# File size: 2,129 Bytes
# Revision: 02c783d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import click
from agents.GaAgent import GaAgent
from models.OpenAI import OpenAIModel
from models.Gemini import GeminiModel
from models.Claude import ClaudeModel
from models.Vllm import VLLMModel
from dataloaders.TritonBench import TritonBench
from args_config import load_config


@click.command()
@click.option('-c', '--config', default='')
def main(config):
    """Run the GaAgent on TritonBench using settings from a config file.

    Loads the configuration named by ``-c/--config``, builds the VLLM-backed
    model, the TritonBench dataset and the GaAgent, then starts the agent run.

    Args:
        config: Value handed to ``load_config`` (e.g. a YAML path such as
            ``configs/tritonbench_gaagent_config_new.yaml``). Defaults to
            the empty string when the option is omitted.
    """
    args = load_config(config)

    # setup LLM model
    # Other backends imported by this file (e.g. ClaudeModel, built from
    # api_key and model_id only) can be swapped in here.
    #
    # The server endpoint must be a URL, not the model identifier. Prefer a
    # dedicated `base_url` setting when the config provides one (for a local
    # server this would look like "http://localhost:8040/v1"); otherwise fall
    # back to the value that was used previously so existing configs keep
    # working unchanged.
    # NOTE(review): assumes the config key is named `base_url` -- confirm
    # against args_config / the YAML schema.
    base_url = getattr(args, "base_url", None) or args.model_id
    model = VLLMModel(
        api_key=args.api_key,
        model_id=args.model_id,
        base_url=base_url,
    )

    # setup dataset
    dataset = TritonBench(
        statis_path=args.statis_path,
        py_folder=args.py_folder,
        instruction_path=args.instruction_path,
        py_interpreter=args.py_interpreter,
        golden_metrics=args.golden_metrics,
        perf_ref_folder=args.perf_ref_folder,
        perf_G_path=args.perf_G_path,
        result_path=args.result_path,
    )

    # setup agent
    agent = GaAgent(
        model=model,
        dataset=dataset,
        corpus_path=args.corpus_path,
        mem_file=args.mem_file,
        descendant_num=args.descendant_num,
    )

    # run the agent; the full args object is forwarded as well so the agent
    # can read any additional settings it needs.
    agent.run(
        output_path=args.output_path,
        multi_thread=args.multi_thread,
        iteration_num=args.max_iteration,
        temperature=args.temperature,
        datalen=args.datalen,
        gpu_id=args.gpu_id,
        start_iter=args.start_iter,
        ancestor_num=args.ancestor_num,
        descendant_num=args.descendant_num,
        descendant_debug=args.descendant_debug,
        target_gpu=args.target_gpu,
        profiling=args.profiling,
        start_idx=args.start_idx,
        args=args,
    )


# Script entry point: invoke the click command only when this file is run
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()