"""
Iteratively refine barman problem until it passes the fastdownward test.
TODO: 1. copy the textgrad code from textgrad_qwen.py enable string as input
TODO: 2. copy the problem from barman
TODO: 3. enable online fastdownward test
"""
import glob
import torch 
import multiprocessing
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from prompt import Loss_Prompt, GD_Prompt, problem2domain_template
import argparse
import json
from tqdm import tqdm
import pdb 
import logging 
import os
import random
import subprocess
import re


# Domain-generation prompt for the barman task: general translation
# instructions, the expected "### Thought / ### Domain" answer template, the
# natural-language description of the barman actions and their effects, and an
# example barman problem PDDL file.
# NOTE(review): this constant is not referenced anywhere else in this file.
# The ```pddl fence in the template is never closed, and the text contains
# typos ("explaination", "predictate1", "its your time"). They are left as-is
# because changing the literal would change what the model is shown.
prompt = '''
You will be given a natural language description of a planning problem. Your task is to translate this description into PDDL domain code. This includes defining predicates and actions based on the information provided.

Information about the AI agent will be provided in the natural language description. Note that individual conditions in preconditions and effects should be listed separately. For example, “object_1 is washed and heated” should be considered as two separate conditions “object_1 is washed” and “object_1 is heated”. Also, in PDDL, two predicates cannot have the same name even if they have different parameters. Each predicate in PDDL must have a unique name, and its parameters must be explicitly defined in the predicate definition. It is recommended to define predicate names in an intuitive and readable way. Remember: Ignore the information that you think is not helpful for the planning task.

You are only responsible for domain generation.
Before you generate the concrete domain code, you should first generate a natural language thought about the meaning of each variable, and the step-by-step explaination of the domain code.
Even if I didn't provide the exact name of the predicates and actions, you should generate them based on the information provided in the natural language description.

Template is:

### Thought:
predicates1: the name of predicate1, explanation of predictate1
...
predicaten: the name of predicaten, explanation of predictaten
action1: the name of action1, explanation of action1
...
actionn: the name of action, explanation of actionn
<thought>

### Domain:
```pddl
The concrete pddl code for domain.pddl 

Now its your time to generate the solution, you have to follow the format I provided above.
NL_Description:
You are a robot barman that manipulates drink dispensers, shot glasses and a shaker. You have two hands. The goal is to find a plan that serves a desired set of drinks. Here are the actions you can do

Grasp a container -- pddl action name: grasp
Leave a container on the table -- pddl action name: leave
Fill a shot glass with an ingredient -- pddl action name: fill-shot
Refill a shot glass with an ingredient -- pddl action name: refill-shot
Empty a shot glass -- pddl action name: empty-shot
Clean a shot glass -- pddl action name: clean-shot
Pour an ingredient from a shot glass to a clean shaker -- pddl action name: pour-shot-to-clean-shaker
Pour an ingredient from a shot glass to a used shaker -- pddl action name: pour-shot-to-used-shaker
Empty a shaker -- pddl action name: empty-shaker
Clean a shaker -- pddl action name: clean-shaker
Shake a cocktail in a shaker -- pddl action name: shake
Pour from a shaker to a shot glass -- pddl action name: pour-shaker-to-shot 

Once you grasp a container, you are holding the container and the container is not on the table.
Once you leave a container on the table, your hand become empty.
Once you pour an ingredient from a shot glass to a shaker, the shaker contains the ingredient and is at one level above the previous level, and the shot glass becomes empty.
Once you empty a shaker, the shaker is at the empty level.
Once you shake, the two ingredients in the shaker become a cocktail.
Once you pour from a shaker to a shot glass, the shot glass contains the beverage in the shaker, the shot glass is no longer clean and empty, and the shaker is at one level below the previous level. 

An example problem PDDL problem file to the domain is: 
 (define (problem prob)
 (:domain barman)
 (:objects 
     shaker1 - shaker
     left right - hand
     shot1 shot2 shot3 shot4 shot5 shot6 shot7 shot8 shot9 shot10 shot11 shot12 shot13 shot14 shot15 - shot
     ingredient1 ingredient2 ingredient3 ingredient4 - ingredient
     cocktail1 cocktail2 cocktail3 cocktail4 cocktail5 cocktail6 cocktail7 cocktail8 cocktail9 cocktail10 cocktail11 - cocktail
     dispenser1 dispenser2 dispenser3 dispenser4 - dispenser
     l0 l1 l2 - level
)
 (:init 
  (ontable shaker1)
  (ontable shot1)
  (ontable shot2)
  (ontable shot3)
  (ontable shot4)
  (ontable shot5)
  (ontable shot6)
  (ontable shot7)
  (ontable shot8)
  (ontable shot9)
  (ontable shot10)
  (ontable shot11)
  (ontable shot12)
  (ontable shot13)
  (ontable shot14)
  (ontable shot15)
  (dispenses dispenser1 ingredient1)
  (dispenses dispenser2 ingredient2)
  (dispenses dispenser3 ingredient3)
  (dispenses dispenser4 ingredient4)
  (clean shaker1)
  (clean shot1)
  (clean shot2)
  (clean shot3)
  (clean shot4)
  (clean shot5)
  (clean shot6)
  (clean shot7)
  (clean shot8)
  (clean shot9)
  (clean shot10)
  (clean shot11)
  (clean shot12)
  (clean shot13)
  (clean shot14)
  (clean shot15)
  (empty shaker1)
  (empty shot1)
  (empty shot2)
  (empty shot3)
  (empty shot4)
  (empty shot5)
  (empty shot6)
  (empty shot7)
  (empty shot8)
  (empty shot9)
  (empty shot10)
  (empty shot11)
  (empty shot12)
  (empty shot13)
  (empty shot14)
  (empty shot15)
  (handempty left)
  (handempty right)
  (shaker-empty-level shaker1 l0)
  (shaker-level shaker1 l0)
  (next l0 l1)
  (next l1 l2)
  (cocktail-part1 cocktail1 ingredient2)
  (cocktail-part2 cocktail1 ingredient4)
  (cocktail-part1 cocktail2 ingredient3)
  (cocktail-part2 cocktail2 ingredient1)
  (cocktail-part1 cocktail3 ingredient1)
  (cocktail-part2 cocktail3 ingredient4)
  (cocktail-part1 cocktail4 ingredient4)
  (cocktail-part2 cocktail4 ingredient1)
  (cocktail-part1 cocktail5 ingredient4)
  (cocktail-part2 cocktail5 ingredient2)
  (cocktail-part1 cocktail6 ingredient3)
  (cocktail-part2 cocktail6 ingredient2)
  (cocktail-part1 cocktail7 ingredient2)
  (cocktail-part2 cocktail7 ingredient1)
  (cocktail-part1 cocktail8 ingredient2)
  (cocktail-part2 cocktail8 ingredient4)
  (cocktail-part1 cocktail9 ingredient4)
  (cocktail-part2 cocktail9 ingredient3)
  (cocktail-part1 cocktail10 ingredient1)
  (cocktail-part2 cocktail10 ingredient2)
  (cocktail-part1 cocktail11 ingredient3)
  (cocktail-part2 cocktail11 ingredient1)
)
 (:goal
  (and
     (contains shot1 cocktail4)
     (contains shot2 cocktail6)
     (contains shot3 cocktail11)
     (contains shot4 cocktail8)
     (contains shot5 cocktail7)
     (contains shot6 cocktail9)
     (contains shot7 cocktail2)
     (contains shot8 cocktail3)
     (contains shot9 cocktail5)
     (contains shot10 cocktail1)
     (contains shot11 cocktail10)
     (contains shot12 cocktail2)
     (contains shot13 ingredient2)
     (contains shot14 ingredient3)
)))


'''


# A small barman problem instance in PDDL: one shaker, three shot glasses,
# three ingredients/dispensers, and a single goal (shot1 contains cocktail1).
# It is preceded by a lead-in sentence, presumably so it can be appended to a
# prompt — TODO confirm.
# NOTE(review): this constant is not referenced anywhere else in this file;
# check_fastdownward solves against prob1.pddl read from disk instead.
problem = '''
    The problem PDDL file to this problem is: 
 (define (problem prob)
 (:domain barman)
 (:objects 
     shaker1 - shaker
     left right - hand
     shot1 shot2 shot3 - shot
     ingredient1 ingredient2 ingredient3 - ingredient
     cocktail1 - cocktail
     dispenser1 dispenser2 dispenser3 - dispenser
     l0 l1 l2 - level
)
 (:init 
  (ontable shaker1)
  (ontable shot1)
  (ontable shot2)
  (ontable shot3)
  (dispenses dispenser1 ingredient1)
  (dispenses dispenser2 ingredient2)
  (dispenses dispenser3 ingredient3)
  (clean shaker1)
  (clean shot1)
  (clean shot2)
  (clean shot3)
  (empty shaker1)
  (empty shot1)
  (empty shot2)
  (empty shot3)
  (handempty left)
  (handempty right)
  (shaker-empty-level shaker1 l0)
  (shaker-level shaker1 l0)
  (next l0 l1)
  (next l1 l2)
  (cocktail-part1 cocktail1 ingredient3)
  (cocktail-part2 cocktail1 ingredient1)
)
 (:goal
  (and
     (contains shot1 cocktail1)
))) 
'''

def parse_args(argv=None):
    """Parse command-line options for the barman TextGrad refinement run.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) reads ``sys.argv[1:]`` exactly
            as before; passing a list allows programmatic use and testing.

    Returns:
        argparse.Namespace with ``model_path``, ``input_data``,
        ``output_path``, ``batch_size``, ``epoch``, ``num_gpu`` and
        ``ith_gpu``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default='/lustre/fast/fast/txiao/zly/ckpt/Qwen25Coder7B', help='Path to model')
    # A smaller test split lives next to the default input:
    # .../result/case_study/case_barman_test.json
    parser.add_argument('--input_data', type=str, default='/lustre/fast/fast/txiao/zly/spatial_head/cot/result/case_study/case_barman.json', help='Path to data')
    parser.add_argument('--output_path', type=str, default='/lustre/fast/fast/txiao/zly/spatial_head/cot/result/textgrad/barman_textgrad.json', help='Path to output')
    parser.add_argument('--batch_size', type=int, default=200, help='Batch size')
    # Previously documented as "Number of process"; the value is the number of
    # refinement rounds the optimizer loops over.
    parser.add_argument('--epoch', type=int, default=100, help='Number of refinement rounds (epochs)')
    parser.add_argument('--num_gpu', type=int, default=4, help='Number of GPUs')
    parser.add_argument('--ith_gpu', type=int, default=0, help='GPU index')
    return parser.parse_args(argv)

class textgrad:
    """TextGrad-style refinement of LLM-generated PDDL domains with vLLM.

    For every record in the input JSON, the model first produces a textual
    critique ("loss", built from ``Loss_Prompt``) of the record's domain
    against its natural-language description. It then rewrites the domain from
    that critique ("gradient step", built from ``GD_Prompt``). Each
    (record, round) pair is logged together with a Fast Downward solvability
    check, and the log is written to ``output_path`` as JSON.

    Each record is expected to provide the natural-language key
    (``'nl_description'`` or ``'question'``), ``'domain'``, ``'file'`` and
    ``'response_id'``.
    """

    def __init__(self, args):
        """Load the data, tokenizer and vLLM engine described by *args*.

        Args:
            args: Namespace as produced by ``parse_args``. Reads
                ``model_path``, ``input_data``, ``output_path``,
                ``batch_size``, ``epoch`` and (if present) ``num_gpu``.
        """
        self.model_path = args.model_path
        self.input_data = args.input_data
        # Load the records to refine; the context manager closes the handle.
        with open(self.input_data, encoding='utf-8') as f:
            self.data = json.load(f)
        self.output_path = args.output_path
        self.batch_size = args.batch_size
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.sampling_params = SamplingParams(temperature=0.7, max_tokens=2048)
        # Honour --num_gpu rather than a hard-coded tensor-parallel size; the
        # fallback of 4 matches both the CLI default and the previous value.
        self.differentiable_llm = LLM(
            model=self.model_path,
            trust_remote_code=True,
            dtype="half",
            tensor_parallel_size=getattr(args, 'num_gpu', 4),
        )
        self.num_epoches = args.epoch

    def _optimize(self, nl_key):
        """Run ``num_epoches`` loss/update rounds over all records and save them.

        Args:
            nl_key: Record key that holds the natural-language description.
        """
        batches = [
            self.data[start:start + self.batch_size]
            for start in range(0, len(self.data), self.batch_size)
        ]
        all_results = []
        rounds = range(self.num_epoches)
        for round_idx in rounds:
            # NOTE(review): every round restarts from the original
            # record['domain']; refined domains are not fed back into the next
            # round. Confirm whether rounds are meant to be independent samples.
            for batch in tqdm(batches):
                all_results.extend(self._refine_batch(batch, nl_key, round_idx))
            # Checkpoint after each round so a crash does not lose earlier rounds.
            self._save_results(all_results)
        if not rounds:
            # Keep the previous behaviour of always writing the output file.
            self._save_results(all_results)

    def _refine_batch(self, batch, nl_key, round_idx):
        """Return one result dict per record of *batch* for round *round_idx*."""
        nl_descriptions = [record[nl_key] for record in batch]
        old_domains = [record['domain'] for record in batch]

        # Step 1: critique each domain against its description ("loss").
        loss_prompts = [
            self.PromptTemplate4Loss(nl_description=nl, domain_code=domain)
            for nl, domain in zip(nl_descriptions, old_domains)
        ]
        loss_outputs = self.differentiable_llm.generate(loss_prompts, self.sampling_params)
        loss_values = [out.outputs[0].text for out in loss_outputs]

        # Step 2: rewrite each domain given its critique ("gradient step").
        gd_prompts = [
            self.PromptTemplate4GD(error=loss, domain_code=domain)
            for loss, domain in zip(loss_values, old_domains)
        ]
        gd_outputs = self.differentiable_llm.generate(gd_prompts, self.sampling_params)
        gd_values = [out.outputs[0].text for out in gd_outputs]

        results = []
        for record, nl, old_domain, loss, new_domain in zip(
                batch, nl_descriptions, old_domains, loss_values, gd_values):
            # NOTE(review): this checks the *input* domain, not the refined
            # new_domain stored under 'domain'; confirm which one is intended.
            find_domain = self.check_fastdownward(old_domain)
            results.append({
                'question': nl,
                'rounds': round_idx,
                'file': record['file'],
                'response_id': record['response_id'],
                'old_domain': old_domain,
                'loss': loss,
                'domain': new_domain,
                'fastdownward': find_domain,
            })
        return results

    def _save_results(self, results):
        """Write *results* to ``self.output_path`` as indented JSON (overwrites)."""
        with open(self.output_path, 'w') as f:
            json.dump(results, f, indent=4)

    def batch_optimization(self):
        """Refine all records, reading the description from ``'nl_description'``."""
        self._optimize('nl_description')

    def batch_optimization_prob(self):
        """Refine all records, reading the description from ``'question'``."""
        self._optimize('question')

    def _chat_prompt(self, content: str) -> str:
        """Wrap *content* as a user turn in the tokenizer's chat template."""
        messages = [
            {
                "role": "system",
                "content": "You are helpful assistant",
            },
            {"role": "user", "content": content}
        ]
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

    def PromptTemplate4Loss(self, nl_description: str, domain_code: str) -> str:
        """Build the chat prompt asking the model to critique *domain_code*."""
        return self._chat_prompt(
            Loss_Prompt.format(NL_Description=nl_description, PDDL_Domain=domain_code)
        )

    def PromptTemplate4GD(self, error: str, domain_code: str) -> str:
        """Build the chat prompt asking the model to fix *domain_code* given *error*."""
        return self._chat_prompt(
            GD_Prompt.format(PDDL_Domain=domain_code, Error=error)
        )

    def check_fastdownward(self, domain_code: str) -> bool:
        """Return True if Fast Downward finds a plan using the domain in *domain_code*.

        The text between a ```pddl fence and the last closing ``` in
        *domain_code* is written to a randomly named ``.pddl`` file in the tmp
        directory. It is then solved against the fixed barman ``prob1.pddl``
        instance with ``astar(lmcut())`` and a 20 s search limit. Success means
        at least one plan file was produced. On success the domain file is
        kept; on failure it is deleted.

        Returns False (instead of raising AttributeError) when *domain_code*
        contains no ```pddl block.
        """
        # TODO: please check the usage of fastdownward
        temp_dir = '/lustre/fast/fast/txiao/zly/spatial_head/cot/tmp/'
        prob_file = '/lustre/fast/fast/txiao/zly/spatial_head/cot/benchmark/case_study/barman/instances/prob1.pddl'
        # Greedy match: captures up to the LAST "\n```" in the response.
        match = re.search(r'```pddl\n([\s\S]*)\n```', domain_code)
        if match is None:
            # Model output without a fenced block used to crash the whole run.
            print("Fail!")
            return False
        extract_domain = match.group(1)
        os.makedirs(temp_dir, exist_ok=True)
        # Random 16-digit name to avoid collisions between calls.
        file_name = ''.join(random.choices('0123456789', k=16))
        domain_file = os.path.join(temp_dir, file_name + '.pddl')
        with open(domain_file, 'w') as f:
            f.write(extract_domain)
        plan_file = os.path.join(temp_dir, f'plan_barman_{file_name}')
        command = [
            'python', '/lustre/fast/fast/txiao/zly/downward/fast-downward.py',
            '--plan-file', plan_file,
            '--search-time-limit', '20s',
            domain_file, prob_file,
            '--search', 'astar(lmcut())',
        ]
        # Argument list (no shell) plus DEVNULL. The former `> nul 2>&1`
        # redirect is a Windows idiom and created a file named "nul" on POSIX.
        subprocess.run(command, stdout=subprocess.DEVNULL,
                       stderr=subprocess.DEVNULL, check=False)
        # The glob also accepts suffixed plan file names.
        if glob.glob(f'{plan_file}*'):
            return True
        print("Fail!")
        try:
            os.remove(domain_file)
        except FileNotFoundError:
            pass  # already gone: nothing left to clean up
        return False
    
        
        
if __name__ == '__main__':
    # batch_optimization_prob() is the alternative entry point for records
    # whose natural-language text is stored under 'question'.
    cli_args = parse_args()
    optimizer = textgrad(cli_args)
    optimizer.batch_optimization()
