import sys
import os 
sys.path.append('/lustre/fast/fast/txiao/zly/spatial_head/cot')
from prompt import Prompt_template, Domain_template, Problem_template, problem2domain_template
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch 
from vllm import LLM, SamplingParams
from typing import List
import argparse
import json 
from tqdm import tqdm
import pdb

# Ask vLLM to start its worker processes with the 'spawn' multiprocessing
# method. This is set at import time so the variable is already in place
# when the LLM engine is constructed in Domain_Generate.__init__.
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'

def _positive_int(value):
    """argparse ``type`` callable: parse ``value`` as an int that must be >= 1."""
    number = int(value)
    if number < 1:
        raise argparse.ArgumentTypeError(f'expected a positive integer, got {value!r}')
    return number


def parse_args(argv=None):
    """Parse the command-line options of the domain generation script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, so existing ``parse_args()``
            callers behave as before.

    Returns:
        argparse.Namespace with ``model_name``, ``tokenizer_name``,
        ``data_path``, ``output_path``, ``num_samples``, ``n_SamplingParams``
        and ``batch_size``.
    """
    # The text must be passed as ``description``: the first positional
    # parameter of ArgumentParser is ``prog``, which would put this sentence
    # into the usage line in place of the script name.
    parser = argparse.ArgumentParser(description='Generate the domain file for a given model')
    parser.add_argument('--model_name', type=str, default='/lustre/fast/fast/txiao/zly/ckpt/Qwen25Coder7B', help='The name of the model')
    # NOTE: Domain_Generate loads its tokenizer from --model_name; this option
    # is parsed but not read anywhere in this file.
    parser.add_argument('--tokenizer_name', type=str, default='/lustre/fast/fast/txiao/zly/ckpt/Qwen25Coder7B', help='The name of the tokenizer')
    parser.add_argument('--data_path', type=str, default='/lustre/fast/fast/txiao/zly/spatial_head/cot/benchmark/ipc_bench/domain_nl.json', help='The path to the data file')
    parser.add_argument('--output_path', type=str, default='/lustre/fast/fast/txiao/zly/spatial_head/cot/result/Domain_gen/coder128prob_ipc.json', help='The path to the output domain files')
    parser.add_argument('--num_samples', type=int, default=None, help='Only use the first N records of the data file (default: all records)')
    # Both values below are validated here so a bad value is rejected before
    # the model is loaded rather than in the middle of generation.
    parser.add_argument('--n_SamplingParams', type=_positive_int, default=128, help='SamplingParams(n)')
    # batch size (a value of 0 would otherwise fail later as a zero range step)
    parser.add_argument('--batch_size', type=_positive_int, default=480, help='Number of records sent to the model per generate call')
    return parser.parse_args(argv)

class Domain_Generate:
    """Sample domain texts from a vLLM-served chat model.

    Two pipelines are offered:

    * ``generate_domain`` reads the ``nl_description`` field of each record
      and prompts with ``Domain_template``.
    * ``generate_domain_from_problem`` reads the ``code`` field of each record
      and prompts with ``problem2domain_template`` (its placeholder is named
      ``PDDL_problem_code``).

    Both return a flat list with one dict per sampled completion.
    """

    def __init__(self, args):
        """Load the tokenizer and the vLLM engine described by ``args``.

        Args:
            args: Namespace providing ``model_name``, ``data_path``,
                ``num_samples``, ``batch_size``, ``n_SamplingParams`` and
                ``output_path``.
        """
        self.model_path = args.model_name
        self.data_path = args.data_path
        self.num_samples = args.num_samples
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.batch_size = args.batch_size
        # NOTE: the tokenizer is loaded from the model path;
        # ``args.tokenizer_name`` is not read here.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.n_SamplingParams = args.n_SamplingParams  # default=128; e.g. 1, 8, 32, 64, 128, 512, 1024
        # At most 32 completions are returned per prompt, while
        # ``n_SamplingParams`` candidates are requested through ``best_of``
        # (which therefore is always >= n).
        self.num_compute = min(self.n_SamplingParams, 32)
        self.sampling_params = SamplingParams(
            n=self.num_compute,
            best_of=self.n_SamplingParams,
            temperature=0.7,
            max_tokens=4096,
        )
        self.model = LLM(
            model=self.model_path,
            trust_remote_code=True,
            dtype="half",
            tensor_parallel_size=8,
            disable_custom_all_reduce=True,
        )
        self.output_path = args.output_path

    def _load_data(self):
        """Read the records at ``self.data_path`` and apply ``self.num_samples``.

        Returns:
            The parsed records, truncated to the first ``num_samples`` entries
            when that attribute is not ``None``.

        Raises:
            ValueError: If ``data_path`` ends in neither ``.jsonl`` nor ``.json``.
        """
        if self.data_path.endswith('.jsonl'):
            with open(self.data_path, 'r', encoding='utf-8') as f:
                # Blank lines (e.g. a trailing newline at the end of the
                # file) are skipped instead of crashing json.loads.
                data = [json.loads(line) for line in f if line.strip()]
        elif self.data_path.endswith('.json'):
            with open(self.data_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        else:
            raise ValueError('data_path must be a .jsonl or .json file')

        if self.num_samples is not None:
            data = data[:self.num_samples]
        return data

    def _run_batches(self, batch_fn):
        """Load the data, split it into ``batch_size`` chunks and run ``batch_fn`` on each."""
        data = self._load_data()
        batches = [data[i:i + self.batch_size] for i in range(0, len(data), self.batch_size)]
        results = []
        for batch in tqdm(batches):
            results += batch_fn(batch)
        return results

    def _build_chat_prompt(self, user_content: str):
        """Render a system + user conversation with the tokenizer's chat template."""
        messages = [
            {
                "role": "system",
                "content": "You are helpful assistant",
            },
            {"role": "user", "content": user_content}
        ]
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

    @staticmethod
    def _collect_results(data, answers, source_key, question_key):
        """Flatten generation outputs into one dict per sampled completion.

        Args:
            data: Input records, in the same order as ``answers``.
            answers: One generation result per record; each exposes
                ``.outputs``, a sequence of objects with a ``.text`` attribute.
            source_key: Record field that held the prompt input.
            question_key: Key under which that input is stored in the result.
        """
        return [
            {
                question_key: record[source_key],
                'file': record['file'],
                'response_id': response_id,
                'domain': completion.text,
            }
            for record, answer in zip(data, answers)
            for response_id, completion in enumerate(answer.outputs)
        ]

    def generate_domain(self):
        """Generate domains for every natural-language record in the data file.

        Returns:
            List of dicts with keys ``nl_description``, ``file``,
            ``response_id`` and ``domain``.

        Raises:
            ValueError: If ``data_path`` is neither a ``.jsonl`` nor a ``.json`` file.
        """
        return self._run_batches(self.generate_domain_batch)

    def apply_prompt_template(self, question: str):
        """Build the chat prompt asking for a domain from a natural-language description."""
        return self._build_chat_prompt(Domain_template.format(NL_description=question))

    def apply_prompt_template_problem(self, question: str):
        """Build the chat prompt asking for a domain from problem code."""
        return self._build_chat_prompt(problem2domain_template.format(PDDL_problem_code=question))

    def generate_domain_batch(self, data):
        """Generate domains for one batch of natural-language records.

        Args:
            data: Records that each provide ``nl_description`` and ``file``.

        Returns:
            One dict per sampled completion with keys ``nl_description``,
            ``file``, ``response_id`` and ``domain``.
        """
        # Natural-language descriptions go through the NL template. Previously
        # they were formatted with the problem-code template, which inserted
        # the description where problem code is expected.
        prompts = [self.apply_prompt_template(d['nl_description']) for d in data]
        answers = self.model.generate(prompts, self.sampling_params)
        return self._collect_results(data, answers, 'nl_description', 'nl_description')

    def generate_domain_from_problem(self):
        """Generate domains for every problem-code record in the data file.

        Returns:
            List of dicts with keys ``question``, ``file``, ``response_id``
            and ``domain``.

        Raises:
            ValueError: If ``data_path`` is neither a ``.jsonl`` nor a ``.json`` file.
        """
        return self._run_batches(self.generate_domain_batch_from_problem)

    def generate_domain_batch_from_problem(self, data):
        """Generate domains for one batch of problem-code records.

        Args:
            data: Records that each provide ``code`` and ``file``.

        Returns:
            One dict per sampled completion with keys ``question`` (the
            original ``code`` value), ``file``, ``response_id`` and ``domain``.
        """
        prompts = [self.apply_prompt_template_problem(d['code']) for d in data]
        answers = self.model.generate(prompts, self.sampling_params)
        return self._collect_results(data, answers, 'code', 'question')
    
    
def is_nl_data_path(data_path):
    """Return True when the data file's name marks it as natural-language input.

    Only the file name without its directories and extension is inspected.
    Checking the whole path would also match the ``.jsonl`` extension and
    unrelated directory names (e.g. ``download``), sending problem files to
    the natural-language pipeline.
    """
    stem = os.path.splitext(os.path.basename(data_path))[0]
    return 'nl' in stem


if __name__ == '__main__':
    args = parse_args()
    print(args.output_path)
    # Reject an unsupported output path before the model is loaded; otherwise
    # the whole generation run would finish and its results be discarded.
    if not args.output_path.endswith('.json'):
        raise SystemExit(f'output_path must be a .json file, got: {args.output_path}')
    # Create the target directory up front so the final write cannot fail on
    # a missing folder after generation has completed.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    dg = Domain_Generate(args)
    if is_nl_data_path(args.data_path):
        results = dg.generate_domain()
    else:
        results = dg.generate_domain_from_problem()
    with open(args.output_path, 'w') as f:
        # dump with indent=4
        json.dump(results, f, indent=4)
    print(f'Domain files saved to {args.output_path}')

            
    
    