import math
import time
from tqdm import tqdm
from termcolor import colored

import torch

from .adversarys_assistant import AdversarysAssistant
from .proxy import (GanBEProxy, GanLSProxy, GanWCritic, GanClassifierProxy,
                    BiGanClassifierProxy, BiGanWCritic)
from .utils import check_generators_valid
from .objectives.utils import ObjectiveType
from .utils import NormalizeType


class TrainBase():
    """Generic generator/proxy training loop.

    The loop in :meth:`train` alternates between (re)training a proxy model
    and taking one optimizer step on the generator.  Subclasses supply the
    model-specific pieces by overriding the hook methods below, and are
    expected to set ``self.generator`` (it is ``None`` here) and to update
    ``self.num_params``.
    """

    def __init__(self, datasets, generator_optimizer, objective,
                 normtype=NormalizeType.Standard, device='cpu'):
        """Store the shared training state.

        Args:
            datasets: datasets used by the subclass.
            generator_optimizer: optimizer stepped once per generator
                iteration in :meth:`train`.
            objective: objective object used by the subclass to compute
                the generator loss.
            normtype: a ``NormalizeType`` member selecting how gradients are
                back-propagated in :meth:`train`.
            device: ``'cuda'`` selects the GPU when one is available; any
                other value (or no GPU) falls back to the CPU.

        Raises:
            ValueError: if ``normtype`` is not a ``NormalizeType``.
        """
        if device == 'cuda' and torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')

        self.datasets = datasets
        self.objective = objective
        self.generator_optimizer = generator_optimizer
        self.num_params = 0
        # accumulated wall-clock seconds spent in proxy + generator updates
        self.running_time = 0.
        # number of generator iterations completed so far
        self.iteration = 0
        self.generator = None
        if not isinstance(normtype, NormalizeType):
            raise ValueError("Normalize type must be of type \"NormalizeType\""
                             " - see utils.py ")
        self.normtype = normtype
        # set by subclasses when state was restored from a checkpoint
        self._resumed = False

    def _train_proxy(self, num_iterations):
        """Train the proxy for ``num_iterations`` steps; return its loss."""
        raise NotImplementedError()

    def _objective_loss(self, data):
        """Return ``(loss, part_losses)`` for one batch of ``data``."""
        raise NotImplementedError()

    def _wandb_update_misc(self, iterations_per_epoch, loss=None,
                           proxy_loss=None, part_losses=None,
                           advas_normalizer=1.):
        """Logging hook, called once before training and after every step."""
        raise NotImplementedError()

    def _make_dataiterators(self):
        """Return a pair whose first element iterates over training data."""
        raise NotImplementedError()

    def _normalized_backward(self, part_losses, retain_first_graph=False):
        """Hook for subclasses; not called by the training loop here."""
        raise NotImplementedError()

    def _normalized_advas_backward(self, part_losses,
                                   retain_first_graph=False):
        """Hook for subclasses; not called by the training loop here."""
        raise NotImplementedError()

    def _get_iterations_per_epoch(self):
        """Return the number of iterations that make up one epoch."""
        raise NotImplementedError()

    def train(self, num_iterations=None, train_proxy_every=10):
        """Run the training loop until the generator iteration budget is hit.

        Args:
            num_iterations: dict with keys ``'generator'`` (total number of
                generator iterations, counted from iteration 0 even when
                resuming) and ``'proxy'`` (proxy steps per proxy update).
                Defaults to ``{'generator': 100, 'proxy': 500}``.
            train_proxy_every: the proxy is retrained whenever
                ``self.iteration`` is a multiple of this value.
        """
        # Build the default here instead of sharing one dict across calls.
        if num_iterations is None:
            num_iterations = {'generator': 100, 'proxy': 500}

        total_iterations = num_iterations['generator']

        iterations_per_epoch = self._get_iterations_per_epoch()
        self._wandb_update_misc(iterations_per_epoch, loss=None,
                                proxy_loss=None, part_losses=None,
                                advas_normalizer=1.)
        # The proxy is only trained on some iterations.  When resuming at an
        # iteration that is not a multiple of ``train_proxy_every`` there is
        # no proxy loss yet, so start from None rather than leaving the name
        # unbound; afterwards the most recent proxy loss is reported.
        proxy_loss = None
        with tqdm(desc=f"Generative Modeling (N_parameters: {self.num_params})",
                  total=total_iterations, ascii=True) as pbar:
            if self._resumed:
                pbar.update(self.iteration)
                self._resumed = False
            # Checking the budget up front means a run resumed at or past
            # ``total_iterations`` performs no further update.
            while self.iteration < total_iterations:
                data_iter_train, _ = self._make_dataiterators()

                for data in data_iter_train:
                    self._train()
                    start_time = time.time()
                    # update proxy
                    if self.iteration % train_proxy_every == 0:
                        proxy_loss = self._train_proxy(num_iterations['proxy'])

                    # calculate loss
                    self.generator_optimizer.zero_grad()
                    loss, part_losses = self._objective_loss(data)
                    advas = AdversarysAssistant(1.)
                    if self.normtype == NormalizeType.Total:
                        advas.normalized_backward(self.generator.parameters(),
                                                  *part_losses)
                    elif self.normtype == NormalizeType.Advas:
                        advas.normalized_advas_backward(
                            self.generator.parameters(), *part_losses)
                    else:
                        loss.backward()
                    self.generator_optimizer.step()
                    self.running_time += time.time() - start_time

                    self._wandb_update_misc(iterations_per_epoch, loss,
                                            proxy_loss, part_losses=part_losses,
                                            advas_normalizer=advas.normalizer)
                    pbar.update(1)
                    self.iteration += 1
                    # stop training
                    if self.iteration >= total_iterations:
                        break

    def _do_metric(self, iterations_per_epoch):
        """Decide which metrics are due at the current iteration.

        Returns:
            ``(do_small_metric, do_large_metric)``.  The small metric is due
            at every epoch boundary and, within the first epoch, at
            iterations 0, 1 and every power of two.  The large metric is due
            every five epochs, never at iteration 0.
        """
        iteration = self.iteration
        do_large_metric = (iteration > 0
                           and iteration % (5 * iterations_per_epoch) == 0)
        if iteration < iterations_per_epoch:
            # log2 of a power of two has no fractional part
            do_small_metric = iteration <= 1 or (math.log2(iteration) % 1 == 0)
        else:
            do_small_metric = False
        do_small_metric = (do_small_metric
                           or (iteration % iterations_per_epoch == 0))
        return do_small_metric, do_large_metric

    def _eval(self):
        """Hook for switching the model(s) to evaluation mode."""
        raise NotImplementedError()

    def _train(self):
        """Hook called at the start of every training iteration."""
        raise NotImplementedError()


class TrainGan(TrainBase):
    """GAN trainer: a generator trained against a proxy model.

    The proxy wrapper is chosen from ``objective.obj_type`` (JS,
    Wasserstein, LS or BE).  Metrics, images and checkpoints are reported
    through ``wandbwrapper`` when one is supplied.
    """

    def __init__(self, generator, datasets, generator_optimizer, batch_size,
                 proxy_model, proxy_optimizer, proxy_batch_size, objective,
                 num_workers=0, wandbwrapper=None, **kwargs):
        """Build the trainer and the proxy wrapper matching the objective.

        Args:
            generator: generator module; moved to the selected device.
            datasets: sequence of datasets; ``datasets[0]`` is the one used
                for metrics, image logging and epoch sizing.
            generator_optimizer: optimizer for the generator.
            batch_size: number of generator samples per training batch.
            proxy_model: model handed to the proxy wrapper.
            proxy_optimizer: optimizer handed to the proxy wrapper.
            proxy_batch_size: batch size handed to the proxy wrapper.
            objective: objective exposing ``obj_type`` plus the
                regularisation attributes its proxy wrapper needs.
            num_workers: worker count handed to the proxy wrapper.
            wandbwrapper: optional logging wrapper; when ``None`` all
                logging in ``_wandb_update_misc`` is skipped.
            **kwargs: forwarded to ``TrainBase.__init__`` (``normtype``,
                ``device``).

        Raises:
            ValueError: BE objective used with a non-``ProxyBEGAN`` model.
            NotImplementedError: unsupported ``objective.obj_type``.
        """
        super().__init__(datasets, generator_optimizer, objective, **kwargs)

        self.num_params += sum(params.numel()
                               for params in generator.parameters())

        self.generator = generator.to(device=self.device, non_blocking=True)
        # use all workers in the proxy training
        self.num_workers_per_dataset = num_workers
        self.batch_size = batch_size
        self.proxy_model = proxy_model
        self.wandbwrapper = wandbwrapper

        self.dataset_is_image = getattr(datasets[0], "is_image", False)

        # ``device`` is optional for the base class (default 'cpu'), so do
        # not require it in kwargs; forward the same default to the proxy.
        proxy_args = [datasets, proxy_model, proxy_optimizer,
                      proxy_batch_size, self.num_workers_per_dataset,
                      kwargs.get('device', 'cpu')]
        obj_type = self.objective.obj_type
        if obj_type == ObjectiveType.JS:
            self.proxy = GanClassifierProxy(*proxy_args)
        elif obj_type == ObjectiveType.Wasserstein:
            reg_args = [self.objective.GP_strength, self.objective.clamp_limit,
                        self.objective.weight_norm]
            self.proxy = GanWCritic(*(reg_args + proxy_args))
        elif obj_type == ObjectiveType.LS:
            reg_args = [self.objective.a, self.objective.b]
            self.proxy = GanLSProxy(*(reg_args + proxy_args))
        elif obj_type == ObjectiveType.BE:
            from .models.proxies import ProxyBEGAN
            # isinstance so that subclasses of ProxyBEGAN are accepted too
            if not isinstance(proxy_model, ProxyBEGAN):
                raise ValueError("Must use BE proxy model")
            reg_args = [self.objective.lambda_k, self.objective.gamma]
            self.proxy = GanBEProxy(*(reg_args + proxy_args))
        else:
            raise NotImplementedError("This Gan objective is not implemented")

    def _train(self):
        """Put the generator in training mode."""
        self.generator.train()

    def _eval(self):
        """Put the generator in evaluation mode."""
        self.generator.eval()

    def _get_iterations_per_epoch(self):
        """Return the length of the proxy's training dataloader."""
        return len(self.proxy.dataloader_train)

    def train(self, *args, **kwargs):
        """Run ``TrainBase.train`` and return ``(generator, proxy_model)``."""
        super().train(*args, **kwargs)
        return self.generator, self.proxy_model

    def _train_proxy(self, num_iterations):
        """Train the proxy against the current generator; return its loss.

        If ``datasets[0]`` has a truthy ``annealing`` attribute, its
        ``update_annealing()`` is called first.
        """
        if getattr(self.datasets[0], "annealing", False):
            self.datasets[0].update_annealing()

        # train proxy
        return self.proxy.train(self.generator, num_iterations=num_iterations)

    def _objective_loss(self, data):
        """Compute the generator loss for one batch.

        Input:
            data: X sampled from generator

        Returns:
            ``(total_loss, (original_loss, regularizer))`` as produced by
            ``self.objective.loss``.

        """
        total_loss, original_loss, regularizer = self.objective.loss(
            data, self.proxy, self.generator)
        return total_loss, (original_loss, regularizer)

    @staticmethod
    def _sample_from_generator(generator, batch_size):
        """Draw ``batch_size`` samples from ``generator``."""
        return generator.sample(torch.Size([batch_size]))

    def _make_dataiterators(self):
        """Return ``(train_iterator, None)``.

        The iterator lazily yields ``len(datasets[0])`` freshly sampled
        generator batches of size ``self.batch_size``.
        """
        dataset_size = len(self.datasets[0])
        data_iterator = (self._sample_from_generator(self.generator,
                                                     self.batch_size)
                         for _ in range(dataset_size))
        return data_iterator, None

    def _wandb_update_misc(self, iterations_per_epoch, loss=None,
                           proxy_loss=None, part_losses=None,
                           advas_normalizer=1.):
        """Report metrics, images, checkpoints and losses for this iteration.

        The schedule (all evaluated at ``self.iteration``):
          * small / large metrics as decided by ``_do_metric``;
          * summary stats, images and a checkpoint whenever a metric ran or
            every 500 iterations;
          * losses whenever any of the above ran or every 50 iterations.
        Does nothing when no ``wandbwrapper`` was supplied.
        """
        # ``wandbwrapper`` is optional in __init__; without it there is
        # nowhere to report to, so skip instead of failing on None.
        if self.wandbwrapper is None:
            return

        self._eval()
        updated_metrics = False
        iteration = self.iteration
        real_data = self.datasets[0]
        with torch.no_grad():
            # do quick metric calculations logarithmically
            do_small_metric, do_large_metric = self._do_metric(iterations_per_epoch)
            if do_small_metric:
                self.wandbwrapper.fid_score(self.generator, real_data,
                                            N=int(1e3), label='small')
                self.wandbwrapper.inception_score(self.generator,
                                                  N=int(1e3), label='small')
                self.wandbwrapper.swd_metric(self.generator, real_data,
                                             N=int(1e3), label='small')

                updated_metrics = True
            if do_large_metric:
                self.wandbwrapper.fid_score(self.generator, real_data,
                                            N=len(real_data),
                                            label='large')
                self.wandbwrapper.inception_score(self.generator,
                                                  label='large')
                self.wandbwrapper.swd_metric(self.generator, real_data,
                                             N=16384, label='large')

                updated_metrics = True
            if updated_metrics or (iteration % 500 == 0):
                self.wandbwrapper.track_summary_stats(self.generator,
                                                      real_data)
                images, image_names = self._get_images_names()
                self.wandbwrapper.add_images(images, image_names,
                                             iteration=iteration)
                objects_to_save, objects_name = self.get_objects_to_save()
                # iteration and running time are saved so that
                # ``load_objects`` can restore them
                self.wandbwrapper.save_objects(objects_to_save
                                               + [iteration,
                                                  self.running_time],
                                               objects_name
                                               + ['iteration', 'running_time'])
                updated_metrics = True
            if updated_metrics or (iteration % 50 == 0):
                if loss is not None:
                    self.wandbwrapper.track_loss(loss, 'generator',
                                                 part_losses=part_losses)
                if proxy_loss is not None:
                    self.wandbwrapper.track_loss(proxy_loss, 'proxy')
                self.wandbwrapper.wandb.log({'advas_normalizer':
                                             advas_normalizer},
                                            commit=False)
                updated_metrics = True
        if updated_metrics:
            self.wandbwrapper.log(iteration, self.running_time)

    def get_objects_to_save(self):
        """Return ``(objects, names)`` for the models and optimizers."""
        objects_to_save = [self.generator, self.proxy_model,
                           self.generator_optimizer, self.proxy.optimizer]
        objects_name = ['generator', 'proxy', 'generator_optimizer',
                        'proxy_optimizer']
        return objects_to_save, objects_name

    def load_objects(self, loaded_objects=None):
        """Restore training state from a mapping of saved objects.

        Expects the keys written by ``_wandb_update_misc``:
        ``'running_time'``, ``'iteration'``, ``'generator'``, ``'proxy'``,
        ``'generator_optimizer'`` and ``'proxy_optimizer'``.  Does nothing
        when ``loaded_objects`` is ``None``.
        """
        if loaded_objects is None:
            return
        self.running_time = loaded_objects['running_time']
        self.iteration = loaded_objects['iteration']
        self.generator.load_state_dict(loaded_objects['generator'])
        self.proxy.proxy_model.load_state_dict(loaded_objects['proxy'])
        self.generator_optimizer.load_state_dict(
            loaded_objects['generator_optimizer'])
        self.proxy.optimizer.load_state_dict(
            loaded_objects['proxy_optimizer'])
        # lets ``train`` advance the progress bar to the restored iteration
        self._resumed = True

    def _get_images_names(self):
        """Collect a few fake images and one real image for logging.

        Returns:
            ``(images, names)``: ``images`` concatenates, along dim 0, three
            fresh generator samples, the generator's fixed samples (only if
            it provides ``get_fixed``) and one randomly chosen item from
            ``datasets[0]``; ``names`` labels each row accordingly.
        """
        self._eval()
        n_fake = 3
        images = [self.generator.sample(torch.Size([n_fake])).cpu()]
        n_fixed = 0
        if getattr(self.generator, "get_fixed", None) is not None:
            fake_fixed_images = self.generator.get_fixed().cpu()
            n_fixed = fake_fixed_images.size(0)
            images.append(fake_fixed_images)
        idx = torch.randint(len(self.datasets[0]), (1,)).item()
        real, _ = self.datasets[0][idx]
        images = torch.cat(images + [real.unsqueeze(0)], dim=0)

        return images, ['fake']*n_fake + ['fake_fixed']*n_fixed + ['real']
