llmll's picture
Upload folder using huggingface_hub
02c783d verified
# Copyright(C) [2025] Advanced Micro Devices, Inc. All rights reserved.
import os
import json
import subprocess
import ast
from ..constants import REPO_ROOT, TMP_ROOT
import re
# Path of the TritonBench_G_v1 data directory, resolved relative to the repository root.
DEFAULT_TRITON_BENCH_ROOT = os.path.join(REPO_ROOT, "data", "TritonBench", "data", "TritonBench_G_v1")
def extract_collection_error(stderr_string: str) -> str:
    """
    Return the body of the pytest "ERRORS" section contained in the output.

    pytest reports failures that occur during test discovery (for example a
    syntax error in the test module) in a section introduced by an
    ``===== ERRORS =====`` banner and followed by the
    ``===== short test summary info =====`` banner.

    Args:
        stderr_string: The complete captured output of a pytest run.

    Returns:
        The text between the two banners with surrounding whitespace
        removed, or an empty string when the section is not present.
    """
    # DOTALL lets the lazy group span multiple lines; the group stops at the
    # first "short test summary info" banner that follows the ERRORS banner.
    found = re.search(
        r"={10,}\s+ERRORS\s+={10,}(.*?)\n={10,}\s+short test summary info\s+={10,}",
        stderr_string,
        flags=re.DOTALL,
    )
    if found is None:
        return ""
    return found.group(1).strip()
def extract_first_pytest_failure(stderr_string: str) -> str:
    """
    Return the first pytest failure block found in the output.

    A failure block starts at a header line such as
    ``___________________ test_correctness[...] ___________________`` and
    runs up to (but not including) the next such header, or to the end of
    the output when there is no further header.

    Args:
        stderr_string: The complete captured output of a pytest run.

    Returns:
        The lines of the first failure block joined with newlines, or an
        empty string when no failure header is present.
    """
    header = re.compile(r'^_{3,} test_.* _{3,}$')
    rows = stderr_string.splitlines()

    # Lazily yields the index of every failure header, in order.
    header_positions = (pos for pos, row in enumerate(rows) if header.match(row))

    begin = next(header_positions, None)
    if begin is None:
        return ""

    # The block ends where the next header starts; default to end of output.
    end = next(header_positions, len(rows))
    return "\n".join(rows[begin:end])
def extract_errors(stderr_string: str) -> str:
    """
    Pick the most relevant error block out of a pytest run's output.

    Extractors are tried in priority order: collection errors first (they
    are fatal and prevent any test from running), then the first runtime
    test failure.

    Args:
        stderr_string: The complete captured output of a pytest run.

    Returns:
        The first non-empty block produced by an extractor, or an empty
        string when none of them finds anything.
    """
    extractors_by_priority = (
        extract_collection_error,
        extract_first_pytest_failure,
    )
    for extractor in extractors_by_priority:
        block = extractor(stderr_string)
        if block:
            return block
    return ""
def extract_json_from_stdout(stdout_string: str) -> dict:
    """
    Parse the first markdown-fenced JSON payload (```json ... ```) in a string.

    Args:
        stdout_string: The complete stdout output as a string.

    Returns:
        The decoded JSON value of the first ```json fenced block. An empty
        dictionary is returned when there is no such block or when its
        content is not valid JSON. Note that the decoded value is returned
        as-is, so a payload that is not a JSON object (e.g. a list) is
        returned with its own type.
    """
    fenced = re.search(r'```json\s*(.*?)\s*```', stdout_string, re.DOTALL)
    if fenced is None:
        return {}

    payload = fenced.group(1).strip()
    try:
        decoded = json.loads(payload)
    except json.JSONDecodeError:
        return {}
    return decoded
def get_fname_difficulty_from_label(label):
    """
    Look up the file name and difficulty recorded for a label.

    Loads ``TritonBench_G_comp_alpac_v1_fixed_with_difficulty.json`` from
    ``<REPO_ROOT>/data/TritonBench/data`` and scans its entries for the
    first one whose ``'output'`` field equals ``label``.

    Args:
        label: Value compared for equality against each entry's ``'output'``.

    Returns:
        A ``(file, difficulty)`` tuple taken from the first matching entry,
        or ``(None, None)`` when no entry matches.

    Raises:
        OSError: If the JSON file cannot be opened.
        json.JSONDecodeError: If the file content is not valid JSON.
        KeyError: If an inspected entry lacks ``'output'``, or the matching
            entry lacks ``'file'`` or ``'difficulty'``.
    """
    index_path = os.path.join(
        REPO_ROOT,
        "data",
        "TritonBench",
        "data",
        "TritonBench_G_comp_alpac_v1_fixed_with_difficulty.json",
    )
    # Read explicitly as UTF-8 (the JSON interchange encoding) so the result
    # does not depend on the platform's locale default.
    with open(index_path, 'r', encoding='utf-8') as f:
        entries = json.load(f)

    for entry in entries:
        if entry['output'] == label:
            return entry['file'], entry['difficulty']
    return None, None
def run_shell(command, cwd=None, env=None, timeout=None):
    """
    Execute a command through the system shell and collect its result.

    Args:
        command: Command line handed to the shell (``shell=True``), so it
            must not be built from untrusted input.
        cwd: Working directory for the command; REPO_ROOT when None.
        env: Environment mapping for the command; a copy of the current
            process environment when None.
        timeout: Seconds to wait before giving up, or None to wait
            indefinitely.

    Returns:
        A ``(status, stdout, stderr)`` tuple where ``status`` is True when
        the command exited with code 0, and ``stdout``/``stderr`` are the
        captured text streams with surrounding whitespace stripped.

    Raises:
        subprocess.TimeoutExpired: If ``timeout`` elapses before the
            command finishes.
    """
    completed = subprocess.run(
        command,
        shell=True,
        cwd=REPO_ROOT if cwd is None else cwd,
        env=os.environ.copy() if env is None else env,
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    succeeded = completed.returncode == 0
    return succeeded, completed.stdout.strip(), completed.stderr.strip()
class TestFunctionRemover(ast.NodeTransformer):
def visit_FunctionDef(self, node):
if node.name.startswith('test_'):
return None # Kill the function
return self.generic_visit(node)
def visit_Expr(self, node):
if isinstance(node.value, ast.Call):
func = node.value.func
if isinstance(func, ast.Name) and func.id.startswith('test_'):
return None # Kill expressions like test_foo()
if isinstance(func, ast.Attribute) and func.attr.startswith('test_'):
return None
return self.generic_visit(node)
def visit_Assign(self, node):
# If the value being assigned is a call to test_ function, kill the entire assignment
if isinstance(node.value, ast.Call):
func = node.value.func
if isinstance(func, ast.Name) and func.id.startswith('test_'):
return None
if isinstance(func, ast.Attribute) and func.attr.startswith('test_'):
return None
return self.generic_visit(node)
def visit_AugAssign(self, node):
# For augmented assignments like x += test_func()
if isinstance(node.value, ast.Call):
func = node.value.func
if isinstance(func, ast.Name) and func.id.startswith('test_'):
return None
if isinstance(func, ast.Attribute) and func.attr.startswith('test_'):
return None
return self.generic_visit(node)
def visit_Module(self, node):
# Manually rebuild body without None's
node.body = [stmt for stmt in map(self.visit, node.body) if stmt is not None]
return node
def visit_ClassDef(self, node):
node.body = [stmt for stmt in map(self.visit, node.body) if stmt is not None]
return node
def strip_test_functions(source_code):
    """
    Return ``source_code`` with test functions and their call sites removed.

    The source is parsed into an AST, passed through TestFunctionRemover and
    unparsed again, so the result is regenerated source text rather than an
    in-place edit of the input (original formatting and comments are not
    kept).

    Args:
        source_code: Python source text.

    Returns:
        The regenerated source text.

    Raises:
        SyntaxError: If ``source_code`` is not valid Python.
    """
    cleaned = TestFunctionRemover().visit(ast.parse(source_code))
    # fix_missing_locations returns the node it was given.
    return ast.unparse(ast.fix_missing_locations(cleaned))
def process_code(code: str):
    """
    Extract runnable Python from a model response and strip its tests.

    If the text contains a ```python fence, only the content of the last
    such fenced block is kept: everything up to and including the opening
    fence is dropped, the text is cut at the closing ``` fence (so neither
    the fence nor any trailing prose leaks into the code), and the
    ``<|im_end|>`` / ``<|EOT|>`` end-of-text markers are removed. Text
    without a ```python fence is used unchanged.

    The result is then passed through strip_test_functions on a best-effort
    basis.

    Args:
        code: Raw text, optionally containing a markdown ```python block.

    Returns:
        The extracted code with test functions removed, or the extracted
        code unchanged when stripping fails (e.g. it is not valid Python).
    """
    opening_fence = "```python"
    if opening_fence in code:
        code = code.split(opening_fence)[-1]
        # Cut at the closing fence; without this the leftover "```" makes
        # the code unparsable and it is returned with the fence attached.
        code = code.split("```")[0]
        code = code.replace("<|im_end|>", "").replace("<|EOT|>", "")
    try:
        return strip_test_functions(code)
    except Exception:
        # Deliberate best effort: whatever goes wrong while stripping, fall
        # back to the extracted code as-is rather than failing the caller.
        return code