import io
import logging
import os
import pathlib
from collections import defaultdict
from typing import Any, Dict, List, Union, Optional

import numpy as np
import torch
import torch as th
from torch.nn import functional as F

from viqs.metadrive.model_utils.GAN import Discriminator
from viqs.sb3.common.buffers import ReplayBuffer
from viqs.sb3.common.save_util import load_from_pkl, save_to_pkl
from viqs.sb3.common.type_aliases import GymEnv, MaybeCallback
from viqs.sb3.common.utils import polyak_update
from viqs.metadrive.haco import HACOReplayBuffer, concat_samples
from viqs.sb3.td3.td3 import TD3

logger = logging.getLogger(__name__)


class VIQSTD3(TD3):
    """TD3 variant for human-in-the-loop training with a learned human-vs-novice action discriminator.

    Transitions collected while a human takeover is active are routed to ``human_data_buffer``
    (a separate buffer when ``use_balance_sample`` is true, otherwise the shared ``replay_buffer``).
    A ``Discriminator`` is trained to tell the behavior (human) action from the novice action at
    the same observation; depending on ``extra_config`` flags its scores shape the critics' proxy
    value targets and/or weight the actor's behavior-cloning loss.
    """

    def __init__(self, use_balance_sample=True, d_trian_steps=10000, q_value_bound=1.0, *args, **kwargs):
        """Create the agent; every keyword not consumed here is forwarded to ``TD3.__init__``.

        :param use_balance_sample: If true, keep human-takeover transitions in a dedicated
            ``human_data_buffer``; otherwise human and agent data share ``replay_buffer``.
        :param d_trian_steps: The discriminator is trained only while ``num_timesteps`` is below
            this value (the misspelled name is kept for backward compatibility).
        :param q_value_bound: Magnitude of the constant proxy Q targets (+bound for behavior
            actions, -bound for novice actions) used when ``use_reward_for_policy`` is false.
        :param kwargs: May also contain ``cql_coefficient`` (default 1), the boolean options listed
            below (as ``"True"``/``"False"`` strings or bools), and ``agent_data_ratio`` /
            ``bc_loss_weight`` / ``decay``. These are removed from ``kwargs`` and stored on the
            instance (mostly in ``self.extra_config``) before ``TD3.__init__`` is called; see the
            original TD3 for the remaining hyperparameters.
        """
        self.d_trian_steps = d_trian_steps
        self.cql_coefficient = kwargs.pop("cql_coefficient", 1)
        # train() reads actions_behavior / actions_novice / interventions from sampled batches,
        # so default to the takeover-aware buffer class.
        kwargs.setdefault("replay_buffer_class", HACOReplayBuffer)

        self.extra_config = {}

        bool_options = (
            "no_done_for_positive",
            "no_done_for_negative",
            "reward_0_for_positive",
            "reward_0_for_negative",
            "reward_n2_for_intervention",
            "reward_1_for_all",
            "use_weighted_reward",
            "remove_negative",
            "adaptive_batch_size",
            "add_bc_loss",
            "only_bc_loss",
            "with_human_proxy_value_loss",
            "with_agent_proxy_value_loss",
            "use_reward_for_policy",
            "use_reward_for_actor",
            "use_original_reward",
        )
        for k in bool_options:
            if k in kwargs:
                v = kwargs.pop(k)
                if not isinstance(v, bool):
                    # String form "True"/"False" (the only form the original code accepted).
                    assert v in ("True", "False"), "{} must be 'True' or 'False', got {!r}".format(k, v)
                    v = v == "True"
                self.extra_config[k] = v

        for k in ("agent_data_ratio", "bc_loss_weight", "decay"):
            if k in kwargs:
                self.extra_config[k] = kwargs.pop(k)

        self.q_value_bound = q_value_bound
        self.use_balance_sample = use_balance_sample

        # Discriminator hyperparameters. They are set before TD3.__init__, which presumably calls
        # _setup_model() (SB3 convention); _setup_model reads discriminator_lr.
        self.discriminator_lr = 1e-4
        # Update the discriminator every `discriminator_update_freq` gradient steps (1 = every step).
        self.discriminator_update_freq = 1
        self.discriminator_steps = 0

        super(VIQSTD3, self).__init__(*args, **kwargs)

    def _setup_model(self) -> None:
        """Build the TD3 networks, the human-data buffer and the discriminator with its optimizer."""
        super(VIQSTD3, self)._setup_model()
        if self.use_balance_sample:
            self.human_data_buffer = HACOReplayBuffer(
                self.buffer_size,
                self.observation_space,
                self.action_space,
                self.device,
                n_envs=self.n_envs,
                optimize_memory_usage=self.optimize_memory_usage,
                **self.replay_buffer_kwargs
            )
        else:
            # Human and agent transitions share one buffer.
            self.human_data_buffer = self.replay_buffer

        # Only shape[0] is used, so this assumes flat (1-D) observation and action spaces.
        state_dim = self.observation_space.shape[0]
        action_dim = self.action_space.shape[0]
        self.discriminator = Discriminator(state_dim, action_dim).to(self.device)
        self.discriminator_optimizer = torch.optim.Adam(
            self.discriminator.parameters(), lr=self.discriminator_lr
        )

    def _update_discriminator(self, replay_data, stat_recorder):
        """Take one optimizer step training the discriminator to separate behavior from novice actions.

        Behavior (human) actions are labeled 1 and novice actions 0, both paired with the same
        observations. Loss and accuracy are appended to ``stat_recorder`` (a ``defaultdict(list)``).

        NOTE(review): ``F.binary_cross_entropy`` expects probabilities and the labels are (N, 1),
        so this assumes the ``Discriminator`` returns (N, 1) values in [0, 1] — confirm.
        """
        human_actions = replay_data.actions_behavior
        agent_actions = replay_data.actions_novice
        states = replay_data.observations

        human_labels = torch.ones((human_actions.shape[0], 1), device=self.device)
        agent_labels = torch.zeros((agent_actions.shape[0], 1), device=self.device)

        all_states = torch.cat([states, states], dim=0)
        all_actions = torch.cat([human_actions, agent_actions], dim=0)
        all_labels = torch.cat([human_labels, agent_labels], dim=0)

        # Shuffle the combined batch (the mean BCE loss itself is order-independent).
        indices = torch.randperm(all_states.shape[0])
        all_states = all_states[indices]
        all_actions = all_actions[indices]
        all_labels = all_labels[indices]

        predictions = self.discriminator(all_states, all_actions)
        loss = F.binary_cross_entropy(predictions, all_labels)

        self.discriminator_optimizer.zero_grad()
        loss.backward()
        self.discriminator_optimizer.step()

        with torch.no_grad():
            accuracy = ((predictions > 0.5).float() == all_labels).float().mean()
        stat_recorder["discriminator_loss"].append(loss.item())
        stat_recorder["discriminator_accuracy"].append(accuracy.item())

        self.discriminator_steps += 1

    def _sample_training_batch(self, batch_size: int):
        """Draw one mini-batch, mixing agent and human data when both buffers hold enough samples.

        With ``adaptive_batch_size`` the agent share is ``agent_data_ratio * batch_size``;
        otherwise each buffer contributes ``batch_size / 2``. Agent samples come first in a
        mixed batch. If mixing is not possible, fall back to the human buffer, then to the
        agent buffer.

        :return: The sampled batch, or ``None`` when neither buffer holds more than
            ``batch_size`` transitions.
        """
        env = self._vec_normalize_env
        # size() (not .pos) so the fill level stays correct after a ring buffer wraps around.
        n_agent = self.replay_buffer.size()
        n_human = self.human_data_buffer.size()

        if self.extra_config["adaptive_batch_size"]:
            agent_data_size = int(batch_size * self.extra_config["agent_data_ratio"])
            human_data_size = int(batch_size - agent_data_size)
            if n_agent > agent_data_size and n_human > human_data_size:
                replay_data_human = self.human_data_buffer.sample(human_data_size, env=env)
                replay_data_agent = self.replay_buffer.sample(agent_data_size, env=env)
                return concat_samples(replay_data_agent, replay_data_human)
        elif n_agent > batch_size and n_human > batch_size:
            half_batch = int(batch_size / 2)
            replay_data_agent = self.replay_buffer.sample(half_batch, env=env)
            replay_data_human = self.human_data_buffer.sample(half_batch, env=env)
            return concat_samples(replay_data_agent, replay_data_human)

        if n_human > batch_size:
            return self.human_data_buffer.sample(batch_size, env=env)
        if n_agent > batch_size:
            return self.replay_buffer.sample(batch_size, env=env)
        return None

    def _discriminator_scores(self, replay_data):
        """Score the behavior and novice actions of every intervention row with the discriminator.

        Rows where ``interventions == 1`` get D(s, a_behavior) / D(s, a_novice); all other rows
        are filled with 1. Scores are computed without gradient.

        :return: ``(full_d_h, full_d_n)``, both shaped like ``replay_data.interventions``.
        """
        full_d_h = torch.ones_like(replay_data.interventions.float())
        full_d_n = torch.ones_like(replay_data.interventions.float())

        int_idx = torch.where(replay_data.interventions == 1)[0]
        if len(int_idx) > 0:
            int_states = replay_data.observations[int_idx]
            with torch.no_grad():
                d_h = self.discriminator(int_states, replay_data.actions_behavior[int_idx])
                d_n = self.discriminator(int_states, replay_data.actions_novice[int_idx])
            # Scatter the scores back to the rows they belong to. This does not depend on how
            # the agent/human sub-batches were ordered or sized when they were concatenated.
            full_d_h[int_idx] = d_h.reshape(full_d_h[int_idx].shape)
            full_d_n[int_idx] = d_n.reshape(full_d_n[int_idx].shape)
        return full_d_h, full_d_n

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        """Run up to ``gradient_steps`` updates of the discriminator, the critics and (delayed) the actor.

        A fresh mini-batch is drawn for every gradient step; training stops early when neither
        buffer holds enough data.

        :param gradient_steps: Maximum number of gradient steps.
        :param batch_size: Total mini-batch size per gradient step.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Weight of the discriminator-shaped proxy loss: multiplied by `decay` every 1000 updates.
        decay = self.extra_config["decay"] ** (self._n_updates / 1000)

        # Update learning rate according to lr schedule
        self._update_learning_rate([self.actor.optimizer, self.critic.optimizer])

        with_human_proxy_value_loss = self.extra_config["with_human_proxy_value_loss"]
        with_agent_proxy_value_loss = self.extra_config["with_agent_proxy_value_loss"]
        use_reward_for_policy = self.extra_config["use_reward_for_policy"]
        use_reward_for_actor = self.extra_config["use_reward_for_actor"]

        # Keys are logged as "train/<key>"; every entry is a list averaged at the end.
        stat_recorder = defaultdict(list)
        stat_recorder["decay"].append(decay)

        for _ in range(gradient_steps):
            replay_data = self._sample_training_batch(batch_size)
            if replay_data is None:
                break

            self._n_updates += 1

            # Train the discriminator on human data early in training only.
            if (
                self._n_updates % self.discriminator_update_freq == 0
                and self.num_timesteps < self.d_trian_steps
            ):
                if self.human_data_buffer.size() > batch_size:
                    self._update_discriminator(
                        self.human_data_buffer.sample(int(batch_size), env=self._vec_normalize_env),
                        stat_recorder,
                    )
                else:
                    stat_recorder["discriminator_loss"].append(0.0)
                    stat_recorder["discriminator_accuracy"].append(0.0)

            with th.no_grad():
                # Select action according to policy and add clipped noise
                noise = replay_data.actions_behavior.clone().data.normal_(0, self.target_policy_noise)
                noise = noise.clamp(-self.target_noise_clip, self.target_noise_clip)
                next_actions = (self.actor_target(replay_data.next_observations) + noise).clamp(-1, 1)

                # Compute the next Q-values: min over all critics targets
                next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
                next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)

                target_q_values = (1 - replay_data.dones) * self.gamma * next_q_values
                if self.extra_config["use_original_reward"]:
                    target_q_values = replay_data.rewards + target_q_values

            # Get current Q-values estimates for each critic network
            current_q_behavior_values = self.critic(replay_data.observations, replay_data.actions_behavior)
            current_q_novice_values = self.critic(replay_data.observations, replay_data.actions_novice)
            stat_recorder["q_value_behavior"].append(current_q_behavior_values[0].mean().item())
            stat_recorder["q_value_novice"].append(current_q_novice_values[0].mean().item())

            full_d_h, full_d_n = self._discriminator_scores(replay_data)
            if use_reward_for_policy:
                # Map discriminator scores from [0, 1] to proxy Q targets in [-1, 1].
                human_target = 2 * full_d_h - 1
                agent_target = 2 * full_d_n - 1
                human_weight = full_d_h
                agent_weight = 1 - full_d_n
                proxy_scale = decay

            # Compute critic loss
            critic_losses = []
            for current_q_behavior, current_q_novice in zip(current_q_behavior_values, current_q_novice_values):
                td_loss = F.mse_loss(current_q_behavior, target_q_values)

                if not use_reward_for_policy:
                    # Constant proxy targets: +bound for behavior actions, -bound for novice actions.
                    human_target = self.q_value_bound * th.ones_like(current_q_behavior)
                    agent_target = -self.q_value_bound * th.ones_like(current_q_behavior)
                    human_weight = agent_weight = proxy_scale = 1.0

                # Disabled proxy terms contribute nothing (previously they raised UnboundLocalError).
                human_l = th.zeros((), device=self.device)
                agent_l = th.zeros((), device=self.device)
                if with_human_proxy_value_loss:
                    human_l = th.mean(
                        replay_data.interventions
                        * human_weight
                        * self.cql_coefficient
                        * F.mse_loss(current_q_behavior, human_target, reduction="none")
                    )
                    stat_recorder["human_l"].append(human_l.item())
                if with_agent_proxy_value_loss:
                    agent_l = th.mean(
                        replay_data.interventions
                        * agent_weight
                        * self.cql_coefficient
                        * F.mse_loss(current_q_novice, agent_target, reduction="none")
                    )
                    stat_recorder["agent_l"].append(agent_l.item())

                critic_losses.append(td_loss + (human_l + agent_l) * proxy_scale)
            critic_loss = sum(critic_losses)

            # Optimize the critics
            self.critic.optimizer.zero_grad()
            critic_loss.backward()
            self.critic.optimizer.step()
            stat_recorder["critic_loss"].append(critic_loss.item())

            # Delayed policy updates
            if self._n_updates % self.policy_delay == 0:
                new_action = self.actor(replay_data.observations)

                # BC loss towards the behavior action, restricted to intervention rows.
                bc_loss = F.mse_loss(replay_data.actions_behavior, new_action, reduction="none").mean(dim=-1)
                mask = replay_data.interventions.flatten().float()
                if use_reward_for_actor:
                    # Flatten so the per-sample weight lines up with the 1-D mask / bc_loss
                    # (an (N, 1) weight would broadcast the product to an (N, N) matrix).
                    weight = (full_d_h - full_d_n).flatten()
                    masked_bc_loss = (mask * bc_loss * weight).sum() / (mask.sum() + 1e-5)
                else:
                    masked_bc_loss = (mask * bc_loss).sum() / (mask.sum() + 1e-5)

                if self.extra_config["only_bc_loss"]:
                    # The critics do not influence the actor in this mode.
                    actor_loss = masked_bc_loss
                else:
                    actor_loss = -self.critic.q1_forward(replay_data.observations, new_action).mean()
                    if self.extra_config["add_bc_loss"]:
                        actor_loss = actor_loss + masked_bc_loss * self.extra_config["bc_loss_weight"]

                # Optimize the actor
                self.actor.optimizer.zero_grad()
                actor_loss.backward()
                self.actor.optimizer.step()
                stat_recorder["actor_loss"].append(actor_loss.item())
                stat_recorder["masked_bc_loss"].append(masked_bc_loss.item())
                stat_recorder["bc_loss"].append(bc_loss.mean().item())

                polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
                polyak_update(self.actor.parameters(), self.actor_target.parameters(), self.tau)

        self.logger.record("train/n_updates", self._n_updates)
        for key, values in stat_recorder.items():
            self.logger.record("train/{}".format(key), np.mean(values))

    def _store_transition(
            self,
            replay_buffer: ReplayBuffer,
            buffer_action: np.ndarray,
            new_obs: Union[np.ndarray, Dict[str, np.ndarray]],
            reward: np.ndarray,
            dones: np.ndarray,
            infos: List[Dict[str, Any]],
    ) -> None:
        """Store the transition, routing it to ``human_data_buffer`` during a human takeover.

        Only ``infos[0]`` is inspected, so this assumes a single environment — TODO confirm for n_envs > 1.
        """
        if infos[0]["takeover"] or infos[0]["takeover_start"]:
            replay_buffer = self.human_data_buffer
        super(VIQSTD3, self)._store_transition(
            replay_buffer, buffer_action, new_obs, reward, dones, infos
        )

    def save_replay_buffer(
            self,
            path_human: Union[str, pathlib.Path, io.BufferedIOBase],
            path_replay: Union[str, pathlib.Path, io.BufferedIOBase],
    ) -> None:
        """Pickle the human-data buffer to ``path_human`` and the agent replay buffer to ``path_replay``."""
        save_to_pkl(path_human, self.human_data_buffer, self.verbose)
        super(VIQSTD3, self).save_replay_buffer(path_replay)

    def load_replay_buffer(
            self,
            path_human: Union[str, pathlib.Path, io.BufferedIOBase],
            path_replay: Union[str, pathlib.Path, io.BufferedIOBase],
            truncate_last_traj: bool = True,
    ) -> None:
        """
        Load the human-data buffer and the agent replay buffer from pickle files.

        NOTE(review): the two buffers are loaded as separate objects, so with
        ``use_balance_sample=False`` they are no longer the same buffer afterwards.

        :param path_human: Path to the pickled human-data buffer.
        :param path_replay: Path to the pickled agent replay buffer.
        :param truncate_last_traj: Forwarded to the parent ``load_replay_buffer`` for the agent
            buffer. When using ``HerReplayBuffer`` with online sampling:
            If set to ``True``, we assume that the last trajectory in the replay buffer was finished
            (and truncate it).
            If set to ``False``, we assume that we continue the same trajectory (same episode).
        """
        # NOTE(review): unpickling executes arbitrary code; only load buffers from trusted paths.
        self.human_data_buffer = load_from_pkl(path_human, self.verbose)
        assert isinstance(
            self.human_data_buffer, ReplayBuffer
        ), "The replay buffer must inherit from ReplayBuffer class"

        # Backward compatibility with SB3 < 2.1.0 replay buffer
        # Keep old behavior: do not handle timeout termination separately
        if not hasattr(
                self.human_data_buffer, "handle_timeout_termination"
        ):  # pragma: no cover
            self.human_data_buffer.handle_timeout_termination = False
            # Size timeouts from the buffer being patched (was mistakenly taken from replay_buffer).
            self.human_data_buffer.timeouts = np.zeros_like(self.human_data_buffer.dones)
        super(VIQSTD3, self).load_replay_buffer(path_replay, truncate_last_traj)

    def learn(
            self,
            total_timesteps: int,
            callback: MaybeCallback = None,
            log_interval: int = 4,
            eval_env: Optional[GymEnv] = None,
            eval_freq: int = -1,
            n_eval_episodes: int = 5,
            tb_log_name: str = "run",
            eval_log_path: Optional[str] = None,
            reset_num_timesteps: bool = True,
            save_timesteps: int = 2000,
            buffer_save_timesteps: int = 2000,
            save_path_human: Union[str, pathlib.Path, io.BufferedIOBase] = "",
            save_path_replay: Union[str, pathlib.Path, io.BufferedIOBase] = "",
            save_buffer: bool = True,
            load_buffer: bool = False,
            load_path_human: Union[str, pathlib.Path, io.BufferedIOBase] = "",
            load_path_replay: Union[str, pathlib.Path, io.BufferedIOBase] = "",
            warmup: bool = False,
            warmup_steps: int = 5000,
    ) -> "OffPolicyAlgorithm":
        """Off-policy training loop with optional buffer loading, offline warm-up and periodic buffer dumps.

        :param save_timesteps: Currently unused; kept for interface compatibility.
        :param buffer_save_timesteps: When ``save_buffer`` is true, both buffers are pickled whenever
            ``num_timesteps`` is a multiple of this value.
        :param save_path_human: Directory for ``human_buffer_<timesteps>.pkl``.
        :param save_path_replay: Directory for ``replay_buffer_<timesteps>.pkl``.
        :param load_buffer: Load both buffers from ``load_path_human`` / ``load_path_replay`` first.
        :param warmup: Run ``warmup_steps`` gradient steps on the loaded data before collecting
            rollouts (requires ``load_buffer``).
        :return: ``self``.
        """
        total_timesteps, callback = self._setup_learn(
            total_timesteps,
            eval_env,
            callback,
            eval_freq,
            n_eval_episodes,
            eval_log_path,
            reset_num_timesteps,
            tb_log_name,
        )
        if load_buffer:
            self.load_replay_buffer(load_path_human, load_path_replay)
        callback.on_training_start(locals(), globals())
        if warmup:
            assert load_buffer, "warmup is useful only when load buffer"
            print("Start warmup with steps: " + str(warmup_steps))
            self.train(batch_size=self.batch_size, gradient_steps=warmup_steps)

        while self.num_timesteps < total_timesteps:
            rollout = self.collect_rollouts(
                self.env,
                train_freq=self.train_freq,
                action_noise=self.action_noise,
                callback=callback,
                learning_starts=self.learning_starts,
                replay_buffer=self.replay_buffer,
                log_interval=log_interval,
            )

            if rollout.continue_training is False:
                break

            if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
                # If no `gradient_steps` is specified,
                # do as many gradients steps as steps performed during the rollout
                gradient_steps = self.gradient_steps if self.gradient_steps >= 0 else rollout.episode_timesteps

                # Special case when the user passes `gradient_steps=0`
                if gradient_steps > 0:
                    self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)

            if save_buffer and self.num_timesteps > 0 and self.num_timesteps % buffer_save_timesteps == 0:
                buffer_location_human = os.path.join(
                    save_path_human, "human_buffer_" + str(self.num_timesteps) + ".pkl"
                )
                buffer_location_replay = os.path.join(
                    save_path_replay, "replay_buffer_" + str(self.num_timesteps) + ".pkl"
                )
                logger.info("Saving...%s", buffer_location_human)
                logger.info("Saving...%s", buffer_location_replay)
                self.save_replay_buffer(buffer_location_human, buffer_location_replay)

        callback.on_training_end()

        return self
