import copy
import pathlib

import gymnasium as gym
import numpy as np
import torch
from metadrive.engine.logger import get_logger
from metadrive.policy.env_input_policy import EnvInputPolicy

from viqs.metadrive.env.human_in_the_loop_env import HumanInTheLoopEnv
from viqs.sb3.common.monitor import Monitor
from viqs.sb3.sac import SAC
from viqs.sb3.sac.policies import SACPolicy

# Directory that contains this module; the expert checkpoints (sac_*.zip) are looked up here.
FOLDER_PATH = pathlib.Path(__file__).parents[0]
# Module-wide logger obtained from MetaDrive's logging helper.
logger = get_logger()

import random
import numpy as np
import torch


def get_sac_expert(expert_level='270'):
    """Load one of the pretrained SAC expert checkpoints stored next to this module.

    Args:
        expert_level: Which checkpoint to load: ``'190'``, ``'270'`` (default)
            or ``'350'``. Non-string values (e.g. the int ``270``) are converted
            with ``str()``. The level selects ``FOLDER_PATH / "sac_<level>.zip"``.

    Returns:
        The model returned by ``SAC.load``, bound to a newly created
        ``HumanInTheLoopEnv`` with manual control and rendering disabled.

    Raises:
        ValueError: If ``expert_level`` is not one of the supported levels.
    """
    supported_levels = ('190', '270', '350')
    level = str(expert_level)
    # Validate before building the environment so a typo fails fast and clearly
    # (an unknown level used to leave ``ckpt`` unbound -> UnboundLocalError).
    if level not in supported_levels:
        raise ValueError(
            f"Unknown expert_level {expert_level!r}; expected one of {supported_levels}"
        )

    logger.info("expert env start")
    train_env = HumanInTheLoopEnv(config={'manual_control': False, "use_render": False})

    ckpt = FOLDER_PATH / f"sac_{level}.zip"
    # custom_objects override the corresponding entries saved in the checkpoint.
    model = SAC.load(ckpt,
                     custom_objects={
                         "policy_class": SACPolicy,  # Using SACPolicy instead of TD3Policy
                         "learning_rate": lambda _: 1e-4,  # Learning rate
                         "gamma": lambda _: 0.99,  # Discount factor
                         "create_eval_env": False,  # Whether to create an evaluation environment
                         "env": train_env,
                         "ent_coef": 'auto',
                         "target_entropy": "auto",
                         "device": "cuda:0",
                     },
                     env=train_env)

    return model





class SACFakeHumanEnv(HumanInTheLoopEnv):
    """Human-in-the-loop environment in which a pretrained SAC expert plays the "human".

    On every ``step`` the agent's action is scored by the SAC expert loaded
    through ``get_sac_expert``. The expert takes over when either:

    * ``config['use_action_diff']`` is truthy and the expert policy's
      probability of the agent action is below ``1 - config['a_free_level']``; or
    * otherwise, when ``Q(s, expert_action) - Q(s, agent_action)`` from the
      expert critic exceeds ``config['q_free_level']``.

    With ``config['use_discrete']`` actions are integer indices into a
    13 x 13 grid over ``[-1, 1] x [-1, 1]``.
    """

    last_takeover = None
    last_obs = None
    expert = None

    def __init__(self, config):
        # default_config() reads self.extra_config, so it is stored before the
        # parent constructor (which presumably calls default_config()) runs.
        self.extra_config = config
        super(SACFakeHumanEnv, self).__init__(config)
        if self.extra_config.get('expert_level'):
            self.expert = get_sac_expert(self.extra_config['expert_level'])
        else:
            self.expert = get_sac_expert()

        if self.config["use_discrete"]:
            self._num_bins = 13
            self._grid = np.linspace(-1, 1, self._num_bins)
            # Cartesian product of the 1-D grid with itself: one row per
            # discrete action, shape (num_bins ** 2, 2).
            self._actions = np.array(np.meshgrid(self._grid, self._grid)).T.reshape(-1, 2)

    @property
    def action_space(self) -> gym.Space:
        """Discrete(num_bins ** 2) in discrete mode, otherwise the parent's action space."""
        if self.config["use_discrete"]:
            return gym.spaces.Discrete(self._num_bins ** 2)
        else:
            return super(SACFakeHumanEnv, self).action_space

    def default_config(self):
        """Parent defaults, then this env's defaults, then the user-supplied config."""
        config = super(SACFakeHumanEnv, self).default_config()
        config.update({
            "use_discrete": False,
            "disable_expert": False,
            "agent_policy": EnvInputPolicy,
            "free_level": 0.95,
            "manual_control": False,
            "use_render": False,
            "expert_deterministic": True,
        })
        config.update(self.extra_config)
        return config

    def continuous_to_discrete(self, a):
        """Return the index of the grid action closest (Euclidean) to continuous action ``a``."""
        distances = np.linalg.norm(self._actions - a, axis=1)
        discrete_index = np.argmin(distances)
        return discrete_index

    def discrete_to_continuous(self, a):
        """Return the continuous grid action(s) for discrete index/indices ``a``."""
        continuous_action = self._actions[a.astype(int)]
        return continuous_action

    def _snap_action_to_grid(self, action):
        """Move ``action`` to the nearest grid action in discrete mode; otherwise return it unchanged."""
        if self.config["use_discrete"]:
            return self.discrete_to_continuous(self.continuous_to_discrete(action))
        return action

    def _evaluate_with_expert(self, agent_action):
        """Score ``agent_action`` with the SAC expert at the last observation.

        Args:
            agent_action: The agent's continuous action (numpy array).

        Returns:
            ``(expert_action, action_prob, q_diff)``:
            ``expert_action`` is the expert's own action as a numpy array (the
            distribution mode if ``expert_deterministic``, else a sample);
            ``action_prob`` is ``exp`` of ``distribution.my_log_prob`` for the
            agent action (NOTE(review): presumably a density value, so it is not
            guaranteed to lie in [0, 1]); ``q_diff`` is
            ``Q(s, expert_action) - Q(s, agent_action)``.
        """
        with torch.no_grad():
            obs_tensor, _ = self.expert.policy.obs_to_tensor(self.last_obs)
            distribution = self.expert.actor.get_distribution(obs_tensor)

            # Batch of one, on the same device as the expert's observation tensor.
            agent_tensor = torch.tensor(agent_action, dtype=torch.float32).unsqueeze(0).to(obs_tensor.device)
            action_prob = distribution.my_log_prob(agent_tensor).exp().cpu().numpy()

            if self.config["expert_deterministic"]:
                expert_tensor = distribution.mode()
            else:
                expert_tensor = distribution.sample()

            # The critic output is indexed with [0]; presumably one Q tensor per
            # critic network (SB3 ContinuousCritic), so this uses the first critic.
            q_expert = self.expert.critic(obs_tensor, expert_tensor)[0].cpu().numpy()[0]
            q_agent = self.expert.critic(obs_tensor, agent_tensor)[0].cpu().numpy()[0]

            expert_action = expert_tensor[0].cpu().numpy()

        return expert_action, action_prob, q_expert - q_agent

    def step(self, actions):
        """Step the environment, letting the SAC expert take over when it disagrees.

        Args:
            actions: The agent's action — a continuous action, or a grid index
                when ``use_discrete`` is set.

        Returns:
            ``(obs, reward, done, info)`` as produced by ``_get_step_return``.
            In discrete mode ``info["raw_action"]`` is converted to a grid index.
        """
        actions = np.asarray(actions).astype(np.float32)

        if self.config["use_discrete"]:
            actions = self.discrete_to_continuous(actions)

        self.agent_action = copy.copy(actions)
        self.last_takeover = self.takeover

        # Get expert action and determine whether to take over.
        if not self.config["disable_expert"]:
            expert_action, action_prob, q_diff = self._evaluate_with_expert(self.agent_action)

            if self.config['use_action_diff']:
                self.takeover = bool(action_prob < 1 - self.config['a_free_level'])
            else:
                self.takeover = bool(q_diff > self.config['q_free_level'])

            if self.takeover:
                # Snap in both criteria so that, in discrete mode, the executed
                # action always matches the grid index reported in raw_action.
                actions = self._snap_action_to_grid(expert_action)

        # Skip HumanInTheLoopEnv.step and call its parent's step directly.
        o, r, d, i = super(HumanInTheLoopEnv, self).step(actions)
        self.takeover_recorder.append(self.takeover)
        self.total_steps += 1

        if self.config["use_render"]:
            super(HumanInTheLoopEnv, self).render(
                text={
                    "Total Cost": round(self.total_cost, 2),
                    "Takeover Cost": round(self.total_takeover_cost, 2),
                    "Takeover": "TAKEOVER" if self.takeover else "NO",
                    "Total Step": self.total_steps,
                    "Takeover Rate": "{:.2f}%".format(np.mean(np.array(self.takeover_recorder) * 100)),
                    "Pause": "Press E",
                }
            )

        assert i["takeover"] == self.takeover

        if self.config["use_discrete"]:
            i["raw_action"] = self.continuous_to_discrete(i["raw_action"])

        return o, r, d, i

    def _get_step_return(self, actions, engine_info):
        """Augment the step info with takeover/cost bookkeeping and return a 4-tuple.

        Returns ``(obs, reward, done, engine_info)`` where ``done`` is
        ``terminated or truncated``. Also caches ``obs`` as ``self.last_obs``
        for the next expert query.
        """
        o, r, tm, tc, engine_info = super(HumanInTheLoopEnv, self)._get_step_return(actions, engine_info)
        self.last_obs = o
        d = tm or tc
        last_t = self.last_takeover
        engine_info["takeover_start"] = bool(not last_t and self.takeover)
        engine_info["takeover"] = self.takeover
        # Charge takeover cost either only on the first takeover step or on every takeover step.
        condition = engine_info["takeover_start"] if self.config["only_takeover_start_cost"] else self.takeover
        if not condition:
            engine_info["takeover_cost"] = 0
        else:
            cost = self.get_takeover_cost(engine_info)
            self.total_takeover_cost += cost
            engine_info["takeover_cost"] = cost
        engine_info["total_takeover_cost"] = self.total_takeover_cost
        engine_info["native_cost"] = engine_info["cost"]
        engine_info["episode_native_cost"] = self.episode_cost
        self.total_cost += engine_info["cost"]
        self.total_takeover_count += 1 if self.takeover else 0
        engine_info["total_takeover_count"] = self.total_takeover_count
        engine_info["total_cost"] = self.total_cost
        return o, r, d, engine_info

    def _get_reset_return(self, reset_info):
        """Reset hook: cache the initial observation and clear the previous-takeover flag."""
        o, info = super(HumanInTheLoopEnv, self)._get_reset_return(reset_info)
        self.last_obs = o
        self.last_takeover = False
        return o, info


if __name__ == "__main__":
    # Smoke run: drive with a constant action and let the SAC expert take over.
    # use_action_diff must be a real bool — any non-empty string (even 'False') is truthy.
    env = SACFakeHumanEnv(dict(a_free_level=0.95, q_free_level=2.0, use_action_diff=True, use_render=False))
    # env = Monitor(env=env, filename='./')
    env.reset()
    while True:
        _, _, done, info = env.step([0, 1])

        if done:
            # NOTE(review): 'episode' statistics are presumably added by the Monitor
            # wrapper commented out above; .get keeps the loop alive without it.
            episode = info.get('episode', {})
            print(
                f"reward:{episode.get('ep_sum_step_reward_original')}  "
                f"takeover_count:{episode.get('ep_total_takeover_count')}")

            env.reset()
