import matplotlib.pyplot as plt
import torch
import torchvision.utils
from torch.utils.data import DataLoader
from tqdm import tqdm

import pdb

from two_step_zoo.utils import count_parameters

class BaseClusteringTrainer:
    """
    Base class for training a clustering two-step module.

    Stores the clustering module, its child trainers, an evaluator and a
    writer, and provides the checkpointing, sampling and metric-logging
    helpers shared by concrete trainers. Subclasses must implement `train`.
    """
    def __init__(
            self,
            module,
            child_trainers,
            evaluator,
            writer,
            train_loader,
            valid_loader,
            test_loader,

            checkpoint_load_list,
            memory_efficient,

            only_test=False,
            cluster_cfg=None
    ):
        """
        Store collaborators and try to restore every child trainer.

        Args:
            module: The clustering module. Must expose `module_id` and, when
                `memory_efficient` is set, `switch_component`.
            child_trainers: Trainers whose `.module` is one sub-module each.
            evaluator: Object used for validation/testing and for sampling.
            writer: Logger exposing `write_scalar`, `write_json`, `write_image`.
            train_loader: Training data loader; also used as whitening loader.
            valid_loader: Validation data loader (stored only).
            test_loader: Test data loader (stored only).
            checkpoint_load_list: Checkpoint tags to try, in priority order.
            memory_efficient: If truthy, the child trainers are attached to
                `module` and `module.switch_component` is called before a
                child trainer is used.
            only_test: If truthy, subclasses may skip training.
            cluster_cfg: Optional clustering configuration, stored as-is.
        """
        self.module = module
        self.child_trainers = child_trainers
        self.writer = writer
        self.evaluator = evaluator
        self.only_test = only_test

        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader

        self.memory_efficient = memory_efficient

        self.cluster_cfg = cluster_cfg

        # Child modules that request a whitening transform get the training
        # loader attached to their trainer.
        for trainer in child_trainers:
            if getattr(trainer.module, "whitening_transform", False):
                trainer.whitening_loader = train_loader

        if self.memory_efficient:
            self.module.trainers = child_trainers

        self.load_checkpoint(checkpoint_load_list)

    def train(self):
        """Run training; must be provided by subclasses."""
        raise NotImplementedError("Implement train function in child classes")

    def write_checkpoint(self, tag):
        """Ask every child trainer to write a checkpoint under `tag`."""
        for trainer in self.child_trainers:
            trainer.write_checkpoint(tag)

    def load_checkpoint(self, checkpoint_load_list):
        """
        Load, for each child trainer, the first checkpoint tag that exists.

        Args:
            checkpoint_load_list: Iterable of checkpoint tags in priority
                order. A single string is accepted and treated as a
                one-element list (iterating a bare string would otherwise
                try each character as a tag).

        A missing checkpoint (`FileNotFoundError`) is reported on stdout and
        the next tag is tried; if none is found the trainer is left as is.
        """
        if isinstance(checkpoint_load_list, str):
            checkpoint_load_list = [checkpoint_load_list]

        for trainer in self.child_trainers:

            if self.memory_efficient:
                self.module.switch_component(trainer)

            for ckpt in checkpoint_load_list:
                try:
                    trainer.load_checkpoint(ckpt)
                    break
                except FileNotFoundError:
                    print(f"Did not find {ckpt} {trainer.module.module_id} checkpoint")

    def sample_and_record(self, epoch):
        """
        Draw samples from the evaluator's module and log them as an image grid.

        Args:
            epoch: Value passed to the writer as `global_step`.
        """
        NUM_SAMPLES = 64
        # Passed to make_grid as `nrow`, i.e. the number of images per grid row.
        GRID_ROWS = 8

        with torch.no_grad():
            imgs = self.evaluator.module.sample(NUM_SAMPLES)
            # Clamp in place to the data range of the first child module.
            reference_module = self.child_trainers[0].module
            imgs.clamp_(reference_module.data_min, reference_module.data_max)
            grid = torchvision.utils.make_grid(imgs, nrow=GRID_ROWS, pad_value=1, normalize=True, scale_each=True)
            grid_permuted = grid.permute((1,2,0))

            fig = plt.figure()
            try:
                plt.axis("off")
                plt.imshow(grid_permuted.detach().cpu().numpy())

                self.writer.write_image(self.module.module_id+"/samples", grid, global_step=epoch)
            finally:
                # The figure is neither shown nor saved here; close it so
                # repeated calls do not pile up open pyplot figures.
                plt.close(fig)

    def record_dict(self, tag_prefix, value_dict, save=False, step=0):
        """
        Print and log a dictionary of scalar metrics.

        Args:
            tag_prefix: Inserted into the tag as `cluster_<tag_prefix>_<key>`.
            value_dict: Mapping from metric name to a scalar value.
            save: If truthy, also dump the metrics through `writer.write_json`.
            step: Step passed to `writer.write_scalar` (defaults to 0).
        """
        for k, v in value_dict.items():
            print(f"clusterer {k}: {v:.4f}")
            self.writer.write_scalar(f"cluster_{tag_prefix}_{k}", v, step)

        if save:
            self.writer.write_json(
                f"cluster_{tag_prefix}_metrics",
                # Values exposing `.item()` (e.g. tensors) are unwrapped;
                # plain Python numbers are written unchanged.
                {k: (v.item() if hasattr(v, "item") else v) for k, v in value_dict.items()}
            )

class DisjointSequentialClusteringTrainer(BaseClusteringTrainer):
    """Trainer that runs each child trainer to completion, one after another."""

    def train(self):
        """
        Train every child trainer in order (unless `only_test` is set), try
        to log samples, then run the evaluator's test and clean up the module.
        """
        if not self.only_test:
            for component_trainer in self.child_trainers:
                self.train_component(component_trainer)

        # Sampling is best-effort: an AttributeError raised while sampling
        # is reported and otherwise ignored.
        try:
            first_epoch = self.child_trainers[0].epoch
            self.sample_and_record(first_epoch)
        except AttributeError:
            print("No sample method available")

        self.record_dict("test", self.evaluator.test(), save=True)
        self.module.cleanup()

    def train_component(self, trainer):
        """
        Train a single child trainer.

        In memory-efficient mode the module is first switched to this
        trainer's component via `module.switch_component`.
        """
        if self.memory_efficient:
            self.module.switch_component(trainer)
        trainer.train()


# TODO: proper overloading of dataset to take advantage of num_workers etc.
class MixedDataloader():
    """
    Iterates several dataloaders in lockstep.

    Each step yields a list holding one batch from every wrapped dataloader.
    Iteration ends as soon as any wrapped iterator is exhausted. The object
    is its own iterator, so `reload_iters` must be called to start over.
    """

    def __init__(self, dataloaders):
        self.dataloaders = dataloaders
        self.reload_iters()

    def reload_iters(self):
        """Create a fresh iterator over each wrapped dataloader."""
        self.data_iters = list(map(iter, self.dataloaders))

    def __len__(self):
        """Length of the shortest wrapped dataloader."""
        return min(map(len, self.dataloaders))

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch of every dataloader as a list."""
        batches = []
        for source in self.data_iters:
            batches.append(next(source))
        return batches


class MixedCluster(BaseClusteringTrainer):
    """
    Trains all child trainers jointly, one mixed batch at a time.

    Every training step draws one batch per child trainer (via
    `MixedDataloader`) and calls `module.shared_step()` either after each
    child trainer's batch or once after all of them, depending on
    `cluster_cfg["shared_step_after"]`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.shared_step_after = self.cluster_cfg["shared_step_after"]

    def _validate(self):
        """
        Put all child modules in eval mode, run validation and log the result.

        Returns:
            The value of the early-stopping metric, or None if the
            validation results do not contain it.
        """
        for trainer in self.child_trainers:
            trainer.module.eval()

        valid_results = self.evaluator.validate()
        # record_dict's third positional parameter is `save`; the epoch must
        # not be passed there, or it silently switches on JSON dumping.
        self.record_dict("validate", valid_results)
        return valid_results.get(self.early_stopping_metric)

    def _test(self):
        """Run the evaluator's test and log and save the resulting metrics."""
        test_results = self.evaluator.test()
        self.record_dict("test", test_results, save=True)

    def train(self):
        """
        Run joint training with optional early stopping, then test.

        Epoch counters and early-stopping settings are taken from the first
        child trainer. When validation has not improved for
        `max_bad_valid_epochs` epochs, the best checkpoint is reloaded and
        tested before returning.
        """
        # TODO: refactor into this class
        lead_trainer = self.child_trainers[0]
        self.epoch = lead_trainer.epoch
        self.max_epochs = lead_trainer.max_epochs
        self.bad_valid_epochs = lead_trainer.bad_valid_epochs
        self.max_bad_valid_epochs = lead_trainer.max_bad_valid_epochs
        self.early_stopping_metric = lead_trainer.early_stopping_metric
        # The early-stopping comparison below needs a starting value; keep
        # one that was already set, otherwise take the first child trainer's
        # or start from infinity.
        if not hasattr(self, "best_valid_loss"):
            self.best_valid_loss = getattr(lead_trainer, "best_valid_loss", float("inf"))

        # Build mixed data loader
        self.mixed_cluster_train_loader = MixedDataloader([trainer.train_loader for trainer in self.child_trainers])

        # Log number of trainable params
        for trainer in self.child_trainers:
            if trainer.module.is_cluster_component:
                trainer.write_combined_scalar("num_params", count_parameters(trainer.module), cidx=trainer.module.cluster_component)
            else:
                trainer.write_scalar("num_params", count_parameters(trainer.module), step=0)

            trainer.update_transform_parameters()

        while self.epoch < self.max_epochs and self.bad_valid_epochs < self.max_bad_valid_epochs:
            for trainer in self.child_trainers:
                trainer.module.train()

            self.train_for_epoch()

            valid_loss = self._validate()

            if self.early_stopping_metric:
                if valid_loss < self.best_valid_loss:
                    self.bad_valid_epochs = 0
                    self.best_valid_loss = valid_loss
                    self.write_checkpoint("best_valid")

                    print(f"Best validation loss of {valid_loss} achieved on epoch {self.epoch}")

                else:
                    self.bad_valid_epochs += 1

                    if self.bad_valid_epochs == self.max_bad_valid_epochs:
                        print(f"No validation improvement for {self.max_bad_valid_epochs}"
                                + " epochs. Training halted.")
                        self.write_checkpoint("latest")

                        # load_checkpoint iterates over its argument, so the
                        # tag has to be wrapped in a list.
                        self.load_checkpoint(["best_valid"])
                        self._test()
                        # Mirror the normal exit path below.
                        self.module.cleanup()

                        return

            self.write_checkpoint("latest")

            if self.epoch % self.child_trainers[0].epoch_sample_every == 0: # If image data
                #TODO: refactor
                is_training = self.module.training
                self.module.eval()
                self.sample_and_record(self.epoch)
                self.module.train(is_training)

        self.sample_and_record(self.epoch)
        self._test()
        self.module.cleanup()

    def _tqdm_progress_bar(self, iterable, desc, length, leave):
        """Wrap `iterable` in a tqdm progress bar with this project's format."""
        return tqdm(
            iterable,
            desc=desc,
            total=length,
            bar_format="{desc}[{n_fmt}/{total_fmt}] {percentage:3.0f}%|{bar}{postfix} [{elapsed}<{remaining}]",
            leave=leave
        )

    @staticmethod
    def _accumulate_losses(full_loss_dict, cidx, loss_dict, reset=False):
        """
        Add `loss_dict` to the running totals stored at `full_loss_dict[cidx]`.

        With `reset` set (or when `cidx` has no totals yet) the totals start
        over from a copy of `loss_dict`. Keys not seen before are added
        instead of raising KeyError.
        """
        if reset or cidx not in full_loss_dict:
            full_loss_dict[cidx] = dict(loss_dict)
            return

        totals = full_loss_dict[cidx]
        for key, value in loss_dict.items():
            totals[key] = totals[key] + value if key in totals else value

    def train_base(self, pbar, full_loss_dict):
        """
        Train every child trainer on its share of each mixed batch.

        Args:
            pbar: Iterable of `(batch_index, batch)` pairs, where `batch`
                holds one entry per child trainer; the first element of
                each entry is fed to `train_single_batch`.
            full_loss_dict: Dict filled with per-trainer loss sums, keyed by
                the trainer's index.

        Returns:
            `full_loss_dict` with the summed losses of this pass.
        """
        for j, batch in pbar:
            for cidx, trainer in enumerate(self.child_trainers):
                loss_dict = trainer.train_single_batch(batch[cidx][0])

                if not self.shared_step_after:
                    self.module.shared_step()

                self._accumulate_losses(full_loss_dict, cidx, loss_dict, reset=(j == 0))

            if self.shared_step_after:
                self.module.shared_step()

            # Index of the last processed batch; read by train_for_epoch.
            self.j = j

        return full_loss_dict

    def train_multiopt(self, pbar, full_loss_dict):
        """
        Train with alternating discriminator and generator steps.

        Each batch is passed through every child trainer
        `num_discriminator_steps` times with `module_to_train="discriminator"`
        (once if the first child module does not define that attribute),
        followed by one pass with `module_to_train="generator"`.

        Args:
            pbar: Iterable of `(batch_index, batch)` pairs as in `train_base`.
            full_loss_dict: Dict filled with per-trainer loss sums over all
                steps, keyed by the trainer's index.

        Returns:
            `full_loss_dict` with the summed losses of this pass.
        """
        num_disc_per_gen = getattr(self.child_trainers[0].module, 'num_discriminator_steps', 1)
        num_module_steps = num_disc_per_gen + 1

        for j, batch in pbar:
            #TODO: does effect this has on lr schedule matter?
            for batch_forward_pass_step in range(num_module_steps): # TODO: does it matter the gen/disc always see same batch order?
                # The last pass over the batch trains the generator.
                module_to_train = "generator" if batch_forward_pass_step == num_disc_per_gen else "discriminator"

                for cidx, trainer in enumerate(self.child_trainers):
                    loss_dict = trainer.train_single_batch(
                        batch[cidx][0],
                        shared_optimizers=self.module.shared_module.optimizer,
                        module_to_train=module_to_train
                    )

                    # Start over only on the very first step of the pass, so
                    # later steps of the first batch are not discarded.
                    self._accumulate_losses(
                        full_loss_dict, cidx, loss_dict,
                        reset=(j == 0 and batch_forward_pass_step == 0)
                    )

                    if not self.shared_step_after:
                        self.module.shared_step()

                if self.shared_step_after:
                    self.module.shared_step()
                self.j = j

        return full_loss_dict

    def train_for_epoch(self):
        """
        Run one epoch over the mixed loader and update per-epoch state.

        Uses `train_multiopt` when the first child module has a
        `discriminator` attribute and `train_base` otherwise. Afterwards it
        optionally records samples, advances the epoch counters, writes
        "latest" checkpoints and prints the per-batch mean of each loss.
        """
        self.mixed_cluster_train_loader.reload_iters()
        pbar = self._tqdm_progress_bar(
            iterable=enumerate(self.mixed_cluster_train_loader),
            desc="Training",
            length=len(self.mixed_cluster_train_loader),
            leave=True
        )

        full_loss_dict = {}
        lead_trainer = self.child_trainers[0]

        if hasattr(lead_trainer.module, 'discriminator'): #TODO: make this a config param
            assert lead_trainer.module.pass_optimizer, "discriminator training requires seperate optimizer"
            train_fn = self.train_multiopt
        else:
            train_fn = self.train_base

        full_loss_dict = train_fn(pbar, full_loss_dict)

        if len(lead_trainer.module.data_shape) > 1 and lead_trainer.epoch % lead_trainer.epoch_sample_every == 0: # If image data
            for trainer in self.child_trainers:
                is_training = trainer.module.training
                trainer.module.eval()
                trainer.sample_and_record()
                trainer.module.train(is_training)

        self.epoch += 1

        for trainer in self.child_trainers:
            trainer.update_transform_parameters()
            trainer.epoch += 1
            trainer.write_checkpoint("latest")

        for module_idx, loss_dict in full_loss_dict.items():
            # self.j is the zero-based index of the last batch, so the number
            # of batches is one more than that.
            num_batches = self.j + 1
            module_id = self.child_trainers[module_idx].module.module_id
            for k, v in loss_dict.items():
                print(f"{module_id} {k}: {v / num_batches:.4f} after {self.epoch} epochs")
