from typing import Union, List, Tuple
import copy
import time
import torch
import numpy as np

from embodied_cd.common.dataset_utils import PromptTemplate
from embodied_cd.common.agent import BaseAgent
from embodied_cd.common.mixin import FewShotMixIn
from embodied_cd.trl.models.core import (
    _Type_Decoding,
    generation, 
    greedy_generation, 
    beam_action_generation,
)
from embodied_cd.common.print_utils import *


class ECoCAgent(BaseAgent, FewShotMixIn):
    """Agent that produces several think steps and then an action (ECoC).

    Named adapters on ``self.model`` are switched with ``set_adapter``:
    ``'reasoning_policy'`` generates think steps, ``'feedback_policy'``
    critiques them (only used in ``correct`` mode), and
    ``'planning_policy'`` generates the final action.
    """
    name = "ecoc"

    # Sampling parameters used whenever ``generate(..., sample=True)`` is called.
    gen_params = {
        "do_sample": True,
        "top_k": 0,
        "top_p": 0.6,
        "temperature": 0.7
    }

    # Multiplier on the per-step deviation added to the average log-prob
    # threshold when ``evaluate == 'unseen'``.
    UNSEEN_THRESH_STD_SCALE = 0.253

    # Fraction passed to ``PromptTemplate.randomize`` per environment when
    # ``perturb`` is enabled; environments not listed are left unperturbed.
    _PERTURB_RATIO = {'virtualhome': 0.5, 'alfred': 0.3}

    def __init__(
        self, 
        pre_model_name,
        base_model,
        base_tokenizer,
        model, 
        tokenizer,
        plan_model=None, 
        plan_tokenizer=None,
        rag_pipe=None,
        correct: bool = False,
        no_critic: bool = False,
        env_name: str = 'virtualhome',
        cl_type: str = 'behavior',
        max_think_token: int = 80,
        total_think: int = 5,
        few_shot_example: str = None,
        perturb: bool = False,
        decoding_strategy: _Type_Decoding = 'beam-action',
        test_time_thresh: dict = None,
    ):
        """Store models, prompt templates and configuration.

        Args:
            pre_model_name: Pretrained model name; if it contains
                'instruct'/'Instruct', prompts use the tokenizer chat template.
            base_model, base_tokenizer: Stored; not used by the active code
                paths of this class.
            model, tokenizer: Model carrying the reasoning/feedback/planning
                adapters and its tokenizer.
            plan_model, plan_tokenizer: Stored; ``get_action`` uses
                ``self.model`` instead.
            rag_pipe: Retrieval pipeline required by ``correct_think``.
            correct: If True, ``forward`` refines think steps with feedback.
            no_critic, cl_type, decoding_strategy: Stored; not read here.
            env_name: Environment name used for prompt templates/perturbation.
            max_think_token: Max generation length for each think step.
            total_think: Number of think steps generated per forward pass.
            few_shot_example: Passed to the "cd-think" prompt template.
            perturb: If True, randomize the state in ``forward``.
            test_time_thresh: If given, enables early exit over think steps;
                read via keys 'average' and 'deviation', each indexed by
                think step.
        """
        super().__init__()

        self.pre_model_name = pre_model_name

        self.env_name = env_name
        self.cl_type = cl_type
        
        self.base_model = base_model
        self.base_tokenizer = base_tokenizer
        self.model = model
        self.tokenizer = tokenizer
        self.max_think_token = max_think_token
        self.total_think = total_think
        self.think_template = PromptTemplate(env_name, "cd-think", few_shot_example)

        self.plan_model = plan_model
        self.plan_tokenizer = plan_tokenizer
        self.plan_template = PromptTemplate(env_name, "cd-action-think")
        self.action_format = PromptTemplate.load_env_action_format(env_name)

        self.correct = correct
        self.no_critic = no_critic
        self.perturb = perturb
        self.decoding_strategy = decoding_strategy

        self.rag_pipe = rag_pipe
        # Counters used by the get_* statistics methods.
        self.forward_count = 0
        self.correct_count = 0  # NOTE(review): not incremented in this class; presumably updated by the caller
        # [0] tokens of initial/full think, [1] tokens of re-generated thinks
        # (feedback mode) or of the think prefix actually used (early-exit
        # mode), [2] action tokens, [3] number of think steps used (early-exit mode).
        self.generated_tokens = [0, 0, 0, 0]
        # Per think-step counter (index j = step j+1).
        self.reasoning_counts = [0 for _ in range(total_think)]
        
        # Early-exit thresholds; ``evaluate`` must be set to 'seen' or
        # 'unseen' by the caller before ``forward`` uses them.
        self.test_time_thresh = test_time_thresh
        self.evaluate = None

    def reset(self, task, goal):
        """No per-episode state to reset."""
        return

    @staticmethod
    def _first_line(text: str) -> str:
        """Return the first line of ``text`` after stripping outer whitespace."""
        return text.strip().split("\n")[0]

    @staticmethod
    def _ensure_period(text: str) -> str:
        """Append '.' to non-empty ``text`` that does not already end with one."""
        if text and not text.endswith('.'):
            return text + '.'
        return text

    def querize(self, _query, tokenizer=None):
        """Wrap a raw prompt in the format expected by ``self.pre_model_name``.

        Instruction-tuned models (name contains 'instruct'/'Instruct') get the
        chat template of ``tokenizer`` (default ``self.tokenizer``); other
        models get a ``### Human: ... ### Assistant:`` completion prompt.
        """
        if 'instruct' in self.pre_model_name or 'Instruct' in self.pre_model_name:
            # Use the tokenizer that will consume this prompt.
            chat_tokenizer = self.tokenizer if tokenizer is None else tokenizer
            chat = self._convert_to_chat([_query], list())
            return chat_tokenizer.apply_chat_template(
                chat, tokenize=False, add_generation_prompt=True)
        return f"### Human: {_query}\n### Assistant:"
    
    def generate(self, model, tokenizer, query, max_length, sample):
        """Run greedy (``sample=False``) or sampled generation without gradients."""
        with torch.no_grad():
            if not sample:
                generation_output = greedy_generation(
                    model, tokenizer, query, max_length)
            else:
                generation_output = generation(
                    model, tokenizer, query, max_length, **self.gen_params) 
        return generation_output
    
    def get_token_length(self, query):
        """Return the number of ids ``self.tokenizer.encode`` yields for ``query``.

        NOTE(review): whether special tokens (e.g. BOS) are included depends
        on the tokenizer's ``encode`` defaults.
        """
        return len(self.tokenizer.encode(query, return_tensors="np")[0])

    def get_feedback_think(
        self,
        model,
        tokenizer,
        instruction: str,
        state: str,
        history: str,
    ) -> Tuple[str, List[str]]:
        """Generate ``total_think`` think steps, each checked by the feedback adapter.

        For step ``j`` the reasoning adapter proposes a think greedily; the
        feedback adapter critiques it and, unless the loop stops, the
        reasoning adapter re-generates the think with sampling, conditioned on
        the feedback and on the previous steps.

        Side effects: increments ``forward_count`` once, adds initial think
        tokens to ``generated_tokens[0]``, re-generated think tokens to
        ``generated_tokens[1]`` and one to ``reasoning_counts[j]`` per
        re-generation.

        Returns:
            The space-joined think steps and the list of individual steps.
        """
        self.forward_count += 1 

        print("="*40)
        context = f"Instruction: {instruction}\nState: {state}\nPrevious Actions: {history}"
        think_list = []
        for j in range(self.total_think):
            step_prompt = PromptTemplate.correct_think_prompts[j]
            reasoning_trace = " ".join(think_list)

            # Initial think (greedy).
            model.set_adapter('reasoning_policy')
            query = self.querize(f"{context}\n{step_prompt}", tokenizer)
            generation_output = self.generate(model, tokenizer, query, self.max_think_token, sample=False)
            think = self._first_line(generation_output.response)
            init_think = think
            self.generated_tokens[0] += self.get_token_length(think)
            
            count = 0
            while True:
                # Critique the current think.
                model.set_adapter('feedback_policy')
                query = self.querize(f"{context}\n{step_prompt}\nRationale: {think}", tokenizer)
                generation_output = self.generate(model, tokenizer, query, 30, sample=False)
                feedback = self._first_line(generation_output.response)

                print(j, think)
                print_warn(feedback)
                # Stop when the critic reports only 'minor' issues, after three
                # re-generations, or when the first re-generation reproduced
                # the initial think.
                # NOTE(review): every exit path restores ``init_think``, so
                # re-generated thinks are never returned -- only the counters
                # are affected. Confirm this is intended.
                if ('minor' in feedback) or (count >= 3) or (count == 1 and init_think == think):
                    think = init_think
                    break

                # Re-generate the think (sampled), conditioned on the feedback.
                model.set_adapter('reasoning_policy')
                query = self.querize(
                    f"{context}\nReasoning Trace: {reasoning_trace}\nThink: {think}\n{feedback} {step_prompt}",
                    tokenizer)
                generation_output = self.generate(model, tokenizer, query, self.max_think_token, sample=True)
                think = self._first_line(generation_output.response)

                self.reasoning_counts[j] += 1
                self.generated_tokens[1] += self.get_token_length(think)
                count += 1

            think_list.append(think)
        return " ".join(think_list), think_list

    def get_think(
        self,
        model,
        tokenizer,
        instruction: str,
        state: str,
        history: str,
        sample: bool = False
    ) -> Tuple[str, List[str]]:
        """Generate ``total_think`` independent think steps with the reasoning adapter.

        Each step keeps only the first line of the response and, within it,
        the text before the first ``". "``.

        Returns:
            The space-joined think steps and the list of individual steps.
        """
        model.set_adapter('reasoning_policy')

        context = f"Instruction: {instruction}\nState: {state}\nPrevious Actions: {history}"
        think_list = []
        for j in range(self.total_think):
            query = self.querize(f"{context}\n{PromptTemplate.correct_think_prompts[j]}", tokenizer)
            generation_output = self.generate(model, tokenizer, query, self.max_think_token, sample)
            think = self._first_line(generation_output.response).split(". ")[0]
            think_list.append(think)
        return " ".join(think_list), think_list
    
    def correct_think(
        self,
        model,
        tokenizer,
        instruction: str,
        state: str,
        history: str,
        init_think: str,
        sample: bool = False,
    ) -> Tuple[str, List[str]]:
        """Rewrite an initial think step-by-step using RAG retrieval scores.

        ``init_think`` is split on ``". "`` into per-step sentences (padded
        with empty strings up to ``total_think``); ``self.rag_pipe.retrieve``
        supplies a per-step score string that is inserted into each prompt.

        Returns:
            The space-joined corrected steps and the list of individual steps.

        Raises:
            RuntimeError: If no ``rag_pipe`` was configured.
        """
        if self.rag_pipe is None:
            raise RuntimeError("correct_think requires a rag_pipe to be configured")
        model.set_adapter('reasoning_policy')

        init_think_list = [self._ensure_period(piece) for piece in init_think.split(". ")]
        # A negative count yields an empty list, so longer inputs are left as-is.
        init_think_list += [""] * (self.total_think - len(init_think_list))

        rag_score, rag_score_str, rag_action = self.rag_pipe.retrieve(
            instruction, state, history, init_think_list)
        print_error(rag_action, rag_score)

        context = f"Instruction: {instruction}\nState: {state}\nPrevious Actions: {history}"
        correct_think_list = []
        for j in range(self.total_think):
            reasoning_trace = " ".join(correct_think_list)
            query = self.querize(
                f"{context}\nReasoning Trace: {reasoning_trace}\nThink: {init_think_list[j]}\n"
                f"{rag_score_str[j]} {PromptTemplate.correct_think_prompts[j]}",
                tokenizer)
            generation_output = self.generate(model, tokenizer, query, self.max_think_token, sample)

            # Keep the first sentence of the first line, terminated by '.'.
            corrected = self._first_line(generation_output.response).split(". ")[0]
            correct_think_list.append(self._ensure_period(corrected))
        return " ".join(correct_think_list), correct_think_list
           
    def get_action(
        self,
        instruction: str,
        state: str,
        history: str,
        think: str = None,
    ):
        """Greedily generate an action with the planning adapter.

        Returns:
            ``(action, log_prob)`` where ``log_prob`` is the sum of the first
            ``token_length`` entries of ``log_probs`` divided by the token
            length of the post-processed action, or ``-inf`` when the action
            encodes to zero tokens (avoids a division by zero).
        """
        self.model.set_adapter('planning_policy')
        
        query = self.plan_template(instruction=instruction, state=state, think=think, history=history)["query"]
        query = self.querize(query, self.tokenizer)
        generation_output = self.generate(self.model, self.tokenizer, query, 20, False)

        action = self._first_line(generation_output.response)
        token_length = self.get_token_length(action)
        if token_length == 0:
            return action, float("-inf")
        log_prob = np.sum(generation_output.log_probs[:token_length]) / token_length
        return action, log_prob

    def _get_exit_thresh(self, step: int):
        """Return the early-exit log-prob threshold for 0-based think ``step``.

        Raises:
            NotImplementedError: If ``self.evaluate`` is not 'seen'/'unseen'.
        """
        if self.evaluate == 'seen':
            return self.test_time_thresh['average'][step]
        if self.evaluate == 'unseen':
            return (self.test_time_thresh['average'][step]
                    + self.UNSEEN_THRESH_STD_SCALE * self.test_time_thresh['deviation'][step])
        raise NotImplementedError(
            f"Unsupported evaluate mode {self.evaluate!r}; expected 'seen' or 'unseen'")

    def _forward_adaptive(self, instruction: str, state: str, history: str):
        """Act on growing think prefixes, stopping once the action log-prob clears the threshold."""
        think, think_list = self.get_think(self.model, self.tokenizer, instruction, state, history, sample=False)
        if not think_list:
            raise ValueError("early exit requires total_think >= 1")
        self.forward_count += 1
        self.generated_tokens[0] += self.get_token_length(think)
        print("=" * 30)
        for j in range(len(think_list)):
            think_seg = " ".join(think_list[:j + 1])
            action, log_prob = self.get_action(instruction, state, history, think_seg)
            self.generated_tokens[2] += self.get_token_length(action)
            thresh = self._get_exit_thresh(j)
            print_check(f"At {j+1}: {log_prob} | {thresh}")
            if log_prob > thresh:
                print_warn("Stop!!")
                break
        self.generated_tokens[1] += self.get_token_length(think_seg)
        self.generated_tokens[3] += (j + 1)
        self.reasoning_counts[j] += 1
        print(f"Added ... {j+1}, {think_seg}")
        return action

    def forward(
        self,
        instruction: str,
        state: str,
        history: str,
        few_shot_examples: Union[str, List[str]] = None,
    ):
        """Return the next action for ``instruction`` given ``state`` and ``history``.

        Modes: ``correct`` uses feedback-refined thinks; otherwise, if
        ``test_time_thresh`` is set, think steps are consumed until the action
        log-prob exceeds the per-step threshold; otherwise all think steps
        are used. ``few_shot_examples`` is accepted but not used.
        """
        state = PromptTemplate.preprocess(state)
        if self.perturb:
            ratio = self._PERTURB_RATIO.get(self.env_name)
            if ratio is not None:
                state = PromptTemplate.randomize(state, ratio)

        if self.correct:
            # Alternative pipeline: get_think with the base model followed by
            # RAG-guided correct_think.
            think, _ = self.get_feedback_think(self.model, self.tokenizer, instruction, state, history)
        elif self.test_time_thresh is not None:
            return self._forward_adaptive(instruction, state, history)
        else:
            think, _ = self.get_think(self.model, self.tokenizer, instruction, state, history, sample=False)

        action, _ = self.get_action(instruction, state, history, think)
        print("=" * 30)
        print(think)
        print(action)
        return action

    def get_correct_ratio(self):
        """Return ``correct_count / forward_count`` (0.0 before any counted call)."""
        if self.forward_count == 0:
            return 0.0
        return self.correct_count / self.forward_count

    def get_generated_tokens(self):
        """Return ``generated_tokens`` averaged over counted forward calls."""
        return np.array(self.generated_tokens) / max(self.forward_count, 1)

    def get_reasoning_counts(self):
        """Return per-step counts, averaged over calls in ``correct`` mode, raw otherwise."""
        if self.correct:
            return np.array(self.reasoning_counts) / max(self.forward_count, 1)
        return self.reasoning_counts
