|
|
|
|
|
import os |
|
|
import json |
|
|
import subprocess |
|
|
import ast |
|
|
|
|
|
from ..constants import REPO_ROOT, TMP_ROOT |
|
|
import re |
|
|
|
|
|
# Default TritonBench-G v1 data location, resolved under the repository root:
# <REPO_ROOT>/data/TritonBench/data/TritonBench_G_v1
DEFAULT_TRITON_BENCH_ROOT = os.path.join(
    REPO_ROOT,
    "data",
    "TritonBench",
    "data",
    "TritonBench_G_v1",
)
|
|
|
|
|
|
|
|
def extract_collection_error(stderr_string: str) -> str:
    """
    Extract the body of pytest's ``ERRORS`` section from captured pytest output.

    This targets errors raised outside a test body (e.g. syntax errors found
    while collecting a test module), which pytest reports in a section headed
    by an ``=== ERRORS ===`` banner.

    The section runs from the line after that banner up to, but not including,
    the next ``=``-banner line (e.g. ``=== FAILURES ===``,
    ``=== short test summary info ===`` or the final ``=== 1 error in ... ===``
    line). If no later banner exists, it runs to the end of the text. Stopping
    at the next banner keeps later sections (failures, warnings) out of the
    result, and still works when the short test summary is disabled.

    Args:
        stderr_string: The complete captured pytest output as a string.

    Returns:
        The ERRORS section body with surrounding whitespace stripped, or an
        empty string if no ERRORS banner is found.
    """
    pattern = re.compile(
        # Header banner, then the lazily-matched section body, which ends just
        # before the next line shaped like "===== <title> =====" or at the end
        # of the text.
        r"={10,}\s+ERRORS\s+={10,}(.*?)(?=^={10,} [^\n]* ={10,}\r?$|\Z)",
        re.DOTALL | re.MULTILINE,
    )

    match = pattern.search(stderr_string)
    if match is None:
        return ""
    return match.group(1).strip()
|
|
|
|
|
def extract_first_pytest_failure(stderr_string: str) -> str: |
|
|
""" |
|
|
Extracts the content of the first pytest failure block from a stderr string. |
|
|
|
|
|
Args: |
|
|
stderr_string: The complete stderr output as a string. |
|
|
|
|
|
Returns: |
|
|
A string containing the first failure block, or an empty string if |
|
|
no failure blocks are found. |
|
|
""" |
|
|
lines = stderr_string.splitlines() |
|
|
|
|
|
|
|
|
|
|
|
failure_start_pattern = re.compile(r'^_{3,} test_.* _{3,}$') |
|
|
|
|
|
first_start_index = -1 |
|
|
|
|
|
for i, line in enumerate(lines): |
|
|
if failure_start_pattern.match(line): |
|
|
first_start_index = i |
|
|
break |
|
|
|
|
|
if first_start_index == -1: |
|
|
|
|
|
return "" |
|
|
|
|
|
next_start_index = -1 |
|
|
|
|
|
for i in range(first_start_index + 1, len(lines)): |
|
|
if failure_start_pattern.match(lines[i]): |
|
|
next_start_index = i |
|
|
break |
|
|
|
|
|
|
|
|
if next_start_index != -1: |
|
|
extracted_lines = lines[first_start_index : next_start_index] |
|
|
else: |
|
|
|
|
|
extracted_lines = lines[first_start_index :] |
|
|
|
|
|
return "\n".join(extracted_lines) |
|
|
|
|
|
def extract_errors(stderr_string: str) -> str:
    """
    Return the most relevant error block from a pytest run's output.

    The pytest ERRORS (collection) section is treated as fatal and takes
    priority. Only when that is absent does this fall back to the first
    runtime test failure block.

    Args:
        stderr_string: The complete output captured from a pytest run.

    Returns:
        The collection-error block if present, otherwise the first failure
        block, otherwise an empty string.
    """
    # `or` only falls through when the collection extractor returns "", and
    # the failure extractor itself yields "" when nothing is found.
    return (
        extract_collection_error(stderr_string)
        or extract_first_pytest_failure(stderr_string)
    )
|
|
|
|
|
def extract_json_from_stdout(stdout_string: str) -> dict:
    """
    Extract a JSON object from text containing a markdown ```json fenced block.

    Every ```json ... ``` block is tried in order of appearance. The first one
    whose contents decode to a JSON object is returned. Blocks that are
    malformed, or that decode to a non-object value (array, number, string,
    ...), are skipped, so the result is always a ``dict`` as annotated.

    Args:
        stdout_string: The complete stdout output as a string.

    Returns:
        The first decoded JSON object, or an empty dictionary if no fenced
        block contains one.
    """
    # Surrounding whitespace is consumed by the \s* on both sides of the group.
    pattern = re.compile(r'```json\s*(.*?)\s*```', re.DOTALL)

    for match in pattern.finditer(stdout_string):
        try:
            parsed = json.loads(match.group(1))
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed

    return {}
|
|
|
|
|
def get_fname_difficulty_from_label(label, dataset_path=None):
    """
    Look up the source file name and difficulty of a TritonBench-G entry.

    The dataset is read from a JSON file and scanned linearly for the first
    entry whose ``"output"`` field equals ``label``. The file is presumably a
    JSON array of objects with ``"output"``, ``"file"`` and ``"difficulty"``
    keys (inferred from the lookups below).

    Args:
        label: Value compared (with ``==``) against each entry's ``"output"``.
        dataset_path: Optional path to the dataset JSON file. Defaults to
            ``<REPO_ROOT>/data/TritonBench/data/TritonBench_G_comp_alpac_v1_fixed_with_difficulty.json``.

    Returns:
        ``(item["file"], item["difficulty"])`` for the first matching entry,
        or ``(None, None)`` if no entry matches.

    Raises:
        OSError: If the dataset file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If a scanned entry lacks ``"output"``, or the matching entry
            lacks ``"file"`` / ``"difficulty"``.
    """
    if dataset_path is None:
        dataset_path = os.path.join(
            REPO_ROOT,
            "data",
            "TritonBench",
            "data",
            "TritonBench_G_comp_alpac_v1_fixed_with_difficulty.json",
        )

    # JSON text is UTF-8. Without an explicit encoding, open() uses the
    # locale's preferred encoding, which can mis-decode or fail on non-ASCII
    # content (e.g. on Windows).
    with open(dataset_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    for item in data:
        if item['output'] == label:
            return item['file'], item['difficulty']

    return None, None
|
|
|
|
|
def run_shell(command, cwd=None, env=None, timeout=None):
    """
    Run a shell command and return its exit status and captured output.

    Args:
        command: Command string executed through the system shell
            (``shell=True``). SECURITY: it is passed to the shell unescaped,
            so it must come from a trusted source.
        cwd: Working directory for the command; defaults to ``REPO_ROOT``.
        env: Environment mapping for the child; defaults to a copy of
            ``os.environ``.
        timeout: Seconds to wait for completion; ``None`` waits indefinitely.

    Returns:
        A tuple ``(status, stdout, stderr)``:

        * ``status`` is ``True`` iff the exit code is 0.
        * ``stdout`` and ``stderr`` are the decoded outputs with surrounding
          whitespace stripped. Bytes that cannot be decoded are replaced
          with U+FFFD rather than raising.

    Raises:
        subprocess.TimeoutExpired: If ``timeout`` elapses before the command
            finishes (``subprocess.run`` kills the child first).
    """
    if cwd is None:
        cwd = REPO_ROOT
    if env is None:
        env = os.environ.copy()

    result = subprocess.run(
        command,
        shell=True,
        cwd=cwd,
        env=env,
        capture_output=True,
        text=True,
        # Tool output (compilers, GPU runtimes) is not guaranteed to be valid
        # in the locale encoding; strict decoding would raise
        # UnicodeDecodeError and lose the whole result.
        errors="replace",
        timeout=timeout,
    )
    status = result.returncode == 0
    stdout = result.stdout.strip()
    stderr = result.stderr.strip()
    return status, stdout, stderr
|
|
|
|
|
|
|
|
class TestFunctionRemover(ast.NodeTransformer): |
|
|
def visit_FunctionDef(self, node): |
|
|
if node.name.startswith('test_'): |
|
|
return None |
|
|
return self.generic_visit(node) |
|
|
|
|
|
def visit_Expr(self, node): |
|
|
if isinstance(node.value, ast.Call): |
|
|
func = node.value.func |
|
|
if isinstance(func, ast.Name) and func.id.startswith('test_'): |
|
|
return None |
|
|
if isinstance(func, ast.Attribute) and func.attr.startswith('test_'): |
|
|
return None |
|
|
return self.generic_visit(node) |
|
|
|
|
|
def visit_Assign(self, node): |
|
|
|
|
|
if isinstance(node.value, ast.Call): |
|
|
func = node.value.func |
|
|
if isinstance(func, ast.Name) and func.id.startswith('test_'): |
|
|
return None |
|
|
if isinstance(func, ast.Attribute) and func.attr.startswith('test_'): |
|
|
return None |
|
|
return self.generic_visit(node) |
|
|
|
|
|
def visit_AugAssign(self, node): |
|
|
|
|
|
if isinstance(node.value, ast.Call): |
|
|
func = node.value.func |
|
|
if isinstance(func, ast.Name) and func.id.startswith('test_'): |
|
|
return None |
|
|
if isinstance(func, ast.Attribute) and func.attr.startswith('test_'): |
|
|
return None |
|
|
return self.generic_visit(node) |
|
|
|
|
|
def visit_Module(self, node): |
|
|
|
|
|
node.body = [stmt for stmt in map(self.visit, node.body) if stmt is not None] |
|
|
return node |
|
|
|
|
|
def visit_ClassDef(self, node): |
|
|
node.body = [stmt for stmt in map(self.visit, node.body) if stmt is not None] |
|
|
return node |
|
|
|
|
|
def strip_test_functions(source_code):
    """
    Return ``source_code`` regenerated without its pytest-style test code.

    The source is parsed and rewritten by ``TestFunctionRemover``, which
    removes ``test_``-prefixed functions and statements whose value is a call
    to a ``test_*`` callee. Missing line/column info is then filled in and the
    tree is unparsed back to source text.

    Any exception from ``ast.parse`` (e.g. ``SyntaxError`` for invalid
    source) propagates to the caller.
    """
    parsed = ast.parse(source_code)
    cleaned = TestFunctionRemover().visit(parsed)
    # fix_missing_locations mutates and returns the same tree.
    return ast.unparse(ast.fix_missing_locations(cleaned))
|
|
|
|
|
def process_code(code: str) -> str:
    """
    Normalize model-generated code: extract the python snippet and strip tests.

    If ``code`` contains a ```` ```python ```` fence:

    * only the text of the last such fenced block is kept, up to its closing
      ```` ``` ```` fence (or to the end of the text if it is never closed);
    * the ``<|im_end|>`` and ``<|EOT|>`` markers are removed from it.

    The result is then passed through ``strip_test_functions``. If that raises
    (e.g. the snippet is not valid Python), the code is returned unstripped.

    Args:
        code: Raw text, possibly containing a markdown-fenced python snippet.

    Returns:
        The extracted code, with test functions stripped when it parses.
    """
    if "```python" in code:
        snippet = code.split("```python")[-1]
        # Drop the closing fence and anything after it. Otherwise the trailing
        # ``` makes ast.parse fail, so tests are never stripped and the fence
        # leaks into the returned code.
        snippet = snippet.split("```", 1)[0]
        code = snippet.replace("<|im_end|>", "").replace("<|EOT|>", "")

    try:
        code = strip_test_functions(code)
    except Exception:
        # Deliberate best effort: code that cannot be parsed or rewritten is
        # returned as extracted rather than failing the caller.
        pass
    return code
|
|
|