from typing import Callable, Dict, Optional, Tuple
import hydra
from hydra import compose, initialize
from omegaconf import DictConfig
import torch
import av
from pathlib import Path
import numpy as np

from mcu.steveI.steveI_lib.utils.mineclip_agent_env_utils import make_agent
from mcu.steveI.steveI_lib.utils.embed_utils import get_prior_embed
from mcu.steveI.steveI_lib.config import MINECLIP_CONFIG, PRIOR_INFO
import mcu.steveI.steveI_lib.mineclip_code.load_mineclip as load_mineclip
from mcu.steveI.steveI_lib.data.text_alignment.vae import load_vae_model
from mcu.assembly.marks import MarkBase
from mcu.stark_tech.env_interface import MinecraftWrapper

# RELATIVE_POLICY_CONFIG_DIR = "./configs"

class SteveI(MarkBase):
    """STEVE-1 text-conditioned policy exposed as a ``MarkBase`` skill.

    Construction loads the policy configuration, the MineCLIP model, the prior
    (VAE) used to turn text into a conditioning embedding, and the policy
    agent. :meth:`do` then drives ``env`` with the policy for one text
    instruction until success, termination, a monitor stop, or a timeout.
    """

    # Action keys overwritten with ``np.array([0])`` before every env step so
    # the policy never triggers hotbar-slot switches or the inventory key
    # (assumes 0 means "not pressed" for these action keys — confirm).
    _MASKED_ACTION_KEYS = tuple(f'hotbar.{i}' for i in range(1, 10)) + ('inventory',)

    def __init__(self, env, config_path, policy_cfg="policy", device="cuda", **kwargs):
        """Load every model needed to run the policy.

        Args:
            env: Environment wrapper; ``do`` calls its ``step``,
                ``noop_action`` and ``manual_set_text`` methods.
            config_path: Directory holding the Hydra policy configs. Only used
                when ``policy_cfg`` is a string.
            policy_cfg: Config name (loads ``<config_path>/<policy_cfg>.yaml``),
                a plain ``dict``, or an ``omegaconf.DictConfig``.
            device: Torch device for MineCLIP, the prior and the agent.
            **kwargs: Stored unchanged on ``self.kwargs``.

        Raises:
            ValueError: If ``policy_cfg`` is not a str, dict or DictConfig.
        """
        import copy

        self.env = env
        self.kwargs = kwargs
        self.policy_cfg = self._load_policy_cfg(config_path, policy_cfg)

        self.cond_scale = self.policy_cfg.text_cond_scale
        self.device = device

        # Deep-copy the module-level defaults: writing the checkpoint paths
        # into MINECLIP_CONFIG / PRIOR_INFO directly would leak this instance's
        # weight paths into every other user of those shared globals.
        mineclip_config = copy.deepcopy(MINECLIP_CONFIG)
        mineclip_config['ckpt']['path'] = self.policy_cfg.mineclip_weights
        self.mineclip = load_mineclip.load(mineclip_config, device=self.device)

        prior_info = copy.deepcopy(PRIOR_INFO)
        prior_info['model_path'] = self.policy_cfg.prior_weights
        self.prior = load_vae_model(prior_info, device=self.device)

        self.agent = make_agent(
            self.policy_cfg.in_model,
            self.policy_cfg.in_weights,
            cond_scale=self.cond_scale,
            device=self.device,
        )
        self.agent.reset(cond_scale=self.cond_scale)

    @staticmethod
    def _load_policy_cfg(config_path, policy_cfg):
        """Resolve ``policy_cfg`` (config name, dict or DictConfig) to a DictConfig."""
        if isinstance(policy_cfg, str):
            # Hydra keeps a process-wide singleton; clear it so initialize()
            # can be called again (e.g. when several agents are constructed).
            hydra.core.global_hydra.GlobalHydra.instance().clear()
            cfg_file = Path(config_path) / f"{policy_cfg}.yaml"
            initialize(config_path=str(cfg_file.parent), version_base='1.3')
            return compose(config_name=cfg_file.stem)
        # DictConfig is checked first; it is not a dict subclass.
        if isinstance(policy_cfg, DictConfig):
            return policy_cfg
        if isinstance(policy_cfg, dict):
            return DictConfig(policy_cfg)
        raise ValueError("policy_cfg must be a string, a dict or a DictConfig")

    def reset(self):
        """Reset the base skill state, then reset the agent with ``cond_scale``."""
        super().reset()
        self.agent.reset(cond_scale=self.cond_scale)

    def do(
        self,
        condition: str = '',
        timeout: int = 500,
        target_reward: float = 1.,
        monitor_fn: Optional[Callable] = None,
        **kwargs,
    ) -> Tuple[bool, Dict]:
        """Run the policy in ``self.env`` for a single text instruction.

        Args:
            condition: Text instruction. It is embedded through MineCLIP and
                the prior, and also passed to ``env.manual_set_text``.
            timeout: Maximum number of policy steps.
            target_reward: Stop with success once the reward accumulated over
                the policy steps reaches this value.
            monitor_fn: Optional callable applied to each step's ``info``. If
                the first element of its result is truthy, that result is
                returned immediately.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            ``(success, details)``. On success ``details`` holds ``'success'``;
            otherwise it holds ``'reason'``. ``'terminated'`` is True only when
            the environment ended the episode. A ``monitor_fn`` result is
            returned as-is.
        """
        prompt_embed = get_prior_embed(condition, self.mineclip, self.prior, device=self.device)

        self.reset()
        self.env.manual_set_text(condition)

        # One no-op step to obtain the first observation; its reward is not
        # counted towards target_reward.
        self.obs, _, terminated, truncated, self.info = self.env.step(self.env.noop_action())
        time_step = 0
        episode_reward = 0

        while not (terminated or truncated) and time_step < timeout:
            # Mixed-precision inference; equivalent to the deprecated
            # torch.cuda.amp.autocast().
            with torch.autocast(device_type='cuda'):
                minerl_action = self.agent.get_action(self.obs, prompt_embed)

            for key in self._MASKED_ACTION_KEYS:
                minerl_action[key] = np.array([0])

            self.obs, self.reward, terminated, truncated, self.info = self.env.step(minerl_action)
            self.record_step()

            if monitor_fn is not None:
                monitor_result = monitor_fn(self.info)
                if monitor_result[0]:
                    return monitor_result

            episode_reward += self.reward
            time_step += 1
            if episode_reward >= target_reward:
                return True, {'success': True, 'terminated': False}

        if terminated:
            return False, {'reason': "environment reset.", "terminated": True}
        return False, {'reason': "reach goal maximum steps.", "terminated": False}

    def make_traj_video(self, file_name='dummy', fps: int = 20):
        """Encode the recorded ``pov`` frames into ``<file_name>.mp4`` (H.264).

        Args:
            file_name: Output path without the ``.mp4`` extension.
            fps: Frame rate of the output video (defaults to the previous
                hard-coded 20).
        """
        # ``with`` guarantees the container is closed even if encoding raises
        # part-way through.
        with av.open(f'{file_name}.mp4', mode='w', format='mp4') as container:
            stream = container.add_stream('h264', rate=fps)
            stream.width = 640
            stream.height = 360
            stream.pix_fmt = 'yuv420p'
            # Assumes record_infos is filled by MarkBase.record_step() and
            # each info['pov'] is an HxWx3 uint8 RGB array — confirm.
            for info in self.record_infos:
                frame = av.VideoFrame.from_ndarray(info['pov'], format='rgb24')
                for packet in stream.encode(frame):
                    container.mux(packet)
            # Flush the frames still buffered inside the encoder.
            for packet in stream.encode():
                container.mux(packet)

if __name__ == '__main__':
    # Demo: run a crafting instruction in a 'craft' environment, then dump the
    # recorded trajectory to dummy.mp4.
    env = MinecraftWrapper('craft')
    env.reset()

    # SteveI requires `config_path`; calling SteveI(env=env) alone raises
    # TypeError. "./configs" mirrors the commented-out
    # RELATIVE_POLICY_CONFIG_DIR near the top of this file — TODO confirm it
    # is the intended config directory (Hydra resolves it relative to this
    # module).
    steveI = SteveI(env=env, config_path="./configs")
    steveI.reset()

    # NOTE: after a series of crafting/smelting steps, the
    # inventory / crafting_table / furnace GUI still needs to be closed.
    # Other instructions previously tried here: 'craft stick',
    # 'craft crafting table', 'craft furnace', 'craft wooden pickaxe',
    # 'smelt charcoal', 'smelt baked_potato', 'craft oak_planks'.
    for instruction in ['craft oak planks']:
        result, error_message = steveI.do(instruction)
        print(result, error_message)

    steveI.make_traj_video()