import torch 
import multiprocessing
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from prompt import Loss_Prompt, GD_Prompt, problem2domain_template
import argparse
import json
from tqdm import tqdm
import pdb 
import logging 
import os

# Ask vLLM to start its worker processes with the 'spawn' multiprocessing
# start method. This is set at import time, i.e. before any LLM engine is built.
# NOTE(review): presumably needed because torch/CUDA may already be initialized
# in this parent process, and fork-after-CUDA is unsafe — confirm.
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'

def parse_args(argv=None):
    """Parse the command-line options for a TextGrad batch run.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``model_path``, ``input_data``, ``output_path``,
        ``batch_size``, ``epoch`` and ``num_gpu``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default='/lustre/fast/fast/txiao/zly/ckpt/Qwen25Coder7B', help='Path to model')
    parser.add_argument('--input_data', type=str, default='/lustre/fast/fast/txiao/zly/spatial_head/cot/benchmark/ipc_bench/domain_nl_bad.json', help='Path to data')
    parser.add_argument('--output_path', type=str, default='/lustre/fast/fast/txiao/zly/spatial_head/cot/result/textgrad/qwen_textgrad_1_nl.json', help='Path to output')
    parser.add_argument('--batch_size', type=int, default=640, help='Batch size')
    # Number of optimization rounds over the whole dataset (typically 1-5).
    parser.add_argument('--epoch', type=int, default=1, help='Number of optimization rounds')
    # NOTE(review): Qwen_TextGrad_Batch currently hard-codes tensor_parallel_size=2
    # and does not read this value.
    parser.add_argument('--num_gpu', type=int, default=4, help='Number of GPUs')
    return parser.parse_args(argv)
  
class Qwen_TextGrad_Batch:
    """TextGrad-style critique-then-revise prompting over PDDL domains, run with vLLM.

    For every record in the input JSON list, the model is first prompted with
    ``Loss_Prompt`` (natural-language text + current PDDL domain) to produce a
    critique (the "loss"). That critique is then fed into ``GD_Prompt`` together
    with the domain to produce a revised domain (the "gradient step"). The
    records from every round are written to ``output_path`` as one JSON list.
    """

    def __init__(self, args):
        """Load the input data, the tokenizer and the vLLM engine.

        Args:
            args: Parsed CLI namespace; reads ``model_path``, ``input_data``,
                ``output_path``, ``batch_size`` and ``epoch``.
        """
        self.model_path = args.model_path
        self.input_data = args.input_data
        # Load the data with a context manager so the file handle is closed.
        with open(self.input_data) as f:
            self.data = json.load(f)
        self.output_path = args.output_path
        self.batch_size = args.batch_size
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.sampling_params = SamplingParams(temperature=0.7, max_tokens=2048)
        # NOTE(review): tensor parallelism is hard-coded to 2; args.num_gpu is
        # parsed on the CLI but not used here — confirm which is intended.
        self.differentiable_llm = LLM(model=self.model_path, trust_remote_code=True, dtype="half", tensor_parallel_size=2)
        self.num_epoches = args.epoch

    def batch_optimization(self):
        """Run the loss/GD rounds on records keyed by ``'nl_description'`` and save them."""
        self._optimize('nl_description')

    def batch_optimization_prob(self):
        """Run the loss/GD rounds on records keyed by ``'question'`` and save them."""
        self._optimize('question')

    def _optimize(self, text_key: str) -> None:
        """Shared loop behind both public optimization entry points.

        Each input record must provide ``text_key``, ``'domain'``, ``'file'``
        and ``'response_id'``. Every output record carries the text under the
        same ``text_key``, plus ``rounds``, ``file``, ``response_id``,
        ``old_domain`` (the input domain), ``loss`` (the critique) and
        ``domain`` (the revised domain).

        Args:
            text_key: Name of the natural-language field in each input record.
        """
        batches = [self.data[start:start + self.batch_size]
                   for start in range(0, len(self.data), self.batch_size)]
        all_results = []
        for round_idx in range(self.num_epoches):
            # NOTE(review): every round starts again from the ORIGINAL domains in
            # self.data. The revised domain of round k is not fed into round k+1,
            # so rounds are independent resamples — confirm that is intended.
            for batch in tqdm(batches):
                texts = [record[text_key] for record in batch]
                domains = [record['domain'] for record in batch]

                loss_values = self._generate_texts([
                    self.PromptTemplate4Loss(nl_description=text, domain_code=domain)
                    for text, domain in zip(texts, domains)
                ])
                gd_values = self._generate_texts([
                    self.PromptTemplate4GD(error=loss, domain_code=domain)
                    for loss, domain in zip(loss_values, domains)
                ])

                for record, text, domain, loss, new_domain in zip(
                        batch, texts, domains, loss_values, gd_values):
                    all_results.append({
                        text_key: text,
                        'rounds': round_idx,
                        'file': record['file'],
                        'response_id': record['response_id'],
                        'old_domain': domain,
                        'loss': loss,
                        'domain': new_domain,
                    })
        # Make sure the destination directory exists so a long run does not
        # fail only at the final write.
        out_dir = os.path.dirname(self.output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(self.output_path, 'w') as f:
            json.dump(all_results, f, indent=4)

    def _generate_texts(self, prompts: list) -> list:
        """Generate with vLLM and return the first completion's text per prompt.

        Results are index-aligned with *prompts*.
        """
        outputs = self.differentiable_llm.generate(prompts, self.sampling_params)
        return [out.outputs[0].text for out in outputs]

    def _chat_prompt(self, user_content: str) -> str:
        """Wrap *user_content* in the fixed system/user chat and render it to text."""
        messages = [
            {
                "role": "system",
                "content": "You are helpful assistant",
            },
            {"role": "user", "content": user_content}
        ]
        # tokenize=False returns the rendered string, which is passed directly
        # to LLM.generate; add_generation_prompt=True appends the assistant turn header.
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

    def PromptTemplate4Loss(self, nl_description: str, domain_code: str) -> str:
        """Build the chat prompt asking the model to critique *domain_code* against *nl_description*."""
        return self._chat_prompt(
            Loss_Prompt.format(NL_Description=nl_description, PDDL_Domain=domain_code))

    def PromptTemplate4GD(self, error: str, domain_code: str) -> str:
        """Build the chat prompt asking the model to revise *domain_code* given the critique *error*."""
        return self._chat_prompt(
            GD_Prompt.format(PDDL_Domain=domain_code, Error=error))

if __name__ == '__main__':
    # Build the optimizer from the CLI options and run the default pass, which
    # reads each record's 'nl_description' field. For inputs whose records use
    # a 'question' field instead, call batch_optimization_prob() here.
    optimizer = Qwen_TextGrad_Batch(parse_args())
    optimizer.batch_optimization()
    