File size: 2,129 Bytes
02c783d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import click
from agents.GaAgent import GaAgent
from models.OpenAI import OpenAIModel
from models.Gemini import GeminiModel
from models.Claude import ClaudeModel
from models.Vllm import VLLMModel
from dataloaders.TritonBench import TritonBench
from args_config import load_config
@click.command()
@click.option('-c', '--config', default='',
              help='Path to the run configuration file, '
                   'e.g. configs/tritonbench_gaagent_config_new.yaml.')
def main(config):
    """Run GaAgent over the TritonBench dataset.

    Every setting comes from the configuration loaded by ``load_config``.
    The attributes read from it are:

    * model settings: ``api_key``, ``model_id`` and the optional ``base_url``;
    * dataset paths: ``statis_path``, ``py_folder``, ``instruction_path``,
      ``py_interpreter``, ``golden_metrics``, ``perf_ref_folder``,
      ``perf_G_path`` and ``result_path``;
    * agent settings: ``corpus_path``, ``mem_file``, ``descendant_num``;
    * run parameters forwarded to ``GaAgent.run``.

    Args:
        config: Path handed unchanged to ``load_config``.
    """
    args = load_config(config)

    # Set up the LLM backend. The vLLM client needs the server endpoint as
    # ``base_url``; previously ``args.model_id`` was passed here by mistake.
    # Prefer ``base_url`` from the config when it is set, otherwise fall back
    # to the local endpoint this script was originally run against.
    # NOTE(review): ClaudeModel / OpenAIModel / GeminiModel are also imported
    # and were presumably swapped in here by hand - confirm before relying on it.
    base_url = getattr(args, 'base_url', None) or 'http://localhost:8040/v1'
    model = VLLMModel(api_key=args.api_key,
                      model_id=args.model_id,
                      base_url=base_url)

    # Set up the benchmark dataset.
    dataset = TritonBench(statis_path=args.statis_path,
                          py_folder=args.py_folder,
                          instruction_path=args.instruction_path,
                          py_interpreter=args.py_interpreter,
                          golden_metrics=args.golden_metrics,
                          perf_ref_folder=args.perf_ref_folder,
                          perf_G_path=args.perf_G_path,
                          result_path=args.result_path)

    # Set up the agent. descendant_num is passed both here and to run() below.
    agent = GaAgent(model=model,
                    dataset=dataset,
                    corpus_path=args.corpus_path,
                    mem_file=args.mem_file,
                    descendant_num=args.descendant_num)

    # Run the agent. The full args object is also forwarded for any settings
    # not exposed as explicit keyword arguments.
    agent.run(output_path=args.output_path,
              multi_thread=args.multi_thread,
              iteration_num=args.max_iteration,
              temperature=args.temperature,
              datalen=args.datalen,
              gpu_id=args.gpu_id,
              start_iter=args.start_iter,
              ancestor_num=args.ancestor_num,
              descendant_num=args.descendant_num,
              descendant_debug=args.descendant_debug,
              target_gpu=args.target_gpu,
              profiling=args.profiling,
              start_idx=args.start_idx,
              args=args)
if __name__ == "__main__":
    # Calling the click command with no arguments makes click parse sys.argv,
    # e.g. ``python script.py -c configs/tritonbench_gaagent_config_new.yaml``.
    main()
|