Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -10,6 +10,49 @@ from transformers import pipeline
|
|
| 10 |
classifier = pipeline(model="Yozhikoff/arxiv-topics-distilbert-base-cased")
|
| 11 |
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
import re
|
| 14 |
import urllib.request
|
| 15 |
import xml.etree.ElementTree as ET
|
|
@@ -50,10 +93,54 @@ def classify_paper(title, abstract):
|
|
| 50 |
input_tensor = torch.tensor(item['input_ids'])[None]
|
| 51 |
logits = classifier.model(input_tensor).logits[0]
|
| 52 |
preds = torch.sigmoid(logits).detach().cpu().numpy()
|
| 53 |
-
result = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
return result
|
| 55 |
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
with gr.Blocks(title='Paper classifier') as demo:
|
| 58 |
gr.Markdown('# Paper Topic Classifier')
|
| 59 |
with gr.Row():
|
|
|
|
| 10 |
classifier = pipeline(model="Yozhikoff/arxiv-topics-distilbert-base-cased")
|
| 11 |
|
| 12 |
|
| 13 |
+
import re
|
| 14 |
+
import urllib.request
|
| 15 |
+
import xml.etree.ElementTree as ET
|
| 16 |
+
|
| 17 |
+
def get_arxiv_title_and_abstract(link):
    """Fetch the title and abstract of an arXiv paper from its URL.

    Args:
        link: URL of the form ``http(s)://arxiv.org/abs/<id>`` or
            ``http(s)://arxiv.org/pdf/<id>[.pdf]``, where ``<id>`` is a
            new-style identifier such as ``1706.03762``.

    Returns:
        A ``(title, summary)`` tuple holding the text of the ``title`` and
        ``summary`` elements of the first Atom entry returned by the arXiv
        export API.

    Raises:
        gr.Error: If the link does not match the expected pattern, the API
            request fails or times out, or the response cannot be parsed
            into an entry with a title and a summary.
    """
    atom_ns = '{http://www.w3.org/2005/Atom}'
    try:
        # Validate the link and capture the paper identifier. The host dot is
        # escaped so that look-alike hosts such as "arxivXorg" are rejected.
        match = re.match(
            r'^https?://arxiv\.org/(abs|pdf)/(\d{4}\.\d{4,5})(\.pdf)?$',
            link,
        )
        if not match:
            raise ValueError('Invalid arxiv link')

        arxiv_id = match.group(2)
        api_url = f'http://export.arxiv.org/api/query?id_list={arxiv_id}'

        # Close the connection deterministically and bound the wait so a
        # stalled request cannot hang the UI callback forever.
        with urllib.request.urlopen(api_url, timeout=10) as response:
            xml_data = response.read()

        root = ET.fromstring(xml_data)
        # A missing element makes find() return None; the resulting
        # AttributeError is reported to the user by the handler below.
        entry = root.find(f'{atom_ns}entry')
        title = entry.find(f'{atom_ns}title').text
        summary = entry.find(f'{atom_ns}summary').text
        return title, summary
    except Exception as err:
        # UI boundary: turn any failure into a user-visible Gradio error while
        # letting KeyboardInterrupt/SystemExit propagate and keeping the cause.
        raise gr.Error('Invalid arXiv URL!') from err
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
import gradio as gr
|
| 45 |
+
import xml.etree.ElementTree as ET
|
| 46 |
+
import re
|
| 47 |
+
import urllib
|
| 48 |
+
import torch
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
from transformers import pipeline
|
| 52 |
+
|
| 53 |
+
classifier = pipeline(model="Yozhikoff/arxiv-topics-distilbert-base-cased")
|
| 54 |
+
|
| 55 |
+
|
| 56 |
import re
|
| 57 |
import urllib.request
|
| 58 |
import xml.etree.ElementTree as ET
|
|
|
|
| 93 |
input_tensor = torch.tensor(item['input_ids'])[None]
|
| 94 |
logits = classifier.model(input_tensor).logits[0]
|
| 95 |
preds = torch.sigmoid(logits).detach().cpu().numpy()
|
| 96 |
+
result = {}
|
| 97 |
+
for num, prob in enumerate(preds):
|
| 98 |
+
if prob < 0.25:
|
| 99 |
+
continue
|
| 100 |
+
if classifier.model.config.id2label[num] in result:
|
| 101 |
+
if result[classifier.model.config.id2label[num]] > prob:
|
| 102 |
+
continue
|
| 103 |
+
result[classifier.model.config.id2label[num]] = float(prob)
|
| 104 |
return result
|
| 105 |
|
| 106 |
|
| 107 |
+
with gr.Blocks(title='Paper classifier') as demo:
|
| 108 |
+
gr.Markdown('# Paper Topic Classifier')
|
| 109 |
+
with gr.Row():
|
| 110 |
+
with gr.Column():
|
| 111 |
+
gr.Markdown('## Inputs')
|
| 112 |
+
gr.Markdown('#### Please enter an arXiv link **OR** fill title and abstract manually')
|
| 113 |
+
arxiv_link = gr.Textbox(label="Arxiv link", placeholder="Flip this text")
|
| 114 |
+
|
| 115 |
+
b1 = gr.Button("Parse Link")
|
| 116 |
+
|
| 117 |
+
title = gr.Textbox(label="Paper title", placeholder="Title text")
|
| 118 |
+
abstract = gr.Textbox(label="Paper abstract", placeholder="Abstract text")
|
| 119 |
+
|
| 120 |
+
b2 = gr.Button("Classify Paper", variant='primary')
|
| 121 |
+
|
| 122 |
+
b1.click(fn=get_arxiv_title_and_abstract, inputs=arxiv_link, outputs=[title, abstract], api_name="parse")
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
with gr.Column():
|
| 126 |
+
gr.Markdown('## Topics')
|
| 127 |
+
gr.Markdown('## ')
|
| 128 |
+
gr.Markdown('## ')
|
| 129 |
+
out = gr.Label(label="Topics")
|
| 130 |
+
b2.click(classify_paper, inputs=[title, abstract], outputs=out)
|
| 131 |
+
|
| 132 |
+
gr.Markdown('## Examples')
|
| 133 |
+
gr.Examples(
|
| 134 |
+
examples=[['https://arxiv.org/abs/1706.03762'], ['https://arxiv.org/abs/2304.06718'], ['https://arxiv.org/abs/1307.0058']],
|
| 135 |
+
inputs=arxiv_link,
|
| 136 |
+
outputs=[title, abstract],
|
| 137 |
+
fn=get_arxiv_title_and_abstract,
|
| 138 |
+
cache_examples=True,
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
demo.launch()
|
| 142 |
+
|
| 143 |
+
|
| 144 |
with gr.Blocks(title='Paper classifier') as demo:
|
| 145 |
gr.Markdown('# Paper Topic Classifier')
|
| 146 |
with gr.Row():
|