import torch
import numpy as np
import matplotlib.pyplot as plt
import wandb
import seaborn as sns
import gym


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise a layer in place and hand it back.

    ``layer.weight`` receives an orthogonal initialisation scaled by ``std``
    (the gain) and every entry of ``layer.bias`` is set to ``bias_const``.
    Returning the same object lets the call wrap a layer constructor inline.
    """
    init = torch.nn.init
    init.orthogonal_(layer.weight, gain=std)
    init.constant_(layer.bias, val=bias_const)
    return layer


def count_dying_relu(relu_outputs):
    """Measure how often ReLU units output exactly zero over a batch.

    Args:
        relu_outputs: tensor of activations; dim 0 is reduced over (treated as
            the data-point axis), every remaining position is one unit.

    Returns:
        ``(weighted_count, completely_dead)`` where ``weighted_count`` (float) is
        the sum over units of the fraction of data points on which that unit is
        zero, and ``completely_dead`` (int) is the number of units that are zero
        for every data point.
    """
    # Per-unit fraction of data points where the output is exactly zero.
    zero_rate = relu_outputs.eq(0).float().mean(dim=0)
    fully_dead = (zero_rate == 1.0).sum().item()
    return zero_rate.sum().item(), fully_dead


def plot_and_log_latents_and_kde(iterations, activation, encoder, STATE, final_latent):
    # check latent dimension
    if final_latent.shape[1] ==512:
        dim_counter = 30
    elif final_latent.shape[1] == 1024:
        dim_counter = 60
    elif final_latent.shape[1] == 256:
        dim_counter = 15
    if activation == 'sigmoid_HR' or activation == 'tanh_HR' or activation == 'plustanh' or activation == 'double_HR':
            with torch.no_grad():
                # Plotting the KDE
                plot1, plot2 = encoder.get_individual_representations(STATE)
                plot1_flat = plot1.flatten().cpu().numpy()
                plot2_flat = plot2.flatten().cpu().numpy()
                plt.figure(figsize=(10, 6))
                sns.kdeplot(plot1_flat, bw_adjust=0.5, color='blue')

                plt.title('Kernel Density Estimation of Extra Rep.')
                plt.xlabel('Output')
                plt.ylabel('Density')
                if activation == 'sigmoid_HR':
                    plt.xlim(0, 1)
                else:
                    plt.xlim(-1, 1)
                wandb.log({"extra_kde": plt,
                           'plot_step': iterations})
                plt.close()
                plt.clf()
                plt.figure(figsize=(10, 6))
                sns.kdeplot(plot2_flat, bw_adjust=0.5, color='red')
                plt.title('Kernel Density Estimation of Base Rep.')
                plt.xlabel('Tanh Output')
                plt.ylabel('Density')
                if activation == 'sigmoid_HR':
                    plt.xlim(0, 1)
                else:
                    plt.xlim(-1, 1)
                wandb.log({"base_kde": plt,
                           'plot_step': iterations})
                plt.close()
                plt.clf()

            plt.figure(figsize=(10, 10))
            for j in range(16):
                plt.subplot(5, 5, j + 1)
                sns.kdeplot(plot1[:, j * dim_counter].flatten().cpu().numpy(), bw_adjust=0.5, color='purple', label='mask')
                sns.kdeplot(plot2[:, j * dim_counter].flatten().cpu().numpy(), bw_adjust=0.5, color='green', label='base')
                if activation == 'sigmoid_HR':
                    plt.xlim(0, 1)
                else:
                    plt.xlim(-1, 1)
                plt.subplots_adjust(hspace=0.2)  # 0.53
            plt.subplots_adjust(wspace=0.2)
            wandb.log({"16_neuron_kde_plot_mask_base": plt,
                       'plot_step': iterations})
            plt.close()
            plt.clf()

    plt.figure(figsize=(10/1.5, 6/1.5))
    sns.kdeplot(final_latent.flatten().cpu().numpy(), bw_adjust=0.5, color='green')
    plt.title('Kernel Density Estimation of Final Latent')
    plt.xlabel('Latent Output')
    plt.ylabel('Density')
    if activation == 'sigmoid' or activation == 'sigmoid_HR':
        plt.xlim(0, 1)  # tanh outputs range from -1 to 1
    else:
        plt.xlim(-1, 1)
    wandb.log({"final_kde": plt,
               'plot_step': iterations})
    plt.close()
    plt.clf()

    plt.figure(figsize=(10, 10))
    for j in range(16):
        plt.subplot(5, 5, j + 1)
        sns.kdeplot(final_latent[:, j * dim_counter].flatten().cpu().numpy(), bw_adjust=0.5, color='purple')
        if activation == 'sigmoid' or activation == 'sigmoid_HR':
            plt.xlim(0, 1)
        else:
            plt.xlim(-1, 1)
    plt.subplots_adjust(hspace=0.2)  # 0.53
    plt.subplots_adjust(wspace=0.2)
    wandb.log({"16_neuron_kde_plot": plt,
               'plot_step': iterations})
    plt.close()
    plt.clf()


def compute_rank_from_features(feature_matrix, rank_delta=0.01):
    """Return the effective (approximate) rank "srank" of a feature matrix.

    See Kumar et al., https://arxiv.org/abs/2010.14498: the smallest ``k`` such
    that the top-``k`` singular values account for at least ``1 - rank_delta``
    of the sum of all singular values.

    Args:
        feature_matrix: 2-D array-like accepted by ``np.linalg.svd``.
        rank_delta: fraction of the singular-value mass allowed to be ignored.

    Returns:
        int in ``[0, min(feature_matrix.shape)]``; 0 for an all-zero or empty
        matrix (which previously reported 1).
    """
    sing_values = np.linalg.svd(feature_matrix, compute_uv=False)
    cumsum = np.cumsum(sing_values)
    # Use the last cumsum entry as the total (rather than a separate np.sum) so
    # the threshold is always reachable by the same accumulation.
    if cumsum.size == 0 or cumsum[-1] <= 0:
        return 0
    threshold = (1.0 - rank_delta) * cumsum[-1]
    # Singular values are non-negative, so cumsum is sorted; find the first
    # prefix that reaches the threshold (1-based count of values kept).
    k = int(np.searchsorted(cumsum, threshold, side='left')) + 1
    return min(k, cumsum.size)

def define_wandb_metrics():
    """Register wandb step metrics and bind each logged metric to its x-axis.

    Every step metric referenced through ``step_metric=`` is declared up front,
    including "reward_step" and "deadrelu_step". "extra_kde" is bound to
    "plot_step" because it is the key the KDE plotting function logs; the
    older "mask_kde" binding is kept for compatibility.
    """
    # Step (x-axis) metrics.
    wandb.define_metric("global_step")
    wandb.define_metric("plot_step")
    wandb.define_metric("normal_step")
    wandb.define_metric("logging_step")
    wandb.define_metric("reward_step")
    wandb.define_metric("deadrelu_step")
    # Metrics plotted against those steps.
    wandb.define_metric("episodic_return", step_metric="reward_step")
    wandb.define_metric("episodic_length", step_metric="reward_step")
    wandb.define_metric("loss", step_metric="global_step")
    wandb.define_metric("q_values", step_metric="global_step")
    wandb.define_metric("SPS", step_metric="global_step")
    wandb.define_metric("mask_kde", step_metric="plot_step")
    wandb.define_metric("extra_kde", step_metric="plot_step")
    wandb.define_metric("base_kde", step_metric="plot_step")
    wandb.define_metric("final_kde", step_metric="plot_step")
    wandb.define_metric("16_neuron_kde_plot", step_metric="plot_step")
    wandb.define_metric("16_neuron_kde_plot_mask_base", step_metric="plot_step")
    wandb.define_metric("dead_neurons", step_metric="logging_step")
    wandb.define_metric("effective_rank", step_metric="global_step")
    wandb.define_metric("dead_relus", step_metric="deadrelu_step")
    wandb.define_metric("bias_contribution", step_metric="global_step")
    wandb.define_metric("mean_q_values", step_metric="global_step")
    wandb.define_metric("completely_dead_tanh", step_metric="logging_step")


class RecordEpisodeStatistics(gym.Wrapper):
    """Gym wrapper that owns per-environment episode-statistics buffers.

    Only ``reset`` is overridden: it (re)allocates zeroed counters sized by
    ``num_envs``. This class defines no ``step`` override, so nothing here
    updates the counters — NOTE(review): confirm the caller maintains them.
    """

    def __init__(self, env, deque_size=100):
        # ``deque_size`` is accepted for signature compatibility but unused.
        super().__init__(env)
        # Vectorised envs expose ``num_envs``; a plain env counts as one.
        self.num_envs = getattr(env, "num_envs", 1)
        # Allocated by reset().
        self.episode_returns = self.episode_lengths = None

    def reset(self, **kwargs):
        """Reset the wrapped env, zero every per-env counter, and pass its reset output through."""
        result = super().reset(**kwargs)
        n = self.num_envs
        # Running totals for the current episode of each env.
        self.episode_returns = np.zeros(n, dtype=np.float32)
        self.episode_lengths = np.zeros(n, dtype=np.int32)
        # Values reported for the most recently returned episode.
        self.returned_episode_returns = np.zeros(n, dtype=np.float32)
        self.returned_episode_lengths = np.zeros(n, dtype=np.int32)
        self.lives = np.zeros(n, dtype=np.int32)
        return result


def define_wandb_metrics_ppo():
    """Register the PPO wandb step metrics and bind each metric to its step.

    Each step metric is declared immediately before the metrics that use it
    as their x-axis.
    """
    step_groups = (
        ("reward_step", ("charts/avg_episodic_return",
                         "charts/episodic_return",
                         "charts/episodic_length")),
        ("deadrelu_step", ("weighted_dying_relu", "completely_dead_relu")),
        ("logging_step", ("completely_dead_tanh", "weighted_dying_tanh")),
        ("rank_step", ("effective_rank",)),
    )
    for step_name, metric_names in step_groups:
        wandb.define_metric(step_name)
        for metric_name in metric_names:
            wandb.define_metric(metric_name, step_metric=step_name)