from __future__ import annotations

import gym
import gym.spaces as spaces
import cv2
import numpy as np
import torch
from collections import deque, OrderedDict
import copy
from typing import Any, Dict, Optional, Tuple

import minedojo
from mineclip import MineCLIP


class HealthReward:
    """Shaping term proportional to the change in health between observations.

    Each call returns ``scale * (current_health - previous_health)``, so losing
    health yields a negative value and regaining it a positive one. ``reset``
    must be called with the first observation of an episode to set the baseline.
    """

    def __init__(self, scale=0.1):
        self.scale = scale
        # Health read from the most recent observation; None until reset().
        self.previous = None

    @staticmethod
    def _read_health(obs):
        """Return the first entry of ``obs["life_stats"]["life"]``."""
        return obs["life_stats"]["life"][0]

    def reset(self, obs):
        """Record the health in ``obs`` as the baseline for the next call."""
        self.previous = self._read_health(obs)

    def __call__(self, obs):
        """Return the scaled health delta since the last observation and update the baseline."""
        current = self._read_health(obs)
        delta = current - self.previous
        self.previous = current
        return self.scale * delta

class ValidAttackReward:
    """Shaping bonus paid when the held item loses exactly one durability point.

    The durability is read from ``obs["inventory"]["cur_durability"][0]``. The
    last two readings are kept; a step is counted as a valid attack when the
    difference between them is 1 (within a small tolerance). ``reset`` must be
    called with the first observation of an episode before the first call.
    """

    def __init__(self, scale=0.1):
        self.scale = scale
        # Holds (previous, current) durability; maxlen=2 drops older readings.
        self._weapon_durability_deque = deque(maxlen=2)

    def reset(self, obs):
        """Forget earlier readings and store the durability in ``obs`` as the baseline."""
        self._weapon_durability_deque.clear()
        self._weapon_durability_deque.append(obs["inventory"]["cur_durability"][0])

    def __call__(self, obs):
        """Return ``scale`` if durability dropped by exactly 1 since the last reading, else 0."""
        self._weapon_durability_deque.append(obs["inventory"]["cur_durability"][0])
        # A sword is also a digging tool: successfully breaking any block whose
        # hardness is non-zero costs 2 durability.
        # Each successful melee hit with a sword costs 1 durability.
        previous, current = (
            self._weapon_durability_deque[0],
            self._weapon_durability_deque[1],
        )
        durability_loss = previous - current
        # Compare with abs(): without it a loss of 0 (no attack at all) gives
        # -1 < 1e-3 and would be rewarded on every idle step.
        valid_attack = abs(durability_loss - 1.0) < 1e-3
        return self.scale * float(valid_attack)

class MineCLIPRewardWrapper(gym.Wrapper):
    """
    A Gym wrapper that adds dense rewards from MineCLIP to any MineDojo environment.
    This wrapper computes a dense reward based on the similarity between the agent's
    recent video frames and a given language prompt using the MineCLIP model.

    Besides the MineCLIP term, the wrapper:

    * exposes a reduced ``Dict`` action space and translates it to the wrapped
      environment's native action vector (see ``_action``);
    * repeats each action ``repeat`` times, rotating the camera only on the
      first repetition;
    * adds health and valid-attack shaping terms;
    * resizes ``obs['rgb']`` and converts it to (H, W, C);
    * writes per-term reward values into the observation under ``log/...`` keys.

    Args:
        env: Environment to wrap. Its action space must provide ``no_op()`` and
            its observations must contain ``rgb``, ``life_stats``,
            ``inventory`` and ``location_stats``.
        prompt: Text the recent frames are scored against.
        model: MineCLIP model; a falsy value disables the MineCLIP term.
        device: Device the buffered frames are moved to. Defaults to CUDA when
            available, otherwise CPU.
        image_size: (H, W) of the frames produced by ``env``.
        image_downsample: Integer divisor applied to ``image_size`` for the
            ``rgb`` observation returned to the agent.
        pitch_limit: (min, max) pitch; camera moves that would go further past
            a limit that has been reached are suppressed.
        frame_skip: The MineCLIP reward is computed only on steps whose index
            is a multiple of this value.
        repeat: Number of environment steps executed per agent step.
        max_episode_steps: Agent steps after which the episode is ended.
        mineclip_reward: Scale of the MineCLIP term.
        success_reward: Scale of the wrapped environment's own reward.
        health_reward: Scale of the health shaping term.
        valid_attack_reward: Scale of the valid-attack shaping term.
    """

    # Number of frames buffered for one MineCLIP video clip.
    _CLIP_LENGTH = 16
    # Native camera bin meaning "no rotation". Per the native layout the camera
    # axes have 25 bins with bin 0 = -180 degrees and bin 24 = 180 degrees.
    _CAMERA_NOOP_BIN = 12

    def __init__(
        self,
        env: gym.Env,
        prompt: str,
        model: MineCLIP,
        device: Optional[torch.device] = None,
        image_size=(160, 256),  # As per MineCLIP input
        image_downsample: int=1,
        pitch_limit=(-60, 60),
        frame_skip: int = 1, 
        repeat: int = 1,
        max_episode_steps: int=1000,
        mineclip_reward: float = 1.0,
        success_reward: float=100,
        health_reward: float=0.1,
        valid_attack_reward: float=0.01,
    ):
        super().__init__(env)
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.prompt = prompt
        self.frame_skip = max(1, frame_skip)
        self.mineclip_reward = mineclip_reward
        self.success_reward = success_reward

        self.model = model

        self.text_feats = None
        if self.model:
            # Pre-encode the text prompt once; no_grad keeps the cached
            # features from holding on to an autograd graph.
            with torch.no_grad():
                self.text_feats = self.model.encode_text([self.prompt])

        # Buffer for recent frames: MineCLIP expects (B, 16, 3, H, W)
        self.frame_buffer = deque(maxlen=self._CLIP_LENGTH)  # RGB frames
        self.step_count = 0
        self.max_episode_steps = max_episode_steps

        self._repeat = max(1, repeat)
        # Must be read before self.action_space is replaced below.
        self._noop_action = self.action_space.no_op()

        # Native action layout of the wrapped environment, by index:
        #   0 forward_and_backward   0: noop, 1: forward, 2: back
        #   1 left_and_right         0: noop, 1: left, 2: right
        #   2 jump_sneak_sprint      0: noop, 1: jump, 2: sneak, 3: sprint
        #   3 camera_delta_pitch     25 bins, 0: -180 degree, 24: 180 degree
        #   4 camera_delta_yaw       25 bins, 0: -180 degree, 24: 180 degree
        #   5 functional_actions     0: noop, 1: use, 2: drop, 3: attack,
        #                            4: craft, 5: equip, 6: place, 7: destroy
        #   6 craft_argument         all possible items to be crafted
        #   7 equip_place_destroy_argument   inventory slot indices
        # The agent only sees the reduced space below.
        self.action_space = spaces.Dict(
            {
                'forward_and_backward': spaces.Discrete(3), # 0: noop, 1: forward, 2: back
                'left_and_right': spaces.Discrete(3), # 0: noop, 1: left, 2: right
                'jump_sneak_sprint': spaces.Discrete(4), # 0: noop, 1: jump, 2: sneak, 3: sprint
                'camera_delta_pitch': spaces.Discrete(3), # 0: -15 degree, 1: 0 degree, 2: 15 degree
                'camera_delta_yaw': spaces.Discrete(3), # 0: -15 degree, 1: 0 degree, 2: 15 degree
                'functional_actions': spaces.Discrete(3), #  0: noop, 1: use, 2: attack
            }
        )
        # Reduced functional action -> native functional action.
        self.functional_actions_map = {
            0: 0,  # noop
            1: 1,  # use
            2: 3,  # attack
        }
        self.pitch_limit = pitch_limit

        H, W = image_size[0], image_size[1]
        self.image_size = (H // image_downsample, W // image_downsample, 3)
        self.observation_space = spaces.Dict({
                **self.observation_space.spaces,
                'rgb': spaces.Box(low=0, high=255, shape=self.image_size, dtype=np.uint8),
            }
        )
        self.rewards = {
            "health": HealthReward(health_reward),
            "valid_attack": ValidAttackReward(valid_attack_reward),
        }

    def reset(
        self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Reset the wrapped environment and all per-episode state.

        ``seed`` and ``options`` are accepted for signature compatibility but
        are not forwarded to the wrapped environment.

        Returns:
            The first observation, with ``rgb`` resized to (H, W, C) and every
            ``log/...`` entry set to 0.
        """
        self.frame_buffer.clear()
        self.step_count = 0

        obs = self.env.reset()
        for shaping in self.rewards.values():
            shaping.reset(obs)

        # Buffer the full-resolution frame before it is resized for the agent
        # (assuming obs['rgb'] is the egocentric RGB frame).
        self._add_frame(obs['rgb'])
        obs['rgb'] = self._resize_rgb(obs['rgb'])
        obs.update({
            'log/success_reward': 0.,
            'log/mineclip_reward': 0.,
            'log/success': 0.,
            **{f"log/{k}_reward": 0. for k in self.rewards}
        })

        self._pitch = obs['location_stats']['pitch']
        return obs

    def _action(self, action):
        """Translate a reduced ``Dict`` action into the native action vector."""
        act = self._noop_action.copy()
        act[0] = action['forward_and_backward']
        act[1] = action['left_and_right']
        act[2] = action['jump_sneak_sprint']

        # Reduced camera values 0/1/2 become deltas of -1/0/+1 bins.
        camera_delta_pitch = action['camera_delta_pitch'] - 1
        # Once a pitch limit is reached, only allow moving back inside it.
        if self._pitch <= self.pitch_limit[0]:
            camera_delta_pitch = max(0, camera_delta_pitch)
        if self._pitch >= self.pitch_limit[1]:
            camera_delta_pitch = min(0, camera_delta_pitch)
        act[3] = camera_delta_pitch + self._CAMERA_NOOP_BIN

        act[4] = (action['camera_delta_yaw'] - 1) + self._CAMERA_NOOP_BIN
        act[5] = self.functional_actions_map[action['functional_actions']]
        return act

    def step(
        self, action: Any
    ) -> Tuple[Dict[str, Any], float, bool, Dict[str, Any]]:
        """Run one agent step (``repeat`` environment steps) and shape the reward.

        Returns:
            ``(obs, total_reward, done, info)`` where ``total_reward`` is the
            scaled environment reward plus the scaled MineCLIP term plus the
            shaping terms, and ``obs`` carries the individual terms under
            ``log/...`` keys.
        """
        action = self._action(action)
        # Repeated steps keep the movement/functional part of the action but
        # must not rotate again. The neutral camera bin is 12; bin 0 would
        # turn the camera by -180 degrees on every repetition.
        following = action.copy()
        following[3:5] = self._CAMERA_NOOP_BIN

        # Count this agent step before the limit check so that an episode
        # lasts max_episode_steps steps rather than one more.
        self.step_count += 1

        # Sum the environment reward over all repetitions so that a reward
        # emitted on an intermediate repetition is not lost.
        reward = 0.0
        for act in [action] + ([following] * (self._repeat - 1)):
            obs, step_reward, done, info = self.env.step(act)
            reward += step_reward
            if (
                done
                or obs["life_stats"]["life"][0] <= 1e-3
                or obs["inventory"]["cur_durability"][0] <= 0
                or "error" in info
                or self.step_count >= self.max_episode_steps
            ):
                done = True
                break

        # Buffer the full-resolution frame, then hand the agent a resized copy.
        self._add_frame(obs['rgb'])
        obs['rgb'] = self._resize_rgb(obs['rgb'])
        self._pitch = obs['location_stats']['pitch']

        # Dense reward only once the buffer holds a full clip, and only on
        # every frame_skip-th step.
        dense_reward = 0.0
        if (
            self.model
            and len(self.frame_buffer) == self._CLIP_LENGTH
            and self.step_count % self.frame_skip == 0
        ):
            dense_reward = self._compute_mineclip_reward()

        reward_shaping = {f"log/{k}_reward": fn(obs) for k, fn in self.rewards.items()}

        # Total reward: original sparse reward + scaled dense reward + shaping
        total_reward = (
            reward * self.success_reward
            + dense_reward * self.mineclip_reward
            + sum(reward_shaping.values())
        )

        obs.update({
            'log/success_reward': reward * self.success_reward,
            'log/mineclip_reward': dense_reward * self.mineclip_reward,
            'log/success': float(self.env.is_successful),
            **reward_shaping
        })

        return obs, total_reward, done, info

    def _resize_rgb(self, frame: np.ndarray) -> np.ndarray:
        """Convert a (C, H, W) frame to (H, W, C) and resize it to ``self.image_size``."""
        height, width = self.image_size[:2]
        # cv2.resize takes the target size as (width, height).
        return cv2.resize(frame.transpose(1, 2, 0), (width, height))

    def _add_frame(self, frame: np.ndarray):
        """Append a copy of ``frame`` to the clip buffer as a tensor on ``self.device``."""
        frame = torch.from_numpy(frame.copy()).to(self.device)  # (3, H, W)
        self.frame_buffer.append(frame)

    def _compute_mineclip_reward(self) -> float:
        """Score the buffered clip against the prompt and return the scalar logit."""
        # Stack frames into video: (1, 16, 3, H, W)
        video = torch.stack(list(self.frame_buffer), dim=0).unsqueeze(0)
        # NOTE(review): MineCLIP may expect raw 0-255 pixel values and rescale
        # internally, in which case this division scales twice - confirm
        # against the MineCLIP preprocessing before changing.
        video = video.float() / 255.0

        with torch.no_grad():
            # Compute video features
            video_feats = self.model.encode_video(video)

            # Compute similarity (logits_per_video)
            logits_per_video, _ = self.model.forward_reward_head(
                video_feats, text_tokens=self.text_feats
            )

        # Use the similarity score as dense reward (first batch item)
        return logits_per_video[0, 0].item()  # Scalar value
    


# Factory for building a MineCLIP-wrapped environment for any MineDojo task
def create_mineclip_env(task_name: str, wrapper_kwargs: dict, **kwargs) -> gym.Env:
    """
    Build a MineDojo environment for ``task_name`` and wrap it with
    ``MineCLIPRewardWrapper``.

    The MineCLIP prompt is not passed in; it is assembled from the base
    environment's ``task_prompt`` and ``task_guidance`` and printed.

    Args:
        task_name: Name of the MineDojo task (e.g., "harvest_milk", "combat_spider").
        wrapper_kwargs: Extra keyword arguments for ``MineCLIPRewardWrapper``.
            It has to provide ``model`` (the wrapper requires it and it is not
            set here) and must not contain ``env``, ``prompt``, ``image_size``
            or ``image_downsample``, which are set here.
        **kwargs: Extra keyword arguments for ``minedojo.make()``; must not
            contain ``task_id`` or ``image_size``.

    Returns:
        Gym environment with MineCLIP dense rewards.
    """
    frame_size = (160, 256)  # As per MineCLIP input

    # Create base MineDojo env
    base_env = minedojo.make(task_id=task_name, image_size=frame_size, **kwargs)
    print(f"Created env for {task_name}.")

    # Same text as the original template, including the leading and trailing
    # newline.
    prompt = (
        "\n"
        f"Task overview: {base_env.task_prompt}.\n"
        "Step-by-step guidance:\n"
        f"{base_env.task_guidance}\n"
    )
    print(prompt)

    # Wrap with MineCLIP reward; the agent-facing image is half resolution.
    return MineCLIPRewardWrapper(
        env=base_env,
        prompt=prompt,
        image_size=frame_size,
        image_downsample=2,
        **wrapper_kwargs,
    )

