# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpoolpy
import os
import random
import time
from collections import deque
from dataclasses import dataclass
from typing import Optional

from scipy.stats import gaussian_kde
from sklearn.neighbors import KernelDensity
import wandb
import envpool
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tyro
import seaborn as sns
import matplotlib.pyplot as plt

from utils import count_dying_relu, compute_rank_from_features, RecordEpisodeStatistics, define_wandb_metrics_ppo
from networks import PPOAgent, PPO_HR_Agent


@dataclass
class Args:
    """Command-line configuration of the PPO training script (parsed with ``tyro``).

    Fields that may legitimately be ``None`` are annotated ``Optional[...]`` so
    that the annotation matches the default value.
    """

    # --- Experiment bookkeeping -------------------------------------------
    # Experiment name; defaults to this file's name without the ".py" suffix.
    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    # Seed for `random`, NumPy, PyTorch and the envpool environments.
    seed: int = 1
    # Value assigned to `torch.backends.cudnn.deterministic`.
    torch_deterministic: bool = True
    # Use CUDA when it is available; otherwise the script runs on CPU.
    cuda: bool = True
    # Experiment-tracking flag.
    # NOTE(review): the training script initialises W&B regardless of this flag.
    track: bool = False
    # W&B project the run is logged to.
    wandb_project_name: str = "envpool"
    # W&B entity (user or team) name, or None.
    wandb_entity: Optional[str] = None

    # --- Algorithm specific arguments -------------------------------------
    # envpool task id.
    env_id: str = "SpaceInvaders-v5"
    # Total number of environment steps, summed over all parallel envs.
    total_timesteps: int = 10000000
    # Initial Adam learning rate.
    learning_rate: float = 2.5e-4
    # NOTE(review): not referenced by the training script in this file.
    sparse_learning_rate: float = 2.5e-4
    # Number of environments to run in parallel.
    num_envs: int = 16
    # Rollout length (steps per environment) collected before each update.
    num_steps: int = 128
    # Linearly decay the learning rate over the iterations.
    anneal_lr: bool = True
    # Discount factor.
    gamma: float = 0.99
    # Lambda of the generalized advantage estimator.
    gae_lambda: float = 0.95
    # Number of minibatches each rollout batch is split into.
    num_minibatches: int = 4
    # Number of passes over the rollout batch per update.
    update_epochs: int = 4
    # Normalise advantages within each minibatch.
    norm_adv: bool = True
    # Clipping coefficient for the policy ratio and, if enabled, the value update.
    clip_coef: float = 0.1
    # Use the clipped value loss.
    clip_vloss: bool = True
    # Weight of the entropy bonus in the total loss.
    ent_coef: float = 0.01
    # Weight of the value loss in the total loss.
    vf_coef: float = 0.5
    # Maximum gradient norm used for gradient clipping.
    max_grad_norm: float = 0.5
    # Stop the update epochs early once the approximate KL exceeds this; None disables it.
    target_kl: Optional[float] = None

    # --- Filled in at runtime (derived from the fields above) -------------
    # num_envs * num_steps
    batch_size: int = 0
    # batch_size // num_minibatches
    minibatch_size: int = 0
    # total_timesteps // batch_size
    num_iterations: int = 0

    # --- Experiment-specific knobs ----------------------------------------
    # Factor applied to the rewards inside the advantage computation.
    reward_scale: float = 1.0
    # Agent variant: "relu", "tanh" or "sigmoid" (PPOAgent) or "tanh_HR" (PPO_HR_Agent).
    activation: str = "tanh_HR"
    # Width of the agent's latent feature layer; also the number of units
    # inspected by the dead-unit diagnostics.
    latent_dim: int = 512


if __name__ == "__main__":
    # Entry point: PPO training on an envpool task, with periodic
    # representation diagnostics (effective rank, dead / saturated units)
    # logged to Weights & Biases.
    args = tyro.cli(Args)

    # Fail fast on an unsupported activation, before W&B and the envs are created.
    # Without this check `agent` would stay undefined and surface later as a NameError.
    plain_activations = ("relu", "tanh", "sigmoid")
    if args.activation not in plain_activations and args.activation != "tanh_HR":
        raise ValueError(
            f"Unsupported activation {args.activation!r}; "
            "expected one of 'relu', 'tanh', 'sigmoid', 'tanh_HR'."
        )

    # Sizes derived from the CLI arguments.
    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // args.num_minibatches)
    args.num_iterations = args.total_timesteps // args.batch_size
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    print(args.num_iterations, 'num_iterations')

    # NOTE(review): W&B is initialised unconditionally; `args.track` is not consulted.
    wandb.init(
        project=args.wandb_project_name,
        entity=args.wandb_entity,  # None keeps the W&B default entity
        config=vars(args),
        name=run_name,
    )
    # Define wandb metrics for proper plotting.
    define_wandb_metrics_ppo()

    # Seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # Env setup
    envs = envpool.make(
        args.env_id,
        env_type="gym",
        num_envs=args.num_envs,
        episodic_life=True,
        reward_clip=True,
        seed=args.seed,
    )
    # Attach the attributes the agent / storage code below reads from `envs`.
    envs.num_envs = args.num_envs
    envs.single_action_space = envs.action_space
    envs.single_observation_space = envs.observation_space
    envs = RecordEpisodeStatistics(envs)
    assert isinstance(envs.action_space, gym.spaces.Discrete), "only discrete action space is supported"

    if args.activation in plain_activations:
        agent = PPOAgent(envs, args.activation, args.latent_dim).to(device)
    else:  # "tanh_HR" (validated above)
        agent = PPO_HR_Agent(envs, args.latent_dim).to(device)

    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

    # ALGO Logic: Storage setup -- one (num_steps, num_envs, ...) buffer per rollout quantity.
    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
    logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
    values = torch.zeros((args.num_steps, args.num_envs)).to(device)
    # Sliding window over the most recent 20 episodic returns.
    avg_returns = deque(maxlen=20)

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    logstep = 0
    start_time = time.time()
    next_obs = torch.Tensor(envs.reset()).to(device)
    next_done = torch.zeros(args.num_envs).to(device)

    for iteration in range(1, args.num_iterations + 1):
        # Annealing the rate if instructed to do so (linear decay towards zero).
        if args.anneal_lr:
            frac = 1.0 - (iteration - 1.0) / args.num_iterations
            lrnow = frac * args.learning_rate
            optimizer.param_groups[0]["lr"] = lrnow  # TODO Run new experiments with this!

        # ---- Rollout collection ------------------------------------------
        for step in range(0, args.num_steps):
            global_step += args.num_envs
            obs[step] = next_obs
            dones[step] = next_done

            # ALGO LOGIC: action logic
            with torch.no_grad():
                action, logprob, _, value = agent.get_action_and_value(next_obs)
                values[step] = value.flatten()
            actions[step] = action
            logprobs[step] = logprob

            # TRY NOT TO MODIFY: execute the game and log data.
            step_obs, reward, step_done, info = envs.step(action.cpu().numpy())
            rewards[step] = torch.tensor(reward).to(device).view(-1)
            next_obs, next_done = torch.Tensor(step_obs).to(device), torch.Tensor(step_done).to(device)

            # Iterate the host-side done flags returned by the env rather than the
            # device tensor, so no per-env device synchronisation is needed.
            for idx, d in enumerate(step_done):
                # Report only when the env is done and has no lives left
                # (presumably the real game over under episodic_life=True).
                if d and info["lives"][idx] == 0:
                    print(f"global_step={global_step}, episodic_return={info['reward'][idx]}")
                    avg_returns.append(info["reward"][idx])
                    wandb.log(
                        {
                            "charts/avg_episodic_return": np.average(avg_returns),
                            "charts/episodic_return": info["reward"][idx],
                            "charts/episodic_length": info["elapsed_step"][idx],
                            "reward_step": global_step,
                        }
                    )

        # ---- Advantage estimation (GAE) ------------------------------------
        # Bootstrap value if not done
        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)
            advantages = torch.zeros_like(rewards).to(device)
            lastgaelam = 0
            # Backward recursion over the rollout.
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]
                delta = rewards[t] * args.reward_scale + args.gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
            returns = advantages + values

        # Flatten the batch
        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
        b_logprobs = logprobs.reshape(-1)
        # Cast to integer actions once per iteration instead of once per minibatch.
        b_actions = actions.reshape((-1,) + envs.single_action_space.shape).long()
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)

        # ---- Optimizing the policy and value network -----------------------
        b_inds = np.arange(args.batch_size)
        clipfracs = []
        for epoch in range(args.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, args.batch_size, args.minibatch_size):
                end = start + args.minibatch_size
                mb_inds = b_inds[start:end]

                _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])
                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

                mb_advantages = b_advantages[mb_inds]
                if args.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                # Policy loss (clipped surrogate objective)
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss
                newvalue = newvalue.view(-1)
                if args.clip_vloss:
                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                    v_clipped = b_values[mb_inds] + torch.clamp(
                        newvalue - b_values[mb_inds],
                        -args.clip_coef,
                        args.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                entropy_loss = entropy.mean()
                loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()

            # Early stop of the update epochs once the approximate KL is too large.
            if args.target_kl is not None and approx_kl > args.target_kl:
                break

        # NOTE(review): explained variance is computed but currently not logged.
        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

        # ---- Representation diagnostics, every 20 iterations ---------------
        # They are evaluated on the last minibatch of the update above (`mb_inds`).
        if iteration % 20 == 0:
            with torch.no_grad():
                big_data = agent.get_pre_activation_value(b_obs[mb_inds])
                effective_rank = compute_rank_from_features(big_data.cpu().numpy())
                wandb.log({"effective_rank": effective_rank, "rank_step": global_step})
            if args.activation == 'relu':
                with torch.no_grad():
                    weighted_relu, completely_dead_relu = count_dying_relu(big_data)
                    wandb.log({"completely_dead_relu": completely_dead_relu, "deadrelu_step": global_step})
            elif args.activation in ('tanh_HR', 'tanh', 'sigmoid'):
                with torch.no_grad():
                    dead = 0
                    jitter_amount = 1e-5
                    # One device-to-host copy shared by the plot and the per-unit loop.
                    features_np = big_data.cpu().numpy()

                    # Close the figure after logging; otherwise a new figure would
                    # stay alive for every diagnostic pass.
                    fig = plt.figure()
                    sns.kdeplot(features_np.flatten())
                    wandb.log({"kde_plot": plt})
                    plt.close(fig)

                    # Evaluation grid for the densities; it does not depend on the unit.
                    # reshape(-1, 1) yields the (n_points, 1) column score_samples is
                    # given below (the former reshape(0, 1) raised a ValueError).
                    if args.activation == 'sigmoid':
                        x_axis = np.linspace(0, 1, 1000).reshape(-1, 1)
                    else:
                        x_axis = np.linspace(-1, 1, 1000).reshape(-1, 1)

                    for i in range(args.latent_dim):
                        data = features_np[:, i].flatten()
                        # Small Gaussian jitter, presumably to avoid a zero bandwidth
                        # when a unit's values are all identical.
                        data = data + np.random.normal(0, jitter_amount, size=data.shape)

                        # Initialize KDE with the Gaussian kernel and Scott's-rule bandwidth
                        bw = gaussian_kde(data).scotts_factor() * data.std(ddof=1)
                        kde = KernelDensity(kernel='gaussian', bandwidth=bw)

                        # Fit the KDE model to the data
                        kde.fit(data[:, None])  # Reshape data for scikit-learn

                        # A unit counts as dead when its estimated density peaks at
                        # 20 or more anywhere on the grid.
                        log_density = kde.score_samples(x_axis)
                        density = np.exp(log_density)
                        max_density = np.max(density)
                        if max_density >= 20:
                            dead += 1
                # Fraction of values with abs(value) >= 0.99.
                # NOTE(review): computed but currently not logged.
                values_dead = features_np.flatten()
                weighted_dead = np.sum(np.abs(values_dead) >= 0.99)
                weighted_dead = weighted_dead / float(len(values_dead))
                wandb.log({"completely_dead_tanh": dead / float(args.latent_dim), "logging_step": global_step})
            del big_data

        # Losses of the last minibatch of this iteration.
        wandb.log(
            {
                "learning_rate": optimizer.param_groups[0]["lr"],
                "losses/value_loss": v_loss.item(),
                "losses/policy_loss": pg_loss.item(),
                "losses/entropy": entropy_loss.item(),
            },
            step=global_step,
        )
        print("SPS:", int(global_step / (time.time() - start_time)))

    envs.close()
