import pickle
import time
from dataclasses import asdict, dataclass
from typing import Any, overload

import gymnasium as gym
import miniworld
import numpy as np
import scipy.stats as stats
import torch
import torch.nn as nn
import torch.optim as optim
from torch import Tensor
from torch.distributions.categorical import Categorical
from tqdm import tqdm

from args import (
    AdversarialTrainingConfig,
    DatasetConfig,
    EvalConfig,
    LoggingConfig,
    ModelConfig,
    PPOMWConfig,
    SeedConfig,
    get_adv_trained_model_name,
    get_model_name,
    parse_args_to_dataclass,
)
from mdp.mdp_attacker import MDPAttacker, MDPGridRandomAttacker
from mdp.mdp_dataset import MW_IMAGE_SHAPE, process_miniworld_images
from mdp.miniworld_env import MW_N_ACTIONS, MDPMiniworldAttacker, MiniworldEnv
from util.argparser_dataclass import parse_args_to_dataclass
from util.datastore import Datastore
from util.logger import PrintLogger, WandbLogger
from util.seed import set_seed

# Single device shared by every model and rollout buffer in this script: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Re-initialise ``layer`` in place and hand it back.

    The weight gets an orthogonal initialisation scaled by ``std``; every bias entry is set to
    ``bias_const``. Returning the same module lets the call wrap a constructor inline.
    """
    weight, bias = layer.weight, layer.bias
    nn.init.orthogonal_(weight, gain=std)
    nn.init.constant_(bias, val=bias_const)
    return layer


class PPOAgent(nn.Module):
    """Actor-critic network with a shared convolutional torso.

    The torso turns an image observation into a 512-d feature vector. A linear actor head maps it to
    ``n_actions`` logits, and a linear critic head maps it to one scalar value estimate. Pixel inputs are
    divided by 255 before the torso.
    """

    n_actions: int
    network: nn.Module
    actor: nn.Module
    critic: nn.Module

    def __init__(self, n_actions: int):
        super().__init__()
        self.n_actions = n_actions

        # Shape notes follow the layer arithmetic for a 3 x 25 x 25 input.
        # Layers are built (and their RNG consumed) in this exact order.
        conv_a = layer_init(nn.Conv2d(3, 32, 3, stride=2))  # -> 32 x 12 x 12
        conv_b = layer_init(nn.Conv2d(32, 64, 3, stride=1))  # -> 64 x 10 x 10
        conv_c = layer_init(nn.Conv2d(64, 64, 3, stride=1))  # -> 64 x 8 x 8
        dense = layer_init(nn.Linear(64 * 8 * 8, 512))  # -> 512
        self.network = nn.Sequential(
            conv_a,
            nn.ReLU(),
            conv_b,
            nn.ReLU(),
            conv_c,
            nn.ReLU(),
            # Flatten the last three (C, H, W) dims so both batched and unbatched inputs work.
            nn.Flatten(start_dim=-3),
            dense,
            nn.ReLU(),
        )

        # Small-gain actor init keeps the initial policy close to uniform.
        self.actor = layer_init(nn.Linear(512, self.n_actions), std=0.01)
        self.critic = layer_init(nn.Linear(512, 1), std=1)

    def get_value(self, x) -> Tensor:
        """Return the critic's value estimate for raw pixel observations ``x``."""
        scaled = x / 255.0
        features = self.network(scaled)
        return self.critic(features)

    def get_action_and_value(self, x, action=None) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Evaluate the policy on ``x``.

        Samples an action from the categorical policy when ``action`` is None; otherwise scores the given
        ``action``. Returns ``(action, log_prob(action), entropy, value)``.
        """
        features = self.network(x / 255.0)
        dist = Categorical(logits=self.actor(features))
        chosen = dist.sample() if action is None else action
        return chosen, dist.log_prob(chosen), dist.entropy(), self.critic(features)


@dataclass()
class DatasetPPO:
    """Rollout buffers gathered by a PPO deployment, indexed ``[env, step, ...]``.

    Fields are positional in this exact order.
    """

    # The shape notes below describe the buffers as allocated by MiniworldEnvForPPO.deploy_ppo.
    obs: Tensor  # (n_envs, n_steps, *MW_IMAGE_SHAPE) processed observations fed to the agents
    actions: Tensor  # (n_envs, n_steps, MW_N_ACTIONS) one-hot encoding of the chosen actions
    logprobs: Tensor  # (n_envs, n_steps) log-probability of each chosen action at collection time
    rewards: Tensor  # (n_envs, n_steps) rewards returned by the env step (attacker may have corrupted them)
    rewards_original: Tensor  # (n_envs, n_steps) presumably the uncorrupted rewards — confirm in MiniworldEnv.step
    dones: Tensor  # (n_envs, n_steps) 1.0 on the final step of each rollout, 0.0 elsewhere
    values: Tensor  # (n_envs, n_steps) critic value estimates at collection time


class MiniworldEnvForPPO(MiniworldEnv):
    """Batch of MiniWorld environments with a rollout routine for one PPO agent per environment."""

    @classmethod
    def sample(cls, n_envs: int, n_steps: int, device=None, seed=0) -> "MiniworldEnvForPPO":
        """Create ``n_envs`` OneRoomS6FastMultiFourBoxesFixedInit environments, each truncated at ``n_steps``.

        Environment ``i`` is assigned task id ``8000 + seed + i``. Returns an instance of ``cls``, so
        subclasses get their own type back.
        """
        envs = []
        for env_idx in range(n_envs):
            env = gym.make("MiniWorld-OneRoomS6FastMultiFourBoxesFixedInit-v0", max_episode_steps=n_steps)  # type: ignore
            # env: OneRoomS6FastMulti
            env.unwrapped.set_task(env_id=8000 + seed + env_idx)  # type: ignore
            envs.append(env)
        return cls(envs, n_steps, device=device)

    @overload
    def deploy_ppo(
        self,
        agents: list[PPOAgent],
        *,
        pbar_desc: str | None = None,
        force_show_progress: bool | None = None,
        **kwargs,
    ) -> DatasetPPO: ...

    @overload
    def deploy_ppo(
        self,
        agents: list[PPOAgent],
        attacker: MDPAttacker | None,
        eps_episodes: float,
        eps_steps: float,
        *,
        pbar_desc: str | None = None,
        force_show_progress: bool | None = None,
        **kwargs,
    ) -> DatasetPPO: ...

    def deploy_ppo(
        self,
        agents: list[PPOAgent],
        attacker: MDPAttacker | None = None,
        eps_episodes: float | None = None,
        eps_steps: float | None = None,
        *,
        pbar_desc: str | None = None,
        force_show_progress: bool | None = None,
        **kwargs,
    ) -> DatasetPPO:
        """Deploy PPO in the environment with corruption. Returns the trajectories of the deployment.

        Args:
            agents: one agent per environment; ``agents[i]`` acts in environment ``i``.
            attacker: if given, installed through ``_set_attacker`` together with ``eps_episodes`` and
                ``eps_steps`` (both then required). If ``None``, any previously installed attacker is cleared.
            pbar_desc: optional prefix for the progress-bar description.
            force_show_progress: show a tqdm bar even for small rollouts.
            **kwargs: forwarded unchanged to ``self.step``.

        Returns:
            DatasetPPO whose tensors are indexed ``[env, step, ...]``. ``actions`` is one-hot over
            ``MW_N_ACTIONS``, and ``dones`` is 1 only at the final step.

        Raises:
            ValueError: if the number of agents differs from ``self.n_envs``, or an attacker is given without
                ``eps_episodes``/``eps_steps``.
        """
        n_envs = self.n_envs
        n_steps = self.n_steps
        if len(agents) != n_envs:
            # zip() below would otherwise silently leave the rows of missing agents at zero.
            raise ValueError(f"expected {n_envs} agents (one per environment), got {len(agents)}")

        self.reset()

        if attacker is None:
            self.attacker = None
        else:
            if eps_episodes is None or eps_steps is None:
                raise ValueError("eps_episodes and eps_steps must be set")
            self._set_attacker(attacker, eps_episodes, eps_steps)

        # NOTE(review): buffers live on the module-level ``device``, not necessarily the device given to
        # ``sample`` — confirm the two agree.
        obs = torch.zeros((n_envs, n_steps, *MW_IMAGE_SHAPE), device=device)
        actions = torch.zeros((n_envs, n_steps, MW_N_ACTIONS), device=device)
        logprobs = torch.zeros((n_envs, n_steps), device=device)
        rewards = torch.zeros((n_envs, n_steps), device=device)
        rewards_original = torch.zeros((n_envs, n_steps), device=device)
        dones = torch.zeros((n_envs, n_steps), device=device)
        values = torch.zeros((n_envs, n_steps), device=device)
        dones[:, -1] = 1  # every rollout ends at its final step

        # Progress bar only for large rollouts, unless explicitly forced.
        steps = range(n_steps)
        if force_show_progress or (n_envs >= 10000 and n_steps >= 100):
            desc = (f"{pbar_desc} " if pbar_desc is not None else "") + "Deploy - PPO"
            steps = tqdm(steps, desc=desc)

        images = [env.render() for env in self.envs]

        for step in steps:
            obs[:, step] = process_miniworld_images(images)
            for env_idx, (agent, image) in enumerate(zip(agents, obs[:, step])):
                with torch.no_grad():
                    action, logprob, _, value = agent.get_action_and_value(image)
                    actions[env_idx, step, action] = 1  # one-hot encode the sampled action
                    logprobs[env_idx, step] = logprob
                    values[env_idx, step] = value.flatten()

            rewards_env, rewards_original_cur, _infos, _ = self.step(actions[:, step, :], **kwargs)
            images = [env.render() for env in self.envs]

            rewards[:, step] = rewards_env.squeeze(-1)
            rewards_original[:, step] = rewards_original_cur.squeeze(-1)

        return DatasetPPO(obs, actions, logprobs, rewards, rewards_original, dones, values)


def main(
    logging_config: LoggingConfig,
    dataset_config: DatasetConfig,
    model_config: ModelConfig,
    eval_config: EvalConfig,
    adv_train_config: AdversarialTrainingConfig,
    seed_config: SeedConfig,
    ppo_config: PPOMWConfig,
) -> None:
    """Train one PPO agent per environment online under an optional reward attacker, then evaluate.

    For ``online_alg_episodes`` iterations, the agents are rolled out with ``deploy_ppo`` (an
    ``MDPGridRandomAttacker`` is active unless ``attacker_against == "clean"``) and updated with
    ``train_ppo``. A final rollout's ``rewards_original`` are summed over steps and averaged over
    environments. The result is printed, stored in the ``Datastore``, and pickled under
    ``models/adv/<setup_name>/``.

    Raises:
        ValueError: if ``adv_train_config.attacker_against`` is not ``"clean"`` or ``"unifrand"``.
    """
    import os

    alg_against = adv_train_config.attacker_against
    # Validate before creating the logger so a bad config does not open a logging run.
    if alg_against not in ("clean", "unifrand"):
        raise ValueError(f"attacker_against must be 'clean' or 'unifrand', got {alg_against!r}")

    run_name = get_adv_trained_model_name(dataset_config, model_config, eval_config, adv_train_config)

    if logging_config.log == "wandb":
        logger = WandbLogger(
            run_name,
            config={
                **asdict(dataset_config),
                **asdict(model_config),
                **asdict(eval_config),
                **asdict(adv_train_config),
            },
        )
    else:
        logger = PrintLogger(run_name, "Step")

    online_alg_episodes = 100  # number of rollout + PPO-update iterations (hardcoded for now)

    n_envs = eval_config.n_envs_eval
    n_steps_eval = eval_config.n_steps_eval if eval_config.n_steps_eval is not None else dataset_config.context_len
    n_actions = dataset_config.n_actions
    victim_alg = "ppo"

    setup_name = get_adv_trained_model_name(dataset_config, model_config, eval_config, adv_train_config, print_against=False)

    seed = seed_config.seed
    set_seed(seed)

    env = MiniworldEnvForPPO.sample(n_envs, n_steps_eval, device=device, seed=60000 + seed)
    attacker = None if alg_against == "clean" else MDPGridRandomAttacker(n_envs, 7, adv_train_config.max_poison_diff, device=device)

    agents = []
    optimizers = []
    for _ in range(n_envs):
        agent = PPOAgent(n_actions).to(device=device)
        # The base LR is ppo_lr, the value the annealing below scales. Previously model_config.lr was set here
        # and then overwritten on the first iteration.
        optimizer = optim.Adam(agent.parameters(), lr=ppo_config.ppo_lr, eps=1e-5)
        agents.append(agent)
        optimizers.append(optimizer)
    anneal_lr = True

    for episode in tqdm(range(online_alg_episodes), desc=f"Learning Online - {victim_alg.upper()}"):
        if anneal_lr:
            # Linear decay over the update iterations: 1.0 at the first update, 1/online_alg_episodes at the last.
            # The old formula, 1 - (episode - 1) / n_steps, divided by the rollout length. It started above 1.0
            # and produced a negative learning rate once episode > n_steps + 1.
            frac = 1.0 - episode / online_alg_episodes
            lrnow = frac * ppo_config.ppo_lr
            for optimizer in optimizers:
                optimizer.param_groups[0]["lr"] = lrnow

        dataset_victim = env.deploy_ppo(agents, attacker, adv_train_config.eps_episodes, adv_train_config.eps_steps, force_show_progress=True)

        metrics = train_ppo(agents, optimizers, dataset_victim, eval_config, dataset_config, ppo_config)
        logger.log(metrics)

    # Final evaluation rollout; performance is measured on the uncorrupted rewards.
    dataset_victim = env.deploy_ppo(agents, attacker, adv_train_config.eps_episodes, adv_train_config.eps_steps, force_show_progress=True)
    rewards_victim: np.ndarray = dataset_victim.rewards_original.numpy(force=True)

    rewards_alg = rewards_victim.sum(-1).mean(-1)  # per-env return, averaged over envs
    print(f"Seed {seed} {victim_alg} reward against {alg_against}:")
    print(f"{rewards_alg:.3f}")
    print()

    datastore = Datastore()
    datastore.store_table_value(f"{dataset_config.env}_reward_epss{adv_train_config.eps_steps:.1f}", victim_alg, alg_against, rewards_alg, seed)
    print(f"Updated '{datastore.path}'.")

    results_filename = f"models/adv/{setup_name}/attacker_against_{alg_against}_{victim_alg}_evals_seed{seed}.pkl"
    # The setup directory may not exist yet; open() would fail otherwise.
    os.makedirs(os.path.dirname(results_filename), exist_ok=True)
    with open(results_filename, "wb") as f:
        pickle.dump(rewards_victim, f)
    print(f"Saved to '{results_filename}'.")


def _compute_gae(rewards: Tensor, values: Tensor, dones: Tensor, gamma: float, gae_lambda: float) -> tuple[Tensor, Tensor]:
    """Return ``(advantages, returns)`` via GAE for ``[env, step]``-shaped rollouts; the final step is never bootstrapped."""
    n_steps = rewards.shape[1]
    advantages = torch.zeros_like(rewards)
    lastgaelam = 0
    for t in reversed(range(n_steps)):
        if t == n_steps - 1:
            nextnonterminal = 0
            nextvalues = 0
        else:
            nextnonterminal = 1.0 - dones[:, t + 1]
            nextvalues = values[:, t + 1]
        delta = rewards[:, t] + gamma * nextvalues * nextnonterminal - values[:, t]
        advantages[:, t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
    returns = advantages + values
    return advantages, returns


def train_ppo(
    agents: list[PPOAgent], optimizers: list[optim.Optimizer], dataset: DatasetPPO, eval_config: EvalConfig, dataset_config: DatasetConfig, ppo_config: PPOMWConfig
) -> dict[str, Any]:
    """Run PPO updates for every ``(agent, optimizer)`` pair on its own environment's rollout.

    Row ``i`` of every ``dataset`` tensor is the rollout of ``agents[i]``. Losses are computed per environment
    and never mix across rows.

    Args:
        agents, optimizers: one per environment, aligned with the first dimension of the dataset tensors.
        dataset: rollout tensors indexed ``[env, step, ...]``.
        eval_config, dataset_config: retained for interface compatibility; the rollout length is now read
            from ``dataset`` itself.
        ppo_config: PPO hyperparameters (gamma, gae_lambda, update_epochs, clip_coef, norm_adv, clip_vloss,
            ent_coef, vf_coef, max_grad_norm, target_kl).

    Returns:
        Metrics from the last minibatch, plus clip fraction and explained variance over the update.

    Raises:
        ValueError: if the rollout is empty or ``update_epochs < 1`` (no update would run).
    """
    start_time = time.perf_counter()

    rewards = dataset.rewards
    values = dataset.values
    n_steps = rewards.shape[1]
    # A batch is every step of each environment's rollout. This was hard-coded to 250, which indexes out of
    # range or silently drops data when the rollout length differs.
    batch_size = n_steps
    minibatch_size = 64
    if batch_size < 1 or ppo_config.update_epochs < 1:
        raise ValueError(f"nothing to train on: n_steps={n_steps}, update_epochs={ppo_config.update_epochs}")

    # bootstrap value if not done
    with torch.no_grad():
        advantages, returns = _compute_gae(rewards, values, dataset.dones, ppo_config.gamma, ppo_config.gae_lambda)

    b_obs = dataset.obs.detach()
    b_logprobs = dataset.logprobs.detach()
    b_action_idx = dataset.actions.detach().argmax(-1)  # one-hot -> action index, computed once
    b_advantages = advantages.detach()
    b_returns = returns.detach()
    b_values = values.detach()

    # Optimizing the policy and value network
    clipfracs: list[float] = []
    for _epoch in range(ppo_config.update_epochs):
        b_inds = torch.randperm(batch_size, device=rewards.device)

        for start in range(0, batch_size, minibatch_size):
            mb_inds = b_inds[start : start + minibatch_size]

            newlogprob_list, entropy_list, newvalue_list = [], [], []
            for env_idx, agent in enumerate(agents):
                _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[env_idx, mb_inds], b_action_idx[env_idx, mb_inds])
                newlogprob_list.append(newlogprob)
                entropy_list.append(entropy)
                newvalue_list.append(newvalue.squeeze(-1))

            # (n_envs, minibatch) tensors; row i depends only on agents[i].
            logratios = torch.stack(newlogprob_list) - b_logprobs[:, mb_inds]
            ratios = logratios.exp()
            entropies = torch.stack(entropy_list)
            newvalues = torch.stack(newvalue_list)

            with torch.no_grad():
                # calculate approx_kl http://joschu.net/blog/kl-approx.html
                old_approx_kl = (-logratios).mean()
                approx_kl = ((ratios - 1) - logratios).mean()
                clipfracs.append(((ratios - 1.0).abs() > ppo_config.clip_coef).float().mean().item())

            mb_advantages = b_advantages[:, mb_inds]
            # Per-env normalisation. Skipped for a 1-sample minibatch, where the unbiased std is NaN and would
            # poison the parameters.
            if ppo_config.norm_adv and mb_inds.numel() > 1:
                mb_advantages = (mb_advantages - mb_advantages.mean(dim=-1, keepdim=True)) / (mb_advantages.std(dim=-1, keepdim=True) + 1e-8)

            # Policy loss (clipped surrogate), one value per env.
            pg_loss1 = -mb_advantages * ratios
            pg_loss2 = -mb_advantages * torch.clamp(ratios, 1 - ppo_config.clip_coef, 1 + ppo_config.clip_coef)
            pg_loss = torch.maximum(pg_loss1, pg_loss2).mean(dim=-1, keepdim=True)

            # Value loss, one value per env.
            mb_returns = b_returns[:, mb_inds]
            if ppo_config.clip_vloss:
                mb_values = b_values[:, mb_inds]
                v_loss_unclipped = (newvalues - mb_returns) ** 2
                v_clipped = mb_values + torch.clamp(newvalues - mb_values, -ppo_config.clip_coef, ppo_config.clip_coef)
                v_loss_clipped = (v_clipped - mb_returns) ** 2
                v_loss = 0.5 * torch.maximum(v_loss_unclipped, v_loss_clipped).mean(dim=-1, keepdim=True)
            else:
                v_loss = 0.5 * ((newvalues - mb_returns) ** 2).mean(dim=-1, keepdim=True)

            entropy_losses = entropies.mean(dim=-1, keepdim=True)
            losses = pg_loss - ppo_config.ent_coef * entropy_losses + v_loss * ppo_config.vf_coef  # (n_envs, 1)

            for optimizer in optimizers:
                optimizer.zero_grad()
            # Gradients from separate backward passes accumulate, so one backward of the sum gives the same
            # gradients as backpropagating each env's loss separately. It needs a single pass instead of
            # n_envs retained-graph passes over every agent.
            losses.sum().backward()

            for agent, optimizer in zip(agents, optimizers):
                nn.utils.clip_grad_norm_(agent.parameters(), ppo_config.max_grad_norm)
                optimizer.step()

        if ppo_config.target_kl is not None and approx_kl > ppo_config.target_kl:
            break

    with torch.no_grad():
        var_y = torch.var(b_returns, dim=-1)
        explained_var = torch.nan if torch.any(var_y == 0) else (1 - torch.var(b_returns - b_values, dim=-1) / var_y).mean().item()

    updates_per_second = 1 / (time.perf_counter() - start_time)
    metrics = {
        "charts/learning_rate": optimizers[0].param_groups[0]["lr"],
        "losses/value_loss": v_loss.detach().mean().item(),
        "losses/policy_loss": pg_loss.detach().mean().item(),
        "losses/entropy": entropy_losses.detach().mean().item(),
        "losses/old_approx_kl": old_approx_kl.item(),
        "losses/approx_kl": approx_kl.item(),
        "losses/clipfrac": np.mean(clipfracs),
        "losses/explained_variance": explained_var,
        "charts/episodes_per_second": updates_per_second,
    }
    print("episodes_per_second:", updates_per_second)
    return metrics


if __name__ == "__main__":
    # Parse the seven config dataclasses from the command line, echo them, then run training + evaluation.
    logging_config, dataset_config, model_config, eval_config, adversarial_training_config, seed_config, ppo_config = parse_args_to_dataclass(
        (LoggingConfig, DatasetConfig, ModelConfig, EvalConfig, AdversarialTrainingConfig, SeedConfig, PPOMWConfig)
    )

    print(logging_config, dataset_config, model_config, eval_config, adversarial_training_config, seed_config, ppo_config, sep="\n")

    # perf_counter is monotonic, so the reported runtime is unaffected by wall-clock adjustments.
    time_start = time.perf_counter()
    main(logging_config, dataset_config, model_config, eval_config, adversarial_training_config, seed_config, ppo_config)
    time_end = time.perf_counter()

    print(f"Total runtime: {time_end - time_start:.2f} s")
