from typing import Union, List, Tuple
import time
import torch
import numpy as np

from embodied_cd.common.dataset_utils import PromptTemplate
from embodied_cd.common.agent import BaseAgent
from embodied_cd.common.mixin import FewShotMixIn
from embodied_cd.trl.models.core import (
    _Type_Decoding,
    generation, 
    greedy_generation, 
    beam_action_generation,
)
from embodied_cd.common.print_utils import *
from embodied_cd.agents.ecoc import ECoCAgent


class ThinkWholeAgent(ECoCAgent):
    """ECoC agent variant whose ``get_think`` emits a single thought line.

    Construction is delegated entirely to :class:`ECoCAgent`. Every argument
    is forwarded unchanged and in the same positional order. The only
    behavioral override is :meth:`get_think`.
    """

    name = "ecoc_think_whole"

    def __init__(
        self, 
        pre_model_name,
        base_model,
        base_tokenizer,
        model, 
        tokenizer,
        plan_model=None, 
        plan_tokenizer=None,
        rag_pipe=None,
        correct: bool = False,
        no_critic: bool = False,
        env_name: str = 'virtualhome',
        cl_type: str = 'behavior',
        max_think_token: int = 200,
        total_think: int = 1,
        few_shot_example: str = None,
        perturb: bool = False,
        decoding_strategy: _Type_Decoding = 'beam-action'
    ):
        # Positional forwarding: the order below must match ECoCAgent.__init__
        # exactly, since no keywords are used.
        super().__init__(
            pre_model_name,
            base_model,
            base_tokenizer,
            model,
            tokenizer,
            plan_model,
            plan_tokenizer,
            rag_pipe,
            correct,
            no_critic,
            env_name,
            cl_type,
            max_think_token,
            total_think,
            few_shot_example,
            perturb,
            decoding_strategy,
        )

    def get_think(
        self,
        model,
        tokenizer,
        instruction: str,
        state: str,
        history: str,
        sample: bool = False
    ):
        """Generate one reasoning line for the current step.

        Args:
            model: Adapter-capable model. Its ``'reasoning_policy'`` adapter is
                activated before generation.
            tokenizer: Tokenizer handed through to ``self.generate``.
            instruction: Task instruction inserted into the prompt.
            state: Current state description inserted into the prompt.
            history: Previous actions inserted into the prompt.
            sample: Forwarded to ``self.generate``. Presumably it toggles
                sampling vs. deterministic decoding; confirm in the parent.

        Returns:
            A pair ``(thought, [thought])``. ``thought`` is the first line of
            the stripped generated response. NOTE(review): the list form
            presumably mirrors ECoCAgent's multi-thought return shape; confirm.
        """
        # The activated adapter stays active on `model` after this call returns.
        model.set_adapter('reasoning_policy')

        prompt = self.querize(
            f"Instruction: {instruction}\nState: {state}\nPrevious Actions: {history}\n{PromptTemplate.correct_think_prompt}"
        )
        # self.max_think_token is passed as the generation length budget.
        result = self.generate(model, tokenizer, prompt, self.max_think_token, sample)

        # Keep only the text before the first newline of the trimmed response.
        first_line = result.response.strip().partition("\n")[0]
        return first_line, [first_line]
