from typing import Union, List, Tuple
import time
import torch
import numpy as np

from embodied_cd.common.dataset_utils import PromptTemplate
from embodied_cd.common.agent import BaseAgent
from embodied_cd.common.mixin import FewShotMixIn
from embodied_cd.trl.models.core import (
    _Type_Decoding,
    generation, 
    greedy_generation, 
    beam_action_generation,
)
from embodied_cd.common.print_utils import *
from embodied_cd.agents.ecoc import ECoCAgent


class CAMAAgent(ECoCAgent):
    """ECoC agent variant registered under the name ``"ecoc_think_whole"``.

    Construction is delegated unchanged to :class:`ECoCAgent`; this class
    only overrides :meth:`get_action`.
    """

    name = "ecoc_think_whole"

    def __init__(
        self,
        pre_model_name,
        base_model,
        base_tokenizer,
        model,
        tokenizer,
        plan_model=None,
        plan_tokenizer=None,
        rag_pipe=None,
        correct: bool = False,
        no_critic: bool = False,
        env_name: str = 'virtualhome',
        cl_type: str = 'behavior',
        max_think_token: int = 200,
        total_think: int = 1,
        few_shot_example: str = None,
        perturb: bool = False,
        decoding_strategy: _Type_Decoding = 'beam-action'
    ):
        """Forward every argument, positionally and in order, to ``ECoCAgent``.

        No additional state is set up here; the meaning of each argument is
        defined by the parent constructor.
        """
        # Arguments are passed positionally, so this order must keep matching
        # the parent's parameter order.
        super().__init__(
            pre_model_name,
            base_model,
            base_tokenizer,
            model,
            tokenizer,
            plan_model,
            plan_tokenizer,
            rag_pipe,
            correct,
            no_critic,
            env_name,
            cl_type,
            max_think_token,
            total_think,
            few_shot_example,
            perturb,
            decoding_strategy,
        )

    def get_action(
        self,
        instruction: str,
        state: str,
        history: str,
        think: str = None,
    ):
        """Generate the next action with the ``'planning_policy'`` adapter.

        The prompt is built by ``self.plan_template`` from the instruction,
        state, think text and history, converted with
        ``self._convert_to_completion``, and decoded greedily with
        ``max_length=20`` and gradients disabled.

        Returns:
            A ``(action, None)`` tuple, where ``action`` is the first line of
            the whitespace-stripped generated response. The second element is
            always ``None``.
        """
        # Switch the active adapter before any prompt building or decoding.
        self.model.set_adapter('planning_policy')

        templated = self.plan_template(
            instruction=instruction, state=state, think=think, history=history
        )
        prompt = self._convert_to_completion(templated["query"])

        # Inference only: no gradients are needed for decoding.
        with torch.no_grad():
            output = greedy_generation(
                self.model, self.tokenizer, prompt, max_length=20
            )

        # Keep only the first line of the stripped response.
        first_line, _, _ = output.response.strip().partition("\n")
        return first_line, None
    
