# Copyright (c) 2024-present, Royal Bank of Canada.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import os
import numpy as np

from pddl_utils import get_problem_pddl_empty_goal, extract_atom_arguments
from utils import postprocess, safe_function_execute
import subprocess
from utils import get_random_temp_file_name, read_and_remove_file, as_file
import logging
from typing import List
import fast_downward
from fast_downward import Atom, Operator, close_lib
import pdb
import yaml
import sys
from LLaMA_Factory.src.llamafactory.chat.chat_model import ChatModel
import json
import prompts
import torch
import gc
import sys
import re

# Domain names accepted by this module; Domain.__init__ asserts that the
# requested name is one of these.
DOMAIN_NAMES = [
    "barman",
    "blocksworld",
    "floortile",
    "grippers",
    "grippers-ood",
    "storage",
    "termes",
    "tyreworld",
    "manipulation",
    "childsnack-opt14-strips",
    "depot",
    "driverlog",
    "hiking-agl14-strips",
    "logistics00",
    "miconic",
    "movie",
    "mprime",
    "openstacks",
    "parking-opt11-strips",
    "rovers",
    "satellite",
    "scanalyzer-08-strips",
    "trucks",
    "zenotravel",
    "frozenlake",
    "maze",
    "sokoban",
    "maze_wall",
    "printer",
    "package",
    "overcooked",
]


class Domain:
    def __init__(self, base_path, name, vlm_path, gpt_client=None):
        """Set up the domain located at ``base_path/name``.

        Args:
            base_path: Directory that holds one sub-directory per domain.
            name: Domain name; must be listed in ``DOMAIN_NAMES``.
            vlm_path: Path to a YAML file; its parsed content is passed to
                ``ChatModel`` and kept as ``self.vlm_args``.
            gpt_client: Optional client object (or the string ``'vila-u'``)
                that ``get_task`` consults when deciding how to obtain the
                natural-language task description.
        """
        assert name in DOMAIN_NAMES
        self.name = name
        self.domain_dir = os.path.join(base_path, name)
        self.gpt_client = gpt_client
        self.tasks = []  # replaced by grab_tasks() with task file names such as "p001.nl"

        # The parsed YAML is stored and also used to build the chat model.
        with open(vlm_path, 'r') as config_file:
            self.vlm_args = yaml.safe_load(config_file)
        self.chat_model = ChatModel(self.vlm_args)

        # Original author's note: predefined problem files are not expected
        # here, only initial observation images.
        self.grab_tasks()

        # Per-task caches, filled lazily by get_task().
        task_count = len(self.tasks)
        self.pddls = [None] * task_count
        self.nls = [None] * task_count
        self.templates = [None] * task_count
        self.image_paths = [None] * task_count

    def grab_tasks(self):
        """Fill ``self.tasks`` with the task file names found in the domain directory.

        Candidate names are ``p001.nl`` through ``p999.nl``. A candidate is
        kept when that file exists in ``self.domain_dir`` or when the same
        name with ``nl`` replaced by ``pddl`` exists there. The ``.nl`` name is
        what gets stored in either case, and the result is sorted.
        """
        task_dir = self.domain_dir

        def is_present(file_name):
            return os.path.isfile(os.path.join(task_dir, file_name))

        candidates = (f"p{index:03d}.nl" for index in range(1, 1000))
        found = [
            nl_name
            for nl_name in candidates
            if is_present(nl_name) or is_present(nl_name.replace('nl', 'pddl'))
        ]
        found.sort()
        self.tasks = found

    def __len__(self):
        """Return how many tasks were discovered for this domain."""
        task_count = len(self.tasks)
        return task_count

    def get_task_suffix(self, i):  # unused
        """Return ``"<domain name>/<task file name>"`` for task ``i``.

        ``grab_tasks`` stores plain file names (e.g. ``"p001.nl"``) in
        ``self.tasks``. Unpacking such a string into two names, as the
        previous implementation did, always raised ``ValueError``. A legacy
        ``(nl, pddl)`` pair is still accepted, in which case the last element
        is used.
        """
        task = self.tasks[i]
        if isinstance(task, (tuple, list)):
            task = task[-1]
        return f"{self.name}/{task}"

    def get_task_file(self, i):
        """Return the path of task ``i``'s file inside the domain directory."""
        task_name = self.tasks[i]
        task_path = os.path.join(self.domain_dir, task_name)
        return task_path

    def get_domain_predicate_descriptor(self):
        """Return the text of ``predicate_descriptor.py`` from the domain directory."""
        descriptor_path = os.path.join(self.domain_dir, "predicate_descriptor.py")
        with open(descriptor_path, 'r') as descriptor_file:
            source_text = descriptor_file.read()
        return source_text
    
    def transform_string(self, text):
        """Drop agent-occupied cells from the "clear positions" sentence of ``text``.

        The agent cells are the ``pos-<r>-<c>`` tokens inside the parentheses
        of an ``agents are at (...)`` phrase. Every sentence of the form
        ``The positions <list> are clear, other positions are counters.`` is
        rewritten so that the list no longer contains those cells and the
        sentence ends with ``counters and agents.``.

        Args:
            text: Scene description to rewrite.

        Returns:
            The rewritten text. If ``text`` has no ``agents are at (...)``
            phrase there is nothing to exclude, so it is returned unchanged
            (previously this case raised ``AttributeError``).
        """
        agents_match = re.search(r"agents are at \(([^)]+)\)", text)
        if agents_match is None:
            return text
        agents = set(re.findall(r"pos-\d+-\d+", agents_match.group(1)))

        def update_clear(match):
            # Keep only the listed positions that are not occupied by an agent.
            clear_pos = [p.strip() for p in match.group(1).split(", ") if p.strip() not in agents]
            return f"The positions {', '.join(clear_pos)} are clear, other positions are counters and agents."

        return re.sub(r"The positions ([^.]+) are clear, other positions are counters\.", update_clear, text)

    def get_nl_from_vlm(self, img_path):
        """Return the natural-language initial-state description for ``img_path``.

        The description is cached next to the image: for ``<stem>.png`` the
        cache file is ``<stem>generated.nl``. On a cache hit the file content
        is returned as is. On a miss, the first conversation turn and the
        image list are read from ``<stem>_seq.json``, ``self.chat_model`` is
        queried, and the part of the response starting at ``Initially:`` and
        ending before ``Step 1`` is written to the cache file and returned.

        Args:
            img_path: Path of the observation image (expected to end in ``.png``).

        Returns:
            The description text.

        Raises:
            IndexError: If the model response has no ``Initially:`` marker.
        """
        print('getting nl for ', img_path)
        cache_path = img_path.replace('.png', 'generated.nl')
        if os.path.exists(cache_path):
            with open(cache_path, 'r') as f:
                return f.read()

        json_path = img_path.replace(".png", "_seq.json")
        with open(json_path, 'r') as file:
            data = json.load(file)

        # Single-turn conversation: only the first user prompt is sent.
        messages = [{"role": "user", "content": data['conversations'][0]['value']}]
        response_parts = []
        with torch.no_grad():
            for new_text in self.chat_model.stream_chat(messages, images=data['images'], skip_special_tokens=True):
                response_parts.append(new_text)
        response = "".join(response_parts)
        print(response)

        segments = response.split('Initially:')
        if len(segments) < 2:
            # Same exception type as the former bare index error, but with context.
            raise IndexError(f"VLM response for {img_path} contains no 'Initially:' section.")
        nl = 'Initially:' + segments[1].split('Step 1')[0]
        with open(cache_path, "w") as file:
            file.write(nl)

        # Drop the generated text and ask torch / the GC to release memory
        # before the next generation call.
        del response_parts, response
        torch.cuda.empty_cache()
        gc.collect()
        return nl

    def get_nl_from_api(self, img_path):
        """Ask ``self.gpt_client`` to describe the scene shown in ``img_path``.

        The task description is cut out of the first conversation turn stored
        in ``<stem>_seq.json`` (the text from ``Task Description`` up to
        ``Action Sequence``). For paths containing ``frozenlake`` the content
        of ``example_task_desc.nl`` is prepended as few-shot examples; for all
        other paths no examples are used (previously this case failed with an
        unbound ``few_shot_messages``).

        Args:
            img_path: Path of the observation image (expected to end in ``.png``).

        Returns:
            A fixed instruction paragraph followed by the client's answer.
            NOTE(review): the instruction paragraph talks about ice holes and
            is prepended for every domain, not only frozenlake — confirm that
            this is intended.
        """
        json_path = img_path.replace(".png", "_seq.json")
        with open(json_path, 'r') as file:
            data = json.load(file)
        system_message = 'Given a sequence of actions and an image observation of the initial setup of a scenario, your goal is to decide the describe the setup you observe from the image.'

        few_shot_messages = ''
        if 'frozenlake' in img_path:
            # Assumes an 8-character image file name such as "p001.png", so
            # that stripping it leaves the containing directory.
            example_path = img_path[:-8] + "example_task_desc.nl"
            with open(example_path, 'r') as file:
                few_shot_messages = file.read()

        messages = 'Task Description' + data['conversations'][0]['value'].split('\n\nTask Description')[1].split('Action Sequence')[0]
        conv_id, _ = self.gpt_client.make_new_chat(system_message=system_message)
        conv_id, gpt_output, _ = self.gpt_client.complete_one_chat(conv_id, (few_shot_messages + messages, prompts.encode_image(img_path)))
        print(img_path)
        print(gpt_output)
        instruction = 'You are tasked with manipulating an agent to reach the goal without falling into the ice hole. The position representation is (row, column) representation. For example, (pos-4-3) represents the position in fourth row and third column. The left upper corner is (pos-1-1). You can perform four actions: move-up, move-down, move-left, and move-right.\n\n'
        nl = instruction + gpt_output
        return nl

    def get_task(self, i):
        """Load task ``i`` and cache the result.

        Args:
            i: Index into ``self.tasks``.

        Returns:
            A tuple ``(pddl, nl, template, img_path)``:
              * ``pddl`` - raw content of the task's ``.pddl`` file, or
                ``None`` when that file does not exist;
              * ``nl`` - post-processed natural-language description;
              * ``template`` - post-processed content of ``<task>_template.pddl``;
              * ``img_path`` - path of the image associated with the task.

        The description is read from the ``.nl`` file for ``blocksworld`` and
        ``maze``; for every other domain it is produced from the task image by
        the API client, by ``vila-u``, or by the local VLM, depending on
        ``self.gpt_client``. Later calls for the same ``i`` return the cached
        values.
        """
        if self.nls[i] is not None:
            return self.pddls[i], self.nls[i], self.templates[i], self.image_paths[i]

        pddl_f = self.get_task_file(i)  # path of the task's .nl file
        pddl_path = pddl_f.replace('.nl', '.pddl')
        if os.path.exists(pddl_path):
            with open(pddl_path, 'r') as f:  # ground-truth problem PDDL, when provided
                pddl = f.read()
        else:
            pddl = None
        with open(pddl_f.replace(".nl", "_template.pddl"), 'r') as f:  # PDDL template (objects, object names)
            template = f.read()

        if self.name == 'blocksworld':
            img_path = os.path.join(self.domain_dir, 'image.jpg')
            with open(pddl_f, 'r') as f:  # textual task description
                nl = f.read()
        elif self.name == 'maze':
            img_path = os.path.join(self.domain_dir, 'image.png')
            with open(pddl_f, 'r') as f:  # textual task description
                nl = f.read()
        else:
            img_path = pddl_f.replace(".nl", ".png")
            if self.gpt_client is not None and self.gpt_client != 'vila-u':
                nl = self.get_nl_from_api(img_path)
            elif self.gpt_client == 'vila-u':
                # NOTE(review): get_nl_from_vila_u is not defined in the
                # visible part of this file; confirm it exists before using
                # gpt_client='vila-u'.
                nl = self.get_nl_from_vila_u(img_path)
            else:
                nl = self.get_nl_from_vlm(img_path)

        # Post-process once so that the first call returns exactly the
        # objects that later (cached) calls return.
        nl = postprocess(nl)
        template = postprocess(template)
        self.pddls[i] = pddl
        self.nls[i] = nl
        self.templates[i] = template
        self.image_paths[i] = img_path
        return pddl, nl, template, img_path

    def get_task_nl(self, i):
        """Return the natural-language description of task ``i``, loading the task if it is not cached."""
        cached_nl = self.nls[i]
        if cached_nl is None:
            return self.get_task(i)[1]
        return cached_nl
    
    def get_task_image(self, i):
        """Return the image path of task ``i``, loading the task if it is not cached."""
        cached_path = self.image_paths[i]
        if cached_path is None:
            return self.get_task(i)[3]
        return cached_path

    def get_domain_pddl(self):
        """Return the post-processed content of the domain's ``domain.pddl`` file."""
        with open(self.get_domain_pddl_file(), 'r') as pddl_file:
            raw_text = pddl_file.read()
        return postprocess(raw_text)

    def get_domain_template_pddl(self):
        """Return the post-processed content of the domain's ``domain_template.pddl`` file."""
        template_path = os.path.join(self.domain_dir, "domain_template.pddl")
        with open(template_path, 'r') as template_file:
            raw_text = template_file.read()
        return postprocess(raw_text)

    def get_domain_pddl_file(self):
        """Return the path of ``domain.pddl`` inside the domain directory."""
        return os.path.join(self.domain_dir, "domain.pddl")

    def get_domain_nl(self):
        """Return the post-processed content of the domain's ``domain.nl`` file.

        Raises:
            Exception: If the file cannot be read. The underlying error is
                attached as ``__cause__``. ``KeyboardInterrupt`` and
                ``SystemExit`` are no longer intercepted, as they were by the
                previous bare ``except``.
        """
        domain_nl_f = self.get_domain_nl_file()
        try:
            with open(domain_nl_f, 'r') as f:
                domain_nl = f.read()
        except Exception as err:
            raise Exception(f"Could not read domain nl file: {domain_nl_f}") from err
        return postprocess(domain_nl)

    def get_task_pddl(self, i):
        """Return the problem PDDL of task ``i`` (may be ``None``), loading the task if needed.

        The index is printed on every call.
        """
        print(i)
        cached_pddl = self.pddls[i]
        if cached_pddl is None:
            # Either not loaded yet or the task has no .pddl file.
            return self.get_task(i)[0]
        return cached_pddl

    def get_domain_nl_file(self):
        """Return the path of ``domain.nl`` inside the domain directory."""
        return os.path.join(self.domain_dir, "domain.nl")

    def get_task_template(self, i):
        """Return the PDDL template of task ``i``, loading the task if it is not cached."""
        cached_template = self.templates[i]
        if cached_template is None:
            return self.get_task(i)[2]
        return cached_template


class PDDLEnv:
    OPTIMAL_ALIAS = "seq-opt-fdss-1"
    SUB_OPTIMAL_ALIAS = "lama-first"

    def __init__(
            self, fd_py_path: str, val_bin_path: str, fd_search_time_limit: int, fd_alias: str = SUB_OPTIMAL_ALIAS
    ) -> None:
        self.fd_py_path = fd_py_path
        self.fd_search_time_limit = fd_search_time_limit
        self.val_bin_path = val_bin_path
        self.fd_alias = fd_alias

    def search_plan(self, domain_pddl: str, problem_pddl: str):
        """Run the planner on the given domain and problem texts.

        Both texts are written to temporary files, the planner script is run
        through ``python3`` with the configured alias and time limit, and its
        standard output is inspected.

        Args:
            domain_pddl: Domain PDDL text.
            problem_pddl: Problem PDDL text.

        Returns:
            A tuple ``(plan, recognized, message)``:
              * ``(plan, True, "Solution found.")`` when the output contains
                ``Solution found.``; ``plan`` is the post-processed plan file;
              * ``(None, True, <explanation>)`` when the search stopped
                without a solution or hit the time limit;
              * ``(None, False, stderr)`` for any other output.

        All temporary files (domain, problem, plan, SAS) are removed before
        returning, including when the planner cannot be started. The SAS file
        and unsolved-run plan files used to be left behind.
        """
        domain_pddl_path = as_file(domain_pddl)
        problem_pddl_path = as_file(problem_pddl)
        temp_plan_path = get_random_temp_file_name()
        temp_sas_path = get_random_temp_file_name()
        try:
            output = subprocess.run(
                [
                    "python3",
                    self.fd_py_path,
                    "--alias",
                    self.fd_alias,
                    "--search-time-limit",
                    f"{self.fd_search_time_limit}",
                    "--plan-file",
                    temp_plan_path,
                    "--sas-file",
                    temp_sas_path,
                    domain_pddl_path,
                    problem_pddl_path
                ],
                capture_output=True,
                text=True,  # same effect as the former extra universal_newlines=True
            )
            print(output)

            search_output = output.stdout
            if "Solution found." in search_output:
                plan = postprocess(read_and_remove_file(temp_plan_path))
                return plan, True, "Solution found."
            if "Search stopped without finding a solution." in search_output:
                return None, True, "Generated PDDL domain is valid, but plan search stopped without finding a solution."
            if "Time limit has been reached." in search_output:
                return None, True, "Generated PDDL domain is valid, but search Time limit has been reached."
            return None, False, output.stderr
        finally:
            read_and_remove_file(domain_pddl_path)
            read_and_remove_file(problem_pddl_path)
            # The plan / SAS files may or may not exist depending on how far
            # the planner got; removing them is best-effort.
            for leftover in (temp_plan_path, temp_sas_path):
                try:
                    os.remove(leftover)
                except FileNotFoundError:
                    pass

    def validate_plan(self, domain_pddl: str, problem_pddl: str, plan: str):
        """Validate ``plan`` against the domain and problem with the VAL binary.

        The three texts are written to temporary files and ``val_bin_path`` is
        run with ``-v``. Its standard output is interpreted by
        ``_parse_val_output``.

        Args:
            domain_pddl: Domain PDDL text.
            problem_pddl: Problem PDDL text.
            plan: Plan text.

        Returns:
            ``(is_valid, message)`` as produced by ``_parse_val_output``.

        All three temporary files are removed before returning, also when the
        validator cannot be started. The plan file used to be left behind.
        """
        domain_pddl_path = as_file(domain_pddl)
        problem_pddl_path = as_file(problem_pddl)
        plan_file = as_file(plan)
        try:
            val_output = subprocess.run(
                [
                    self.val_bin_path,
                    "-v",
                    domain_pddl_path,
                    problem_pddl_path,
                    plan_file
                ],
                capture_output=True,
                text=True,  # same effect as the former extra universal_newlines=True
            )
            print(val_output)
        finally:
            read_and_remove_file(domain_pddl_path)
            read_and_remove_file(problem_pddl_path)
            read_and_remove_file(plan_file)
        is_valid, val_message = self._parse_val_output(val_output.stdout)
        return is_valid, val_message

    def get_random_walk_plan(
            self, domain_pddl: str, problem_pddl: str, predicate_descriptor_fn, max_steps: int
    ):
        """Produce a random-walk plan, retrying until an attempt yields a result.

        A seed is drawn once from NumPy's global random state and the same
        seed is passed to every attempt. ``_get_random_walk_plan`` is invoked
        through ``safe_function_execute``, and a ``None`` result triggers
        another attempt.

        NOTE(review): there is no retry limit, so this never returns if every
        attempt yields ``None``.

        Returns:
            ``(plan, state_descs)`` as returned by ``_get_random_walk_plan``.
        """
        seed = np.random.randint(2 ** 32 - 1)
        walk_args = (domain_pddl, problem_pddl, predicate_descriptor_fn, max_steps, seed)
        outcome = safe_function_execute(self._get_random_walk_plan, *walk_args)
        while outcome is None:
            outcome = safe_function_execute(self._get_random_walk_plan, *walk_args)
        plan, state_descs = outcome
        return plan, state_descs

    def _get_random_walk_plan(
            self, domain_pddl: str, problem_pddl: str, predicate_descriptor_fn, max_steps: int, seed
    ):
        """Execute up to ``max_steps`` randomly chosen applicable actions.

        The problem's goal is emptied first, the task is translated to SAS and
        loaded into the planner library, and at each step one applicable
        action is picked with a generator seeded by ``seed``.

        Args:
            domain_pddl: Domain PDDL text.
            problem_pddl: Problem PDDL text.
            predicate_descriptor_fn: Optional callable used to describe states;
                when ``None`` no descriptions are collected.
            max_steps: Maximum number of actions to take.
            seed: Seed for ``np.random.default_rng``.

        Returns:
            ``(plan, state_descs)``. ``plan`` is the list of chosen action
            names. When a descriptor is given, ``state_descs`` starts with a
            description of the full initial state, followed by one entry per
            action, computed *before* that action is applied and restricted to
            the facts relevant to it.

        The library handle is now closed even if a step raises.
        """
        rng = np.random.default_rng(seed)
        problem_pddl = get_problem_pddl_empty_goal(problem_pddl)
        lib = fast_downward.load_lib()
        try:
            _task, sas = fast_downward.pddl2sas(domain_pddl, problem_pddl)
            lib.load_sas(sas.encode('utf-8'))

            plan, state_descs = [], []
            if predicate_descriptor_fn is not None:
                # Full description of the initial state; state_descs is the feedback.
                state_descs.append(self._get_state_natural_language(lib, predicate_descriptor_fn, action_name=None))
            for _ in range(max_steps):
                available_actions = self._get_applicable_actions(lib)
                available_action_names = list(available_actions.keys())
                if len(available_action_names) == 0:
                    break  # dead end: nothing is applicable
                action_name = rng.choice(available_action_names)
                plan.append(action_name)
                action = available_actions[action_name]
                if predicate_descriptor_fn is not None:
                    state_descs.append(self._get_state_natural_language(lib, predicate_descriptor_fn, action_name))
                _ = self._apply_action(lib, action)
        finally:
            close_lib(lib)
        return plan, state_descs

    def get_plan_execution_feedback(
            self, domain_pddl: str, problem_pddl: str, plan: List[str], state_descs,
            predicate_descriptor_fn
    ):
        """Return execution feedback for ``plan``, retrying until a result is obtained.

        ``_get_plan_execution_feedback`` is invoked through
        ``safe_function_execute`` with identical arguments on every attempt; a
        ``None`` result triggers another attempt.

        NOTE(review): there is no retry limit, so this never returns if every
        attempt yields ``None``.

        Returns:
            ``(executable, message)`` as returned by
            ``_get_plan_execution_feedback``.
        """
        call_args = (domain_pddl, problem_pddl, plan, state_descs, predicate_descriptor_fn)
        feedback = safe_function_execute(self._get_plan_execution_feedback, *call_args)
        while feedback is None:
            feedback = safe_function_execute(self._get_plan_execution_feedback, *call_args)
        return feedback

    def _get_plan_execution_feedback(
            self, domain_pddl: str, problem_pddl: str, plan: List[str], state_descs: List[str], predicate_descriptor_fn
    ):
        """Execute ``plan`` step by step and report the first inapplicable action.

        The problem's goal is emptied, the task is translated to SAS and
        loaded into the planner library, and the actions are applied in order.
        Execution stops at the first action that is not among the currently
        applicable ones.

        Args:
            domain_pddl: Domain PDDL text.
            problem_pddl: Problem PDDL text.
            plan: Action names to execute.
            state_descs: Optional pre-computed state descriptions, indexed by
                step; when given, the entry for the failing step is quoted
                (truncated to 1000 characters) in the feedback.
            predicate_descriptor_fn: Optional callable used to describe the
                current state when ``state_descs`` is ``None``.

        Returns:
            ``(executable, message)``. ``message`` lists the actions attempted
            so far followed by either ``The plan is executable.`` or the error
            feedback.

        The library handle is now closed even if a step raises.
        """
        assert state_descs is not None or predicate_descriptor_fn is not None, "Either state_descs or predicate_descriptor_fn must be provided."
        problem_pddl = get_problem_pddl_empty_goal(problem_pddl)
        lib = fast_downward.load_lib()
        try:
            _task, sas = fast_downward.pddl2sas(domain_pddl, problem_pddl)
            lib.load_sas(sas.encode('utf-8'))
            plan_so_far = []
            feedback = "The plan is executable."
            executable = True
            for action_name in plan:
                plan_so_far.append(f"({action_name})")
                available_actions = self._get_applicable_actions(lib)
                available_action_names = list(available_actions.keys())
                print(action_name, available_action_names)
                if action_name not in available_action_names:
                    if state_descs is not None:
                        # The description recorded for this step of the plan.
                        state_desc_str = state_descs[len(plan_so_far) - 1]
                        if len(state_desc_str) > 1000:
                            state_desc_str = state_desc_str[:1000] + "..."
                            logging.warning(f"State description is too long: {state_desc_str}, truncating.")
                        feedback = (f"Error when executing the action ({action_name}).\n"
                                    f"Current state: {state_desc_str}\n"
                                    f"This action is not executable on the environment.")
                    elif predicate_descriptor_fn is not None:
                        feedback = (f"Error when executing the action ({action_name}).\n"
                                    f"Current state: {self._get_state_natural_language(lib, predicate_descriptor_fn, action_name)}\n"
                                    f"This action is executable on the environment, but your generated environment recognizes this as an illegal action.")
                    else:
                        # Only reachable when the assertion above is disabled.
                        feedback = f"Error when executing the action ({action_name}). This action is not executable on the environment."
                    executable = False
                    break
                else:
                    action = available_actions[action_name]
                    _ = self._apply_action(lib, action)
        finally:
            close_lib(lib)

        exec_description = f"Executing the following actions sequentially on the environment:\n{self.plan_to_str(plan_so_far)}\n\nResult: "
        return executable, f"{exec_description}{feedback}"

    def _get_applicable_actions(self, lib) -> dict:
        """Return the operators applicable in the library's current state, keyed by name.

        An ``Operator`` array sized by ``lib.get_applicable_operators_count()``
        is allocated and filled by ``lib.get_applicable_operators``.
        """
        count = lib.get_applicable_operators_count()
        operator_buffer = (Operator * count)()
        lib.get_applicable_operators(operator_buffer)
        by_name = {}
        for operator in operator_buffer:
            by_name[operator.name] = operator
        return by_name

    def _apply_action(self, lib, action):
        """Apply ``action`` through the library and return the effect-atom buffer.

        An ``Atom`` array with ``action.nb_effect_atoms`` slots is allocated
        and handed to ``lib.apply_operator`` together with ``action.id``.
        """
        effect_buffer = (Atom * action.nb_effect_atoms)()
        lib.apply_operator(action.id, effect_buffer)
        return effect_buffer

    def _get_state_natural_language(self, lib, predicate_desc_fn, action_name=None):
        """Describe the current state as one space-joined string.

        Args:
            lib: Planner library handle holding the current state.
            predicate_desc_fn: Callable ``(atom_name, fact_args)`` returning an
                indexable pair: index 0 is used for a positive fact, index 1
                for a negated one.
            action_name: When given, only facts relevant to this action are
                described; otherwise all facts are.

        Returns:
            The fact descriptions joined with single spaces.
        """
        if action_name is not None:
            facts = self._get_relevant_atom_facts(lib, action_name)
        else:
            facts = self._get_all_atom_facts(lib)

        sentences = []
        for negated, predicate, arguments in facts:
            descriptions = predicate_desc_fn(predicate, arguments)
            sentences.append(descriptions[1] if negated else descriptions[0])
        return " ".join(sentences)

    def _get_relevant_atom_facts(self, lib, action_name):
        """Return the state facts whose arguments all appear among the action's parameters.

        ``action_name`` is split on whitespace; the first token is the action
        itself and the remaining tokens are its parameters. Facts without
        arguments are always relevant, because the empty set is a subset of
        any parameter set. They are returned exactly once — the previous
        implementation listed each of them twice.

        Returns:
            A list of ``(is_not, atom_name, fact_args)`` tuples in the order
            produced by ``_get_all_atom_facts``.
        """
        action_params = set(action_name.split()[1:])
        return [
            (is_not, atom_name, fact_args)
            for is_not, atom_name, fact_args in self._get_all_atom_facts(lib)
            if set(fact_args).issubset(action_params)
        ]

    def _get_all_atom_facts(self, lib):
        """Return the current state as parsed ``(is_not, atom_name, fact_args)`` tuples.

        The state atoms are read from the library, converted to strings and
        de-duplicated. ``NegatedAtom `` prefixes become ``not `` and ``Atom ``
        prefixes are removed; entries mentioning ``new-axiom@0`` are skipped
        and the rest are parsed with ``extract_atom_arguments``.

        Duplicates are removed in first-seen order, so the result no longer
        depends on string hashing and is stable between runs.
        """
        state_size = lib.get_state_size()
        atoms = (Atom * state_size)()
        lib.get_state(atoms)

        atoms_parsed = []
        for raw_name in dict.fromkeys(map(str, atoms)):  # order-preserving de-duplication
            atom_text = raw_name.replace("NegatedAtom ", "not ").replace("Atom ", "")
            # 'new-axiom@0' is a special atom that is not relevant to the user
            if 'new-axiom@0' in atom_text:
                continue
            is_not, atom_name, fact_args = extract_atom_arguments(atom_text)
            atoms_parsed.append((is_not, atom_name, fact_args))
        return atoms_parsed

    def _parse_val_output(self, val_output: str):
        """Turn the validator's standard output into ``(is_valid, message)``.

        Checks are made in this order: ``Plan valid``; the "Plan Validation
        details" header (the stripped text after its first occurrence becomes
        the message); ``Goal not satisfied``; ``Plan invalid``. Anything else
        is logged at info level and reported as ``Unknown error.``.
        """
        details_header = "Plan Validation details\n-----------------------"
        if "Plan valid" in val_output:
            return True, "The plan is valid."
        if details_header in val_output:
            details = val_output.split(details_header)[1]
            return False, details.strip()

        fixed_messages = (
            ("Goal not satisfied", "The goal is not satisfied."),
            ("Plan invalid", "The plan is invalid."),
        )
        for marker, message in fixed_messages:
            if marker in val_output:
                return False, message

        logging.info("Unknown validation output: " + val_output)
        return False, "Unknown error."

    def plan_to_str(self, plan):
        """Render ``plan`` as text: a list is joined with newlines, anything else goes through ``str``."""
        return "\n".join(plan) if isinstance(plan, list) else str(plan)
