import torch
from pytorch_lightning.callbacks import Callback

import os
from src.utils import get_wandb_logger

import sys
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', 'src'))
from utils.viz import plot_pcs, plot_corr
import wandb
from utils.distance import hungarian, chamfer
import matplotlib.pyplot as plt 

class EvalCallback(Callback):
    """Collect step outputs on rank 0 and log diagnostic plots to W&B and disk.

    During training (every ``train_log_freq`` epochs), validation and testing
    the callback buffers the tensors returned by the step functions. At the
    end of the epoch it plots a few point-cloud pairs picked along the sorted
    ``"dist"`` values, the correlation between ``"dist"`` and ``"preds"``, and
    (validation/test only) a histogram of the cosine similarity between the
    gradient of ``hungarian`` and the gradient of the model, both taken with
    respect to the target. Validation/test additionally run the same
    comparison on randomly perturbed targets (``test_generalization``).

    The step outputs are expected to be dicts of tensors containing at least
    ``"source"``, ``"target"``, ``"dist"`` and ``"preds"``; that is what the
    code below indexes.

    Args:
        output_dir: Directory where the PDF figures are written (created if
            missing).
        train_log_freq: Training outputs are buffered and plotted only on
            epochs where ``current_epoch % train_log_freq == 0``.
        val_batch_size: Batch size used by the generalization checks.
    """

    def __init__(self, output_dir, train_log_freq=50, val_batch_size=16):
        super().__init__()

        # Per-mode buffers: key -> list of per-batch CPU tensors. They are
        # concatenated once at epoch end instead of after every batch.
        self.train_data = {}
        self.val_data = {}
        self.test_data = {}
        self.output_dir = output_dir
        self.train_log_freq = train_log_freq
        self.val_batch_size = val_batch_size

        os.makedirs(self.output_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _accumulate(store, output):
        """Append a detached CPU copy of every tensor in ``output`` to ``store``.

        The ``"loss"`` entry is skipped: it is never plotted, and a scalar
        loss cannot be concatenated along dim 0.
        """
        for key, value in output.items():
            if key == "loss":
                continue
            store.setdefault(key, []).append(value.detach().cpu())

    @staticmethod
    def _collect(store):
        """Concatenate the buffered per-batch chunks of ``store`` along dim 0."""
        return {key: torch.cat(chunks, dim=0) for key, chunks in store.items()}

    @staticmethod
    def _dist_and_grad(dist_fn, source, target):
        """Return ``dist_fn(source, target)`` and d(mean dist)/d(target).

        ``torch.autograd.grad`` is used instead of ``backward()`` so that only
        the gradient with respect to the target copy is produced and nothing
        is accumulated into ``.grad`` of the model parameters. Gradient mode
        is forced on because evaluation hooks may run with it disabled.
        """
        with torch.enable_grad():
            probe = target.detach().clone().requires_grad_(True)
            dist = dist_fn(source, probe)
            (grad,) = torch.autograd.grad(dist.mean(), probe)
        return dist.detach(), grad

    @staticmethod
    def _quantile_indices(dist):
        """Pick sample indices at fixed quantiles of ``dist``, in ascending order.

        Returns the indices of the samples at positions 0, 1/8, 1/4, 1/2,
        5/8, 3/4 and last of the sorted distances.
        """
        order = torch.argsort(dist).tolist()
        count = len(order)
        return [
            order[0],
            order[count // 8],
            order[count // 4],
            order[count // 2],
            order[5 * count // 8],
            order[3 * count // 4],
            order[-1],
        ]

    def _log_figure(self, experiment, fig, key, filename):
        """Log ``fig`` to W&B under ``key``, save it to ``output_dir``, close it."""
        experiment.log({key: wandb.Image(fig)})
        fig.savefig(os.path.join(self.output_dir, filename))
        # Close explicitly; otherwise pyplot keeps every figure alive.
        plt.close(fig)

    def _log_cosine_hist(self, experiment, cos_sim, key, filename):
        """Plot a density histogram of ``cos_sim`` and log/save it."""
        fig, ax = plt.subplots()
        ax.hist(cos_sim.flatten().numpy(), bins=100, density=True)
        ax.set_xlabel("Cosine Similarity", fontsize=14)
        ax.set_ylabel("Density", fontsize=14)
        fig.tight_layout()
        self._log_figure(experiment, fig, key, filename)

    def _log_pred_and_corr(self, data, experiment, tag, epoch):
        """Log the point-cloud sample plot and the dist/preds correlation plot."""
        idxs = self._quantile_indices(data["dist"])

        fig = plot_pcs(data["source"], data["target"], idxs, data["preds"])
        self._log_figure(experiment, fig, f"{tag}/{tag}_pred", f"{tag}_pred_{epoch}.pdf")

        fig = plot_corr(data["dist"], data["preds"])
        self._log_figure(experiment, fig, f"{tag}/{tag}_corr", f"{tag}_corr_{epoch}.pdf")

    def _record_eval_batch(self, store, pl_module, output):
        """Add reference/model gradients to ``output`` and buffer it in ``store``.

        ``output`` is extended in place with ``"true_dist"``, ``"true_grad"``
        (from ``hungarian``) and ``"pred_grad"`` (from ``pl_module``).
        """
        true_dist, true_grad = self._dist_and_grad(hungarian, output["source"], output["target"])
        _, pred_grad = self._dist_and_grad(pl_module, output["source"], output["target"])

        output["true_dist"] = true_dist
        output["true_grad"] = true_grad
        output["pred_grad"] = pred_grad

        self._accumulate(store, output)

    def _finish_eval_epoch(self, trainer, pl_module, store, tag, pure_noise_gain, additive_noise_gain):
        """Produce all validation/test plots from the buffered batches.

        Args:
            store: The buffer (``self.val_data`` or ``self.test_data``).
            tag: ``"val"`` or ``"test"``; used in W&B keys and file names.
            pure_noise_gain: Multiplier of the uniform random scale used for
                the pure-noise targets (scale = 0.1 + gain * U[0, 1)).
            additive_noise_gain: Multiplier of the uniform random scale used
                for the source+noise and target+noise targets.
        """
        if not store:
            # No batch was recorded, so there is nothing to plot.
            return

        data = self._collect(store)
        experiment = get_wandb_logger(trainer=trainer).experiment
        epoch = pl_module.current_epoch

        self._log_pred_and_corr(data, experiment, tag, epoch)

        cos_sim = torch.nn.functional.cosine_similarity(
            data["true_grad"], data["pred_grad"], dim=2, eps=1e-08
        )
        self._log_cosine_hist(
            experiment,
            cos_sim,
            f"{tag}/{tag}_grads_cosine_dis",
            f"{tag}_grads_cosine_dis_{epoch}.pdf",
        )

        # Generalization checks on synthetic targets.
        source = data["source"]
        num_samples = source.shape[0]

        ## pure noise
        scale = 0.1 + pure_noise_gain * torch.rand(num_samples)[..., None, None]
        target = scale * torch.randn_like(source)
        self.test_generalization(source, target, pl_module, experiment,
                                 batch_size=self.val_batch_size, label=f"{tag}_N")

        ## source + noise
        scale = additive_noise_gain * torch.rand(num_samples)[..., None, None]
        target = source + scale * torch.randn_like(source)
        self.test_generalization(source, target, pl_module, experiment,
                                 batch_size=self.val_batch_size, label=f"{tag}_S+N")

        ## target + noise
        scale = additive_noise_gain * torch.rand(num_samples)[..., None, None]
        target = data["target"] + scale * torch.randn_like(data["target"])
        self.test_generalization(source, target, pl_module, experiment,
                                 batch_size=self.val_batch_size, label=f"{tag}_T+N")

    # ------------------------------------------------------------------
    # Lightning hooks
    # ------------------------------------------------------------------
    def on_train_batch_end(self, trainer, pl_module, output, batch, batch_idx, dataloader_idx=0):
        """Buffer the training step output on rank 0 during logging epochs.

        ``dataloader_idx`` has a default so the hook can be invoked with or
        without that trailing argument.
        """
        if pl_module.current_epoch % self.train_log_freq != 0:
            return
        if pl_module.global_rank != 0:
            return
        self._accumulate(self.train_data, output)

    def on_validation_batch_end(self, trainer, pl_module, output, batch, batch_idx, dataloader_idx=0):
        """Compute reference/model gradients for the batch and buffer them (rank 0)."""
        if pl_module.global_rank == 0:
            self._record_eval_batch(self.val_data, pl_module, output)

    def on_train_epoch_end(self, trainer, pl_module, outputs=None):
        """Plot the buffered training outputs (logging epochs, rank 0) and reset."""
        is_log_epoch = pl_module.current_epoch % self.train_log_freq == 0

        if is_log_epoch and pl_module.global_rank == 0 and self.train_data:
            data = self._collect(self.train_data)
            experiment = get_wandb_logger(trainer=trainer).experiment
            self._log_pred_and_corr(data, experiment, "train", pl_module.current_epoch)

        self.reset(mode="train")

    def on_validation_epoch_end(self, trainer, pl_module, outputs=None):
        """Plot validation diagnostics and generalization checks on rank 0."""
        if pl_module.global_rank != 0:
            return

        self._finish_eval_epoch(
            trainer, pl_module, self.val_data, "val",
            pure_noise_gain=1.0, additive_noise_gain=2.0,
        )
        self.reset(mode="val")

    def on_test_batch_end(self, trainer, pl_module, output, batch, batch_idx, dataloader_idx=0):
        """Compute reference/model gradients for the batch and buffer them (rank 0)."""
        if pl_module.global_rank == 0:
            self._record_eval_batch(self.test_data, pl_module, output)

    def on_test_epoch_end(self, trainer, pl_module, outputs=None):
        """Plot test diagnostics and generalization checks on rank 0."""
        if pl_module.global_rank != 0:
            return

        self._finish_eval_epoch(
            trainer, pl_module, self.test_data, "test",
            pure_noise_gain=1.25, additive_noise_gain=2.5,
        )
        self.reset(mode="test")

    def test_generalization(self, source, target, pl_module, experiment, batch_size, label):
        """Compare ``hungarian`` and the model on ``(source, target)`` pairs.

        The pairs are processed in chunks of ``batch_size`` on
        ``pl_module.device``. For each chunk the distance and its gradient
        with respect to the target are computed with both ``hungarian`` and
        ``pl_module``. A correlation plot of the two distances and a
        histogram of the cosine similarity of the two gradients are logged
        under ``gen/{label}_*`` and saved to ``output_dir``.

        Gradients are obtained with ``torch.autograd.grad``, so the model
        parameters' ``.grad`` buffers are neither written nor cleared.
        """
        if source.shape[0] == 0:
            # Nothing to evaluate; torch.cat would fail on empty lists.
            return

        true_dists = []
        pred_dists = []
        cos_sims = []

        for start in range(0, source.shape[0], batch_size):
            stop = start + batch_size
            source_batch = source[start:stop].to(pl_module.device)
            target_batch = target[start:stop].to(pl_module.device)

            true_dist, true_grad = self._dist_and_grad(hungarian, source_batch, target_batch)
            pred_dist, pred_grad = self._dist_and_grad(pl_module, source_batch, target_batch)

            cos_sim = torch.nn.functional.cosine_similarity(true_grad, pred_grad, dim=2, eps=1e-08)

            true_dists.append(true_dist.cpu())
            pred_dists.append(pred_dist.cpu())
            cos_sims.append(cos_sim.detach().cpu())

        true_dists = torch.cat(true_dists, dim=0)
        pred_dists = torch.cat(pred_dists, dim=0)
        cos_sim = torch.cat(cos_sims, dim=0)

        epoch = pl_module.current_epoch

        fig = plot_corr(true_dists, pred_dists)
        self._log_figure(experiment, fig, f"gen/{label}_corr", f"{label}_corr_{epoch}.pdf")

        self._log_cosine_hist(
            experiment,
            cos_sim,
            f"gen/{label}_grads_cosine_dis",
            f"{label}_grads_cosine_dis_{epoch}.pdf",
        )

    def reset(self, mode=None):
        """Clear buffered data.

        Args:
            mode: ``"train"`` or ``"val"`` clears that buffer; ``None`` (the
                default) clears all three; any other value clears the test
                buffer.
        """
        if mode is None:
            self.train_data = {}
            self.val_data = {}
            self.test_data = {}
        elif mode == "train":
            self.train_data = {}
        elif mode == "val":
            self.val_data = {}
        else:
            self.test_data = {}