|
|
import os |
|
|
from tqdm import tqdm |
|
|
from loguru import logger |
|
|
import json |
|
|
from dataclasses import asdict |
|
|
from agents.Reflexion import Reflexion |
|
|
from utils.utils import extract_function_signatures, clear_code, extract_function_calls |
|
|
from prompts import prompt_for_reflection |
|
|
from memories.Memory import MemoryClassMeta |
|
|
from models.Base import BaseModel |
|
|
from retrievers.retriever import BM25Retriever |
|
|
from prompts import prompt_for_generation |
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed |
|
|
|
|
|
|
|
|
|
|
|
class Reflexion_Oneshot(Reflexion): |
|
|
|
|
|
def __init__(self, model: BaseModel, dataset, corpus_path, mem_file=None, descendant_num=1): |
|
|
self.model = model |
|
|
self.dataset = dataset |
|
|
self.memories = [] |
|
|
|
|
|
self.instruction_retriever = BM25Retriever() |
|
|
self.instruction_retriever.process(content_input_path=corpus_path) |
|
|
self.code_retriever = BM25Retriever(mode="code") |
|
|
self.code_retriever.process(content_input_path=corpus_path) |
|
|
|
|
|
self.memory_init(mem_file, descendant_num) |
|
|
|
|
|
def memory_init(self, mem_file=None, descendant_num=1): |
|
|
class Memory(metaclass=MemoryClassMeta, field_names=["ps", |
|
|
"err_msg", |
|
|
"reflection", |
|
|
"function_signatures", |
|
|
"oneshot", |
|
|
"pass_call", |
|
|
]): |
|
|
pass |
|
|
|
|
|
if mem_file is not None: |
|
|
assert mem_file.endswith(".json"), f"expect a json file, but got {mem_file} instead" |
|
|
with open(mem_file, "r") as f: |
|
|
input_mems = json.load(f) |
|
|
assert len(input_mems) == len(self.dataset), f"expect {len(self.dataset)} samples, but got {len(input_mems)} instead" |
|
|
|
|
|
for ps in self.dataset.problem_states: |
|
|
if ps.label: |
|
|
fs_mem = extract_function_signatures(ps.label) |
|
|
else: |
|
|
fs_mem = None |
|
|
if mem_file is None: |
|
|
os_mem = self.instruction_retriever.query(ps.instruction)[0] |
|
|
tmp_mem = Memory(ps=ps, |
|
|
err_msg=None, |
|
|
reflection=None, |
|
|
function_signatures=fs_mem, |
|
|
oneshot=os_mem["code"], |
|
|
pass_call=False, |
|
|
) |
|
|
else: |
|
|
input_mem = input_mems[ps.filename] |
|
|
tmp_mem = Memory(ps=ps, |
|
|
err_msg=input_mem["err_msg"], |
|
|
reflection=input_mem["reflection"], |
|
|
function_signatures=fs_mem, |
|
|
oneshot=input_mem["oneshot"], |
|
|
pass_call=input_mem["pass_call"], |
|
|
) |
|
|
self.memories.append(tmp_mem) |
|
|
|
|
|
def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None, iteration_num=0, temperature=0): |
|
|
data_len = datalen if datalen else len(self.dataset) |
|
|
for iter in range(iteration_num): |
|
|
logger.info(f"\n=== Iteration {iter} ===") |
|
|
if output_path is not None: |
|
|
root, extension = os.path.splitext(output_path) |
|
|
iter_path = f"{root}_{iter}{extension}" |
|
|
|
|
|
if multi_thread: |
|
|
thread_num = 3 |
|
|
|
|
|
|
|
|
logger.info(f"\ngenerate solution") |
|
|
with tqdm(total=data_len) as pbar: |
|
|
if multi_thread: |
|
|
|
|
|
with ThreadPoolExecutor(max_workers=thread_num) as executor: |
|
|
futures = {executor.submit(self.generate_solution, mem, temperature): mem for mem in self.memories[:data_len]} |
|
|
for future in as_completed(futures): |
|
|
pbar.update(1) |
|
|
else: |
|
|
for mem in self.memories[:data_len]: |
|
|
self.generate_solution(mem, temperature=temperature) |
|
|
pbar.update(1) |
|
|
|
|
|
|
|
|
logger.info(f"\nrun scripts on gpu") |
|
|
for mem in tqdm(self.memories[:data_len]): |
|
|
if mem.pass_call: |
|
|
continue |
|
|
is_pass, err_msg = self.dataset.run_single_call(mem.ps) |
|
|
if not is_pass: |
|
|
mem.err_msg = err_msg |
|
|
|
|
|
|
|
|
logger.info(f"\ngenerate reflections") |
|
|
with tqdm(total=data_len) as pbar: |
|
|
if multi_thread: |
|
|
with ThreadPoolExecutor(max_workers=thread_num) as executor: |
|
|
futures = {executor.submit(self.generate_reflexion, mem, temperature): mem for mem in self.memories[:data_len]} |
|
|
for future in as_completed(futures): |
|
|
pbar.update(1) |
|
|
else: |
|
|
for mem in self.memories[:data_len]: |
|
|
self.generate_reflexion(mem, temperature=temperature) |
|
|
pbar.update(1) |
|
|
|
|
|
if output_path is not None: |
|
|
self.dataset.write_file(iter_path) |
|
|
|
|
|
|
|
|
|
|
|
def generate_solution(self, mem, temperature=0): |
|
|
if mem.pass_call: |
|
|
return |
|
|
|
|
|
tab = "\n" |
|
|
fss_text = "".join(f"* {sig}{tab}" for sig in mem.function_signatures) |
|
|
text = prompt_for_generation.prompt.format( |
|
|
instruction=mem.ps.instruction, |
|
|
function_signatures=fss_text |
|
|
) |
|
|
|
|
|
if not mem.ps.solution: |
|
|
text += f"\nHere is an example snippet of code: {mem.oneshot}" |
|
|
else: |
|
|
one_shot = self.code_retriever.query(mem.ps.solution)[0]["code"] |
|
|
text += f"\nHere is an example snippet of code: {one_shot}" |
|
|
text += f"\nPrevious attempt implementation:{mem.ps.solution}" |
|
|
|
|
|
|
|
|
if mem.err_msg: |
|
|
text += f"\nTest messages for previous attempt:{mem.err_msg}" |
|
|
|
|
|
if mem.reflection: |
|
|
text += f"\nReflection on previous attempt:{mem.reflection}" |
|
|
|
|
|
text += "Please output the codes only without explanation, which we can run directly." |
|
|
msg = [ |
|
|
{"role": "user", "content": text}, |
|
|
] |
|
|
response = self.model.generate(msg, temperature=temperature) |
|
|
mem.ps.solution = clear_code(response) |
|
|
|
|
|
return |
|
|
|
|
|
|
|
|
|
|
|
def generate_reflexion(self, mem, temperature): |
|
|
if mem.pass_call: |
|
|
return |
|
|
reflect_txt = prompt_for_reflection.prompt.format( |
|
|
problem=mem.ps.instruction, |
|
|
solution=mem.ps.solution, |
|
|
test_result=mem.err_msg |
|
|
) |
|
|
reflect_msg = [ |
|
|
{ |
|
|
"role": "user", |
|
|
"content": reflect_txt |
|
|
} |
|
|
] |
|
|
mem.reflection = self.model.generate(reflect_msg, temperature=temperature) |