from typing import Union, List, Tuple
import time
import torch
import numpy as np

from embodied_cd.common.dataset_utils import PromptTemplate
from embodied_cd.common.agent import BaseAgent
from embodied_cd.common.mixin import FewShotMixIn
from embodied_cd.trl.models.core import (
    _Type_Decoding,
    generation, 
    greedy_generation, 
    beam_action_generation,
)
from embodied_cd.common.print_utils import *
from embodied_cd.agents.ecoc import ECoCAgent


class ThinkThreeStepAgent(ECoCAgent):
    """ECoC agent variant that builds its "think" from several prompted steps.

    Registered under the name ``"ecoc_think_three_step"``. It reuses all of
    ``ECoCAgent``'s construction and overrides only :meth:`get_think`. That
    method queries the reasoning adapter once per step, using the step-specific
    prompts in ``PromptTemplate.correct_think_prompts_step_3``.
    """

    name = "ecoc_think_three_step"

    def __init__(
        self,
        pre_model_name,
        base_model,
        base_tokenizer,
        model,
        tokenizer,
        plan_model=None,
        plan_tokenizer=None,
        rag_pipe=None,
        correct: bool = False,
        no_critic: bool = False,
        env_name: str = 'virtualhome',
        cl_type: str = 'behavior',
        max_think_token: int = 60,
        total_think: int = 2,
        few_shot_example: str = None,
        perturb: bool = False,
        decoding_strategy: _Type_Decoding = 'beam-action'
    ):
        """Forward every argument unchanged to ``ECoCAgent.__init__``.

        The arguments are passed positionally. This relies on ``ECoCAgent``
        declaring its parameters in this same order
        (NOTE(review): confirm if ``ECoCAgent``'s signature changes).
        """
        super().__init__(
            pre_model_name, base_model, base_tokenizer, model, tokenizer, plan_model, plan_tokenizer,
            rag_pipe, correct, no_critic, env_name, cl_type, max_think_token, total_think, few_shot_example, perturb, decoding_strategy)

    @staticmethod
    def _postprocess_think(response: str, step: int) -> str:
        """Trim one raw generation to the sentence budget of the given step.

        Only the first line of the stripped response is kept, and it is split
        on ``". "``.

        * Step 0 keeps the first three sentences when there are more than two.
          Otherwise it keeps just the first.
        * Step 1 keeps up to two sentences.
        * Later steps keep the whole first line.

        A non-empty result always ends with ``"."``. An empty response yields
        ``""`` instead of raising.
        """
        first_line = response.strip().split("\n")[0]
        sentences = first_line.split(". ")
        if step == 0:
            # NOTE(review): with exactly two sentences only the first is kept;
            # this mirrors the original behavior -- confirm it is intended.
            think = ". ".join(sentences[:3]) if len(sentences) > 2 else sentences[0]
        elif step == 1:
            think = ". ".join(sentences[:2])
        else:
            # No step-specific budget is defined past step 1; keep the line whole
            # rather than failing (the original raised TypeError here).
            think = first_line
        # Guard on emptiness: the original indexed think[-1] and raised IndexError
        # on an empty generation.
        if think and not think.endswith("."):
            think += "."
        return think

    def get_think(
        self,
        model,
        tokenizer,
        instruction: str,
        state: str,
        history: str,
        sample: bool = False
    ) -> Tuple[str, List[str]]:
        """Generate ``self.total_think`` reasoning snippets and combine them.

        Args:
            model: Model exposing ``set_adapter``; the ``'reasoning_policy'``
                adapter is activated before generating.
            tokenizer: Tokenizer passed through to ``self.generate``.
            instruction: Task instruction inserted into every query.
            state: Current state description inserted into every query.
            history: Previous actions inserted into every query.
            sample: Forwarded to ``self.generate`` (presumably toggles sampling
                vs. deterministic decoding -- confirm in ``ECoCAgent``).

        Returns:
            A tuple ``(combined, think_list)``. ``think_list`` holds one
            post-processed snippet per step, in step order. ``combined`` is the
            non-empty snippets joined by single spaces.
        """
        # Set Adapter
        model.set_adapter('reasoning_policy')

        think_list = []
        for step in range(self.total_think):
            # Each step uses its own prompt suffix on the same context.
            query = f"Instruction: {instruction}\nState: {state}\nPrevious Actions: {history}\n{PromptTemplate.correct_think_prompts_step_3[step]}"
            query = self.querize(query)
            generation_output = self.generate(model, tokenizer, query, self.max_think_token, sample)
            think_list.append(self._postprocess_think(generation_output.response, step))
        # Skip empty snippets so the combined text has no stray spaces; think_list
        # still keeps one entry per step.
        return " ".join(think for think in think_list if think), think_list
