import pdb
import json
import sys

import fire
import gradio as gr
import torch
import transformers
from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    # prepare_model_for_int8_training,
    prepare_model_for_kbit_training,
    set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer, GPT2Tokenizer
from peft import PeftModel

import torch.nn.functional as F
import os
import torch.nn as nn
import numpy as np

from critic.critic import Critic
from torch.distributions.categorical import Categorical
import copy 

# Make project packages importable: register both the repository root (three
# directory levels above this file) and this file's parent directory.
_module_dir = os.path.dirname(os.path.abspath(__file__))
root = os.path.dirname(os.path.dirname(_module_dir))
sys.path.append(root)

_parent_dir = os.path.abspath(os.path.join(_module_dir, os.path.pardir))
sys.path.append(_parent_dir)


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialize ``layer`` in place and return it.

    The weight receives an orthogonal initialization with gain ``std`` and
    every bias entry is set to ``bias_const``.
    """
    nn.init.orthogonal_(layer.weight, gain=std)
    nn.init.constant_(layer.bias, val=bias_const)
    return layer

class LLMAgent(nn.Module):
    """Actor-critic agent that ranks candidate text actions with a LoRA-adapted LLaMA.

    The actor is ``self.llama`` wrapped with a PEFT LoRA adapter and the critic
    is a ``Critic`` built on top of the actor. Every observation is a dict
    with a ``"prompt"`` string and an ``"action"`` list of candidate strings.
    The policy is a categorical distribution over the candidates whose logits
    are the (optionally length-normalized) log-likelihoods the actor assigns
    to each candidate's tokens.
    """

    # Checkpoint used when the caller does not pass ``base_model``.
    # Alternatives used previously:
    #   'Neko-Institute-of-Science/LLaMA-7B-HF'
    #   "/lustre/S/tianzikang/LLMs/llama2/llama2-13b-chat-hf/"
    #   "/lustre/S/tianzikang/LLMs/lite_llama/LiteLlama-460M-1T/"
    DEFAULT_BASE_MODEL = "/lustre/S/tianzikang/LLMs/tiny_llama/TinyLlama-1.1B-intermediate-step-1431k-3T/"

    def __init__(self, normalization_mode = 'token', load_path = None, load_8bit = False, task = 3, base_model = None):
        """Build tokenizer, base model, actor and critic.

        Args:
            normalization_mode: How candidate log-likelihoods are normalized:
                ``'token'`` (divide by token count), ``'word'`` (divide by
                whitespace-separated word count) or ``'sum'`` (no division).
            load_path: Optional directory holding saved LoRA weights (and
                optionally ``critic.pth``). When falsy, a fresh LoRA adapter
                and critic are created.
            load_8bit: Load the base model in 8-bit and prepare it for k-bit
                training instead of casting it to half precision.
            task: Stored on the instance; not otherwise used by this class.
            base_model: Path or hub id of the base checkpoint. ``None`` selects
                ``DEFAULT_BASE_MODEL``.

        Raises:
            ValueError: If the resolved ``base_model`` is empty.
        """
        super().__init__()

        self.load_8bit = load_8bit
        self.base_model = self.DEFAULT_BASE_MODEL if base_model is None else base_model
        self.lora_r  = 8
        self.lora_alpha = 16
        self.lora_dropout = 0
        self.lora_target_modules  = ["q_proj", "v_proj",]

        self.task = task

        if not self.base_model:
            raise ValueError(
                "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
            )

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Best-effort MPS detection: any failure to query the backend simply
        # leaves the CUDA/CPU choice above in place. ``Exception`` (rather than
        # a bare ``except``) lets KeyboardInterrupt/SystemExit propagate.
        try:
            if torch.backends.mps.is_available():
                self.device = "mps"
        except Exception:
            pass

        self.normalization_mode = normalization_mode

        if "lite_llama" in self.base_model:
            self.tokenizer = GPT2Tokenizer.from_pretrained(self.base_model)
        else:
            self.tokenizer = LlamaTokenizer.from_pretrained(self.base_model)
        self.tokenizer.pad_token_id = (
            0  # unk. we want this to be different from the eos token
        )

        self.llama = self._init_llama()

        if load_path:
            load_critic = self.load(load_path)
            if not load_critic:
                # when there is no critic to load, we need to initialize a new critic
                self.critic = self._init_critic().to(self.device)
        else:
            self.actor = self._init_actor().to(self.device)
            self.critic = self._init_critic().to(self.device)

    def _init_llama(self):
        """Load the base causal LM in fp16 (or 8-bit when ``self.load_8bit``)."""
        model = LlamaForCausalLM.from_pretrained(
            self.base_model,
            torch_dtype=torch.float16,
            load_in_8bit=self.load_8bit,
            device_map="auto",
            cache_dir=os.path.join(root, 'weights/llama')
        )

        if not self.load_8bit:
            model.half().to(self.device)
        else:
            model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

        return model

    def _init_actor(self, lora_weights = None):
        """Wrap ``self.llama`` with a LoRA adapter.

        Args:
            lora_weights: Optional path to saved adapter weights. When ``None``
                a new adapter is created from the ``self.lora_*`` settings.

        Returns:
            The PEFT model (compiled with ``torch.compile`` when the torch
            version check passes and the platform is not win32).
        """
        if lora_weights is None:
            config = LoraConfig(
                r=self.lora_r,
                lora_alpha=self.lora_alpha,
                target_modules=self.lora_target_modules,
                lora_dropout=self.lora_dropout,
                bias="none",
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(self.llama, config)

            model.print_trainable_parameters()

            # Make state_dict() return only the PEFT (adapter) state.
            old_state_dict = model.state_dict
            model.state_dict = (
                lambda self, *_, **__: get_peft_model_state_dict(
                    self, old_state_dict()
                )
            ).__get__(model, type(model))
        else:
            model = PeftModel.from_pretrained(
                self.llama,
                lora_weights,
                torch_dtype=torch.float16,
            )

        if torch.__version__ >= "2" and sys.platform != "win32":
            model = torch.compile(model)

        if not self.load_8bit:
            model.half()
        else:
            model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

        return model

    def _init_critic(self, critic_weights = None):
        """Create a ``Critic`` on top of the actor, optionally loading its value head.

        Args:
            critic_weights: Optional path to a state dict for ``critic.v_head``.
        """
        critic = Critic(self.actor, self.tokenizer)
        if critic_weights is not None:
            critic.v_head.load_state_dict(torch.load(critic_weights, map_location= "cpu"))
        return critic

    def save(self, epoch, exp_path):
        """Save the LoRA adapter to ``<exp_path>/epoch_<epoch:04d>``.

        Critic weights are not written here (the ``torch.save`` of the value
        head is disabled), so ``load`` only finds ``critic.pth`` if it was
        placed in the directory by other means.
        """
        print("save model")
        exp_path = os.path.join(exp_path, "epoch_{:04d}".format(epoch))

        os.makedirs(exp_path, exist_ok=True)
        # save lora
        self.actor.save_pretrained(exp_path)
        # save critic
        # torch.save(self.critic.v_head.state_dict(), os.path.join(exp_path, "critic.pth"))

    def load(self, exp_path):
        """Load the LoRA adapter (and the critic, if present) from ``exp_path``.

        Returns:
            ``True`` when ``<exp_path>/critic.pth`` exists and the critic was
            loaded from it, ``False`` when only the actor was loaded.
        """
        print("load model")
        lora_weights = exp_path
        critic_weights = os.path.join(exp_path, "critic.pth")
        self.actor = self._init_actor(lora_weights).to(self.device)
        if os.path.exists(critic_weights):
            self.critic = self._init_critic(critic_weights).to(self.device)
            return True
        return False

    def get_value(self, x):
        """Return the critic's value for a batch of prompts.

        Args:
            x: Non-empty list of prompt strings, or of observation dicts
                carrying a ``"prompt"`` key.
        """
        if isinstance(x[0], dict):
            x = [o["prompt"] for o in x]
        inputs = self.tokenizer(x, return_tensors="pt", padding=True)
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        # The critic is evaluated inside the actor's disable_adapter() context.
        with self.actor.disable_adapter():
            value = self.critic(input_ids, attention_mask=attention_mask)
        return value

    def _forward_actor(self, is_warmup, *args, **kwargs):
        """Run the actor, under ``torch.no_grad()`` when ``is_warmup`` is true."""
        if is_warmup:
            with torch.no_grad():
                return self.actor(*args, **kwargs)
        return self.actor(*args, **kwargs)

    def _policy_outputs(self, action_logits, action_num, action, return_value, prompt):
        """Turn per-candidate scores into the tuple returned by the public methods.

        Returns:
            ``(action, log_prob, entropy, value, probs)`` where ``action`` is
            sampled when the caller passed ``None`` and ``value`` is ``None``
            when ``return_value`` is false.
        """
        action_logits = action_logits.reshape(-1, action_num).float()

        probs = Categorical(logits=action_logits)
        if action is None:
            action = probs.sample()

        log_prob = probs.log_prob(action)
        entropy = probs.entropy()
        value = self.get_value(prompt) if return_value else None
        return action, log_prob, entropy, value, probs.probs

    def get_action_and_value(self, obs, action=None, is_warmup=False, return_value = True, candidate_plans=None):
        """Score every ``prompt + " " + candidate`` sequence in one actor pass.

        Args:
            obs: List of dicts with ``"prompt"`` (str) and ``"action"`` (list
                of candidate strings). The candidate count is taken from the
                first observation.
            action: Optional tensor of chosen candidate indices; sampled from
                the policy when ``None``.
            is_warmup: Run the actor forward pass without gradients.
            return_value: Also evaluate the critic on the prompts.
            candidate_plans: Unused; kept for interface compatibility.

        Returns:
            ``(action, log_prob, entropy, value_or_None, probs)``.

        Raises:
            ValueError: If ``self.normalization_mode`` is not one of
                ``'token'``, ``'word'``, ``'sum'``.

        Side effects:
            Stores the tokenized candidates in ``self.action_list_ids`` and
            their token counts in ``self.action_list_length``.
        """
        prompt = [o["prompt"] for o in obs]
        action_list = [o["action"] for o in obs]

        action_num = len(action_list[0])

        # One sequence per (prompt, candidate) pair: "prompt a_1", "prompt a_2", ...
        sequence = [p + " " + a for p, ac in zip(prompt, action_list) for a in ac]

        inputs = self.tokenizer(sequence, return_tensors="pt", padding=True)
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        outputs = self._forward_actor(is_warmup, input_ids, attention_mask=attention_mask)

        # transform action_list from [[a_1, a_2], [a_3, a_4]] to [a_1, a_2, a_3, a_4]
        action_list = [item for sublist in action_list for item in sublist]
        self.action_list_ids = self.tokenizer(action_list, return_tensors="pt", padding=True)

        # "- 1" drops the first token of each tokenized candidate (presumably
        # the BOS token the tokenizer prepends -- TODO confirm).
        self.action_list_length = torch.sum(self.action_list_ids["attention_mask"], dim = -1) - 1
        sequence_length = torch.sum(attention_mask, dim = -1)
        # NOTE(review): treating [end - start, end) as the candidate's token
        # positions assumes real tokens occupy the first `sequence_length`
        # positions, i.e. right padding -- confirm the tokenizer's padding side.
        action_index = [[end - start, end] for start, end in zip(self.action_list_length, sequence_length)]

        logits = torch.log_softmax(outputs.logits, dim=-1)

        # Shift by one: the distribution at position t scores the token at t + 1.
        logits = logits[:, :-1, :]
        input_ids = input_ids[:, 1:]
        gen_logits = torch.gather(logits, 2, input_ids[:, :, None]).squeeze(-1)

        # Because of the shift, tokens at [start, end) live at [start - 1, end - 1).
        slices = [gen_logits[i, start-1:end-1] for i, (start, end) in enumerate(action_index)]

        action_logits = torch.stack([torch.sum(s) for s in slices])
        if self.normalization_mode == 'token':
            action_logits = action_logits / self.action_list_length.to(self.device)
        elif self.normalization_mode == 'word':
            action_word_num = torch.tensor([len(action.split()) for action in action_list]).to(self.device)
            action_logits = action_logits / action_word_num
        elif self.normalization_mode != 'sum':
            # An explicit error survives `python -O`, unlike an assert.
            raise ValueError(f"Unknown normalization_mode: {self.normalization_mode!r}")

        return self._policy_outputs(action_logits, action_num, action, return_value, prompt)

    def get_action_and_value_optim(self, obs, action=None, is_warmup=False, return_value = True, candidate_plans=None):
        """Score candidates by encoding each prompt once and reusing its KV cache.

        The prompts are run through the actor with ``use_cache=True``; each
        candidate column is then fed separately together with the prompt's
        ``past_key_values``.

        Args:
            obs: List of dicts with ``"prompt"`` (str) and ``"action"`` (list
                of candidate strings). The candidate count is taken from the
                first observation.
            action: Optional tensor of chosen candidate indices; sampled from
                the policy when ``None``.
            is_warmup: Run the prompt forward pass without gradients.
            return_value: Also evaluate the critic on the prompts.
            candidate_plans: Unused; kept for interface compatibility.

        Returns:
            ``(action, log_prob, entropy, value_or_None, probs)``.

        Raises:
            ValueError: If ``self.normalization_mode`` is not one of
                ``'token'``, ``'word'``, ``'sum'``.
        """
        prompt = [o["prompt"] for o in obs]
        action_list = [o["action"] for o in obs]

        action_num = len(action_list[0])

        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        outputs = self._forward_actor(is_warmup, input_ids, attention_mask=attention_mask, use_cache=True)

        flat_actions = [item for sublist in action_list for item in sublist]
        action_list_tokens = self.tokenizer(flat_actions, return_tensors="pt", padding=True)
        # [..., 1:] drops the first token of each tokenized candidate
        # (presumably the BOS token -- TODO confirm).
        action_list_ids = action_list_tokens["input_ids"][..., 1:]
        dim_1 = action_list_ids.shape[0] // action_num
        action_list_ids = action_list_ids.reshape(dim_1, action_num, -1).to(self.device)
        action_list_attention_mask = action_list_tokens["attention_mask"][..., 1:].reshape(dim_1, action_num, -1).to(self.device)

        # Token count per candidate, laid out as (bs, action_num) so that it
        # lines up with `action_logits` below for any batch size (a flat
        # (bs * action_num,) vector only broadcasts when bs == 1).
        action_list_length = (torch.sum(action_list_tokens["attention_mask"], dim = -1) - 1).reshape(dim_1, action_num)

        # NOTE(review): these per-candidate passes are not wrapped in no_grad
        # even when is_warmup is true, and they all receive the same
        # past_key_values object -- confirm both are intended for the
        # transformers version in use.
        per_action_logits = []
        for i in range(action_num):
            step_mask = torch.cat([attention_mask, action_list_attention_mask[:, i, :]], dim=-1)
            step_out = self.actor(
                action_list_ids[:, i, :],
                attention_mask=step_mask,
                past_key_values=outputs.past_key_values,
                use_cache=True,
            )
            per_action_logits.append(step_out.logits)
        outputs_action = torch.stack(per_action_logits, dim=1)

        # The logits at the prompt's last position score the first candidate
        # token; prepend them so position j scores candidate token j.
        # NOTE(review): index -1 is the last *padded* position of the prompt
        # batch -- confirm the padding side makes this the last real token.
        temp = outputs.logits[:, -1:, :].unsqueeze(1).expand(-1, action_num, -1, -1)
        logits = torch.log_softmax(torch.cat([temp, outputs_action], dim=-2), dim=-1) # (bs, action_num, token_num + 1, vocab)
        gen_logits = torch.gather(logits, 3, action_list_ids[:, :, :, None]).squeeze(-1) # (bs, action_num, token_num)

        # Padding positions are zeroed by the mask before summing.
        action_logits = torch.sum(gen_logits * action_list_attention_mask, dim=-1) # (bs, action_num)

        if self.normalization_mode == 'token':
            action_logits = action_logits / action_list_length.to(self.device)
        elif self.normalization_mode == 'word':
            action_word_num = torch.tensor([[len(action.split()) for action in a_list] for a_list in action_list]).to(self.device)
            action_logits = action_logits / action_word_num
        elif self.normalization_mode != 'sum':
            # An explicit error survives `python -O`, unlike an assert.
            raise ValueError(f"Unknown normalization_mode: {self.normalization_mode!r}")

        return self._policy_outputs(action_logits, action_num, action, return_value, prompt)

    def get_action_and_value_parallel_less_mem(self, obs, action=None, is_warmup=False, return_value = True, candidate_plans=None):
        """Alias of ``get_action_and_value``.

        This method previously carried a line-for-line copy of
        ``get_action_and_value``; it now delegates so the two cannot drift
        apart. Arguments and return value are those of ``get_action_and_value``.
        """
        return self.get_action_and_value(
            obs,
            action=action,
            is_warmup=is_warmup,
            return_value=return_value,
            candidate_plans=candidate_plans,
        )
    
def preprocess_info_v1(messages):
    """Build the plan-selection prompt and candidate list from ``messages``.

    Reads ``messages["chosen_skills"]`` (candidate plans; falsy entries are
    replaced by ``"do nothing"``), ``messages["text_obs"]`` (text describing
    the current situation) and ``messages["task"]``.

    Returns:
        ``(obs_text, action_list)``: the full prompt string and the cleaned
        list of candidate plans.
    """
    # Falsy plans (empty string, None) become an explicit no-op plan.
    action_list = [skill or "do nothing" for skill in messages["chosen_skills"]]
    situation = messages["text_obs"]
    goal = messages["task"]

    plan_count = len(action_list)
    plans_block = "\n".join(action_list)
    question = f"Now, which of these {plan_count} plans is the most effective for accomplishing the task?"

    obs_text = "".join(
        [
            f"You are a household agent. {goal}\n",
            "The current situation is:\n",
            situation,
            "\n",
            f"Now there are {plan_count} plans:\n{plans_block}\n",
            question,
        ]
    )

    return obs_text, action_list

def obs2text(obs):
    """Convert an observation dict into ``{"prompt": ..., "action": ...}``.

    ``obs`` is passed unchanged to ``preprocess_info_v1``; the prompt text it
    returns is stored under ``"prompt"`` and the candidate list under
    ``"action"``.
    """
    prompt_text, candidate_actions = preprocess_info_v1(obs)
    return dict(prompt=prompt_text, action=candidate_actions)
