import os
import numpy as np
import torch
import wandb
from tensordict import TensorDictBase
from torchrl.envs.libs.vmas import VmasEnv
from torchrl.record.loggers import get_logger
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from src.models.gp_model import MultitaskGPModel


class Logger:
    """Forward training and evaluation metrics of a MARL run to a torchrl logger.

    Wraps the logger returned by ``torchrl.record.loggers.get_logger`` (logging is
    done through its ``experiment.log``). For systems whose name starts with
    ``"Halluc"`` it also logs model-learning statistics, and, when ``model`` is a
    ``MultitaskGPModel``, GP hyperparameters and covariance heatmaps.
    """

    def __init__(self, config, model=None, project="cooperative_marl", entity=None):
        """Create the backend logger and record the run's hyperparameters.

        Args:
            config: Run configuration. Must expose ``system``, ``logger.backend``,
                ``exp_name`` and ``toDict()``.
            model: Optional dynamics model; GP metrics are only logged when it is
                a ``MultitaskGPModel``.
            project: wandb project name.
            entity: wandb entity (user or team). When ``None``, the
                ``WANDB_ENTITY`` environment variable is used; if that is unset
                too, ``None`` is passed through to the backend.
        """
        self.config = config
        self.model = model
        self.system = self.config.system
        self.prev_lengthscales = None
        self.prev_outputscales = None
        # Resolve the wandb entity explicitly instead of relying on an undefined global.
        if entity is None:
            entity = os.environ.get("WANDB_ENTITY")
        self.logger = get_logger(
            logger_type=self.config.logger.backend,
            logger_name=os.getcwd(),
            experiment_name=self.config.exp_name,
            wandb_kwargs={
                "project": project,
                "entity": entity,
            },
        )
        self.logger.log_hparams(self.config.toDict())

    @staticmethod
    def _ensure_per_agent_key(sampling_td, name):
        """Create ``("next", "agents", name)`` by broadcasting ``("next", name)`` if it is missing."""
        key = ("next", "agents", name)
        if key not in set(sampling_td.keys(True, True)):
            sampling_td.set(
                key,
                sampling_td.get(("next", name))
                .expand(sampling_td.get("agents").shape)
                .unsqueeze(-1),
            )

    @staticmethod
    def _has_frames(env):
        """Return True if ``env`` exists and has at least one rendered frame in ``env.frames``."""
        return env is not None and hasattr(env, "frames") and len(env.frames) > 0

    @staticmethod
    def _frames_to_video(env, num_frames):
        """Turn the first ``num_frames`` entries of ``env.frames`` into a ``wandb.Video``.

        Frames are transposed with axes ``(0, 3, 1, 2)``; this presumably converts
        (T, H, W, C) images to (T, C, H, W) — confirm against the renderer. Playback
        runs at ``2 / env.world.dt`` frames per second.
        """
        vid = torch.tensor(
            np.transpose(env.frames[:num_frames], (0, 3, 1, 2)),
            dtype=torch.uint8,
        ).unsqueeze(0)
        return wandb.Video(vid, fps=2 / env.world.dt, format="mp4")

    def _gp_hyperparameters(self):
        """Collect noise, outputscale, signal-to-noise and lengthscale metrics of ``self.model``.

        Returns:
            dict mapping ``train_gp/...`` keys to logged values.
        """
        metrics = {}
        likelihood = self.model.likelihood
        output_dim = self.model.output_dim

        noises = None
        if likelihood.has_global_noise:
            global_noise = likelihood.noise.detach().clone().item()
            metrics["train_gp/global_noise"] = global_noise
            noises = torch.ones((output_dim), device=self.model.device) * global_noise
        if likelihood.has_task_noise and likelihood.rank == 0:
            task_noises = likelihood.task_noises.detach().clone()
            if noises is not None:
                noises = torch.abs(noises) + torch.abs(task_noises)
            else:
                noises = torch.abs(task_noises)
            for i in range(output_dim):
                metrics[f"train_gp/task_noise_{i}"] = task_noises[i]

        outputscales = self.model.get_outputscales()
        for i in range(output_dim):
            # Without any noise estimate the ratio is undefined; skip it rather than index None.
            if noises is not None:
                metrics[f"train_gp/signal_to_noise_gp_{i}"] = outputscales[i] / (outputscales[i] + noises[i])
            metrics[f"train_gp/outputscale_{i}"] = outputscales[i]
            # NOTE(review): the hard-coded index 4 presumably marks an output without its
            # own lengthscales when no shared covar_module exists — confirm with the model.
            if hasattr(self.model.gp_model, "covar_module") or i != 4:
                lengthscales = self.model.get_lengthscales(i)
                for j in range(self.model.input_dim):
                    metrics[f"train_gp/lengthscales/gp_{i}_input_{j}"] = lengthscales[j].item()
        return metrics

    def log_training(
        self,
        iteration: int,
        training_td: TensorDictBase,
        sampling_td: TensorDictBase,
        total_env_steps: int,
        total_policy_updates: int,
        total_halluc_env_steps: int = 0,
        total_model_updates: int = 0,
        nn_model_dataset_size: int = 0,
        gp_model_dataset_size: int = 0,
        replay_buffer_size: int = 0,
        beta: float = 0,
        evaluate_on_iteration: bool = False,
        halluc_env: VmasEnv = None,
        env: VmasEnv = None,
        nn_loss: float = 0,
        gp_loss: float = 0,
    ):
        """Log one training iteration and return the logged dictionary.

        Adds per-agent ``reward`` / ``episode_reward`` entries to ``sampling_td``
        in place when they are missing. Logs the mean of every entry of
        ``training_td``, optional ``("agents", "info")`` statistics, step and
        episode reward statistics, and counters. For "Halluc" systems it also logs
        model-learning counters and, for GP models, GP hyperparameters. When
        ``evaluate_on_iteration`` is set, rollout videos are logged for any env
        that has rendered frames.

        Returns:
            dict of everything passed to ``experiment.log`` (committed).
        """
        self._ensure_per_agent_key(sampling_td, "reward")
        self._ensure_per_agent_key(sampling_td, "episode_reward")

        to_log = {
            f"train/learner/{key}": value.mean().item()
            for key, value in training_td.items()
        }

        if "info" in sampling_td.get("agents").keys():
            to_log.update(
                {
                    f"train/info/{key}": value.mean().item()
                    for key, value in sampling_td.get(("agents", "info")).items()
                }
            )

        reward = sampling_td.get(("next", "agents", "reward")).mean(-2)  # Mean over agents
        done = sampling_td.get(("next", "done"))
        if done.ndim > reward.ndim:
            done = done[..., 0, :]  # Remove expanded agent dim
        # Episode rewards are only meaningful at steps where an episode ended.
        episode_reward = sampling_td.get(("next", "agents", "episode_reward")).mean(-2)[done]
        if len(episode_reward) != 0:
            to_log.update(
                {
                    "train/reward/episode_reward_min": episode_reward.min().item(),
                    "train/reward/episode_reward_mean": episode_reward.mean().item(),
                    "train/reward/episode_reward_max": episode_reward.max().item(),
                }
            )

        if self.system.startswith("Halluc"):
            to_log.update(
                {
                    "train/total_halluc_env_steps": total_halluc_env_steps,
                    "train/total_model_updates": total_model_updates,
                    "train/nn_model_dataset_size": nn_model_dataset_size,
                    "train/gp_model_dataset_size": gp_model_dataset_size,
                    "train/replay_buffer_size": replay_buffer_size,
                    "train/beta": beta,
                    "train_gp/nn_model_loss": nn_loss,
                    "train_gp/gp_model_loss": gp_loss,
                }
            )
            if isinstance(self.model, MultitaskGPModel):
                to_log.update(self._gp_hyperparameters())

        to_log.update(
            {
                "train/reward/reward_min": reward.min().item(),
                "train/reward/reward_mean": reward.mean().item(),
                "train/reward/reward_max": reward.max().item(),
                "train/training_iteration": iteration,
                "train/total_env_steps": total_env_steps,
                "train/total_policy_updates": total_policy_updates,
            }
        )

        if evaluate_on_iteration:
            # sampling_td holds the rollout; its first entry's leading batch dim sets the video length.
            if self._has_frames(halluc_env):
                to_log["train/rollouts/halluc_rollout"] = self._frames_to_video(
                    halluc_env, sampling_td[0].batch_size[0]
                )
            if self._has_frames(env):
                # Rollout rendered by the real environment.
                to_log["train/rollouts/real_env_rollout"] = self._frames_to_video(
                    env, sampling_td[0].batch_size[0]
                )

        self.logger.experiment.log(to_log, commit=True)

        return to_log

    def log_evaluation(
        self,
        iteration: int,
        rollouts: TensorDictBase,
        env_test: VmasEnv,
        total_env_steps: int,
        total_policy_updates: int,
        total_halluc_env_steps: int = 0,
        total_model_updates: int = 0,
        rollout_model_input: torch.Tensor = None,
    ):
        """Log evaluation statistics and the evaluation video.

        ``rollouts`` is unbound along dim 0 into individual trajectories. Each one is
        truncated after its first ``done`` step; a trajectory that never terminates
        is kept whole. Episode reward min/max/mean, the mean episode length and the
        counters are logged without committing. GP covariance heatmaps are included
        for "Halluc" systems with a ``MultitaskGPModel`` when
        ``rollout_model_input`` is given. The video of ``env_test.frames`` is then
        logged with ``commit=True``.
        """
        rollouts = list(rollouts.unbind(0))
        for k, r in enumerate(rollouts):
            done = r.get(("next", "done"))
            # Collapse every non-batch dim so a step counts as done if any entry is done.
            next_done = done.sum(
                tuple(range(r.batch_dims, done.ndim)),
                dtype=torch.bool,
            )
            done_steps = next_done.nonzero(as_tuple=True)[0]
            if done_steps.numel() > 0:
                rollouts[k] = r[: int(done_steps[0]) + 1]  # First done index for this traj

        rewards = [td.get(("next", "agents", "reward")).sum(0).mean().item() for td in rollouts]
        to_log = {
            "eval/episode_reward_min": min(rewards),
            "eval/episode_reward_max": max(rewards),
            "eval/episode_reward_mean": sum(rewards) / len(rollouts),
            "eval/episode_len_mean": sum(td.batch_size[0] for td in rollouts) / len(rollouts),
            "train/training_iteration": iteration,
            "train/total_env_steps": total_env_steps,
            "train/total_policy_updates": total_policy_updates,
        }

        if self.system.startswith("Halluc"):
            to_log.update(
                {
                    "train/total_halluc_env_steps": total_halluc_env_steps,
                    "train/total_model_updates": total_model_updates,
                }
            )
            if isinstance(self.model, MultitaskGPModel) and rollout_model_input is not None:
                for i in range(self.model.output_dim):
                    fig = self.get_covariance(rollout_model_input, dim=i)
                    to_log[f"eval_covariances/gp_covariance_dim_{i}"] = wandb.Image(fig)
                    # wandb.Image has rendered the figure; close it so figures don't accumulate.
                    plt.close(fig)

        self.logger.experiment.log(to_log, commit=False)

        self.logger.experiment.log(
            {
                "eval/video": self._frames_to_video(env_test, rollouts[0].batch_size[0]),
            },
            commit=True,
        )

    def get_covariance(self, rollout, dim=0):
        """Plot the GP prior covariance of output ``dim`` evaluated on ``rollout`` as a heatmap.

        Uses ``gp_model.kernels[dim]`` when ``use_separate_reward_cov`` is set.
        Otherwise it uses ``gp_model.covar_module`` and selects ``[0, dim]`` from the
        dense result (presumably batch 0, output ``dim`` — confirm the layout).

        Returns:
            The matplotlib figure; the caller is responsible for closing it.
        """
        fig, ax = plt.subplots(1, 1, figsize=(30, 30))
        if self.model.use_separate_reward_cov:
            dense = self.model.gp_model.kernels[dim](rollout).to_dense()
        else:
            dense = self.model.gp_model.covar_module(rollout).to_dense()[0, dim]
        covariances = pd.DataFrame(dense.detach().clone().cpu().numpy())
        sns.heatmap(covariances, cmap='coolwarm', linecolor='white', linewidths=1, ax=ax)
        ax.set_title(f"Dimension {dim}")
        fig.tight_layout()
        return fig

    def finish(self):
        """Finalize the backend logger."""
        self.logger.finish()