import functools
from mcu.arm.models.agents.base_agent import BaseAgent
import torch
import re
import av
import cv2
import numpy as np
import typing
import pdb
from typing import Union, Dict, Optional, List, Tuple, Any

from mcu.arm.utils.vpt_lib.action_head import ActionHead

from omegaconf import DictConfig
from gymnasium.spaces.space import Space
from mcu.arm.models.policys import make_policy, load_policy_cfg

def tree_get(obj, keys: List, default=None):
    """Walk ``obj`` along ``keys`` and return the nested value, or ``default``.

    Each key is checked with ``key in obj`` and then read with ``obj[key]``,
    so any container supporting both operations (e.g. a dict) can be used.

    Args:
        obj: Root container to descend into.
        keys: Keys applied in order, one nesting level per key.
        default: Returned when a key is missing or a lookup raises.

    Returns:
        The nested value, ``obj`` itself when ``keys`` is empty, or ``default``.
    """
    try:
        for key in keys:
            if key not in obj:
                return default
            obj = obj[key]
        return obj
    except Exception:
        # Best-effort lookup (e.g. `in` on a non-container raises TypeError).
        # Catch Exception rather than using a bare `except:` so that
        # KeyboardInterrupt / SystemExit still propagate.
        return default

class ConditionedAgent(BaseAgent):
    """Agent wrapping a policy that can be conditioned on reference videos.

    The policy is built with ``make_policy``. When the policy reports
    ``is_conditioned()`` and the observation dict carries an ``'obs_conf'``
    entry, the reference video(s) named there are encoded into latents that
    are passed to the policy as ``ice_latent``.
    """

    def __init__(
        self, 
        obs_space: Space, 
        action_space: Space, 
        policy_config: Union[DictConfig, str]
    ) -> None:
        """Build the agent and its underlying policy.

        Args:
            obs_space: Observation space (stored on the instance).
            action_space: Action space, forwarded to ``make_policy``.
            policy_config: A policy config object, or a string that is
                resolved through ``load_policy_cfg``.
        """
        super().__init__()
        self.obs_space = obs_space
        self.action_space = action_space
        self.policy_config = policy_config

        if isinstance(self.policy_config, str):
            self.policy_config = load_policy_cfg(self.policy_config)

        self.policy, self.policy_building_info = make_policy(
            policy_cfg=self.policy_config, action_space=self.action_space
        )

        # Length of the reference-video prefix that gets encoded; falls back
        # to 128 when the config has no policy_kwargs.timesteps.
        self.timesteps = tree_get(
            obj=self.policy_config,
            keys=['policy_kwargs', 'timesteps'],
            default=128,
        )

        # batch_size -> initial recurrent state (see initial_state).
        self.cached_init_states = {}
        # (B, T) -> all-False `first` tensor reused across forward calls.
        self.cached_first = {}
        # (ref_video, ref_mask, resolution) -> encoded latent. Held per
        # instance instead of functools.lru_cache on the method, which would
        # key on `self` and keep every agent alive for the process lifetime.
        self._encoded_video_cache = {}

    def wrapped_forward(self, 
                        obs: Dict[str, Any], 
                        state_in: Optional[List[torch.Tensor]],
                        first: Optional[torch.Tensor] = None, 
                        **kwargs
    ) -> Tuple[Dict[str, torch.Tensor], List[torch.Tensor], Dict[str, Any]]:
        '''Fill in ``state_in`` / ``first`` when not specified, then call the policy.

        Args:
            obs: Observation dict. The first four dims of ``obs['img']`` are
                read as (B, T, W, H). An optional ``obs['obs_conf']`` selects
                the reference video(s) used for conditioning.
            state_in: Recurrent state, or None to use ``initial_state(B)``.
            first: Per-step flags, or None to use a cached all-False (B, T)
                bool tensor.
            **kwargs: Forwarded unchanged to the policy.

        Returns:
            Whatever ``self.policy(...)`` returns.
        '''
        B, T, W, H = obs['img'].shape[:4]
        # NOTE(review): dims 2/3 are named (W, H) and later used as the
        # cv2.resize target (width, height). If images are laid out as
        # (B, T, H, W, C), that target is transposed -- confirm the layout.

        if state_in is None:
            state_in = self.initial_state(B)

        if self.policy.is_conditioned() and 'obs_conf' in obs:
            ice_latent = self.load_input_condition(obs_conf=obs['obs_conf'], resolution=(W, H))
        else:
            ice_latent = None

        if first is None:
            # Build the tensor only on a cache miss. The previous
            # dict.get(key, default) form constructed (and copied to the
            # device) a new tensor on every call, which defeated the cache.
            first = self.cached_first.get((B, T))
            if first is None:
                first = torch.tensor([[False]], device=self.device).repeat(B, T)
                self.cached_first[(B, T)] = first

        return self.policy(
            obs=obs, 
            first=first, 
            state_in=state_in, 
            ice_latent=ice_latent, 
            **kwargs
        )

    def encode_video(
        self, ref_video: str, ref_mask: float, resolution: Tuple[int, int], 
    ) -> torch.Tensor:
        '''Encode the first ``self.timesteps`` frames of a video into a latent.

        Results are memoized per instance, keyed on
        (ref_video, ref_mask, resolution).

        Args:
            ref_video: Video source opened with ``av.open``.
            ref_mask: Controls ``input_mask``. For ref_mask < 1, every
                ``int(1 / (1 - ref_mask))``-th position (stride clamped to
                >= 1) is set to 1. For ref_mask >= 1 the mask stays all zeros.
            resolution: (width, height) target passed to ``cv2.resize``.

        Returns:
            ``policy.encode_condition(...)['ce_latent']`` with dim 0 squeezed.

        Raises:
            ValueError: if no frames can be decoded from ``ref_video``.
        '''
        cache_key = (ref_video, float(ref_mask), tuple(resolution))
        if cache_key in self._encoded_video_cache:
            return self._encoded_video_cache[cache_key]

        input_mask = torch.zeros(self.timesteps, device=self.device)
        if ref_mask < 1:
            # For ref_mask < 0 the raw stride truncates to 0, which
            # torch.arange rejects; clamp it to at least 1.
            step = max(1, int(1 / (1 - ref_mask)))
            one_idx = torch.arange(0, self.timesteps, step, device=self.device)
            input_mask[one_idx] = 1

        frames = []
        num_frames = 0
        with av.open(ref_video, "r") as container:
            for frame in container.decode(video=0):
                num_frames += 1
                # Only the prefix is encoded, so skip converting/resizing the
                # rest. Decoding continues so the reported frame count stays
                # the full length of the video.
                if len(frames) < self.timesteps:
                    frames.append(cv2.resize(
                        frame.to_ndarray(format="rgb24"),
                        (resolution[0], resolution[1]), interpolation=cv2.INTER_LINEAR,
                    ))

        if not frames:
            raise ValueError(f"Reference video {ref_video!r} yielded no decodable frames.")

        segment = torch.stack(
            [torch.from_numpy(frame).to(self.device) for frame in frames], dim=0
        ).unsqueeze(0)

        conditions = self.policy.encode_condition(img=segment, infer=True, input_mask=input_mask)
        ce_latent = conditions['ce_latent'].squeeze(0)
        print(
            "==============================================================================================================\n"
            f"[->] Reference video is from: {ref_video};\n"
            f"[->] Number of frames: {num_frames}, only use the prefix of length {self.timesteps}. \n"
            f"[->] Shape of latent instruction: {ce_latent.shape} | mean: {ce_latent.mean().item(): .3f} | std: {ce_latent.std(): .3f}\n"
            "==============================================================================================================\n"
        )
        self._encoded_video_cache[cache_key] = ce_latent
        return ce_latent

    def load_input_condition(self, obs_conf: Dict, resolution: Tuple[int, int]) -> torch.Tensor:
        '''Encode each reference video listed in ``obs_conf`` and stack the latents.

        Args:
            obs_conf: Must contain ``'ref_video'``, whose i-th item is indexed
                with ``[0]`` to obtain the video source. An optional
                ``'ref_mask'`` is indexed the same way and defaults to 0.0.
            resolution: (width, height) forwarded to ``encode_video``.

        Returns:
            The per-video latents stacked along a new leading dim.
        '''
        assert 'ref_video' in obs_conf, 'ref_video should be specified in obs_conf. '
        has_mask = 'ref_mask' in obs_conf
        ice_latent = []
        for i, video_entry in enumerate(obs_conf['ref_video']):
            ref_mask = obs_conf['ref_mask'][i][0] if has_mask else 0.0
            ice_latent.append(
                self.encode_video(ref_video=video_entry[0], ref_mask=float(ref_mask), resolution=resolution)
            )
        return torch.stack(ice_latent, dim=0)

    @property
    def action_head(self) -> ActionHead:
        '''The policy's action head (``self.policy.pi_head``).'''
        return self.policy.pi_head

    @property
    def value_head(self) -> torch.nn.Module:
        '''The policy's value head (``self.policy.value_head``).'''
        return self.policy.value_head

    def initial_state(self, batch_size: Optional[int] = None) -> List[torch.Tensor]:
        '''Return the policy's initial recurrent state on ``self.device``.

        With ``batch_size=None``, a batch-1 state is built each call and
        squeezed along dim 0. Otherwise the state is cached per batch size,
        and the same list object is returned on later calls.
        '''
        if batch_size is None:
            return [t.squeeze(0).to(self.device) for t in self.policy.initial_state(1)]
        if batch_size not in self.cached_init_states:
            self.cached_init_states[batch_size] = [t.to(self.device) for t in self.policy.initial_state(batch_size)]
        return self.cached_init_states[batch_size]

    def forward(self, 
                obs: Dict[str, Any], 
                state_in: Optional[List[torch.Tensor]],
                first: Optional[torch.Tensor] = None,
                **kwargs
    ) -> Tuple[Dict[str, torch.Tensor], List[torch.Tensor], Dict[str, Any]]:
        '''Module entry point; delegates to ``wrapped_forward``.'''
        forward_result, state_out, latents = self.wrapped_forward(obs=obs, state_in=state_in, first=first, **kwargs)
        return forward_result, state_out, latents
    

if __name__ == "__main__":
    # Smoke run: build an agent for one environment and take two steps,
    # feeding the state returned by the first action into the second.
    from mcu.stark_tech.env_interface import MinecraftWrapper

    env = MinecraftWrapper('diverses/collect_grass', prev_action_obs=True)
    obs, _ = env.reset()
    agent = ConditionedAgent(
        obs_space=env.observation_space,
        action_space=env.action_space,
        policy_config='groot_eff_1x',
    )

    action, state = agent.get_action(obs, input_shape="*")
    obs, _, _, _, _ = env.step(action)
    action, state = agent.get_action(obs, state_in=state, input_shape="*")