import sys
import os 
sys.path.append('/lustre/fast/fast/txiao/zly/spatial_head/cot')
from prompt import Prompt_template, Domain_template, Problem_template, problem2domain_template
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch 
from vllm import LLM, SamplingParams
from typing import List
import argparse
import json 
from tqdm import tqdm
import pdb

# Select the 'spawn' start method for vLLM worker processes via its
# environment switch; this is set at import time, before any engine is built.
os.environ.update({'VLLM_WORKER_MULTIPROC_METHOD': 'spawn'})

def parse_args(argv=None):
    """Parse the command-line options of the generation script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace with ``model_name``, ``tokenizer_name``,
        ``data_path``, ``output_path``, ``num_samples``, ``n_SamplingParams``
        and ``batch_size``.
    """
    # The text must be passed as ``description``: the first positional
    # parameter of ArgumentParser is ``prog`` (the program name shown in the
    # usage line), not the description.
    parser = argparse.ArgumentParser(
        description='Generate the domain file for a given model'
    )
    parser.add_argument(
        '--model_name',
        type=str,
        default='/lustre/fast/fast/txiao/zly/ckpt/Qwen25Coder7B',
        help='The name of the model',
    )
    parser.add_argument(
        '--tokenizer_name',
        type=str,
        default='/lustre/fast/fast/txiao/zly/ckpt/Qwen25Coder7B',
        help='The name of the tokenizer',
    )
    parser.add_argument(
        '--data_path',
        type=str,
        default='/lustre/fast/fast/txiao/zly/spatial_head/cot/benchmark/planetarium/planetarium_test.json',
        help='The path to the data file',
    )
    parser.add_argument(
        '--output_path',
        type=str,
        default='/lustre/fast/fast/txiao/zly/spatial_head/cot/result/planetarium/coder16prob_new.json',
        help='The path to the output domain files',
    )
    parser.add_argument(
        '--num_samples',
        type=int,
        default=None,
        help='Only process the first N records of the data file (default: all)',
    )
    parser.add_argument(
        '--n_SamplingParams',
        type=int,
        default=16,
        help='SamplingParams(n)',
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=480,
        help='Number of records sent to the model per generate call',
    )
    return parser.parse_args(argv)

# PDDL text of the STRIPS "gripper" domain (predicates plus the move / pick /
# drop actions). Domain_Generate.generate_domain_batch copies this string
# verbatim into the 'domain' field of a result record when the input record's
# 'domain' value contains 'gripper', so its bytes are part of the output data.
gripper = """
(define (domain gripper)
  (:requirements :strips)
  (:predicates
    (room ?r)
    (ball ?b)
    (gripper ?g)
    (at-robby ?r)
    (at ?b ?r)
    (free ?g)
    (carry ?o ?g)
  )

  (:action move
    :parameters (?from ?to)
    :precondition (and (room ?from) (room ?to) (at-robby ?from))
    :effect (and (at-robby ?to)
      (not (at-robby ?from)))
  )

  (:action pick
    :parameters (?obj ?room ?gripper)
    :precondition (and (ball ?obj) (room ?room) (gripper ?gripper)
      (at ?obj ?room) (at-robby ?room) (free ?gripper))
    :effect (and (carry ?obj ?gripper)
      (not (at ?obj ?room))
      (not (free ?gripper)))
  )

  (:action drop
    :parameters (?obj ?room ?gripper)
    :precondition (and (ball ?obj) (room ?room) (gripper ?gripper)
      (carry ?obj ?gripper) (at-robby ?room))
    :effect (and (at ?obj ?room)
      (free ?gripper)
      (not (carry ?obj ?gripper)))
  )
)
"""

# PDDL text of the STRIPS "blocksworld" domain (predicates plus the pickup /
# putdown / stack / unstack actions). Domain_Generate.generate_domain_batch
# uses this string as the 'domain' field of every result record whose input
# 'domain' value does not contain 'gripper', so its bytes are part of the
# output data.
block_world = """
(define (domain blocksworld)

  (:requirements :strips)

  (:predicates
    (clear ?x)
    (on-table ?x)
    (arm-empty)
    (holding ?x)
    (on ?x ?y)
  )

  (:action pickup
    :parameters (?ob)
    :precondition (and (clear ?ob) (on-table ?ob) (arm-empty))
    :effect (and (holding ?ob) (not (clear ?ob)) (not (on-table ?ob))
      (not (arm-empty)))
  )

  (:action putdown
    :parameters (?ob)
    :precondition (holding ?ob)
    :effect (and (clear ?ob) (arm-empty) (on-table ?ob)
      (not (holding ?ob)))
  )

  (:action stack
    :parameters (?ob ?underob)
    :precondition (and (clear ?underob) (holding ?ob))
    :effect (and (arm-empty) (clear ?ob) (on ?ob ?underob)
      (not (clear ?underob)) (not (holding ?ob)))
  )

  (:action unstack
    :parameters (?ob ?underob)
    :precondition (and (on ?ob ?underob) (clear ?ob) (arm-empty))
    :effect (and (holding ?ob) (clear ?underob)
      (not (on ?ob ?underob)) (not (clear ?ob)) (not (arm-empty)))
  )
)
"""

class Domain_Generate:
    """Sample model completions for every record of a dataset with vLLM.

    The class loads a tokenizer and a vLLM engine once, then turns each input
    record's ``question`` into a chat prompt, generates several completions
    per prompt and returns one flat result dict per completion.
    """

    def __init__(self, args):
        """Build the tokenizer, the sampling configuration and the vLLM engine.

        Args:
            args: Namespace providing ``model_name``, ``data_path``,
                ``num_samples``, ``batch_size``, ``n_SamplingParams`` and
                ``output_path``; ``tokenizer_name`` is used when present.
        """
        self.model_path = args.model_name
        self.data_path = args.data_path
        self.num_samples = args.num_samples
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.batch_size = args.batch_size
        # Honour --tokenizer_name (it used to be parsed but ignored); fall
        # back to the model path when the namespace has no such attribute.
        tokenizer_path = getattr(args, 'tokenizer_name', None) or self.model_path
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.n_SamplingParams = args.n_SamplingParams
        # At most 32 completions are returned per prompt, while best_of keeps
        # the full requested sample count.
        self.num_compute = min(self.n_SamplingParams, 32)
        self.sampling_params = SamplingParams(
            n=self.num_compute,
            best_of=self.n_SamplingParams,
            temperature=0.7,
            max_tokens=2048,
        )
        self.model = LLM(
            model=self.model_path,
            trust_remote_code=True,
            dtype="half",
            tensor_parallel_size=4,
        )
        self.output_path = args.output_path

    def _load_data(self) -> List[dict]:
        """Read ``self.data_path`` and return its records as a list.

        ``.jsonl`` files are parsed one JSON document per line, skipping
        blank lines; ``.json`` files are parsed as a whole. When
        ``self.num_samples`` is not ``None`` only that many leading records
        are kept.

        Raises:
            ValueError: If the path ends in neither ``.jsonl`` nor ``.json``.
        """
        if self.data_path.endswith('.jsonl'):
            with open(self.data_path, 'r', encoding='utf-8') as f:
                # Blank lines (e.g. a trailing newline at EOF) are not JSON
                # documents and would make json.loads raise.
                data = [json.loads(line) for line in f if line.strip()]
        elif self.data_path.endswith('.json'):
            with open(self.data_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        else:
            raise ValueError('data_path must be a .jsonl or .json file')
        if self.num_samples is not None:
            data = data[:self.num_samples]
        return data

    def generate_prob(self):
        """Generate completions for the whole dataset, batch by batch.

        Returns:
            A list with one result dict per generated completion, in dataset
            order (see ``generate_domain_batch`` for the dict layout).

        Raises:
            ValueError: If the data path has an unsupported extension or
                ``batch_size`` is smaller than 1.
        """
        if self.batch_size < 1:
            raise ValueError('batch_size must be a positive integer')
        data = self._load_data()
        results = []
        for start in tqdm(range(0, len(data), self.batch_size)):
            batch = data[start:start + self.batch_size]
            results.extend(self.generate_domain_batch(batch))
        return results

    def _build_chat_prompt(self, user_content: str) -> str:
        """Render a system + user conversation with the tokenizer's chat template."""
        messages = [
            {
                "role": "system",
                "content": "You are helpful assistant",
            },
            {"role": "user", "content": user_content},
        ]
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    def apply_prompt_template(self, question: str):
        """Return the chat prompt for the (currently unused) nl2domain path.

        NOTE(review): this formats ``Problem_template`` with the keyword
        ``NL_description`` while ``apply_prompt_template_problem`` uses
        ``NL_Description``; one of the two spellings (or the template choice)
        is presumably wrong - confirm against the ``prompt`` module.
        """
        prompt_question = Problem_template.format(NL_description=question)
        return self._build_chat_prompt(prompt_question)

    def apply_prompt_template_problem(self, question: str):
        """Return the chat prompt that wraps ``question`` in ``Problem_template``."""
        prompt_question = Problem_template.format(NL_Description=question)
        return self._build_chat_prompt(prompt_question)

    def generate_domain_batch(self, data):
        """Generate completions for one batch of records.

        Args:
            data: Sequence of record dicts; each must provide ``id``,
                ``name``, ``nl_description``, ``question``, ``domain``,
                ``init_is_abstract`` and ``goal_is_abstract``.

        Returns:
            A list of dicts, one per (record, completion) pair. The original
            question is stored under ``gt_question`` and the generated text
            under ``question``; ``domain`` holds the gripper or blocksworld
            PDDL text depending on the record's ``domain`` value.
        """
        prompts = [self.apply_prompt_template_problem(d['question']) for d in data]
        answers = self.model.generate(prompts, self.sampling_params)
        results = []
        # Outputs are paired with the input records by position.
        for record, answer in zip(data, answers):
            domain_text = gripper if 'gripper' in record['domain'] else block_world
            for response_id, completion in enumerate(answer.outputs):
                results.append({
                    'id': record['id'],
                    'name': record['name'],
                    'nl_description': record['nl_description'],
                    'gt_question': record['question'],
                    'response_id': response_id,
                    'domain': domain_text,
                    'init_is_abstract': record['init_is_abstract'],
                    'goal_is_abstract': record['goal_is_abstract'],
                    'question': completion.text,
                })

        return results
    
    
if __name__ == '__main__':
    # Script entry point: parse options, run generation over the dataset and
    # persist every generated sample.
    args = parse_args()
    print(args.output_path)
    # Create the destination directory before loading the model so that an
    # unusable output location fails fast instead of after generation.
    output_dir = os.path.dirname(os.path.abspath(args.output_path))
    os.makedirs(output_dir, exist_ok=True)
    dg = Domain_Generate(args)
    results = dg.generate_prob()
    # Always write the results: previously any path not ending in '.json'
    # skipped the dump while still reporting success.
    with open(args.output_path, 'w', encoding='utf-8') as f:
        if args.output_path.endswith('.jsonl'):
            # One JSON document per line for JSON Lines targets.
            f.writelines(json.dumps(record) + '\n' for record in results)
        else:
            # dump with indent=4
            json.dump(results, f, indent=4)
    print(f'Domain files saved to {args.output_path}')

            
    
    