import textgrad as tg
import torch
import os 
import sys
sys.path.append('/lustre/fast/fast/txiao/zly/spatial_head/cot/textgrad')
os.environ["OPENAI_API_KEY"] = "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"
import pdb
from tqdm import tqdm
import argparse
from vllm import LLM, SamplingParams
import json 
from textgrad.engine.qwen_vllm import QwenCoderVLLM
from prompt import Domain_template
import subprocess



"""
Implementing textgrad for qwen 
"""
def parse_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace: Namespace with the attributes ``model_path``,
        ``input_data``, ``output_path``, ``batch_size``, ``epoch``,
        ``num_gpu`` and ``ith_gpu``; each falls back to the default listed
        below when the flag is not given.
    """
    # One (flag, type, default, help) row per option, in registration order.
    option_specs = (
        ('--model_path', str, '/lustre/fast/fast/txiao/zly/ckpt/Qwen25Coder7B', 'Path to model'),
        ('--input_data', str, '/lustre/fast/fast/txiao/zly/spatial_head/cot/test_textgrad.json', 'Path to data'),
        ('--output_path', str, '/lustre/fast/fast/txiao/zly/spatial_head/cot/result/textgrad/qwen_textgrad_test.json', 'Path to output'),
        ('--batch_size', int, 5, 'Batch size'),
        ('--epoch', int, 1, 'Number of process'),
        ('--num_gpu', int, 2, 'GPU numbers'),
        ('--ith_gpu', int, 0, 'GPU index'),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, fallback, description in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback, help=description)
    return cli.parse_args()

class Qwen_TextGrad:
    """Refine PDDL domain code with TextGrad using a single LLM engine.

    The engine returned by ``tg.get_engine(model_path)`` is used both as the
    TextGrad backward engine and as the evaluator inside ``loss_fn``.

    Each dataset entry is read through the keys ``'domain'``,
    ``'nl_description'`` and ``'file'``.
    """

    def __init__(self, args):
        """Load this worker's shard of the dataset and create the LLM engine.

        Args:
            args: Parsed CLI namespace providing ``model_path``,
                ``input_data``, ``output_path``, ``batch_size``, ``epoch``,
                ``num_gpu`` and ``ith_gpu``.
        """
        self.model_path = args.model_path
        self.input_data = args.input_data
        # Load the data. The context manager closes the file handle right
        # away instead of leaving it to the garbage collector; JSON is read
        # as UTF-8 regardless of the process locale.
        with open(self.input_data, encoding="utf-8") as f:
            samples = json.load(f)
        # Keep only the ith_gpu-th of num_gpu contiguous shards. Multiplying
        # before the floor division spreads the remainder over the shards.
        total = len(samples)
        start = args.ith_gpu * total // args.num_gpu
        stop = (args.ith_gpu + 1) * total // args.num_gpu
        self.data = samples[start:stop]
        self.output_path = args.output_path
        self.batch_size = args.batch_size
        self.llm_engine = tg.get_engine(self.model_path)
        self.num_epoches = args.epoch

    def optimization(self):
        """Run the TextGrad update on every sample and save the results.

        After each epoch, the records accumulated so far (all epochs) are
        written to ``self.output_path`` as indented JSON, replacing the
        previous contents of that file.
        """
        tg.set_backward_engine(self.llm_engine)

        # Create the output directory before the expensive LLM calls so that
        # a missing directory cannot discard a finished run at save time.
        out_dir = os.path.dirname(self.output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # Samples are still processed one at a time; batching only sets the
        # granularity of the tqdm progress bar.
        batches = [self.data[i:i+self.batch_size] for i in range(0, len(self.data), self.batch_size)]
        all_results = []
        for epoch in range(self.num_epoches):
            # NOTE(review): every epoch rebuilds the variable from
            # data['domain'], so updates are not carried over between
            # epochs -- confirm this is intended.
            results = [
                self._optimize_sample(data, epoch)
                for batch in tqdm(batches)
                for data in batch
            ]
            all_results.extend(results)
            # Save the results to the output path.
            # NOTE(review): the path does not depend on ith_gpu, so workers
            # started with the same --output_path overwrite each other's
            # file -- confirm callers pass distinct paths.
            with open(self.output_path, 'w') as f:
                json.dump(all_results, f, indent=4)

    def _optimize_sample(self, data, epoch):
        """Apply one TextGrad step to a single sample and build its record.

        Args:
            data: Dataset entry with the keys ``'domain'``,
                ``'nl_description'`` and ``'file'``.
            epoch: Index of the current epoch, stored in the record.

        Returns:
            dict: Record with the original inputs, the loss text and the
            value of the domain variable after the optimizer step.
        """
        domain_code = tg.Variable(value=data['domain'],
                                  requires_grad=True,
                                  role_description="pddl domain instance to optimize")
        prompt = tg.Variable(value=Domain_template.format(NL_description=data['nl_description']),
                             requires_grad=False,
                             role_description="The question to optimize")

        optimizer = tg.TGD(parameters=[domain_code])
        optimizer.zero_grad()
        loss = self.loss_fn(prompt, domain_code)
        loss.backward()
        try:
            optimizer.step()
        except Exception as e:
            # Best effort: a failed step is reported and the record is still
            # emitted with whatever value domain_code holds at that point.
            print(e)
        return {
            "nl_description": data['nl_description'],
            "file": data['file'],
            "original_domain": data['domain'],
            "epoch": epoch,
            "loss": loss.value,
            "updated_domain_code": domain_code.value
        }

    def loss_fn(self, prompt: tg.Variable, domain_code: tg.Variable) -> tg.Variable:
        """Ask the LLM engine to review the domain code for the given problem.

        Args:
            prompt: Variable holding the problem statement shown to the
                evaluator.
            domain_code: Variable holding the PDDL domain code under review.

        Returns:
            The variable produced by the formatted LLM call; ``optimization``
            calls ``backward()`` on it and stores its ``value``.
        """
        # The system prompt that will guide the behavior of the loss function.
        loss_system_prompt = "You are an expert in reviewing pddl domain code. You do not solve problems or propose new code snippets, only evaluate the grammar and logic of the pddl domain codes"
        loss_system_prompt = tg.Variable(loss_system_prompt, requires_grad=False, role_description="system prompt to the loss function")

        # The instruction that will be the prefix. Spelled out with explicit
        # whitespace so the prompt text (including the indentation and the
        # trailing space after the first sentence) cannot be altered by
        # editors that strip trailing whitespace.
        instruction = (
            "\n"
            "        Clearly check the grammar and logic of each predicates and actions in the pddl domain code. \n"
            "        Do you think the code is correct? If not, what is the problem?\n"
            "        "
        )

        # The format string and setting up the call. The doubled braces
        # survive this first format() as the {prompt} / {domain_code}
        # placeholders that are passed on to FormattedLLMCall.
        format_string = "{instruction}\nProblem: {{prompt}}\nCurrent Code: {{domain_code}}"
        format_string = format_string.format(instruction=instruction)

        fields = {"prompt": None, "domain_code": None}
        formatted_llm_call = tg.autograd.FormattedLLMCall(engine=self.llm_engine,
                                                          format_string=format_string,
                                                          fields=fields,
                                                          system_prompt=loss_system_prompt)

        inputs = {"prompt": prompt, "domain_code": domain_code}
        loss = formatted_llm_call(inputs=inputs, response_role_description=f"evaluation of the {domain_code.get_role_description()}")
        return loss


# def single_process_textgrad(args, i):
#     data = json.load(open(args.input_data))
#     # split the data into i-th part
#     data = data[i*len(data)//args.epoch:(i+1)*len(data)//args.epoch]
#     qwen = Qwen_TextGrad(args)
#     qwen.optimization(data)

        
    



if __name__ == '__main__':
    # Script entry point: parse the CLI options, build the optimizer for this
    # worker's shard of the data, and run the optimization.
    Qwen_TextGrad(parse_args()).optimization()

        