from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import numpy as np
import torch as th
from torch.nn import functional as F

from viqs.metadrive.model_utils.GAN import Discriminator
from viqs.metadrive.haco.PairedSAEBuffer import PairedSAEBuffer
from viqs.sb3.common.noise import ActionNoise
from viqs.sb3.common.type_aliases import GymEnv
from viqs.sb3.common.utils import polyak_update
from viqs.metadrive.haco.haco import HACOSAC
from viqs.metadrive.haco.haco_buffer import HACOReplayBuffer
from viqs.metadrive.haco.policies import HACOPolicy


class HACOOptimized(HACOSAC):


    def __init__(
        self,
        policy: Union[str, Type[HACOPolicy]],
        env: Union[GymEnv, str],
        learning_rate: dict = dict(actor=0.0, critic=0.0, entropy=0.0),
        buffer_size: int = 100,
        learning_starts: int = 100,
        batch_size: int = 256,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = 1,
        gradient_steps: int = 1,
        action_noise: Optional[ActionNoise] = None,
        replay_buffer_class: Optional[HACOReplayBuffer] = HACOReplayBuffer,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = True,
        ent_coef: Union[str, float] = "auto",
        target_update_interval: int = 1,
        target_entropy: Union[str, float] = "auto",
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        use_sde_at_warmup: bool = False,
        tensorboard_log: Optional[str] = None,
        create_eval_env: bool = False,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
        # Discriminator hyper-parameters.
        discriminator_lr: float = 1e-4,
        discriminator_update_freq: int = 1,
        d_train_steps: int = 10_000,
        # Scale of the intervention term subtracted from the critic loss.
        cql_coefficient: float = 1.0,
        monitor_wrapper: bool = False,
    ):
        """Create the algorithm, its paired-sample buffer and (optionally) its networks.

        Most arguments are handed straight to the parent constructor. The ones
        handled here are:

        :param learning_rate: dict keyed by ``actor`` / ``critic`` / ``entropy``;
            a shallow copy is forwarded to the parent.
        :param replay_buffer_class: must be ``HACOReplayBuffer``.
        :param _init_setup_model: when true, ``_setup_model()`` is called at the
            end of construction; when false, the caller is responsible for it.
        :param discriminator_lr: learning rate of the discriminator's Adam optimizer.
        :param discriminator_update_freq: the discriminator is updated on
            gradient steps whose global index is a multiple of this value.
        :param d_train_steps: the discriminator stops being updated once it has
            received this many updates.
        :param cql_coefficient: multiplier of the intervention term in the
            critic loss.
        :param monitor_wrapper: forwarded to the parent constructor.
        """
        assert replay_buffer_class == HACOReplayBuffer, "replay_buffer_class must be HACOReplayBuffer"

        # The default ``learning_rate`` dict is built once, when the function is
        # defined, and would otherwise be shared by every instance. Forward a
        # private copy so a mutation on one model cannot leak into another.
        if isinstance(learning_rate, dict):
            learning_rate = dict(learning_rate)

        # NOTE(review): ``ent_coef``, ``target_update_interval`` and
        # ``target_entropy`` are accepted above but are neither forwarded here
        # nor stored below, so values passed by the caller have no effect in
        # this constructor. Confirm how HACOSAC populates the attributes of the
        # same names that ``train()`` reads.
        # NOTE(review): ``HACOPolicy`` is passed as the third positional
        # argument; verify this matches HACOSAC.__init__'s parameter order.
        super(HACOOptimized, self).__init__(
            policy,
            env,
            HACOPolicy,
            learning_rate,
            buffer_size,
            learning_starts,
            batch_size,
            tau,
            gamma,
            train_freq,
            gradient_steps,
            action_noise,
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            policy_kwargs=policy_kwargs,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            device=device,
            create_eval_env=create_eval_env,
            seed=seed,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            use_sde_at_warmup=use_sde_at_warmup,
            optimize_memory_usage=optimize_memory_usage,
            monitor_wrapper=monitor_wrapper,
            # Network construction is deferred to the end of this constructor.
            _init_setup_model=False
        )

        self.cql_coefficient = float(cql_coefficient)
        self.discriminator_lr = float(discriminator_lr)
        self.discriminator_update_freq = int(discriminator_update_freq)
        self.d_train_steps = int(d_train_steps)
        # Number of discriminator updates performed so far.
        self._disc_updates = 0

        # Buffer of (state, novice action, behaviour action) triples used to
        # train the discriminator.
        obs_shape = self.observation_space.shape
        act_shape = self.action_space.shape
        self.paired_buf = PairedSAEBuffer(
            capacity=50_000,
            obs_shape=obs_shape,
            act_shape=act_shape,
            device=self.device,
        )
        # Pairs drawn per discriminator update: half a batch, but at least 8.
        self.k_pairs_per_step = max(8, self.batch_size // 2)

        # Honour the flag instead of always building the networks, so callers
        # that pass ``_init_setup_model=False`` can run ``_setup_model()``
        # themselves after restoring state.
        if _init_setup_model:
            self._setup_model()


    def _setup_model(self) -> None:
        """Run the parent setup, then add the discriminator and local aliases.

        Creates ``human_data_buffer`` (a second name for ``replay_buffer``),
        the ``Discriminator`` with its Adam optimizer, and the cost-critic
        aliases.
        """
        super(HACOOptimized, self)._setup_model()

        # Same object as the replay buffer, exposed under a second name.
        self.human_data_buffer = self.replay_buffer

        # Only the first dimension of each space is used for sizing.
        obs_dim, act_dim = (
            int(space.shape[0]) for space in (self.observation_space, self.action_space)
        )
        discriminator = Discriminator(obs_dim, act_dim).to(self.device)
        self.discriminator = discriminator
        self.discriminator_optimizer = th.optim.Adam(
            discriminator.parameters(),
            lr=self.discriminator_lr,
        )

        self._create_aliases_optimized()

    def _create_aliases_optimized(self):

        self.cost_critic = self.policy.cost_critic
        self.cost_critic_target = self.policy.cost_critic_target

    def _feed_pairs_from_batch(self, replay_data) -> None:
        mask = replay_data.interventions
        if th.is_tensor(mask):
            mask = mask.squeeze(-1).bool()
        else:
            mask = th.as_tensor(mask).bool()
        if not mask.any():
            return

        def to_np(x):
            if th.is_tensor(x):
                return x.detach().cpu().numpy()
            return np.asarray(x)

        s_batch = to_np(replay_data.observations)[mask.cpu().numpy()]
        a_agent_batch = to_np(replay_data.actions_novice)[mask.cpu().numpy()]
        a_human_batch = to_np(replay_data.actions_behavior)[mask.cpu().numpy()]
        for s, aa, ah in zip(s_batch, a_agent_batch, a_human_batch):
            self.paired_buf.add_pair(s, aa, ah)


    def _update_discriminator_from_pairs(self, stat_recorder: Dict[str, list], k_pairs: int) -> None:
        if self.paired_buf.size() < k_pairs:
            stat_recorder["disc_paired_size"].append(float(self.paired_buf.size()))
            stat_recorder["disc_skip_pairs_short"].append(1.0)
            return

        states, actions, labels = self.paired_buf.sample_pairs(k_pairs)
        probs = self.discriminator(states, actions)
        loss = F.binary_cross_entropy(probs, labels)

        self.discriminator_optimizer.zero_grad()
        loss.backward()
        self.discriminator_optimizer.step()
        self._disc_updates += 1

        with th.no_grad():
            preds = (probs > 0.5).float()
            acc = (preds == labels).float().mean()

        stat_recorder["discriminator_loss"].append(float(loss.item()))
        stat_recorder["discriminator_accuracy"].append(float(acc.item()))
        stat_recorder["disc_batch_size"].append(float(labels.shape[0]))
        stat_recorder["disc_updates"].append(float(self._disc_updates))


    def train(self, gradient_steps: int, batch_size: int = 64) -> None:
        """Run ``gradient_steps`` optimisation steps on batches from the replay buffer.

        Each step, in order:

        1. samples a batch, pushes its intervention transitions into the paired
           buffer and, on the configured schedule, updates the discriminator;
        2. updates the entropy coefficient when it is learned;
        3. builds bootstrapped targets for the Q critics (no reward term is
           added) and for the cost critics (``intervention_costs`` is the
           immediate term);
        4. forms the critic loss as the TD error on ``actions_behavior`` minus
           a discriminator-weighted term, active on intervention steps only,
           built from the Q-values of ``actions_behavior`` and ``actions_novice``;
        5. steps the actor, then the critics, then Polyak-updates the targets.

        Per-step statistics are averaged and sent to the logger as ``train/<key>``.

        :param gradient_steps: number of update steps to perform.
        :param batch_size: number of transitions sampled per step.
        :raises ValueError: if ``policy_kwargs["share_features_extractor"]`` is
            ``"actor"``.
        """
        # Reject the unsupported configuration before any state is modified.
        # Every other value follows the same update path.
        if self.policy_kwargs.get("share_features_extractor") == "actor":
            raise ValueError("share_features_extractor='actor' not supported here.")

        self.policy.set_training_mode(True)
        optimizers = {
            "actor": self.actor.optimizer,
            "critic": self.critic.optimizer,
        }
        if self.ent_coef_optimizer is not None:
            optimizers["entropy"] = self.ent_coef_optimizer
        self._update_learning_rate(optimizers)

        stat_recorder = defaultdict(list)

        for g_step in range(gradient_steps):
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            # --- Discriminator -------------------------------------------------
            self._feed_pairs_from_batch(replay_data)

            # The schedule uses the global update index so it carries across
            # calls to train(); updates stop after ``d_train_steps`` of them.
            global_update_idx = self._n_updates + g_step
            if (global_update_idx % self.discriminator_update_freq == 0) and (self._disc_updates < self.d_train_steps):
                self._update_discriminator_from_pairs(stat_recorder, k_pairs=self.k_pairs_per_step)

            if self.use_sde:
                self.actor.reset_noise()

            # --- Entropy coefficient -------------------------------------------
            actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations)
            log_prob = log_prob.reshape(-1, 1)

            ent_coef_loss = None
            if self.ent_coef_optimizer is not None:
                ent_coef = th.exp(self.log_ent_coef.detach())
                ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
                stat_recorder["ent_coef_loss"].append(ent_coef_loss.item())
            else:
                ent_coef = self.ent_coef_tensor

            stat_recorder["entropy"].append(-log_prob.mean().item())
            stat_recorder["ent_coef"].append(ent_coef.item())

            if ent_coef_loss is not None:
                self.ent_coef_optimizer.zero_grad()
                ent_coef_loss.backward()
                self.ent_coef_optimizer.step()

            # --- Bootstrapped targets ------------------------------------------
            with th.no_grad():
                next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)

                next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
                next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
                next_q_values = next_q_values - ent_coef * next_log_prob.reshape(-1, 1)
                # No reward term is added to the Q target.
                target_q_values = (1 - replay_data.dones) * self.gamma * next_q_values

                next_cost_q_values = th.cat(self.cost_critic_target(replay_data.next_observations, next_actions), dim=1)
                next_cost_q_values, _ = th.min(next_cost_q_values, dim=1, keepdim=True)
                target_cost_q_values = (
                    replay_data.intervention_costs + (1 - replay_data.dones) * self.gamma * next_cost_q_values
                )

            # --- Critic loss ---------------------------------------------------
            current_q_behavior_values = self.critic(replay_data.observations, replay_data.actions_behavior)
            current_q_novice_values = self.critic(replay_data.observations, replay_data.actions_novice)

            stat_recorder["q_value_behavior"].append(current_q_behavior_values[0].mean().item())
            stat_recorder["q_value_novice"].append(current_q_novice_values[0].mean().item())

            w_h, w_n = self._discriminator_cql_weights(replay_data)

            critic_loss_list = []
            for current_q_behavior, current_q_novice in zip(current_q_behavior_values, current_q_novice_values):
                td_loss = 0.5 * F.mse_loss(current_q_behavior, target_q_values)
                # Multiplying by ``interventions`` zeroes the term on
                # non-intervention transitions.
                cql_term = th.mean(
                    replay_data.interventions
                    * self.cql_coefficient
                    * (w_h * current_q_behavior - w_n * current_q_novice)
                )
                critic_loss_list.append(td_loss - cql_term)
            critic_loss = sum(critic_loss_list)
            stat_recorder["critic_loss"].append(critic_loss.item())

            current_cost_q_values = self.cost_critic(replay_data.observations, replay_data.actions_behavior)
            cost_critic_loss = 0.5 * sum(
                [F.mse_loss(current_cost_q, target_cost_q_values) for current_cost_q in current_cost_q_values]
            )
            for i, v in enumerate(current_cost_q_values):
                stat_recorder[f"cost_q_value_{i}"].append(v.mean().item())
            stat_recorder["cost_critic_loss"].append(cost_critic_loss.item())
            merged_critic_loss = cost_critic_loss + critic_loss

            # --- Actor loss ----------------------------------------------------
            q_values_pi = th.cat(self.critic(replay_data.observations, actions_pi), dim=1)
            min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
            stat_recorder["q_value_min"].append(min_qf_pi.mean().item())

            # The cost Q-value of the policy's actions is only logged; it is
            # not part of ``actor_loss``, so no autograd graph is needed.
            # NOTE(review): if the cost term is ever added back into the actor
            # loss, this ``no_grad`` block must be removed.
            with th.no_grad():
                cost_q_values_pi = th.cat(self.cost_critic(replay_data.observations, actions_pi), dim=1)
                min_cost_qf_pi, _ = th.min(cost_q_values_pi, dim=1, keepdim=True)
            stat_recorder["cost_q_value_min"].append(min_cost_qf_pi.mean().item())

            native_actor_loss = ent_coef * log_prob - min_qf_pi
            cost_actor_loss = min_cost_qf_pi
            actor_loss = native_actor_loss.mean()
            stat_recorder["actor_loss"].append(actor_loss.item())
            stat_recorder["cost_actor_loss"].append(cost_actor_loss.mean().item())

            # --- Parameter updates ---------------------------------------------
            self._optimize_actor(actor_loss=actor_loss)
            self._optimize_critics(merged_critic_loss=merged_critic_loss)

            if g_step % self.target_update_interval == 0:
                polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
                polyak_update(self.cost_critic.parameters(), self.cost_critic_target.parameters(), self.tau)

        self._n_updates += gradient_steps
        self.logger.record("train/n_updates", self._n_updates)
        for key, values in stat_recorder.items():
            self.logger.record(f"train/{key}", np.mean(values))

    def _discriminator_cql_weights(self, replay_data):
        """Return the per-transition weights ``(w_h, w_n)`` used in the critic's intervention term.

        Both weights are all-ones tensors shaped like ``replay_data.interventions``
        when the batch has no interventions. Otherwise, on intervention rows,
        ``w_h`` is the discriminator output for the behaviour action and ``w_n``
        is one minus the discriminator output for the novice action; on the
        remaining rows ``w_h`` is 1 and ``w_n`` is 0.
        """
        interventions = replay_data.interventions
        w_h = th.ones_like(interventions, dtype=th.float32, device=self.device)
        w_n = th.ones_like(interventions, dtype=th.float32, device=self.device)

        mask_bool = interventions.squeeze(-1).bool() if interventions.ndim > 1 else interventions.bool()
        if not mask_bool.any():
            return w_h, w_n

        # The weights are constants with respect to the critic loss.
        with th.no_grad():
            int_states = replay_data.observations[mask_bool]
            d_h = self.discriminator(int_states, replay_data.actions_behavior[mask_bool])
            d_n = self.discriminator(int_states, replay_data.actions_novice[mask_bool])

        # Match the trailing shape of the weight tensors so the scatter works
        # whether the discriminator returns (K,) or (K, 1).
        row_shape = tuple(w_h.shape[1:])
        w_h[mask_bool] = d_h.reshape(-1, *row_shape)
        w_n[mask_bool] = d_n.reshape(-1, *row_shape)
        return w_h, 1.0 - w_n

    def _optimize_actor(self, actor_loss):
        self.actor.optimizer.zero_grad()
        actor_loss.backward()
        self.actor.optimizer.step()

    def _optimize_critics(self, merged_critic_loss):
        self.critic.optimizer.zero_grad()
        merged_critic_loss.backward()
        self.critic.optimizer.step()

    def _excluded_save_params(self) -> List[str]:
        """Return the attribute names that must not be serialised with the model.

        On top of the parent's list this excludes:

        * ``paired_buf`` - the paired-sample buffer;
        * ``human_data_buffer`` - a second name for ``replay_buffer`` assigned
          in ``_setup_model``; leaving it in would serialise the buffer's
          contents with every checkpoint;
        * ``cost_critic`` / ``cost_critic_target`` - aliases of policy
          sub-modules assigned in ``_create_aliases_optimized``.

        All of these are recreated by ``__init__`` or ``_setup_model``.

        NOTE(review): assumes, as in upstream Stable-Baselines3, that the
        parent list already excludes ``replay_buffer`` itself and that the
        policy is persisted through its state dict - confirm against HACOSAC.
        """
        excluded = list(super(HACOOptimized, self)._excluded_save_params())
        rebuilt_on_setup = ("paired_buf", "human_data_buffer", "cost_critic", "cost_critic_target")
        for name in rebuilt_on_setup:
            # The parent may already list some of these; avoid duplicates.
            if name not in excluded:
                excluded.append(name)
        return excluded

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
        if self.ent_coef_optimizer is not None:
            saved_pytorch_variables = ["log_ent_coef"]
            state_dicts.append("ent_coef_optimizer")
        else:
            saved_pytorch_variables = ["ent_coef_tensor"]

        return state_dicts, saved_pytorch_variables
