# geak_eval / GEAK-agent_debug / src/agents/reflexion_oneshot.py
# Uploaded by llmll using huggingface_hub (commit 02c783d, verified).
import os
from tqdm import tqdm
from loguru import logger
import json
from dataclasses import asdict
from agents.Reflexion import Reflexion
from utils.utils import extract_function_signatures, clear_code, extract_function_calls
from prompts import prompt_for_reflection
from memories.Memory import MemoryClassMeta
from models.Base import BaseModel
from retrievers.retriever import BM25Retriever
from prompts import prompt_for_generation
from concurrent.futures import ThreadPoolExecutor, as_completed
class Reflexion_Oneshot(Reflexion):
    """Reflexion agent that seeds each problem with a retrieved one-shot example.

    One memory record is kept per problem state in ``dataset``. Each record holds:

    - the problem state (``ps``);
    - the latest test error message (``err_msg``);
    - the latest reflection (``reflection``);
    - the reference function signatures (``function_signatures``);
    - a retrieved example snippet (``oneshot``);
    - whether the current solution passed the call test (``pass_call``).

    ``run`` repeatedly generates solutions, executes them through the dataset
    and asks the model to reflect on the failures.
    """

    def __init__(self, model: BaseModel, dataset, corpus_path, mem_file=None, descendant_num=1):
        """Build the retrievers over ``corpus_path`` and initialise per-problem memories.

        Args:
            model: LLM wrapper; called as ``model.generate(messages, temperature=...)``.
            dataset: Provides ``problem_states``, ``run_single_call`` and ``write_file``.
            corpus_path: Corpus that both BM25 retrievers are built from.
            mem_file: Optional ``.json`` file with previously saved memories.
            descendant_num: Forwarded to ``memory_init``.
        """
        self.model = model
        self.dataset = dataset
        self.memories = []
        # Two retrievers over the same corpus: one is queried with the problem
        # instruction, the other (mode="code") with a previous code attempt.
        self.instruction_retriever = BM25Retriever()
        self.instruction_retriever.process(content_input_path=corpus_path)
        self.code_retriever = BM25Retriever(mode="code")
        self.code_retriever.process(content_input_path=corpus_path)
        self.memory_init(mem_file, descendant_num)

    def memory_init(self, mem_file=None, descendant_num=1):
        """Create one memory record per problem state in ``self.dataset``.

        Without ``mem_file``, every record starts fresh and takes its one-shot
        example from the top hit of the instruction retriever. With
        ``mem_file`` (a JSON object keyed by problem filename), the error
        message, reflection, one-shot example and pass flag are restored from
        that file. ``descendant_num`` is accepted but not used here.

        Raises:
            ValueError: If ``mem_file`` is not a ``.json`` path, or if its entry
                count differs from the dataset size.
        """
        class Memory(metaclass=MemoryClassMeta, field_names=["ps",
                                                             "err_msg",
                                                             "reflection",
                                                             "function_signatures",
                                                             "oneshot",
                                                             "pass_call",
                                                             ]):
            pass

        input_mems = None
        if mem_file is not None:
            # Explicit checks instead of `assert`, which is stripped under -O.
            if not mem_file.endswith(".json"):
                raise ValueError(f"expect a json file, but got {mem_file} instead")
            with open(mem_file, "r", encoding="utf-8") as f:
                input_mems = json.load(f)
            if len(input_mems) != len(self.dataset):
                raise ValueError(f"expect {len(self.dataset)} samples, but got {len(input_mems)} instead")

        for ps in self.dataset.problem_states:
            # Reference signatures exist only when the problem carries a label;
            # otherwise None is stored (generate_solution tolerates that).
            fs_mem = extract_function_signatures(ps.label) if ps.label else None
            if input_mems is None:
                os_mem = self.instruction_retriever.query(ps.instruction)[0]
                tmp_mem = Memory(ps=ps,
                                 err_msg=None,
                                 reflection=None,
                                 function_signatures=fs_mem,
                                 oneshot=os_mem["code"],
                                 pass_call=False,
                                 )
            else:
                input_mem = input_mems[ps.filename]
                tmp_mem = Memory(ps=ps,
                                 err_msg=input_mem["err_msg"],
                                 reflection=input_mem["reflection"],
                                 function_signatures=fs_mem,
                                 oneshot=input_mem["oneshot"],
                                 pass_call=input_mem["pass_call"],
                                 )
            self.memories.append(tmp_mem)

    def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None, iteration_num=0, temperature=0, thread_num=3):
        """Run ``iteration_num`` rounds of generate -> execute -> reflect.

        Args:
            output_path: If given, ``dataset.write_file`` is called after each
                iteration with ``"<root>_<iteration><ext>"`` derived from this path.
            multi_thread: Run the LLM calls (generation and reflection) in a
                thread pool. Script execution always runs sequentially.
            verbose: Currently unused.
            datalen: Process only the first ``datalen`` memories; a falsy value
                means all of them.
            iteration_num: Number of rounds to run (default 0 runs nothing).
            temperature: Sampling temperature forwarded to the model.
            thread_num: Thread-pool size when ``multi_thread`` is true.
        """
        data_len = datalen if datalen else len(self.dataset)
        mems = self.memories[:data_len]
        for iter_idx in range(iteration_num):
            logger.info(f"\n=== Iteration {iter_idx} ===")

            # generate solution
            logger.info("\ngenerate solution")
            self._apply_to_memories(self.generate_solution, mems, temperature, multi_thread, thread_num)

            # run scripts
            logger.info("\nrun scripts on gpu")
            for mem in tqdm(mems):
                if mem.pass_call:
                    continue
                is_pass, err_msg = self.dataset.run_single_call(mem.ps)
                if is_pass:
                    # Mark as solved so later iterations skip regeneration and
                    # reflection, and drop the now-stale error from an earlier failure.
                    mem.pass_call = True
                    mem.err_msg = None
                else:
                    mem.err_msg = err_msg

            # generate reflections
            logger.info("\ngenerate reflections")
            self._apply_to_memories(self.generate_reflexion, mems, temperature, multi_thread, thread_num)

            if output_path is not None:
                root, extension = os.path.splitext(output_path)
                self.dataset.write_file(f"{root}_{iter_idx}{extension}")

    def _apply_to_memories(self, fn, mems, temperature, multi_thread, thread_num):
        """Call ``fn(mem, temperature)`` on every memory, optionally in a thread pool.

        A failure for one memory is logged with its traceback and does not
        abort the others. Without this, thread-pool exceptions would be
        silently discarded, because ``Future.result()`` would never be called.
        """
        def label(mem):
            return getattr(mem.ps, "filename", "<unknown>")

        with tqdm(total=len(mems)) as pbar:
            if multi_thread:
                with ThreadPoolExecutor(max_workers=thread_num) as executor:
                    futures = {executor.submit(fn, mem, temperature): mem for mem in mems}
                    for future in as_completed(futures):
                        try:
                            future.result()
                        except Exception:
                            logger.exception(f"{fn.__name__} failed for {label(futures[future])}")
                        pbar.update(1)
            else:
                for mem in mems:
                    try:
                        fn(mem, temperature)
                    except Exception:
                        logger.exception(f"{fn.__name__} failed for {label(mem)}")
                    pbar.update(1)

    def generate_solution(self, mem, temperature=0):
        """Ask the model for a new solution and store it in ``mem.ps.solution``.

        The prompt is built from:

        - the generation template (instruction plus reference signatures);
        - an example snippet, which is ``mem.oneshot`` for the first attempt
          and the code retriever's top hit for the previous attempt otherwise;
        - the previous attempt itself, if there is one;
        - the last test error message and reflection, if present.

        Memories that already passed are skipped.
        """
        if mem.pass_call:
            return
        # function_signatures is None when the problem had no label.
        signatures = mem.function_signatures or []
        fss_text = "".join(f"* {sig}\n" for sig in signatures)
        text = prompt_for_generation.prompt.format(
            instruction=mem.ps.instruction,
            function_signatures=fss_text,
        )
        if not mem.ps.solution:
            text += f"\nHere is an example snippet of code: {mem.oneshot}"
        else:
            one_shot = self.code_retriever.query(mem.ps.solution)[0]["code"]
            text += f"\nHere is an example snippet of code: {one_shot}"
            text += f"\nPrevious attempt implementation:{mem.ps.solution}"
        if mem.err_msg:
            text += f"\nTest messages for previous attempt:{mem.err_msg}"
        if mem.reflection:
            text += f"\nReflection on previous attempt:{mem.reflection}"
        # Newline so the closing instruction is not glued onto the previous section.
        text += "\nPlease output the codes only without explanation, which we can run directly."
        msg = [
            {"role": "user", "content": text},
        ]
        response = self.model.generate(msg, temperature=temperature)
        mem.ps.solution = clear_code(response)

    def generate_reflexion(self, mem, temperature=0):
        """Ask the model to reflect on the latest attempt; store the result in ``mem.reflection``.

        The reflection template receives the instruction, the current solution
        and the latest test error message. Memories that already passed are skipped.
        """
        if mem.pass_call:
            return
        reflect_txt = prompt_for_reflection.prompt.format(
            problem=mem.ps.instruction,
            solution=mem.ps.solution,
            test_result=mem.err_msg,
        )
        reflect_msg = [
            {
                "role": "user",
                "content": reflect_txt,
            }
        ]
        mem.reflection = self.model.generate(reflect_msg, temperature=temperature)