# geak_eval / GEAK-agent_debug / src / main_parallel_scaling_tritonbench.py
# Provenance: uploaded by llmll via huggingface_hub (commit 02c783d, verified).
import json
import os
import shutil

from tqdm import tqdm

from args_config import load_config
from dataloaders.TritonBench import TritonBench
def main():
    """Score parallel-scaling predictions against TritonBench, keeping the best per kernel.

    Reads one or more JSONL prediction files (``args.file_list``), runs every
    predicted kernel through TritonBench's call/execution correctness checks
    and (for executing kernels) a latency measurement, and keeps, per target
    filename, the best candidate seen so far:

    1. an executing candidate beats a non-executing one,
    2. a compiling ("call") candidate beats a non-compiling one,
    3. among executing candidates, lower measured latency wins.

    The surviving predictions are written to ``args.output_file`` as JSONL.
    Aggregate call/exe accuracy (per-index pass across all files) is printed.
    """
    args = load_config("configs/parallel_scaling_config.yaml")
    file_list = args.file_list
    output_file = args.output_file
    # Fix: original message referenced the undefined name `file`, which would
    # raise NameError instead of the intended message when the assert fires.
    assert output_file.endswith(".jsonl"), f"expect output file to be a jsonl file, but got {output_file} instead"

    # setup dataset
    dataset = TritonBench(statis_path=args.statis_path,
                          py_folder=args.py_folder,
                          instruction_path=args.instruction_path,
                          py_interpreter=args.py_interpreter,
                          golden_metrics=args.golden_metrics,
                          perf_ref_folder=args.perf_ref_folder,
                          perf_G_path=args.perf_G_path)

    # Best candidate per target filename, across all prediction files.
    result_dict = {}
    # Per-index tallies: entry i is 1 if the i-th prediction of ANY file passed.
    call_num = [0] * args.data_len
    exe_num = [0] * args.data_len

    for file in file_list:
        assert file.endswith(".jsonl"), f"expect jsonl file, but got {file} instead"
        with open(file, "r") as f:
            lines = f.readlines()
        assert args.data_len == len(lines), f"expect {args.data_len} entries, but got {len(lines)} entries instead for file {file}"

        with tqdm(total=args.data_len) as pbar:
            for i, line in enumerate(lines):
                content = json.loads(line)
                filename = content["filename"]
                # Scratch dirs are recreated by the dataset per candidate and
                # removed at the end of each iteration.
                tmp_dir = "par_scaling_tmp"
                exe_dir = "par_scaling_exe"
                pass_call, pass_exe, call_stdout, call_stderr, exe_stdout, exe_stderr = dataset.test_opt_correctness(content["predict"],
                                                                                                                     filename=filename,
                                                                                                                     tmp_dir=tmp_dir,
                                                                                                                     save_scripts=True,
                                                                                                                     exe_dir=exe_dir)
                ms = None  # measured latency; stays None when unmeasurable
                if pass_call:
                    call_num[i] = 1
                if pass_exe:
                    exe_num[i] = 1
                    result_dir = os.path.join(exe_dir, "result")
                    script_dir = os.path.join(exe_dir, "script")
                    log_dir = os.path.join(exe_dir, "log")
                    path_gen = os.path.join(result_dir, filename[:-3] + ".json")
                    path_ref = os.path.join(dataset.perf_ref_folder, filename[:-3] + ".json")
                    dataset.write_perf_file(input_folder_path=exe_dir, results_path=result_dir, tmp_dir=script_dir)
                    dataset.run_perf_scripts(script_dir=script_dir, log_dir=log_dir)
                    try:
                        _, _, ms = dataset.calculate(path_gen=path_gen, path_ref=path_ref)
                    except Exception:
                        # Latency measurement is best-effort; a failed measurement
                        # leaves ms as None. (Narrowed from a bare `except:` so
                        # KeyboardInterrupt/SystemExit still propagate.)
                        pass

                if filename not in result_dict:
                    result_dict[filename] = {
                        "predict": content["predict"],
                        "pass_call": pass_call,
                        "pass_exe": pass_exe,
                        "instruction": content["instruction"],
                        "label": content["label"] if "label" in content else None,
                        "latency": ms
                    }
                else:
                    prev = result_dict[filename]
                    if (not prev["pass_exe"]) and pass_exe:
                        # New candidate executes where the stored one did not.
                        prev["predict"] = content["predict"]
                        prev["pass_call"] = pass_call
                        prev["pass_exe"] = pass_exe
                        prev["latency"] = ms
                    elif (not prev["pass_call"]) and pass_call:
                        # New candidate at least compiles where the stored one did not.
                        prev["predict"] = content["predict"]
                        prev["pass_call"] = pass_call
                        prev["pass_exe"] = pass_exe
                        prev["latency"] = ms
                    elif prev["pass_exe"] and ms:
                        # Both execute: keep the faster one.
                        if prev["latency"] is None or prev["latency"] > ms:
                            prev["predict"] = content["predict"]
                            prev["latency"] = ms

                # Fix: replaced `os.system('rm -rf ...')` shell strings with
                # stdlib shutil.rmtree (no shell, no injection surface).
                shutil.rmtree(tmp_dir, ignore_errors=True)
                shutil.rmtree(exe_dir, ignore_errors=True)
                pbar.update(1)

    # Fix: denominator was hard-coded to 184.0, silently wrong for any other
    # dataset size; use the data_len the script already asserts against.
    call_acc = sum(call_num) / float(args.data_len)
    exe_acc = sum(exe_num) / float(args.data_len)
    print(f"call acc: {call_acc}")
    print(f"exe acc: {exe_acc}")

    with open(output_file, "w") as f:
        for filename, result in result_dict.items():
            output = {
                "instruction": result["instruction"],
                "label": result["label"],
                "filename": filename,
                "predict": result["predict"]
            }
            f.write(json.dumps(output) + "\n")


if __name__ == "__main__":
    main()