# This code has been adapted from the cleanRL repository

import os
import random
import time
from dataclasses import dataclass
from scipy.stats import gaussian_kde
from sklearn.neighbors import KernelDensity
import gymnasium as gym
import numpy as np
import torch
from utils import count_dying_relu,plot_and_log_latents_and_kde, compute_rank_from_features, define_wandb_metrics
from networks import QNetwork
import torch.nn.functional as F
import torch.optim as optim
import tyro
from stable_baselines3.common.atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)
from stable_baselines3.common.buffers import ReplayBuffer
import wandb


@dataclass
class Args:
    """Command-line configuration for the DQN training script (parsed with ``tyro.cli``)."""

    # Experiment name; defaults to this script's filename with its last three
    # characters (the ".py" suffix) sliced off.
    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    # Seed for `random`, NumPy, PyTorch, the environment reset and the
    # per-environment action-space sampling (offset by env index).
    seed: int = 1
    # Value assigned to `torch.backends.cudnn.deterministic`.
    torch_deterministic: bool = True
    # Use CUDA when it is available; otherwise the script falls back to CPU.
    cuda: bool = True
    # Weights & Biases project the run is logged to.
    wandb_project_name: str = "cleanRL"
    # Algorithm specific arguments
    # Gymnasium environment id passed to `gym.make`.
    env_id: str = "BreakoutNoFrameskip-v4"
    # Total number of environment steps to run.
    total_timesteps: int = 10000000
    # Adam learning rate for the online Q-network.
    learning_rate: float = 1e-4
    # Number of environments in the vector env; only 1 is currently supported.
    num_envs: int = 1
    # Replay buffer capacity.
    buffer_size: int = 1000000
    # Discount factor used in the TD target.
    gamma: float = 0.99
    # Target-network update coefficient (1.0 means the target is overwritten
    # with the online weights).
    tau: float = 1.0
    # Number of environment steps between target-network updates.
    target_network_frequency: int = 1000
    # Minibatch size sampled from the replay buffer for each gradient step.
    batch_size: int = 32
    # Initial epsilon for epsilon-greedy exploration.
    start_e: float = 1
    # Final (minimum) epsilon.
    end_e: float = 0.01
    # Fraction of `total_timesteps` over which epsilon decays linearly from
    # `start_e` to `end_e`.
    exploration_fraction: float = 0.10
    # Training starts only once the step counter exceeds this value.
    learning_starts: int = 80000
    # Take one gradient step every this many environment steps.
    train_frequency: int = 4
    # Activation name for the convolutional layers; not read directly by this
    # script, presumably consumed by `QNetwork` (TODO confirm).
    conv_activation: str = 'relu'
    # Latent-layer activation name. Selects which dead-unit metric is logged,
    # the KDE evaluation range, and the saved model filename; presumably also
    # consumed by `QNetwork` (verify).
    activation: str = 'tanh_HR'
    # Not read directly by this script; presumably the image format used by the
    # plotting utilities (TODO confirm).
    format: str = "png"
    # Number of latent units examined by the dead-unit KDE check; presumably
    # also the width of the QNetwork latent layer (verify).
    latent_dim: int = 512


def make_env(env_id, seed, idx, capture_video, run_name):
    """Return a zero-argument factory that builds one preprocessed Atari environment.

    The factory is meant for ``gym.vector.SyncVectorEnv``. Video recording is
    attached only to the environment with index 0, and only when
    ``capture_video`` is truthy. The wrapper stack (applied in order) is:
    episode statistics, no-op resets, frame skip/max-pool, episodic life,
    FIRE-on-reset (when the game has a FIRE action), reward clipping,
    84x84 resize, grayscale, and a 4-frame stack.

    Args:
        env_id: Gymnasium environment id.
        seed: Seed applied to the environment's action space.
        idx: Index of this environment within the vector env.
        capture_video: Whether to record videos for environment 0.
        run_name: Sub-directory under ``videos/`` used for recordings.
    """
    record_this_env = capture_video and idx == 0

    def _build():
        if record_this_env:
            base = gym.make(env_id, render_mode="rgb_array")
            base = gym.wrappers.RecordVideo(base, f"videos/{run_name}")
        else:
            base = gym.make(env_id)

        wrapped = gym.wrappers.RecordEpisodeStatistics(base)
        wrapped = NoopResetEnv(wrapped, noop_max=30)
        wrapped = MaxAndSkipEnv(wrapped, skip=4)
        wrapped = EpisodicLifeEnv(wrapped)

        # Some games need FIRE pressed to start; wrap only those.
        if "FIRE" in wrapped.unwrapped.get_action_meanings():
            wrapped = FireResetEnv(wrapped)

        wrapped = ClipRewardEnv(wrapped)
        wrapped = gym.wrappers.ResizeObservation(wrapped, (84, 84))
        wrapped = gym.wrappers.GrayScaleObservation(wrapped)
        wrapped = gym.wrappers.FrameStack(wrapped, 4)

        wrapped.action_space.seed(seed)
        return wrapped

    return _build


def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
    """Linearly anneal from ``start_e`` toward ``end_e`` over ``duration`` steps.

    The value at step ``t`` lies on the line through ``(0, start_e)`` and
    ``(duration, end_e)``. It is floored at ``end_e``, so it stays at
    ``end_e`` once ``t >= duration`` (assuming ``start_e > end_e``).
    """
    rate = (end_e - start_e) / duration
    value = rate * t + start_e
    # Same selection rule as max(value, end_e): end_e wins only if strictly larger.
    return end_e if end_e > value else value


# Activations whose dead units are measured with `count_dying_relu`. Every
# other activation is treated as bounded and checked with the KDE-based test
# in `_fraction_dead_saturating_units`.
RELU_LIKE_ACTIVATIONS = ("relu", "selu", "relu_HR")


def _fraction_dead_saturating_units(activations, activation, latent_dim, jitter=1e-5, density_threshold=20.0):
    """Return the fraction of latent units whose activation distribution has collapsed.

    For each of the first ``latent_dim`` columns of ``activations``, a 1-D
    Gaussian KDE (Scott's-rule bandwidth) is evaluated on 1000 evenly spaced
    points. The grid covers [0, 1] for sigmoid variants and [-1, 1] otherwise.
    A unit is counted as dead when its peak estimated density reaches
    ``density_threshold``, i.e. its activations sit in a very narrow interval.

    Args:
        activations: Tensor of latent activations, indexed as
            ``activations[:, unit]`` (assumes (batch, units); TODO confirm).
        activation: Activation-function name; selects the evaluation range.
        latent_dim: Number of leading units to examine; also the denominator.
        jitter: Std of Gaussian noise added so a constant unit does not give a
            singular KDE covariance. Drawn from NumPy's global RNG.
        density_threshold: Peak density at or above which a unit counts as dead.

    Returns:
        Fraction (float) of examined units classified as dead.
    """
    if activation in ("sigmoid", "sigmoid_HR"):
        grid = np.linspace(0, 1, 1000)
    else:
        grid = np.linspace(-1, 1, 1000)

    # Move to host once instead of once per unit.
    columns = activations.detach().cpu().numpy()
    dead = 0
    for unit in range(latent_dim):
        samples = columns[:, unit].flatten()
        samples = samples + np.random.normal(0, jitter, size=samples.shape)
        # gaussian_kde's kernel std is scotts_factor * std(ddof=1), the same
        # bandwidth previously handed to a separate sklearn KernelDensity, so
        # evaluating it directly yields the same density estimate.
        density = gaussian_kde(samples)(grid)
        if np.max(density) >= density_threshold:
            dead += 1
    return dead / float(latent_dim)


if __name__ == "__main__":
    args = tyro.cli(Args)
    # Explicit check (not `assert`) so it still runs under `python -O`.
    if args.num_envs != 1:
        raise ValueError("vectorized envs are not supported at the moment")
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    wandb.init(
        project=args.wandb_project_name,
        name=run_name,
    )
    # Define wandb metrics for clean x-axis on plots
    define_wandb_metrics()

    # Seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # Environment setup
    envs = gym.vector.SyncVectorEnv(
        [make_env(args.env_id, args.seed + i, i, False, run_name) for i in range(args.num_envs)]
    )
    q_network = QNetwork(envs, args).to(device)
    target_network = QNetwork(envs, args).to(device)

    optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate, eps=1e-5)

    target_network.load_state_dict(q_network.state_dict())

    # The parameter count is fixed for the whole run, so compute it once here
    # rather than inside the environment loop.
    total_params = sum(p.numel() for p in q_network.parameters())
    print("Total Q-network parameters:", total_params)

    # Replay Buffer
    rb = ReplayBuffer(
        args.buffer_size,
        envs.single_observation_space,
        envs.single_action_space,
        device,
        optimize_memory_usage=True,
        handle_timeout_termination=False,
    )
    start_time = time.time()

    # Start the game
    obs, _ = envs.reset(seed=args.seed)
    for global_step in range(args.total_timesteps):
        # Epsilon-greedy action selection.
        epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
        if random.random() < epsilon:
            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
        else:
            # Acting needs no gradients; skip building an autograd graph.
            with torch.no_grad():
                q_values = q_network(torch.Tensor(obs).to(device))[0]
            actions = torch.argmax(q_values, dim=1).cpu().numpy()

        # Execute the game and log data.
        next_obs, rewards, terminations, truncations, infos = envs.step(actions)

        # Record rewards for plotting purposes
        if "final_info" in infos:
            for info in infos["final_info"]:
                if info and "episode" in info:
                    print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                    wandb.log({"episodic_return": info["episode"]["r"], "episodic_length": info["episode"]["l"],
                               "reward_step": global_step})

        # Save data to replay buffer. For truncated episodes the vector env has
        # already auto-reset, so the true last observation comes from infos.
        real_next_obs = next_obs.copy()
        for idx, trunc in enumerate(truncations):
            if trunc:
                real_next_obs[idx] = infos["final_observation"][idx]
        rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

        # Crucial step: advance to the next observation.
        obs = next_obs

        # Training.
        if global_step > args.learning_starts:
            if global_step % args.train_frequency == 0:
                data = rb.sample(args.batch_size)
                with torch.no_grad():
                    target_max, _ = target_network(data.next_observations)[0].max(dim=1)
                    td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())
                old_val, tanh_activation = q_network(data.observations)
                old_val = old_val.gather(1, data.actions).squeeze()
                loss = F.mse_loss(td_target, old_val)

                # Periodic diagnostics: effective rank of latents on a large batch.
                if global_step % 50000 == 0:
                    with torch.no_grad():
                        big_data = rb.sample(4096)
                        _, big_tanh_activation = q_network(big_data.observations)
                        effective_rank = compute_rank_from_features(big_tanh_activation.cpu().numpy())
                        del big_data, big_tanh_activation
                    print("Effective rank:", effective_rank)
                    sps = int(global_step / (time.time() - start_time))
                    print("SPS:", sps)
                    # Dead-unit count for ReLU-like activations.
                    if args.activation in RELU_LIKE_ACTIVATIONS:
                        with torch.no_grad():
                            weighted, completely_dead_relu = count_dying_relu(tanh_activation)
                        wandb.log({"completely_dead_relu": completely_dead_relu,
                                   "weighted_dead_relu": weighted,
                                   "deadrelu_step": global_step})
                    wandb.log({"loss": loss.item(), "q_values": old_val.mean().item(), "SPS": sps,
                               "global_step": global_step, "effective_rank": effective_rank})

                # Dead units for bounded activations (e.g. tanh, sigmoid), via KDE.
                if global_step % 100000 == 0 and args.activation not in RELU_LIKE_ACTIVATIONS:
                    with torch.no_grad():
                        plot_and_log_latents_and_kde(global_step, args.activation, q_network, data.observations, tanh_activation)
                        dead_fraction = _fraction_dead_saturating_units(tanh_activation, args.activation, args.latent_dim)
                    wandb.log({"completely_dead_tanh": dead_fraction, "logging_step": global_step})

                # optimize the model
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # update target network (tau == 1.0 makes this a hard copy)
            if global_step % args.target_network_frequency == 0:
                for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()):
                    target_network_param.data.copy_(
                        args.tau * q_network_param.data + (1.0 - args.tau) * target_network_param.data
                    )

    envs.close()
    # Save the Q network
    torch.save(q_network.state_dict(), f'q_network_{args.activation}.pt')
    wandb.finish()

