File size: 7,924 Bytes
02c783d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
# Copyright(C) [2025] Advanced Micro Devices, Inc. All rights reserved.
import os
import json
import subprocess
import ast

from ..constants import REPO_ROOT, TMP_ROOT
import re

# Path to the TritonBench_G_v1 data under REPO_ROOT
# (presumably a directory of benchmark kernel files -- TODO confirm).
DEFAULT_TRITON_BENCH_ROOT = os.path.join(REPO_ROOT, "data", "TritonBench", "data", "TritonBench_G_v1")


def extract_collection_error(stderr_string: str) -> str:
    """
    Return the text of pytest's "ERRORS" section from captured output.

    The section is the text found between an ``==== ERRORS ====`` banner and
    the following ``==== short test summary info ====`` banner. Both banners
    must be present for anything to be returned.

    Args:
        stderr_string: The complete captured output as a string.

    Returns:
        The section text with surrounding whitespace removed, or an empty
        string when the two banners are not both found.
    """
    # DOTALL lets '.' cross line breaks so the lazy group can span the
    # whole multi-line section between the two banners.
    found = re.search(
        r"={10,}\s+ERRORS\s+={10,}(.*?)\n={10,}\s+short test summary info\s+={10,}",
        stderr_string,
        flags=re.DOTALL,
    )
    if found is None:
        return ""
    return found.group(1).strip()

def extract_first_pytest_failure(stderr_string: str) -> str:
    """
    Return the first pytest failure block found in captured output.

    A failure block starts at a header line such as
    ``___________ test_correctness[...] ___________`` and runs up to, but not
    including, the next such header. If there is no later header, the block
    runs to the end of the text.

    Args:
        stderr_string: The complete captured output as a string.

    Returns:
        The lines of the first failure block joined with newlines, or an
        empty string if no failure header is present.
    """
    header = re.compile(r'^_{3,} test_.* _{3,}$')
    all_lines = stderr_string.splitlines()

    # Lazily yield the position of every header line; the first value is the
    # start of the block and the second (if any) is where the block stops.
    header_positions = (
        idx for idx, text in enumerate(all_lines) if header.match(text)
    )

    start = next(header_positions, None)
    if start is None:
        return ""

    stop = next(header_positions, len(all_lines))
    return "\n".join(all_lines[start:stop])

def extract_errors(stderr_string: str) -> str:
    """
    Pick the most relevant error block out of a pytest run's output.

    Collection errors take priority because they stop tests from running at
    all; only when none is found is the first runtime test failure used.

    Args:
        stderr_string: The complete captured output from a pytest run.

    Returns:
        The collection error block if there is one, otherwise the first
        failure block, otherwise an empty string.
    """
    # Both helpers return "" when they find nothing, so `or` falls through
    # from the collection error to the runtime failure to the empty string.
    return (
        extract_collection_error(stderr_string)
        or extract_first_pytest_failure(stderr_string)
    )
    
def extract_json_from_stdout(stdout_string: str) -> dict:
    """
    Extracts a JSON object from the first ```json ... ``` fenced block in a string.

    Args:
        stdout_string: The complete stdout output as a string.

    Returns:
        The decoded JSON object as a dictionary. An empty dictionary is
        returned when there is no ```json fenced block, when the block's
        content is not valid JSON, or when the decoded value is not a JSON
        object (for example a list or a number).
    """
    # re.DOTALL lets the lazy group span the multi-line fenced content.
    match = re.search(r'```json\s*(.*?)\s*```', stdout_string, re.DOTALL)
    if match is None:
        return {}

    try:
        parsed = json.loads(match.group(1).strip())
    except json.JSONDecodeError:
        return {}

    # json.loads also accepts arrays and scalars, but callers are promised a
    # dict, so anything else is treated the same as "no usable JSON".
    return parsed if isinstance(parsed, dict) else {}

def get_fname_difficulty_from_label(label):
    """
    Look up the file name and difficulty recorded for a given label.

    Loads the TritonBench_G_comp_alpac_v1_fixed_with_difficulty.json file
    under REPO_ROOT on every call and scans its entries in order.

    Args:
        label: Value compared for equality against each entry's 'output' field.

    Returns:
        A ``(file, difficulty)`` tuple taken from the first entry whose
        'output' equals ``label``, or ``(None, None)`` if no entry matches.
    """
    index_path = os.path.join(REPO_ROOT, "data", "TritonBench", "data", "TritonBench_G_comp_alpac_v1_fixed_with_difficulty.json")
    with open(index_path, 'r') as handle:
        entries = json.load(handle)

    hit = next((entry for entry in entries if entry['output'] == label), None)
    if hit is None:
        return None, None
    return hit['file'], hit['difficulty']

def run_shell(command, cwd=None, env=None, timeout=None):
    """
    Run a command through the shell and collect its result.

    Args:
        command: Command line handed to the shell (``shell=True``).
        cwd: Working directory for the command; REPO_ROOT when None.
        env: Environment mapping for the command; a copy of the current
            process environment when None.
        timeout: Passed through to ``subprocess.run``.

    Returns:
        A ``(status, stdout, stderr)`` tuple where ``status`` is True when the
        exit code is 0 and both output strings have surrounding whitespace
        stripped.
    """
    working_dir = REPO_ROOT if cwd is None else cwd
    environment = os.environ.copy() if env is None else env

    completed = subprocess.run(
        command,
        shell=True,
        cwd=working_dir,
        env=environment,
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    return (
        completed.returncode == 0,
        completed.stdout.strip(),
        completed.stderr.strip(),
    )


class TestFunctionRemover(ast.NodeTransformer):
    def visit_FunctionDef(self, node):
        if node.name.startswith('test_'):
            return None  # Kill the function
        return self.generic_visit(node)

    def visit_Expr(self, node):
        if isinstance(node.value, ast.Call):
            func = node.value.func
            if isinstance(func, ast.Name) and func.id.startswith('test_'):
                return None  # Kill expressions like test_foo()
            if isinstance(func, ast.Attribute) and func.attr.startswith('test_'):
                return None
        return self.generic_visit(node)

    def visit_Assign(self, node):
        # If the value being assigned is a call to test_ function, kill the entire assignment
        if isinstance(node.value, ast.Call):
            func = node.value.func
            if isinstance(func, ast.Name) and func.id.startswith('test_'):
                return None
            if isinstance(func, ast.Attribute) and func.attr.startswith('test_'):
                return None
        return self.generic_visit(node)

    def visit_AugAssign(self, node):
        # For augmented assignments like x += test_func()
        if isinstance(node.value, ast.Call):
            func = node.value.func
            if isinstance(func, ast.Name) and func.id.startswith('test_'):
                return None
            if isinstance(func, ast.Attribute) and func.attr.startswith('test_'):
                return None
        return self.generic_visit(node)

    def visit_Module(self, node):
        # Manually rebuild body without None's
        node.body = [stmt for stmt in map(self.visit, node.body) if stmt is not None]
        return node

    def visit_ClassDef(self, node):
        node.body = [stmt for stmt in map(self.visit, node.body) if stmt is not None]
        return node

def strip_test_functions(source_code):
    """
    Return ``source_code`` regenerated with test code removed.

    The source is parsed into an AST, passed through ``TestFunctionRemover``
    and turned back into text with ``ast.unparse``, so comments and the
    original formatting are not preserved.

    Args:
        source_code: Python source text.

    Returns:
        The regenerated source text.

    Raises:
        SyntaxError: If ``source_code`` cannot be parsed by ``ast.parse``.
    """
    pruned = TestFunctionRemover().visit(ast.parse(source_code))
    # fix_missing_locations returns the node it was given, so it can be chained.
    return ast.unparse(ast.fix_missing_locations(pruned))

def process_code(code: str):
    """
    Reduce a model response to source code with test code removed.

    If the text contains a ```python fence, only the content of the last such
    fenced block is kept: everything from its closing ``` onwards is dropped
    (when a closing fence is present) and the ``<|im_end|>`` / ``<|EOT|>``
    markers are removed. The result is then passed through
    ``strip_test_functions``.

    Args:
        code: Raw text, optionally wrapped in a markdown ```python fence.

    Returns:
        The code with test functions and test calls stripped, or the
        un-stripped code if stripping fails.
    """
    if "```python" in code:
        code = code.split("```python")[-1]
        # Cut at the closing fence. Leaving it (and any prose after it) in
        # place makes the text unparsable, which silently disables the test
        # stripping below.
        code = code.partition("```")[0]
        code = code.replace("<|im_end|>", "").replace("<|EOT|>", "")
    try:
        code = strip_test_functions(code)
    except Exception:
        # Deliberate best effort: text that cannot be parsed is handed back
        # unchanged rather than failing the caller.
        pass
    return code