import torch
from pytorch_lightning.callbacks import Callback

import os
from src.utils import get_wandb_logger

import sys
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', 'src'))
from utils.viz import plot_pcs_recon, plot_corr
from utils.distance import hungarian, chamfer
import wandb
import matplotlib.pyplot as plt

class GenEvalCallback(Callback):
    """Collect step outputs during train/val/test and log diagnostic plots to W&B.

    Per-batch outputs (dicts of tensors) are concatenated along dim 0 into
    per-phase stores on global rank 0.  At the end of each epoch the stored
    tensors are moved to CPU, plotted with ``plot_pcs_recon`` / ``plot_corr``,
    logged through the W&B experiment and then discarded.

    Args:
        output_dir: Directory where test-time distances and the test
            correlation plot are written.  Created if it does not exist.
        dist_fn_type: ``'chamfer'`` or ``'hungarian'``; selects the reference
            distance function compared against the model's ``"dists"`` output.
        train_log_freq: Training data is collected and logged on epochs 0 and 1
            and on every epoch that is a multiple of this value.

    Raises:
        NotImplementedError: If ``dist_fn_type`` is not a supported name.
    """

    # Upper bound on the number of leading samples shown in reconstruction plots.
    NUM_RECON_SAMPLES = 8
    # Training collection stops for the epoch once this many samples are stored.
    MAX_TRAIN_SAMPLES = 250

    def __init__(self, output_dir, dist_fn_type, train_log_freq=50):
        super().__init__()

        self.train_data = {}
        self.val_data = {}
        self.test_data = {}
        self.output_dir = output_dir
        self.dist_fn_type = dist_fn_type
        self.train_log_freq = train_log_freq

        if dist_fn_type == 'chamfer':
            self.dist_fn = chamfer
        elif dist_fn_type == 'hungarian':
            self.dist_fn = hungarian
        else:
            raise NotImplementedError(
                f"Unsupported dist_fn_type {dist_fn_type!r}; expected 'chamfer' or 'hungarian'"
            )

        os.makedirs(self.output_dir, exist_ok=True)

    # ------------------------------------------------------------------ helpers

    def _is_train_log_epoch(self, epoch):
        """Return True when training data should be collected/logged for ``epoch``."""
        return epoch % self.train_log_freq == 0 or epoch < 2

    @staticmethod
    def _accumulate(store, output, skip_keys=()):
        """Append every tensor in ``output`` to ``store`` along dim 0, detached."""
        for key, value in output.items():
            if key in skip_keys:
                continue
            value = value.detach()
            if key in store:
                store[key] = torch.cat((store[key], value), dim=0)
            else:
                store[key] = value

    @staticmethod
    def _move_to_cpu(store):
        """Move every tensor held in ``store`` to CPU, in place."""
        for key in store:
            store[key] = store[key].cpu()

    @staticmethod
    def _log_figure(experiment, name, fig):
        """Log ``fig`` to W&B under ``name`` and release it if it is a Matplotlib figure."""
        experiment.log({name: wandb.Image(fig)})
        # Figures created through pyplot stay registered until closed; without
        # this they pile up over the course of a long run.
        if isinstance(fig, plt.Figure):
            plt.close(fig)

    def _recon_indices(self, store):
        """Indices of the samples to plot, clamped to the number actually stored."""
        return list(range(min(self.NUM_RECON_SAMPLES, len(store["source"]))))

    # -------------------------------------------------------------------- train

    def on_train_batch_end(self, trainer, pl_module, output, batch, batch_idx, dataloader_idx=0):
        """Store this batch's outputs and, when available, predicted vs. true gradients.

        ``dataloader_idx`` defaults to 0 so the hook works whether or not the
        installed Lightning version passes it.
        """
        if pl_module.global_rank != 0:
            return
        if not self._is_train_log_epoch(pl_module.current_epoch):
            return
        if self.train_data and len(self.train_data['source']) >= self.MAX_TRAIN_SAMPLES:
            return

        # Work on a shallow copy so the extra keys are not injected into the
        # dict owned by the training loop.
        record = dict(output)
        target = record['target']

        if target.grad is not None:
            record["pred_grad"] = target.grad
            target.grad = None

            # Compute loss and grad from the reference algorithm on a fresh leaf
            # so the result is independent of the model's autograd graph.
            temp_target = target.detach().clone().requires_grad_(True)
            with torch.enable_grad():
                preds = self.dist_fn(record['source'].detach(), temp_target)
                record['true_dist'] = preds.clone().detach()
                preds.mean().backward()
            record["true_grad"] = temp_target.grad

        self._accumulate(self.train_data, record, skip_keys=("loss",))

    def on_train_epoch_end(self, trainer, pl_module, outputs=None):
        """Log reconstruction, distance-correlation and gradient-cosine plots, then reset."""
        if not self._is_train_log_epoch(pl_module.current_epoch):
            return

        # Nothing is plotted when no batch was collected this epoch.
        if pl_module.global_rank == 0 and self.train_data:
            data = self.train_data
            self._move_to_cpu(data)

            experiment = get_wandb_logger(trainer=trainer).experiment

            fig = plot_pcs_recon(data["source"], data["target"], self._recon_indices(data),
                                 pl_module.criterion, self.dist_fn_type)
            self._log_figure(experiment, 'train_recons', fig)

            # "true_dist"/"true_grad"/"pred_grad" exist only if at least one
            # batch exposed ``target.grad``; skip the dependent plots otherwise.
            if "true_dist" in data:
                fig = plot_corr(data["true_dist"], data["dists"])
                self._log_figure(experiment, 'train_corr', fig)

            if "true_grad" in data and "pred_grad" in data:
                # dim=2 assumes 3-D gradient tensors -- presumably
                # (samples, points, coords); TODO confirm against the module.
                cos_sim_train = torch.nn.functional.cosine_similarity(
                    data["true_grad"], data["pred_grad"], dim=2, eps=1e-08)
                fig, ax = plt.subplots()
                ax.hist(cos_sim_train.flatten().numpy(), bins=100, density=True)
                fig.tight_layout()
                self._log_figure(experiment, 'grads_cosine_dis', fig)

        self.reset('train')

    # --------------------------------------------------------------- validation

    def on_validation_batch_end(self, trainer, pl_module, output, batch, batch_idx, dataloader_idx=0):
        """Append this validation batch's outputs to ``val_data`` on rank 0."""
        if pl_module.global_rank == 0:
            self._accumulate(self.val_data, output)

    def on_validation_epoch_end(self, trainer, pl_module, outputs=None):
        """Log validation reconstruction and distance-correlation plots, then reset."""
        if pl_module.global_rank != 0:
            return

        if self.val_data:
            data = self.val_data
            self._move_to_cpu(data)

            experiment = get_wandb_logger(trainer=trainer).experiment

            fig = plot_pcs_recon(data["source"], data["target"], self._recon_indices(data),
                                 pl_module.criterion, self.dist_fn_type)
            self._log_figure(experiment, 'val_recons', fig)

            d = self.dist_fn(data["source"], data["target"])

            fig = plot_corr(d, data["dists"])
            self._log_figure(experiment, 'val_corr', fig)

        self.reset('val')

    # --------------------------------------------------------------------- test

    def on_test_batch_end(self, trainer, pl_module, output, batch, batch_idx, dataloader_idx=0):
        """Append this test batch's outputs to ``test_data`` on rank 0."""
        if pl_module.global_rank == 0:
            self._accumulate(self.test_data, output)

    def on_test_epoch_end(self, trainer, pl_module, outputs=None):
        """Save reference distances to ``output_dir`` and log test plots, then reset."""
        if pl_module.global_rank != 0:
            return

        if self.test_data:
            data = self.test_data
            self._move_to_cpu(data)

            experiment = get_wandb_logger(trainer=trainer).experiment

            fig = plot_pcs_recon(data["source"], data["target"], self._recon_indices(data),
                                 pl_module.criterion, self.dist_fn_type)
            self._log_figure(experiment, 'test_recons', fig)

            d = self.dist_fn(data["source"], data["target"])
            # Avoid recomputing when the configured distance already is Chamfer.
            if self.dist_fn is chamfer:
                d_chamfer = d
            else:
                d_chamfer = chamfer(data["source"], data["target"])

            # NOTE: 'test_dists_emd.pt' holds the output of the configured
            # ``dist_fn`` whatever ``dist_fn_type`` is; the file name is kept
            # for compatibility with existing consumers.
            torch.save(d, os.path.join(self.output_dir, 'test_dists_emd.pt'))
            torch.save(d_chamfer, os.path.join(self.output_dir, 'test_dists_chamfer.pt'))

            fig = plot_corr(d, data["dists"])
            fig.savefig(os.path.join(self.output_dir, 'test_corr.png'))
            self._log_figure(experiment, 'test_corr', fig)

        self.reset('test')

    # -------------------------------------------------------------------- reset

    def reset(self, mode):
        """Discard the data collected for ``mode``.

        ``'train'`` and ``'test'`` clear their own stores; any other value
        clears the validation store.
        """
        if mode == 'train':
            print('resetting train data')
            self.train_data = {}
        elif mode == 'test':
            print('resetting test data')
            self.test_data = {}
        else:
            print('resetting val data')
            self.val_data = {}