# adapt to the new environment
import json 
import dataclasses
from enum import auto, Enum
from typing import List, Tuple, Any

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

from human_model import Conversation, SeparatorStyle, empty_conv

import peft 
from peft import PeftConfig, PeftModel
from peft.utils import _get_submodules

# Maps the leading part of a task name to the short tag used when looking up
# few-shot examples (keys of the form ``ask_think_<tag>_<n>``) in
# ``AgentModel.init_prompt``. Insertion order is the order in which the
# task-name prefixes are tried.
prefixes = dict(
    pick_and_place='put',
    pick_clean_then_place='clean',
    pick_heat_then_place='heat',
    pick_cool_then_place='cool',
    look_at_obj='examine',
    pick_two_obj='puttwo',
)

# Special-token strings. EOS, BOS and UNK deliberately share one string here;
# only the padding token is distinct.
DEFAULT_EOS_TOKEN = DEFAULT_BOS_TOKEN = DEFAULT_UNK_TOKEN = "</s>"
DEFAULT_PAD_TOKEN = "[PAD]"


class AgentModel():
    """Causal-LM agent that picks one action out of a list of admissible ones.

    The model is loaded in 8-bit with ``device_map='auto'``; when ``peft_path``
    is given, the adapter stored there is loaded on top of the base model.
    ``init_prompt`` builds a few-shot prompt (skipped when an adapter is used)
    and ``act_greedy_search`` selects the next action.
    """

    # Set to True (on the class or on an instance) to drop into an IPython
    # shell when ``act_greedy_search`` fails. Off by default so that
    # non-interactive runs do not block on a debugging prompt.
    debug_on_error = False

    def __init__(self, prompt_path, model_path, tokenizer_path, peft_path):
        """Load the model, the optional PEFT adapter and the tokenizer.

        Args:
            prompt_path: path of the JSON file read by ``init_prompt``.
            model_path: name or path passed to
                ``AutoModelForCausalLM.from_pretrained``; a path containing
                ``'gpt2'`` selects the shorter context limit and the
                BOS-less token layout.
            tokenizer_path: name or path passed to
                ``AutoTokenizer.from_pretrained``.
            peft_path: adapter path for ``PeftModel.from_pretrained``; a
                falsy value means the plain base model is used.
        """
        print("initializing agent model")
        self.conv = None
        self.prompt_str = None

        self.prompt_path = prompt_path
        self.model_path = model_path
        self.peft_path = peft_path

        # Give every visible GPU the same budget: the free memory reported
        # for the current device minus 2 GB of headroom.
        free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024 ** 3)
        per_gpu_budget = f'{free_in_GB-2}GB'
        max_memory = {i: per_gpu_budget for i in range(torch.cuda.device_count())}

        # The base model is loaded the same way with or without an adapter.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path, return_dict=True, load_in_8bit=True,
            max_memory=max_memory, device_map='auto')
        if self.peft_path:
            # Load the Lora model
            self.model = PeftModel.from_pretrained(self.model, peft_path)

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.tokenizer.add_special_tokens(
            {
                "eos_token": "</s>",
                "bos_token": "</s>",
                "unk_token": "</s>",
                'pad_token': '[PAD]',
            }
        )
        # Prompt budget: 150 tokens below 2048 (1024 for gpt2) -- presumably
        # headroom for the generated continuation; TODO confirm.
        if 'gpt2' in model_path:
            self.tokenizer.model_max_length = 1024 - 150
        else:
            self.tokenizer.model_max_length = 2048 - 150

        print(self.tokenizer.model_max_length)
        # Truncate from the left so the most recent context is kept.
        self.tokenizer.truncation_side = 'left'

    @staticmethod
    def _collect_steps(examples, key):
        """Return ``" > step1 > step2 ..."`` for the example stored at ``key``.

        Steps are read from ``examples[key]["1"]``, ``["2"]``, ... and the
        walk stops at the first missing entry (at most 100 steps). A missing
        ``key`` yields an empty string.
        """
        pieces = []
        try:
            steps = examples[key]
            for idx in range(1, 101):
                pieces.append(" > " + steps[str(idx)])
        except KeyError:
            # First missing key/step marks the end of the example.
            pass
        return "".join(pieces)

    def init_prompt(self, name):
        """Build the few-shot prompt for task ``name`` into ``self.prompt_str``.

        Does nothing when a PEFT adapter is in use. Otherwise the JSON file at
        ``self.prompt_path`` is read and, for every entry of the module-level
        ``prefixes`` whose key is a prefix of ``name``, the two examples
        ``ask_think_<tag>_0`` and ``ask_think_<tag>_1`` are appended,
        separated by ``" $$$ "``.
        """
        if self.peft_path:
            return
        with open(self.prompt_path, 'r', encoding='utf-8') as f:
            d = json.load(f)

        prompt_str = 'Interact with a household to solve a task. Here are two examples: '
        for task_prefix, tag in prefixes.items():
            if not name.startswith(task_prefix):
                continue
            print(task_prefix, tag)
            prompt_str += self._collect_steps(d, f'ask_think_{tag}_0')
            prompt_str += " $$$ "
            prompt_str += self._collect_steps(d, f'ask_think_{tag}_1')

        self.prompt_str = prompt_str

    @staticmethod
    def _pick_next_token(last_token_logits, column_ids):
        """Choose the best-scoring token among those present in ``column_ids``.

        Args:
            last_token_logits: 1-D tensor of next-token logits.
            column_ids: 1-D tensor with the candidate token id of every
                action that is still in play.

        Returns:
            ``(token, keep)`` where ``token`` is the chosen id as an ``int``
            and ``keep`` is a boolean mask over ``column_ids`` marking the
            actions that continue with that token.
        """
        candidates = torch.unique(column_ids)
        # Comparing raw logits gives the same argmax as comparing softmax
        # probabilities, without normalising over the whole vocabulary.
        scores = last_token_logits.to(candidates.device)[candidates]
        token = candidates[scores.argmax()].item()
        return token, column_ids == token

    @torch.inference_mode()
    def act_greedy_search(self, ob, _, admissible_actions):
        """Greedily select one of ``admissible_actions`` given observation ``ob``.

        The actions are tokenized together and compared column by column: at
        each step the model scores the tokens that the remaining actions
        offer at that position, the best one is kept, and actions that do not
        continue with it are dropped, until a single action remains. If the
        chosen action starts with ``"think"`` or ``"ask"``, the model
        additionally generates its content and the result is returned as
        ``"<action>: <generated text>"``.

        Args:
            ob: observation text appended to the prompt.
            _: unused.
            admissible_actions: non-empty sequence of action strings.

        Returns:
            The selected action string. If an error occurs after an action
            was selected (i.e. while generating the think/ask content), the
            bare action is returned.

        Raises:
            ValueError: if ``admissible_actions`` is empty.
            Exception: any error raised before an action could be selected is
                reported and re-raised.
        """
        if len(admissible_actions) == 0:
            raise ValueError("admissible_actions must not be empty")

        final_choice = None
        try:
            device = self.model.device
            choice_ids = self.tokenizer(
                admissible_actions, padding=True, return_tensors="pt",
                add_special_tokens=True).input_ids.to(device)

            # Row indices (into admissible_actions) that are still consistent
            # with the tokens chosen so far.
            alive = torch.arange(choice_ids.shape[0], device=device)
            # The gpt2 tokenizer puts the first action token in column 0;
            # other tokenizers are assumed to prepend BOS, so they start at 1.
            column = 0 if 'gpt2' in self.model_path else 1

            self.conv = empty_conv.copy()
            self.conv.append_message(self.conv.roles[0], self.prompt_str)
            self.conv.append_message(self.conv.roles[1], ob + " > ")

            if self.peft_path:
                input_str = ob + " > "
            else:
                input_str = self.conv.get_prompt()[:-3]

            inputs = self.tokenizer([input_str], truncation=True)

            past_key_values = None
            token = None
            first_step = True
            # Stop when one action is left, or when the padded width is used
            # up (identical actions can never be told apart).
            while len(alive) > 1 and column < choice_ids.shape[1]:
                if first_step:
                    # Encode the whole prompt once and keep the KV cache.
                    out = self.model(
                        torch.as_tensor(inputs.input_ids).to(device),
                        use_cache=True, return_dict=True)
                    first_step = False
                else:
                    # Feed only the token chosen in the previous step.
                    attention_mask = torch.ones(
                        1, past_key_values[0][0].shape[-2] + 1).to(device)
                    out = self.model(
                        input_ids=torch.as_tensor([[token]]).to(device),
                        attention_mask=attention_mask,
                        past_key_values=past_key_values,
                        use_cache=True)
                past_key_values = out.past_key_values

                token, keep = self._pick_next_token(
                    out.logits[0][-1], choice_ids[alive, column])
                alive = alive[keep]
                column += 1

            # Drop the cache reference before the (optional) generation pass.
            del past_key_values

            # Use the surviving row directly rather than string-matching the
            # decoded prefix, which could hit a longer action sharing it.
            final_choice = admissible_actions[alive[0].item()]

            if final_choice.startswith(("think", "ask")):
                bare_action = final_choice
                input_str += bare_action + ":"
                inputs = self.tokenizer([input_str], truncation=False)

                raw_input_ids = torch.as_tensor(inputs.input_ids).to(device)
                max_len = self.tokenizer.model_max_length
                # Only the last max_len tokens go to the model; the cut-off
                # head is re-attached afterwards so that the decoded text
                # lines up with input_str again.
                truncated_input_ids = raw_input_ids[:, -max_len:]
                tokens_left = raw_input_ids[:, :-max_len]

                output_ids = self.model.generate(
                    input_ids=truncated_input_ids,
                    do_sample=False,
                    max_new_tokens=100,
                    # NOTE(review): 1405 is a hard-coded extra stop token id;
                    # its meaning depends on the tokenizer -- confirm.
                    eos_token_id=[1405, self.tokenizer.eos_token_id]
                )

                full_token = torch.cat([tokens_left, output_ids], dim=-1)

                outputs = self.tokenizer.batch_decode(full_token, skip_special_tokens=True)[0]
                # Keep the first line of the continuation up to the next '>'.
                outputs = outputs[len(input_str):].split('>')[0].strip().split('\n')[0].strip()
                final_choice = bare_action + ": " + outputs
        except Exception:
            import traceback

            print("get action error")
            traceback.print_exc()
            if self.debug_on_error:
                from IPython import embed
                embed()
            if final_choice is None:
                # Nothing was selected, so there is no value to fall back on.
                raise
        return final_choice
