from typing import Any, Dict, Generator, List, Optional, Union
from dataclasses import dataclass 
from collections import deque
import sys
import os

from gymnasium.core import Env
import gymnasium as gym
from gymnasium import Wrapper, ActionWrapper

import numpy as np
from tqdm import tqdm
import cv2
import matplotlib.pyplot as plt

import torch as th

from stable_baselines3.common.utils import obs_as_tensor

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from model.reward import Reward
from model.vae.encoder import Encoder


class CustomReward(Wrapper):
    """Wrapper that replaces the environment reward with a reward-network output.

    On every step the previous observation and the new observation are
    converted to float tensors, given a leading batch dimension and passed,
    together with ``encoder``, to ``reward_net.forward``.  The first value
    returned by the network (converted with ``.item()``) is returned as the
    step reward instead of the reward produced by the wrapped environment.

    The sum of the network rewards since the last ``reset`` is kept in
    ``irl_reward`` and, whenever ``info`` contains an ``'episode'`` entry, is
    stored under ``info['episode']['irl_r']``.

    Args:
        env: Environment to wrap.
        reward_net: Network whose ``forward(obs, next_obs, encoder)`` returns
            a pair whose first element is a single-element tensor.
        encoder: Optional encoder forwarded to ``reward_net.forward``.  When
            given, its ``device`` attribute decides where tensors are created.
        **kwargs: Forwarded unchanged to ``Wrapper.__init__``.
    """

    def __init__(self, env, reward_net: Reward, encoder: Encoder | None = None, **kwargs):
        super().__init__(env, **kwargs)
        self.reward_net = reward_net
        self.encoder = encoder

        # Last observation seen; set by reset() and updated by step().
        self.obs = None
        # Accumulated network reward for the current episode.
        self.irl_reward = 0

    def _tensor_device(self):
        """Return the device on which observation tensors should be created.

        The encoder's device is used whenever an encoder was supplied.
        Without an encoder (the constructor default) the device is taken from
        the reward network instead, so that ``step`` no longer fails with an
        ``AttributeError`` on ``None.device``.
        """
        if self.encoder is not None:
            return self.encoder.device

        # NOTE(review): the fallbacks below assume ``reward_net`` either
        # exposes a ``device`` attribute or behaves like a ``torch.nn.Module``
        # (has ``parameters()``) -- confirm against model.reward.Reward.
        device = getattr(self.reward_net, 'device', None)
        if device is not None:
            return device

        parameters = getattr(self.reward_net, 'parameters', None)
        if callable(parameters):
            first_param = next(iter(parameters()), None)
            if first_param is not None:
                return first_param.device

        return th.device('cpu')

    def step(self, action):
        """Step the wrapped environment and return the network reward.

        Returns:
            ``(next_obs, reward, terminated, truncated, info)`` where
            ``reward`` is the Python scalar produced by the reward network;
            the reward of the wrapped environment is discarded.
        """
        next_obs, _env_reward, terminated, truncated, info = super().step(action)

        device = self._tensor_device()
        with th.no_grad():
            # [None, :] adds the leading batch dimension.
            reward, _ = self.reward_net.forward(
                obs_as_tensor(self.obs, device).float()[None, :],
                obs_as_tensor(next_obs, device).float()[None, :],
                self.encoder,
            )
        reward = reward.item()

        self.obs = next_obs
        self.irl_reward += reward
        # 'episode' is presumably added by an episode-statistics wrapper
        # applied inside this one; only then is the running total reported.
        if 'episode' in info:
            info['episode']['irl_r'] = self.irl_reward

        return next_obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Reset the wrapped environment and clear the accumulated reward."""
        obs, info = super().reset(**kwargs)
        self.obs = obs

        self.irl_reward = 0

        return obs, info
    

class DisabledHalfCheetah(CustomReward):
    """``CustomReward`` variant that only exposes the enabled joints as actions.

    ``joints_status`` holds one flag per joint; entries equal to ``1`` mark
    joints the agent controls.  The exposed action space is a ``Box`` in
    ``[-1, 1]`` with one dimension per unit of ``sum(joints_status)``, and
    every incoming action is expanded to a fixed 6-element vector before it
    is forwarded.
    """

    def __init__(self, env: Env, reward_net: Reward, joints_status: List[int], encoder: Encoder | None = None):
        super().__init__(env, reward_net, encoder)
        self.joints_status = joints_status
        self.non_disabled_joints_total = np.array(joints_status).sum()

        # Symmetric bounds: one [-1, 1] interval per controllable joint.
        upper_bound = np.ones(self.non_disabled_joints_total, dtype=np.float32)
        self.action_space = gym.spaces.Box(low=-upper_bound, high=upper_bound)

    def step(self, act):
        """Expand ``act`` to 6 entries, step, and report the expanded action.

        Entries of the expanded action whose status is not ``1`` stay zero.
        The expanded vector is stored in ``info['converted_action']``.
        """
        # Positions (within the first 6 joints) that receive a value from act.
        enabled_joints = [joint for joint in range(6) if self.joints_status[joint] == 1]

        full_action = np.zeros(6)
        for source, joint in enumerate(enabled_joints):
            full_action[joint] = act[source]

        observation, reward, terminated, truncated, info = super().step(full_action)
        info['converted_action'] = full_action
        return observation, reward, terminated, truncated, info


class DisabledHalfCheetahOnly(Wrapper):
    """Wrapper that only exposes the enabled joints as actions.

    ``joints_status`` holds one flag per joint; entries equal to ``1`` mark
    joints the agent controls.  The exposed action space is a ``Box`` in
    ``[-1, 1]`` with one dimension per unit of ``sum(joints_status)``, and
    every incoming action is expanded to a fixed 6-element vector before it
    is forwarded to the wrapped environment.  The reward is left untouched.
    """

    def __init__(self, env: Env, joints_status: List[int]):
        super().__init__(env)
        self.joints_status = joints_status
        self.non_disabled_joints_total = np.array(joints_status).sum()

        # Symmetric bounds: one [-1, 1] interval per controllable joint.
        upper_bound = np.ones(self.non_disabled_joints_total, dtype=np.float32)
        self.action_space = gym.spaces.Box(low=-upper_bound, high=upper_bound)

    def step(self, act):
        """Expand ``act`` to 6 entries, step, and report the expanded action.

        Entries of the expanded action whose status is not ``1`` stay zero.
        The expanded vector is stored in ``info['converted_action']``.
        """
        # Positions (within the first 6 joints) that receive a value from act.
        enabled_joints = [joint for joint in range(6) if self.joints_status[joint] == 1]

        full_action = np.zeros(6)
        for source, joint in enumerate(enabled_joints):
            full_action[joint] = act[source]

        observation, reward, terminated, truncated, info = super().step(full_action)
        info['converted_action'] = full_action
        return observation, reward, terminated, truncated, info
    

class DisabledAnt(CustomReward):
    """``CustomReward`` variant that only exposes the enabled joints as actions.

    ``joints_status`` holds one flag per joint; entries equal to ``1`` mark
    joints the agent controls.  The exposed action space is a ``Box`` in
    ``[-1, 1]`` with one dimension per unit of ``sum(joints_status)``, and
    every incoming action is expanded to a fixed 8-element vector before it
    is forwarded.
    """

    def __init__(self, env: Env, reward_net: Reward, joints_status: List[int], encoder: Encoder | None = None):
        super().__init__(env, reward_net, encoder)
        self.joints_status = joints_status
        self.non_disabled_joints_total = np.array(joints_status).sum()

        # Symmetric bounds: one [-1, 1] interval per controllable joint.
        upper_bound = np.ones(self.non_disabled_joints_total, dtype=np.float32)
        self.action_space = gym.spaces.Box(low=-upper_bound, high=upper_bound)

    def step(self, act):
        """Expand ``act`` to 8 entries, step, and report the expanded action.

        Entries of the expanded action whose status is not ``1`` stay zero.
        The expanded vector is stored in ``info['converted_action']``.
        """
        # Positions (within the first 8 joints) that receive a value from act.
        enabled_joints = [joint for joint in range(8) if self.joints_status[joint] == 1]

        full_action = np.zeros(8)
        for source, joint in enumerate(enabled_joints):
            full_action[joint] = act[source]

        observation, reward, terminated, truncated, info = super().step(full_action)
        info['converted_action'] = full_action
        return observation, reward, terminated, truncated, info


class DisabledAntOnly(Wrapper):
    """Wrapper that only exposes the enabled joints as actions.

    ``joints_status`` holds one flag per joint; entries equal to ``1`` mark
    joints the agent controls.  The exposed action space is a ``Box`` in
    ``[-1, 1]`` with one dimension per unit of ``sum(joints_status)``, and
    every incoming action is expanded to a fixed 8-element vector before it
    is forwarded to the wrapped environment.  The reward is left untouched.
    """

    def __init__(self, env: Env, joints_status: List[int]):
        super().__init__(env)
        self.joints_status = joints_status
        self.non_disabled_joints_total = np.array(joints_status).sum()

        # Symmetric bounds: one [-1, 1] interval per controllable joint.
        upper_bound = np.ones(self.non_disabled_joints_total, dtype=np.float32)
        self.action_space = gym.spaces.Box(low=-upper_bound, high=upper_bound)

    def step(self, act):
        """Expand ``act`` to 8 entries, step, and report the expanded action.

        Entries of the expanded action whose status is not ``1`` stay zero.
        The expanded vector is stored in ``info['converted_action']``.
        """
        # Positions (within the first 8 joints) that receive a value from act.
        enabled_joints = [joint for joint in range(8) if self.joints_status[joint] == 1]

        full_action = np.zeros(8)
        for source, joint in enumerate(enabled_joints):
            full_action[joint] = act[source]

        observation, reward, terminated, truncated, info = super().step(full_action)
        info['converted_action'] = full_action
        return observation, reward, terminated, truncated, info