import inspect
import math
import os
import pickle
import time
from enum import Enum
from functools import wraps
from pathlib import Path

import scipy
import scipy.stats
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from .swd import SWD
from .metrics import inception_score, fid_score


def subtract_time(fn, *args, **kwargs):
    """Decorator that accumulates the duration of each call on ``args[0]``.

    The elapsed time of every call to ``fn`` is added to the
    ``time_to_subtract`` attribute of the first positional argument
    (normally ``self``). Presumably this lets callers exclude
    logging/metric time from their training-time measurements.

    The attribute is created (starting at 0.0) when it does not exist yet.
    The time is recorded even if the wrapped call raises.

    The decorator's own ``*args``/``**kwargs`` are ignored; they are kept
    only for backward compatibility.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        wandb_object = args[0]
        # perf_counter is monotonic, which makes it the right clock for durations
        start = time.perf_counter()
        try:
            return fn(*args, **kwargs)
        finally:
            elapsed = time.perf_counter() - start
            wandb_object.time_to_subtract = (
                getattr(wandb_object, 'time_to_subtract', 0.0) + elapsed)
    return wrapper


def limit_n_decorator(fn, *args, **kwargs):
    """Clamp the ``N`` argument of a metric method to ``self.N_limit``.

    The wrapped function's first positional argument must be an object with
    an ``N_limit`` attribute; ``None`` disables the limit. The effective
    value of ``N`` is reduced to ``N_limit`` when it exceeds it. This applies
    whether ``N`` is passed positionally, by keyword, or taken from the
    function's default. Values already within the limit are left untouched.

    The decorator's own ``*args``/``**kwargs`` are ignored; they are kept
    only for backward compatibility.
    """
    signature = inspect.signature(fn)
    takes_n = 'N' in signature.parameters

    @wraps(fn)
    def wrapper(*args, **kwargs):
        wandb_object = args[0]
        limit = wandb_object.N_limit
        if limit is None or not takes_n:
            return fn(*args, **kwargs)
        # Bind to resolve N uniformly (positional, keyword or default) so the
        # clamped value never collides with a positional argument.
        bound = signature.bind(*args, **kwargs)
        bound.apply_defaults()
        if bound.arguments['N'] > limit:
            bound.arguments['N'] = limit
        return fn(*bound.args, **bound.kwargs)
    return wrapper


class WandbWrapper:
    """Helper around the ``wandb`` API for logging GAN experiments.

    Logs losses, images, sample summary statistics and sample-quality
    metrics. The metrics are FID, inception score, sliced Wasserstein
    distance and nearest neighbours. It also saves and restores checkpoint
    objects. Most calls log with ``commit=False``; ``log`` commits the
    accumulated step.
    """

    def __init__(self, wandb, batch_size, num_workers, device='cpu',
                 normalize_metric_input=True, N_limit=None, swd_limit=None):
        """
        INPUT:

            wandb: the ``wandb`` module, or an object exposing the same
                   ``log``/``Image``/``run``/``save``/``restore`` API
            batch_size (int): batch size used for sampling and data loading
            num_workers (int): number of DataLoader worker processes
            device (string): 'cuda' selects the GPU when one is available;
                             any other value selects the CPU
            normalize_metric_input (boolean): set to True IF your generated
                                              images are in the range [0, 1].
                                              If your generated images are
                                              between [-1, 1] set to False.
            N_limit (int:None): if given, upper bound on the ``N`` argument
                                of the metric methods
            swd_limit (int:None): if given, upper bound on the number of
                                  samples used by ``swd_metric``

        """
        self._normalize_metric_input = normalize_metric_input
        self.wandb = wandb
        if device == 'cuda' and torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.N_limit = N_limit
        self.swd_limit = swd_limit
        self.nn_real_indices = list(range(5))
        self._tracking_objects = False
        # Accumulator read and incremented by the `subtract_time` decorator.
        self.time_to_subtract = 0.0
        # Pending nearest-neighbour images keyed by (N, label). Filled by
        # update_nearest_neighbours and emptied by log_nearest_neighbours.
        self.nearest_neighbours = {}

    def track_loss(self, loss, label, part_losses=None):
        """Track losses

        INPUT:

            loss (float): total loss value
            label (string): name of this loss, used as key prefix
            part_losses (list:tuple:None): optional pair
                                           (base_loss, regularizer), logged
                                           as two extra entries

        """
        d = {f"{label}_total_loss": float(loss)}

        if isinstance(part_losses, (list, tuple)):
            base_loss, regularizer = part_losses
            d.update({f"{label}_base_loss": float(base_loss),
                      f"{label}_regularizer": float(regularizer)})

        self.wandb.log(d, commit=False)

    def add_images(self, images, image_names, iteration=''):
        """Log a list of images to wandb under the 'images' key

        INPUT:

            images (iterable): image tensors. Values are clamped to [0, 1]
                               and converted to PIL images on the CPU.
            image_names (list:string): one name per image, used in captions
            iteration (int:string): optional tag included in each caption

        """
        images = list(images)
        run_id = self.wandb.run.id
        captions = [f"{run_id}_image{iteration}_from_{name}"
                    for name in image_names]

        to_pil = transforms.ToPILImage()
        cpu = torch.device('cpu')
        pil_images = [to_pil(torch.clamp(image, 0, 1).to(device=cpu,
                                                          dtype=torch.float))
                      for image in images]
        self.wandb.log({'images': [self.wandb.Image(img, caption=caption)
                                   for img, caption
                                   in zip(pil_images, captions)]},
                       commit=False)

    def track_summary_stats(self, generator, dataset, label=''):
        """Log mean/variance statistics of generated and real samples

        About 100 samples are drawn from both the generator and the
        dataset; the count is rounded up to a whole number of batches. For
        each source this logs the overall mean, and the per-element variance
        across samples averaged over elements.

        INPUT:

            generator (object, nn.Module): GAN generator with a sample() method
            dataset (torch.data.Dataset object): dataset providing real samples
            label (string): label appended to the generated-sample keys

        RAISES:

            ValueError: if the dataset cannot provide enough samples

        """
        generator.eval()
        # ~100 samples, rounded up to a whole number of batches
        samples = math.ceil(100 / self.batch_size) * self.batch_size
        suffix = f'-{label}' if label else ''
        with torch.no_grad():
            dataloader, _ = make_dataloader([dataset, dataset],
                                            self.batch_size, 0)

            g = []
            gt = []
            n_real = 0
            for X, _ in dataloader:
                g.append(generator.sample(torch.Size([self.batch_size])))
                gt.append(X.to(device=self.device))
                n_real += X.size(0)
                # Stop as soon as enough samples are collected; testing the
                # batch index before counting this batch overshoots by one.
                if n_real >= samples:
                    break
            if not gt:
                raise ValueError("Dataset is empty: cannot compute "
                                 "summary statistics")
            g = torch.cat(g, dim=0)
            gt = torch.cat(gt, dim=0)
            # Test both counts explicitly: a chained `a != b != c` only means
            # `a != b and b != c`.
            if g.size(0) != samples or gt.size(0) != samples:
                raise ValueError(f"Number of samples mismatch: expected "
                                 f"{samples}, got {g.size(0)} generated and "
                                 f"{gt.size(0)} real")
            # monitor outputs
            d = {f"g_mean_{samples}" + suffix: g.mean(dim=0).mean(),
                 f"g_var_{samples}" + suffix: g.var(dim=0).mean()}
            d.update({f"mean_true_{samples}": gt.mean(dim=0).mean(),
                      f"var_true_{samples}": gt.var(dim=0).mean()})

            self.wandb.log(d, commit=False)

    def log(self, iteration, running_time):
        """Log iteration and running time, committing the current wandb step"""
        d = {"running_time": running_time,
             "iteration": iteration}
        self.wandb.log(d)

    @limit_n_decorator
    def fid_score(self, generator, dataset, N=int(1e3), label=''):
        """Calculate FID score

        INPUT:

            generator (object, nn.Module): GAN generator with a sample() method
            dataset (torch.data.Dataset object): dataset to calculate FID from
            N (int): number of samples from generator to estimate FID.
                     default value: 1e3 (clamped to N_limit if set)
            label (string): label identifying the logged metric

        """
        generator.eval()
        fid_scores = {}
        # the module-level `fid_score` metric function, not this method
        f_score = fid_score(generator, dataset, self.batch_size,
                            next(generator.parameters()).is_cuda, 2048,
                            N, normalize_input=self._normalize_metric_input)
        fid_scores[f"fid-{N}" + (f'-{label}' if label else '')] = f_score
        self.wandb.log(fid_scores, commit=False)

    def set_neighbour_indices(self, indices):
        """Choose which dataset items are used as nearest-neighbour queries"""
        self.nn_real_indices = indices

    def process_neighbours(self, imgs):
        """Flatten every item of a batch into a single feature vector"""
        return imgs.flatten(start_dim=1)

    def neighbour_distance(self, n1, n2):
        """L1 distances between rows of n1 and rows of n2.

        Returns a tensor of shape (len(n2), len(n1)).
        """
        return (n1.unsqueeze(0) - n2.unsqueeze(1)).abs().sum(dim=-1)

    def start_nearest_neighbours(self, dataset, N, label):
        """Reset the nearest-neighbour search for a new metric run

        INPUT:

            dataset (torch.data.Dataset object): source of the real query
                                                 images; items at
                                                 ``self.nn_real_indices``
                                                 are used
            N (int): number of generated samples to search over
            label (string): label identifying the logged neighbours

        """
        if not hasattr(self, 'nearest_neighbours'):
            self.nearest_neighbours = {}
        self.nn_N = N
        self.nn_label = label
        reals = torch.stack([dataset[i][0] for i in self.nn_real_indices])
        # self.device rather than a hard-coded .cuda(), so CPU-only runs work
        self.nn_reals = self.process_neighbours(reals.to(self.device))
        self.nn_n_generated = 0
        self.nn_neighbours = [None] * len(self.nn_real_indices)
        self.nn_distances = [float('inf')] * len(self.nn_real_indices)

    def update_nearest_neighbours(self, fake_batch):
        """Update the closest generated sample for each real query image

        Does nothing once ``nn_N`` generated samples have been processed.
        """
        if self.nn_n_generated >= self.nn_N:
            return
        fakes = self.process_neighbours(fake_batch)
        # compare on the generator's device (no-op if already there)
        reals = self.nn_reals.to(fakes.device)
        distances = self.neighbour_distance(reals, fakes)  # (n_fakes, n_reals)
        min_values, min_indices = distances.min(dim=0)
        for i_real, (val, ind) in enumerate(zip(min_values, min_indices)):
            if val < self.nn_distances[i_real]:
                self.nn_neighbours[i_real] = fake_batch[ind]
                self.nn_distances[i_real] = val
        self.nn_n_generated += fake_batch.size(0)
        # store a copy so later updates don't mutate the pending entry
        self.nearest_neighbours[(self.nn_N, self.nn_label)] = list(self.nn_neighbours)

    def log_nearest_neighbours(self):
        """Log all pending nearest-neighbour images and clear them"""
        for (n, label), neighbours in self.nearest_neighbours.items():
            label_suffix = '' if label == '' else f'_{label}'
            captions = [f'neighbour_{n}_{i}' + label_suffix
                        for i in range(len(neighbours))]
            images = [self.wandb.Image(img, caption=caption)
                      for img, caption in zip(neighbours, captions)]
            self.wandb.log({f'neighbours_{n}_{label}': images}, commit=False)
        self.nearest_neighbours = {}

    @limit_n_decorator
    def inception_score(self, generator, N=int(5e4), label=''):
        """Calculate inception score

        INPUT:

            generator (object, nn.Module): GAN generator with a sample() method
            N (int): number of samples from generator to estimate the score,
                     default value: 5e4 (clamped to N_limit if set)
            label (string): label identifying the logged metric

        """
        generator.eval()
        inception_scores = {}
        # the module-level `inception_score` metric function, not this method
        ic_score = inception_score(generator, N=N,
                                   cuda=next(generator.parameters()).is_cuda,
                                   batch_size=self.batch_size, resize=True,
                                   splits=10,
                                   normalize_input=self._normalize_metric_input)
        ic_score_mean, ic_score_std = ic_score
        suffix = f'-{label}' if label else ''
        inception_scores[f'inception-mean-{N}' + suffix] = ic_score_mean
        inception_scores[f'inception-std-{N}' + suffix] = ic_score_std
        self.wandb.log(inception_scores, commit=False)

    @limit_n_decorator
    def swd_metric(self, generator, dataset, N=int(1e3), label='',
                   and_get_nearest_neighbours=True):
        """Calculate sliced wasserstein distance score

        INPUT:

            generator (object, nn.Module): GAN generator with a sample() method
            dataset (torch.data.Dataset object): dataset to calculate sliced
                                                 wasserstein distance from
            N (int): number of samples from generator to estimate the distance
                     default value: 1e3 (clamped to N_limit and swd_limit)
            label (string): label identifying the logged metric
            and_get_nearest_neighbours: optionally also log the nearest neighbours -
                this is efficient as it uses the same generated images as the SWD

        If the generator does not produce 4-D (N, C, H, W) batches, the 1D
        Wasserstein distance (``swd_1D_metric``) is logged instead.

        """
        generator.eval()
        if and_get_nearest_neighbours:
            self.start_nearest_neighbours(dataset=dataset, N=N, label=label)

        if self.swd_limit is not None and N > self.swd_limit:
            N = self.swd_limit
        with torch.no_grad():
            dataloader, _ = make_dataloader([dataset, dataset],
                                            batch_size=self.batch_size,
                                            num_workers=self.num_workers)
            real_iter = iter(dataloader)

            SWD_tracker = None
            n_generated = 0
            n_batch = self.batch_size
            while n_generated < N:
                fake_batch = generator.sample(torch.Size([n_batch]))
                # Check the rank explicitly: unpacking a wrong-rank shape
                # raises ValueError, so catching TypeError never worked.
                if len(fake_batch.shape) != 4:
                    # not 4 dimensional tensors so probably is 1d experiment
                    self.swd_1D_metric(generator, dataset, N=N, label=label)
                    return
                if and_get_nearest_neighbours:
                    self.update_nearest_neighbours(fake_batch)
                _, C, H, W = fake_batch.shape
                if SWD_tracker is None:
                    SWD_tracker = SWD(C, H, W, smallest_res=12,
                                      save_to_disk=(N > 1000))
                try:
                    real_batch = next(real_iter)[0]
                except StopIteration:
                    # dataset exhausted before N samples: start another pass
                    real_iter = iter(dataloader)
                    real_batch = next(real_iter)[0]
                # Keep equally many fakes and reals; the last batch of a pass
                # can be short.
                n_batch_kept = min(n_batch, N - n_generated, real_batch.size(0))
                SWD_tracker.project_images(fake_batch[:n_batch_kept],
                                           real_batch[:n_batch_kept])
                n_generated += n_batch_kept

            swds = {f'swd-{n_generated}' + (f'-{label}' if label else ''):
                    SWD_tracker.get_swd()}
            self.wandb.log(swds, commit=False)
            SWD_tracker.clean_up()

        if and_get_nearest_neighbours:
            self.log_nearest_neighbours()

    @limit_n_decorator
    def swd_1D_metric(self, generator, dataset, N=int(1e3), label=''):
        """Calculate 1D wasserstein distance

        INPUT:

            generator (object, nn.Module): GAN generator with a sample() method
            dataset (torch.data.Dataset object): dataset to calculate
                                                 wasserstein distance from
            N (int): number of samples to use to estimate wasserstein
                     distance. default value: 1e3 (clamped to N_limit if set)
            label (string): label identifying the logged metric

        All generated and real values are flattened into one 1D sample each.

        """
        generator.eval()

        with torch.no_grad():
            # one batch of (up to) N real samples
            train_loader, _ = make_dataloader([dataset, dataset],
                                              batch_size=N,
                                              num_workers=self.num_workers)
            fake_batches = []
            n_generated = 0
            while n_generated < N:
                n_batch = min(self.batch_size, N - n_generated)
                fake_batches.append(generator.sample(torch.Size([n_batch])))
                n_generated += n_batch
            fakes_npy = torch.cat(fake_batches, dim=0).flatten().cpu().numpy()
            # flatten (not squeeze(1)) so reals shaped (N,) or (N, 1) both work
            reals = next(iter(train_loader))[0]
            reals_npy = reals.flatten().cpu().numpy()
            distance = scipy.stats.wasserstein_distance(reals_npy, fakes_npy)
            self.wandb.log({'wd_1d' + (f'-{label}' if label else ''): distance},
                           commit=False)

    def save_objects(self, objects, names, local_save=None, **kwargs):
        """Save checkpoint objects plus a pickle listing their names

        INPUT:

            objects (list): objects to save. Objects exposing ``state_dict``
                            are saved via ``state_dict()``; everything else
                            is saved as-is with ``torch.save``.
            names (list:string): file name for each object; must include
                                 'iteration' and 'running_time'
            local_save (string:None): target directory; defaults to the
                                      wandb run directory
            **kwargs: ignored

        RAISES:

            ValueError: if ``objects`` and ``names`` differ in length or a
                        required name is missing

        """
        print("Saving checkpoints")
        base_path = self.wandb.run.dir if local_save is None else local_save
        if len(objects) != len(names):
            raise ValueError("Number of names must match number of models")
        for name in ['iteration', 'running_time']:
            if name not in names:
                raise ValueError(f'Ensure we are logging "{name}"')

        run_id = self.wandb.run.id
        for obj, name in zip(objects, names):
            path = os.path.join(base_path, name)
            if getattr(obj, "state_dict", None) is not None:
                torch.save(obj.state_dict(), path)
            else:
                torch.save(obj, path)

        object_name = f'object_names_{run_id}.pickle'
        path = os.path.join(base_path, object_name)
        with open(path, 'wb') as f:
            pickle.dump(names, f)

        # register the wandb file-sync pattern only once per wrapper
        if local_save is None and not self._tracking_objects:
            self.wandb.save(os.path.join(base_path, '*pickle'),
                            base_path=base_path)
        self._tracking_objects = True

    def load_objects(self, run_path, local_save=None, **kwargs):
        """Load objects previously written by ``save_objects``

        INPUT:

            run_path (string): wandb run path to restore files from; it is
                               lower-cased before use
            local_save (string:None): directory to read from instead of
                                      restoring from wandb
            **kwargs: ignored

        OUTPUT:

            dict mapping each saved name to the ``torch.load``-ed object

        NOTE(review): the names file is looked up with the *current* run id.
        This presumably assumes the current run resumes the saved run;
        confirm before loading from a different run. pickle.load and
        torch.load can execute arbitrary code, so only load from trusted
        runs or directories.

        """
        def load_pickle(path_to_object):
            print('loading object from', path_to_object)
            with open(path_to_object, 'rb') as f:
                obj = pickle.load(f)
            return obj

        run_id = self.wandb.run.id
        if local_save is None:
            path = self.wandb.restore(f'object_names_{run_id}.pickle',
                                      run_path=run_path.lower(),
                                      replace=True).name
        else:
            path = os.path.join(local_save, f'object_names_{run_id}.pickle')
        object_names = load_pickle(path)

        objects = {}
        for obj_name in object_names:
            if local_save is None:
                path = self.wandb.restore(obj_name, run_path=run_path.lower(),
                                          replace=True).name
            else:
                path = os.path.join(local_save, obj_name)
            objects[obj_name] = torch.load(path)
        return objects


def collate_fn(batch):
    """ Collate function used in DP

    Args:

    batch (list): list of length batch_size. Each element is assumed
        (sample, label). Samples must be tensors of equal shape. Labels may
        be tensors or plain Python numbers (e.g. ``int`` class indices).

    Returns:

    samples, labels (tuple): batched samples and labels

    Raises:

    ValueError: if the batch is empty

    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")

    samples, labels = zip(*batch)

    # torch.stack only accepts tensors; as_tensor passes tensors through
    # unchanged and converts Python scalars.
    return torch.stack(samples), torch.stack([torch.as_tensor(lbl) for lbl in labels])


def make_dataloader(datasets, batch_size, num_workers):
    """Build a (train, validation) pair of DataLoaders.

    Args:

    datasets (sequence): exactly two datasets, [training, validation]
    batch_size (int): batch size for both loaders
    num_workers (int): worker processes for the training loader only

    Returns:

    (train_loader, val_loader): both shuffle. Only the training loader uses
    worker processes and pinned memory.

    Raises:

    ValueError: if ``datasets`` does not contain exactly two entries

    """
    if len(datasets) != 2:
        raise ValueError("Datasets must contain both training and validation")

    train_set, val_set = datasets[0], datasets[1]
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True)
    return train_loader, val_loader


def log1mexp(x):
    """Calculate log1mexp(x) = log(1-exp(x)) stable for x < 0

    see - https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf

    Args:

    x (tensor): tensor on which we apply log1mexp; every entry must be < 0.
        Any shape (including 0-dim), floating dtype and device is accepted.

    Returns:

    y (tensor): the values of log1mexp(x), with the same shape as x

    Raises:

    ValueError: if any entry of x is not strictly negative

    """
    if not torch.all(x < 0):
        raise ValueError("arguments to log1mexp must be less than 0")

    # Switch point log(1/2) from Maechler's note. A Python float compares
    # against x on any device/dtype, whereas a CPU torch.Tensor([0.5]) does not.
    threshold = math.log(0.5)
    # 1-D working view so boolean-mask indexing also works for 0-dim inputs
    flat_x = x.reshape(-1)
    mask = flat_x < threshold
    y = torch.zeros_like(flat_x)

    # far from 0: exp(x) is small, so log1p(-exp(x)) is accurate
    y[mask] = torch.log1p(-torch.exp(flat_x[mask]))
    # close to 0: expm1 avoids the cancellation in 1 - exp(x)
    y[~mask] = torch.log(-torch.expm1(flat_x[~mask]))

    return y.reshape(x.shape)



def check_generators_valid(p1, p2):
    """Placeholder validity check for a pair of generator/proxy models.

    Intended to run a try / except to test the given proxy model. It is not
    implemented yet: both arguments are ignored and None is returned.
    """
    return None

class NormalizeType(Enum):
    """Closed set of normalization modes, identified by integer value.

    The meaning of each mode is defined by the code that consumes this enum;
    that code is not part of this module section.
    """

    Standard = 0
    Total = 1
    Advas = 2
