from typing import Any, Optional, Union, List, Type

from functools import partial
from types import SimpleNamespace

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.swa_utils import AveragedModel
from torch.nn.utils.rnn import pad_sequence, unpad_sequence

import wandb

from algo.buffer import VecTrajectoryBuffer
from algo.net import ortho_init, TransitionModel
from algo.util import LinearSchedule, update_optimizer, tree_backup_target, MovingStatistics, entropy_from_log_policy, get_paired

# Environment steps between progress prints / wandb logs in Agent.train.
LOGGING_STEPS: int = 1_000
# Batch size handed to the replay buffer's `sample` when measuring feature rank.
FEATURE_RANK_SAMPLES: int = 2_048


class Policy(nn.Module):
    """Dueling actor-critic network on top of a shared feature torso.

    Heads: ``adv_net`` (per-action advantages), ``val_net`` (state value),
    ``pol_net`` (policy logits) and, when ``lucky`` is set, ``luck_val_net``
    (``z_dim`` values per action, read by the agent's off-policy DAE update).
    """

    def __init__(
        self,
        env: "gym.Env",
        torso: Type[nn.Module],
        z_dim: int = 32,
        lucky: bool = False,
    ) -> None:
        """Create the torso and the output heads.

        Args:
            env: Vectorized environment; only ``single_observation_space``
                (given to ``torso``) and ``single_action_space.n`` are read.
            torso: Module class built from the observation space; the
                instance must expose ``features_dim``.
            z_dim: Number of latent codes per action for the luck head.
            lucky: Whether to create ``luck_val_net``.

        Raises:
            ValueError: If ``lucky`` is set and ``z_dim`` is not positive.
        """
        super().__init__()

        self.torso = torso(env.single_observation_space)
        self.n_actions = env.single_action_space.n

        fdim = self.torso.features_dim
        self.fdim = fdim

        # Per-head init gains: 0.1 for advantages, 1 for the value and 0.01
        # for the policy logits (presumably to start near a uniform policy).
        self.adv_net = nn.Linear(fdim, self.n_actions)
        self.adv_net.apply(partial(ortho_init, gain=0.1))

        self.val_net = nn.Linear(fdim, 1)
        self.val_net.apply(partial(ortho_init, gain=1))

        self.pol_net = nn.Linear(fdim, self.n_actions)
        self.pol_net.apply(partial(ortho_init, gain=0.01))

        self.lucky = lucky
        if lucky:
            if z_dim <= 0:
                raise ValueError(f'z_dim must be positive when lucky=True, got {z_dim}')
            self.z_dim = z_dim
            self.luck_val_net = nn.Linear(fdim, self.n_actions * z_dim)

    def forward(self, obs: torch.Tensor, norm_pol: Optional[Union[torch.Tensor, str]] = None):
        """Run the torso and every head.

        Args:
            obs: Batch of observations accepted by the torso.
            norm_pol: How the raw advantages are centered:
                ``'unif'`` subtracts their plain mean over actions; a tensor of
                action probabilities subtracts their mean under that
                distribution (no gradient flows into it); ``None`` uses this
                network's own policy as that distribution.

        Returns:
            SimpleNamespace with ``latent``, ``adv_raw``, ``val`` (``(batch, 1)``),
            ``logits``, ``pol``, ``logpol``, centered ``adv``, ``q = adv + val``
            and, when ``lucky``, ``luck_val`` of shape ``(batch, n_actions, z_dim)``.

        Raises:
            ValueError: If ``norm_pol`` is a string other than ``'unif'``.
        """
        result = SimpleNamespace()

        latent = self.torso(obs)
        result.latent = latent

        adv_raw = self.adv_net(latent)
        result.adv_raw = adv_raw

        result.val = self.val_net(latent)

        # The policy head reads detached features, so the actor loss never
        # updates the shared torso.
        logits = self.pol_net(latent.detach())
        result.logits = logits
        result.pol = F.softmax(logits, dim=1)
        result.logpol = F.log_softmax(logits, dim=1)

        if isinstance(norm_pol, str):
            if norm_pol != 'unif':
                raise ValueError(f"unknown norm_pol {norm_pol!r}; expected 'unif', a tensor or None")
            result.adv = adv_raw - adv_raw.mean(dim=1, keepdim=True)
        else:
            if norm_pol is None:
                norm_pol = result.pol
            result.adv = adv_raw - (adv_raw * norm_pol.detach()).sum(dim=1, keepdim=True)

        result.q = result.adv + result.val

        if self.lucky:
            result.luck_val = self.luck_val_net(latent).view(-1, self.n_actions, self.z_dim)

        return result

    @torch.no_grad()
    def feature_rank(self, obs: torch.Tensor, eps: float = 0.01):
        """Return the numerical rank of the torso features of ``obs``.

        Singular values not exceeding ``eps`` (absolute tolerance) are not
        counted. This is a logging diagnostic, so no autograd graph is built.
        """
        latent = self.torso(obs)
        return torch.linalg.matrix_rank(latent, atol=eps)


class Agent:
    """Off-policy actor-critic learner for a vectorized gymnasium environment.

    Transitions go into a ``VecTrajectoryBuffer`` and are sampled as batches of
    trajectory segments of at most ``n_step`` steps. The critic loss depends on
    ``backup``:

    * ``'naive'``  - n-step target bootstrapped with ``max_a Q_target(s_n, a)``;
    * ``'tree'``   - target from ``tree_backup_target``;
    * ``'dae'``    - direct-advantage-estimation loss on advantages and values;
    * ``'offdae'`` - ``'dae'`` plus a ``TransitionModel`` whose posterior/prior
      over latent codes provides a "luck" correction.

    In every mode the actor maximizes the (normalized) expected advantage,
    penalized by ``KL(pi || pi_target)`` towards a Polyak-averaged target
    network.
    """

    def __init__(
        self,
        env: gym.Env,
        policy: nn.Module,
        learning_rate: LinearSchedule = LinearSchedule(2.5e-4),
        learning_rate_model: float = 2.5e-4,
        adam_eps: float = 1e-5,
        buffer_size: int = 100000,
        initial_steps: int = 2000,
        updates_per_step: float = 2,
        batch_size: int = 64,
        n_step: int = 16,
        z_dim: int = 16,
        gumbel_temperature: LinearSchedule = LinearSchedule(1.),
        quantizer: str = 'gumbel_soft',
        gamma: float = 0.99,
        target_update_steps: int = 100,
        target_update_tau: float = 0.99,
        target_bootstrap: bool = True,
        beta_kl: float = 1.,
        beta_entropy_model: float = 1e-4,
        beta_commit: float = 1.,
        backup: str = 'naive',
        max_grad_norm: float = 10.,
        logging: bool = False,
        device: Union[torch.device, str] = "cpu",
    ) -> None:
        """Build the buffer, optimizers, target network and statistics.

        Selected args:
            updates_per_step: Gradient updates per environment step, counting
                every sub-environment's step (one loop iteration is
                ``env.num_envs`` steps). Below ``1 / num_envs`` this means one
                update every ``int(1 / updates_per_step)`` environment steps.
            backup: ``'naive'``, ``'tree'``, ``'dae'`` or ``'offdae'``.
            target_update_steps: Gradient updates between target-network updates.
            target_update_tau: Weight kept on the old target parameters:
                ``target = tau * target + (1 - tau) * online``.
            target_bootstrap: DAE only; bootstrap from the target network's
                value instead of the online network's.
            max_grad_norm: Gradient-clipping threshold; ``<= 0`` disables it.

        Raises:
            ValueError: If ``updates_per_step`` is not positive.
            NotImplementedError: For an unknown ``backup``.
        """
        super().__init__()

        self.env = env
        self.n_actions = env.single_action_space.n
        self.policy = policy
        self.buffer = VecTrajectoryBuffer(env, buffer_size, n_step, device)
        self.initial_steps = initial_steps

        if updates_per_step <= 0:
            raise ValueError(f'updates_per_step must be positive, got {updates_per_step}')
        _updates_per_iter = updates_per_step * env.num_envs
        if _updates_per_iter >= 1:
            # At least one update per loop iteration: run them all every iteration.
            self.update_period = 1
            self.updates_per_step = int(_updates_per_iter)
        else:
            # One update whenever `total_steps` crosses a multiple of
            # `update_period`. `total_steps` counts environment steps, so the
            # period is 1 / updates_per_step, not 1 / (updates_per_step * num_envs).
            self.update_period = int(1 / updates_per_step)
            self.updates_per_step = 1

        self.batch_size = batch_size
        self.n_step = n_step
        self.gamma = gamma
        self.beta_kl = beta_kl

        self.transition_model = None
        self.lucky = False
        if backup == 'naive':
            self.update = self._update_naive
        elif backup == 'tree':
            self.update = self._update_tree
        elif backup == 'dae':
            self.update = self._update_dae
        elif backup == 'offdae':
            self.lucky = True
            self.update = self._update_dae
            self.transition_model = TransitionModel(
                env,
                z_dim,
                quantizer=quantizer,
                beta_commit=beta_commit,
                beta_entropy=beta_entropy_model,
            ).to(device=device)
            self.transition_model.train()
            self.optimizer_model = torch.optim.Adam(
                self.transition_model.parameters(),
                lr=learning_rate_model,
                betas=(0.5, 0.9),
            )
            self.z_dim = z_dim
        else:
            raise NotImplementedError
        self.gumbel_temperature = gumbel_temperature

        self.learning_rate = learning_rate
        self.optimizer = torch.optim.Adam(policy.parameters(), lr=learning_rate.init, eps=adam_eps)

        self.max_grad_norm = max_grad_norm

        # Polyak-averaged copy of the policy; the initial call copies weights.
        self.target_policy = AveragedModel(
            self.policy,
            avg_fn=lambda p0, p1, n: target_update_tau * p0 + (1-target_update_tau) * p1,
            use_buffers=True
        )
        self.target_policy.update_parameters(self.policy)
        self.target_policy.eval()
        self.target_update_steps = target_update_steps
        self.target_bootstrap = target_bootstrap
        self.updates = 0

        self.device = device

        self.logging = logging

        self._init_discount_matrix()

        self._stats = MovingStatistics(dict(
            score=10,
            length=10,
            gnorm=100,
            loss=100,
            entropy=1000,
            loss_critic=100,
            loss_actor=100,
            loss_kl=100,
            adv_std=100,
            ))

        if self.lucky:
            self._stats.regist('luck', 100)
            self._stats.regist('loss_recon', 100)
            self._stats.regist('gnorm_model', 100)
            self._stats.regist('entropy_prior', 100)
            if quantizer == 'vq':
                self._stats.regist('loss_commit', 100)
                self._stats.regist('loss_prior', 100)
            elif 'gumbel' in quantizer or quantizer == 'exact':
                self._stats.regist('loss_kl_model', 100)
                self._stats.regist('entropy_posterior', 100)

        self.episode_stats = []

    def _init_discount_matrix(self):
        """Precompute discount factors for segments of up to ``n_step`` steps.

        * ``discount_matrix[i, j] = gamma ** (j - i)`` for ``j >= i``, else 0.
        * ``discount_vector = [gamma**n, ..., gamma**1]``; its last ``s``
          entries discount a value at the end of a length-``s`` segment back
          to each step of the segment.
        * ``discount_vector_2 = [gamma**0, ..., gamma**(n-1)]``.
        """
        length = self.n_step

        self.discount_matrix = torch.tensor(
            [[0 if j < i else self.gamma ** (j - i) for j in range(length)] for i in range(length)],
            dtype=torch.float,
            device=self.device)

        self.discount_vector = self.gamma ** torch.arange(length, 0, -1, dtype=torch.float, device=self.device)
        self.discount_vector_2 = self.gamma ** torch.arange(length, dtype=torch.float, device=self.device)

    def _dcsum(self, x: torch.Tensor):
        """Discounted reverse cumulative sum along dim 0.

        Returns ``y`` with ``y[i] = sum_{j >= i} gamma ** (j - i) * x[j]``.
        ``x`` is ``(steps,)`` or ``(steps, batch)`` with at most ``n_step``
        steps; trailing zero padding does not change the real entries' sums.
        """
        assert len(x) <= len(self.discount_matrix)

        mat = self.discount_matrix[:len(x), :len(x)]
        return torch.matmul(mat, x)

    def _segment_dcsum(self, x: torch.Tensor, splits):
        """Apply ``_dcsum`` to each segment of the flat tensor ``x`` (lengths ``splits``)."""
        return torch.cat(unpad_sequence(self._dcsum(pad_sequence(x.split(splits))), splits))

    def _bootstrap_term(self, dones, splits, values):
        """Discounted bootstrap value for every step of every sampled segment.

        Step ``i`` of a length-``s`` segment ending in value ``v`` gets
        ``gamma ** (s - i) * v``, or 0 if the segment ended with ``done``.
        """
        return torch.cat([
            self.discount_vector[-s:] * v * (1 - float(d))
            for d, s, v in zip(dones, splits, values)])

    def _policy_losses(self, adv, pol, logpol, old_pol, old_logpol):
        """Return ``(loss_pol, loss_kl)``, the actor losses shared by all backups.

        ``loss_kl`` is the batch mean of ``KL(pol || old_pol)``. ``loss_pol``
        is minus the expected detached advantage under ``pol``. It is divided by
        ``sqrt(mean_batch(sum_a old_pol * adv**2))``, which is logged as ``'adv_std'``.
        """
        loss_kl = (pol * (logpol - old_logpol)).sum(dim=1).mean()
        adv_std = (adv.square() * old_pol).sum(dim=1).mean().sqrt().detach()
        self._stats['adv_std'].append(adv_std.item())
        loss_pol = -(adv.detach() * pol).sum(dim=1).mean() / adv_std
        return loss_pol, loss_kl

    @staticmethod
    def _grad_norm(parameters) -> float:
        """L2 norm over all existing ``.grad`` tensors of ``parameters`` (0 if none)."""
        grads = [p.grad.detach() for p in parameters if p.grad is not None]
        if not grads:
            return 0.
        return torch.norm(torch.stack([torch.norm(g) for g in grads])).item()

    @staticmethod
    def _finished_episodes(info):
        """Yield ``(return, length)`` for each sub-env that reported episode stats this step."""
        if 'final_info' not in info:
            return
        for _info in info['final_info'][info['_final_info']]:
            if 'episode' in _info:
                yield _info['episode']['r'][0], _info['episode']['l'][0]

    @torch.no_grad()
    def act(self, obs: torch.Tensor):
        """Sample one action per observation from the current policy.

        Puts the policy in eval mode and records the policy entropy in the
        ``'entropy'`` statistic. Returns a NumPy array of action indices.
        """
        self.policy.eval()
        result = self.policy(obs)
        self._stats['entropy'].append(entropy_from_log_policy(result.logpol))
        actions = torch.multinomial(result.pol, 1).flatten().cpu().numpy()

        return actions

    def optimizer_step(
        self,
        loss: torch.Tensor,
        loss_pol: torch.Tensor,
        loss_kl: torch.Tensor,
    ):
        """Take one policy-optimizer step on ``loss + loss_pol + beta_kl * loss_kl``.

        Logs each component, the combined loss and the gradient norm measured
        after clipping.
        """
        self._stats['loss_critic'].append(loss.item())
        self._stats['loss_actor'].append(loss_pol.item())
        self._stats['loss_kl'].append(loss_kl.item())

        total = loss + loss_pol + self.beta_kl * loss_kl
        self.optimizer.zero_grad(set_to_none=True)
        total.backward()
        if self.max_grad_norm > 0:
            nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
        grad_norm = self._grad_norm(self.policy.parameters())
        self.optimizer.step()

        self._stats['gnorm'].append(grad_norm)
        self._stats['loss'].append(total.item())

    def _update_postprocessing(self):
        """Count the update and refresh the target network every ``target_update_steps``."""
        self.updates += 1
        if self.updates % self.target_update_steps == 0:
            self.target_policy.update_parameters(self.policy)

    def _update_tree(self, updates: int):
        """Run ``updates`` gradient steps whose critic target is ``tree_backup_target``.

        Advantages are centered under the target policy. The target uses the
        target network's Q values and policy on ``obs`` and its values on
        ``next_obs``.
        """
        for _update in range(updates):
            (obs, act, rew), dones, next_obs, splits = self.buffer.sample(self.batch_size)
            with torch.no_grad():
                rt = self.target_policy(obs)
                q_target, old_pol, old_logpol = rt.q, rt.pol, rt.logpol
                v_target_n = self.target_policy(next_obs).val

            result = self.policy(obs, norm_pol=old_pol)
            q_act = result.q.gather(-1, act[:, None]).squeeze(-1)

            q_target = tree_backup_target(
                    act,
                    rew,
                    q_target,
                    old_pol,
                    v_target_n,
                    dones,
                    splits,
                    self.gamma)
            loss = (q_act - q_target).square().mean()

            loss_pol, loss_kl = self._policy_losses(
                result.adv, result.pol, result.logpol, old_pol, old_logpol)
            self.optimizer_step(loss, loss_pol, loss_kl)

            self._update_postprocessing()

    def _update_naive(self, updates: int):
        """Run ``updates`` gradient steps with an n-step max-Q critic target.

        For each segment step the target is the discounted reward sum to the
        segment end plus the discounted ``max_a Q_target(next_obs, a)`` (zero
        after a terminal).
        """
        for _update in range(updates):
            (obs, act, rew), dones, next_obs, splits = self.buffer.sample(self.batch_size)
            with torch.no_grad():
                rt = self.target_policy(obs)
                old_pol, old_logpol = rt.pol, rt.logpol
                q_target = self.target_policy(next_obs).q

            result = self.policy(obs, norm_pol=old_pol)
            # squeeze(-1) keeps a 1-D tensor even for a single transition.
            q_act = result.q.gather(-1, act[:, None]).squeeze(-1)

            target = (self._segment_dcsum(rew, splits)
                      + self._bootstrap_term(dones, splits, q_target.amax(dim=-1)))
            loss = (target - q_act).square().mean()

            loss_pol, loss_kl = self._policy_losses(
                result.adv, result.pol, result.logpol, old_pol, old_logpol)
            self.optimizer_step(loss, loss_pol, loss_kl)

            self._update_postprocessing()

    def _luck_terms(self, result, obs, next_obs, act, splits):
        """Train the transition model for one step and return per-transition luck.

        Luck for transition ``t`` is ``luck_val[t, a_t] * (posterior - prior)``
        summed over the last (latent-code) axis. Posterior and prior are
        detached, so only the policy's luck head gets gradients from it.
        """
        paired_obs = get_paired(obs, next_obs, splits)
        result_m = self.transition_model(*paired_obs, act, self._gumbel_temperature)
        self._stats['loss_recon'].append(result_m.loss_recon.item())
        self._stats['entropy_prior'].append(result_m.entropy_prior.item())

        if self.transition_model.quantizer == 'vq':
            self._stats['loss_commit'].append(result_m.loss_commit.item())
            self._stats['loss_prior'].append(result_m.loss_prior.item())
        else:
            self._stats['loss_kl_model'].append(result_m.loss_kl.item())
            self._stats['entropy_posterior'].append(result_m.entropy_posterior.item())

        self.optimizer_model.zero_grad(set_to_none=True)
        result_m.loss_model.backward()
        if self.max_grad_norm > 0:
            nn.utils.clip_grad_norm_(self.transition_model.parameters(), self.max_grad_norm)
        self._stats['gnorm_model'].append(self._grad_norm(self.transition_model.parameters()))
        self.optimizer_model.step()

        posterior, prior = result_m.posterior, result_m.prior
        rows = torch.arange(len(obs), device=act.device)
        luck_val_act = result.luck_val[rows, act]
        lucks = (luck_val_act * (posterior - prior).detach()).sum(dim=-1)
        self._stats['luck'].append(lucks.square().mean().item())
        return lucks

    def _update_dae(self, updates: int):
        """Run ``updates`` gradient steps with the (off-policy) DAE loss.

        ``delta_t = r_t + gamma * luck_t - A(s_t, a_t)`` (luck is zero unless
        ``lucky``). Its discounted segment sum plus the discounted bootstrap
        value regresses onto ``V(s_t)``.
        """
        for _update in range(updates):
            (obs, act, rew), dones, next_obs, splits = self.buffer.sample(self.batch_size)
            n = len(obs)
            # Single forward pass over current and next observations.
            obs_all = torch.cat([obs, next_obs])
            with torch.no_grad():
                _rt = self.target_policy(obs_all)
                old_pol, old_logpol = _rt.pol[:n], _rt.logpol[:n]

            result = self.policy(obs_all, norm_pol=_rt.pol)
            adv, val = result.adv[:n], result.val[:n].squeeze(-1)
            pol, logpol = result.pol[:n], result.logpol[:n]

            if self.target_bootstrap:
                val_next = _rt.val[n:].squeeze(-1)
            else:
                # NOTE(review): not detached, so the critic loss also
                # backpropagates through the bootstrap value; confirm intended.
                val_next = result.val[n:].squeeze(-1)

            if self.lucky:
                lucks = self._luck_terms(result, obs, next_obs, act, splits)
            else:
                lucks = torch.zeros_like(rew)

            adv_act = adv.gather(-1, act[:, None]).squeeze(-1)
            delta = rew + self.gamma * lucks - adv_act

            loss_pol, loss_kl = self._policy_losses(adv, pol, logpol, old_pol, old_logpol)

            target = self._segment_dcsum(delta, splits) + self._bootstrap_term(dones, splits, val_next)
            loss = (target - val).square().mean()

            self.optimizer_step(loss, loss_pol, loss_kl)

            self._update_postprocessing()

    @torch.no_grad()
    def evaluate(self, env: gym.Env, episodes: int = 5, seed: Optional[int] = None):
        """Return the mean return of the first finished episode of each sub-env.

        Requires ``episodes == env.num_envs`` and ``'episode'`` entries in the
        final info (e.g. from an episode-statistics wrapper).

        Raises:
            RuntimeError: If every sub-env finished but some reported no
                episode statistics (the loop could otherwise never end).
        """
        assert episodes == env.num_envs

        self.policy.eval()

        scores = []

        obs, _ = env.reset(seed=seed)
        dones = np.zeros(episodes, dtype=bool)
        while len(scores) < episodes:
            action = self.act(torch.as_tensor(obs, dtype=torch.float, device=self.device))
            obs, _, _, _, info = env.step(action)

            if 'final_info' in info:
                # Only count each sub-env's first finished episode.
                mask = np.logical_and(np.invert(dones), info['_final_info'])
                for _info in info['final_info'][mask]:
                    if 'episode' in _info:
                        scores.append(_info['episode']['r'][0])
                dones = np.logical_or(dones, info['_final_info'])
                if dones.all() and len(scores) < episodes:
                    raise RuntimeError(
                        'evaluation episodes finished without episode statistics in info')

        return np.mean(scores)

    def _log_progress(self, total_steps: int, score_eval: Optional[float]) -> None:
        """Print a progress line and, if ``self.logging``, send statistics to wandb."""
        buf = self.buffer
        if buf.trajectories:
            obs_sample = buf.sample(FEATURE_RANK_SAMPLES)[0][0]
            with torch.no_grad():
                feature_rank = self.policy.feature_rank(obs_sample).item()
        else:
            feature_rank = 0.

        print(f'Steps: {total_steps}\t'
              f'Updates: {self.updates}\t'
              f"Score: {np.mean(self._stats['score']):.3f}\t"
              f"Loss: {np.mean(self._stats['loss_critic']):.5f}\t"
              f"Entropy: {np.mean(self._stats['entropy']):.5f}\t"
              f'Rank: {feature_rank}\t',
              flush=True)

        if not self.logging:
            return

        _data = dict(
            steps=total_steps,
            updates=self.updates,
            feature_rank=feature_rank,
        )
        if score_eval is not None:
            _data['score_eval'] = score_eval
        for k, v in self._stats._dict.items():
            _data[k] = np.mean(v)

        if hasattr(self.transition_model, 'code_usage'):
            freq = self.transition_model.code_usage.cpu().numpy()
            bins = np.arange(len(freq) + 1)
            _data['code_usage'] = wandb.Histogram(np_histogram=(freq, bins))

            freq = self.transition_model.learned_prior.cpu().numpy()
            _data['learned_prior'] = wandb.Histogram(np_histogram=(freq, bins))

        wandb.log(_data)

    def train(
        self,
        steps: int = 100000,
        env_eval: Optional[gym.Env] = None,
        steps_per_eval: int = -1,
        eval_episodes: int = 100,
        seed: Optional[int] = None,
    ):
        """Interact for ``steps`` environment steps, learning after ``initial_steps``.

        Evaluates on ``env_eval`` every ``steps_per_eval`` steps when all eval
        arguments are positive / given. Logs every ``LOGGING_STEPS``. After the
        budget is spent, keeps acting without updates until the episodes in
        ``episode_stats`` cover at least ``steps`` steps.
        """
        env, buf = self.env, self.buffer

        obs, _ = env.reset(seed=seed)

        self.updates = total_steps = recorded_steps = 0

        score_eval = None
        do_eval = env_eval is not None and steps_per_eval > 0 and eval_episodes > 0
        if do_eval:
            score_eval = self.evaluate(env_eval, eval_episodes)

        while total_steps < steps:

            update_optimizer(self.optimizer, 'lr', self.learning_rate.value(total_steps))
            self._gumbel_temperature = self.gumbel_temperature.value(total_steps)
            action = self.act(torch.as_tensor(obs, dtype=torch.float, device=self.device))

            prev_steps = total_steps
            total_steps += env.num_envs
            obs_next, rew, done, trunc, info = env.step(action)
            # Keep the returned observations for acting; for finished sub-envs
            # the buffer instead gets the true final observation.
            obs_cache = np.copy(obs_next)

            if 'final_observation' in info:
                mask = info['_final_observation']
                obs_next[mask] = np.stack(info['final_observation'][mask])

            # NOTE(review): only `done` is stored; `trunc` is dropped. Confirm the
            # buffer still ends trajectories at truncations, otherwise a segment
            # could span an auto-reset.
            buf.add(obs, action, rew, done, obs_next)
            obs = obs_cache

            for _r, _l in self._finished_episodes(info):
                self._stats['score'].append(_r)
                self._stats['length'].append(_l)
                self.episode_stats.append((_r, _l))
                recorded_steps += _l

            if do_eval and total_steps // steps_per_eval != prev_steps // steps_per_eval:
                score_eval = self.evaluate(env_eval, eval_episodes)

            if total_steps >= self.initial_steps and buf.trajectories:
                # Update each time total_steps crosses a multiple of update_period.
                if total_steps // self.update_period != prev_steps // self.update_period:
                    self.policy.train()
                    self.update(self.updates_per_step)
                    self.policy.eval()

            if total_steps // LOGGING_STEPS != prev_steps // LOGGING_STEPS:
                self._log_progress(total_steps, score_eval)

        self.policy.eval()
        while recorded_steps < steps:
            action = self.act(torch.as_tensor(obs, dtype=torch.float, device=self.device))
            obs, _, _, _, info = env.step(action)
            for _r, _l in self._finished_episodes(info):
                self.episode_stats.append((_r, _l))
                recorded_steps += _l
