import torch
import skimage.io as io
import copy
import random
import numpy as np
from torch import nn
from tqdm import tqdm
from matplotlib import pyplot as plt
import os
from torchmetrics.image import PeakSignalNoiseRatio
import mlflow
from mlflow.models import infer_signature
from torchvision.datasets.cifar import CIFAR10
from fld.features.InceptionFeatureExtractor import InceptionFeatureExtractor
from torchvision import datasets, transforms
from torchvision.utils import save_image
from fld.metrics.FID import FID
from torchdiffeq import odeint
from deepinv.loss.regularisers import JacobianSpectralNorm

from denflow.utils import ema, model_mul
from denflow.dataloaders import CelebADataset

# CelebA locations, relative to the working directory: the folder of aligned
# face images and the CSV assigning each image to a data partition.
img_dir_celeba, partition_csv_celeba = (
    './data/celeba/img_align_celeba/',
    './data/celeba/list_eval_partition.csv',
)


class GENERAL_DENOISER(torch.nn.Module):
    """Time-conditioned denoiser on the path xt = t * x + (1 - t) * noise.

    ``self.model(xt, t)`` is a network taking a batch of images and a [B]
    vector of times. Depending on ``class_denoiser`` its output is used as the
    denoised image ("NN"), as a residual scaled by (1 - t) ("shifted"), or
    through the regulariser g(x) = 0.5 * ||x - model(x, t)||^2 of a
    gradient-step denoiser ("gradient_step", "gradient_step_shifted").
    Samples are generated by integrating the velocity (D(xt) - xt) / (1 - t)
    with ``torchdiffeq.odeint`` from t=0 (pure noise) upwards.

    NOTE: ``train(data_loaders)`` runs the training loop. Boolean calls such as
    ``train(False)`` (used by ``nn.Module.eval``) are forwarded to
    ``nn.Module.train``.
    """

    def __init__(self, model, loss_denoising, class_denoiser, device, args):
        """
        :param model: network called as ``model(xt, t)``; moved to ``device``.
        :param loss_denoising: "classic" or a weighting name handled by
            ``_weighted_mse``; may be None when the object is only used for
            sampling.
        :param class_denoiser: "NN", "shifted", "gradient_step" or
            "gradient_step_shifted".
        :param device: torch device on which times and noise are drawn.
        :param args: configuration namespace (dim_image, num_channels, lr,
            ema, ema_decay, ema_start, jacobian_penalisation, optional
            sigma_max, ...).
        """
        super().__init__()
        self.d = args.dim_image
        self.num_channels = args.num_channels
        self.device = device
        self.args = args
        self.loss_denoising = loss_denoising
        self.class_denoiser = class_denoiser
        self.lr = args.lr
        self.model = model.to(device)
        # Images are assumed to live in [-1, 1], hence a data range of 2.
        self.psnr = PeakSignalNoiseRatio(data_range=2.0).to(device)
        self.criterion = nn.MSELoss(reduction='mean')
        self.time_start = 0.00
        self.time_end = 1.0
        self.ema_decay = args.ema_decay
        self.ema_start = args.ema_start
        self.ema = args.ema
        # Lower bound of t for the "classic" loss: t ~ U[a, 1].
        self.a = 1.0 / (1.0 + getattr(self.args, "sigma_max", 1.0))
        self.jacobian_penalisation = self.args.jacobian_penalisation
        # Switched on during training (see train_denoiser).
        self.jacobian_penalisation_activated = False
        if self.jacobian_penalisation:
            self.jacobian_loss = JacobianSpectralNorm(
                max_iter=10, tol=1e-5, eval_mode=True, verbose=False, reduction=None)
            self.lmbda = 0.1
            self.eps_jacob = 0.1

    def calculate_grad(self, x, t):
        """Return Dg(x), the gradient of g(x) = 0.5 * ||x - N(x, t)||^2.

        With N = ``self.model`` this is Dg(x) = (x - N) - J_N(x)^T (x - N).
        The higher-order graph is only built when gradients are enabled by the
        caller (training). Under ``torch.no_grad`` sampling no extra graph is
        kept.

        :param x: torch.tensor input images; ``requires_grad`` is switched on
            in place on the (float) tensor.
        :param t: torch.tensor [B] times.
        :return: torch.tensor Dg(x), same shape as ``x``.
        """
        create_graph = torch.is_grad_enabled()
        x = x.float()
        x = x.requires_grad_()
        with torch.enable_grad():
            N = self.model(x, t)
        # Vector-Jacobian product J_N^T (x - N).
        JN = torch.autograd.grad(
            N, x, grad_outputs=x - N, create_graph=create_graph, only_inputs=True)[0]
        return x - N - JN

    def get_denoiser(self, xt, t):
        """Return the denoised estimate D(xt, t).

        :param xt: torch.tensor [B, C, H, W] noisy images.
        :param t: torch.tensor [B] times.
        :raises ValueError: if ``self.class_denoiser`` is unknown.
        """
        B = xt.shape[0]
        t_b = t.view(B, *([1] * (xt.ndim - 1)))  # [B,1,1,1] for broadcasting

        if self.class_denoiser == "NN":
            return self.model(xt, t)
        if self.class_denoiser == "shifted":
            # D = x + (1 - t) N
            return xt + (1 - t_b) * self.model(xt, t)
        if self.class_denoiser == "gradient_step":
            return xt - self.calculate_grad(xt, t)
        if self.class_denoiser == "gradient_step_shifted":
            return xt - (1 - t_b) * self.calculate_grad(xt, t)
        raise ValueError("The class_denoiser you have given does not exist")

    def get_velocity(self, xt, t):
        """Return the velocity V = (D(xt, t) - xt) / (1 - t).

        The two "shifted" classes give V in closed form. The others divide by
        (1 - t) clamped at 1e-3.

        :param xt: torch.tensor [B, C, H, W] current state.
        :param t: torch.tensor [B] times.
        """
        B = xt.shape[0]
        t_b = t.view(B, *([1] * (xt.ndim - 1)))  # [B,1,1,1]

        if self.class_denoiser == "gradient_step_shifted":
            # D = x - (1-t)Dg  =>  V = (D-x)/(1-t) = -Dg
            return -self.calculate_grad(xt, t)
        if self.class_denoiser == "shifted":
            # D = x + (1-t)N  =>  V = N
            return self.model(xt, t)

        # NOTE(review): this writes an attribute on the wrapped network, not on
        # self; kept as-is in case the network reads it — confirm intent.
        self.model.time_end = 0.999
        denoiser = self.get_denoiser(xt, t)
        return (denoiser - xt) / torch.clamp(1 - t_b, min=1e-3)

    def _weighted_mse(self, x_hat, clean_imgs, t_b):
        """MSE between ``x_hat`` and ``clean_imgs`` under the time weighting ``self.loss_denoising``.

        Most schemes divide both arguments by a per-sample factor w(t), which
        weights the squared error by 1 / w(t)^2.

        :raises ValueError: if the weighting name is unknown.
        """
        scheme = self.loss_denoising
        if scheme == "den":
            return self.criterion(x_hat, clean_imgs)
        if scheme == "scaled":
            return self.criterion(t_b * x_hat, t_b * clean_imgs)
        if scheme == "avg":
            weighting_sqrt = torch.sqrt(
                0.5 * (1 / torch.clamp(t_b, min=1e-2)**2 + 1 / torch.clamp(1-t_b, min=1e-3)**2))
            return self.criterion(x_hat * weighting_sqrt, clean_imgs * weighting_sqrt)

        if scheme == "FM":
            divisor = torch.clamp(1 - t_b, min=1e-3)
        elif scheme == "pow_1":
            divisor = torch.clamp(1 - t_b, min=1e-3) ** 0.5
        elif scheme == "pow_3":
            divisor = torch.clamp(1 - t_b, min=1e-3) ** 1.5
        elif scheme == "mid":
            divisor = torch.clamp(0.5 - t_b, min=1e-3)
        elif scheme == "mid_02":
            divisor = torch.clamp(0.2 - t_b, min=1e-3)
        elif scheme == "mid_01_1":
            divisor = torch.clamp((0.1 - t_b) * (1 - t_b), min=1e-3)
        elif scheme == "one_over_tsquared":
            divisor = torch.clamp(t_b, min=1e-3)
        else:
            raise ValueError("The loss_denoising you have given does not exist")
        return self.criterion(x_hat / divisor, clean_imgs / divisor)

    def _jacobian_penalty(self, xt, t):
        """Hinge penalty on the denoiser Jacobian spectral norm.

        Only samples with t in [args.t_jac_min, args.t_jac_max] are used.
        Returns None when no sample falls in that interval.
        """
        mask = (t >= self.args.t_jac_min) & (t <= self.args.t_jac_max)
        if not mask.any():
            return None
        xt_masked = xt[mask].clone().requires_grad_()
        x_hat_masked = self.get_denoiser(xt_masked, t[mask])
        jacobian_norm = self.jacobian_loss(x_hat_masked, xt_masked)
        # max(||J||, max_jacob - eps): only norms above the threshold get gradient.
        return torch.maximum(
            jacobian_norm,
            self.args.max_jacob *
            torch.ones_like(jacobian_norm) - self.eps_jacob
        ).mean()

    def get_loss(self, clean_imgs):
        """Return the training loss for a batch of clean images.

        A time t is drawn per sample and ``xt = t * clean + (1 - t) * noise``
        is denoised. The running PSNR metric is updated with the estimate.
        "classic" draws t ~ U[a, 1] and weights by 1 / t^2. The other schemes
        draw t ~ U[args.t_min, args.t_max] and use ``_weighted_mse``. Once
        ``self.jacobian_penalisation_activated`` is set, they also add
        ``lmbda`` times the Jacobian penalty.

        :param clean_imgs: torch.tensor [B, C, H, W] clean images.
        :return: scalar loss tensor.
        :raises ValueError: if ``self.loss_denoising`` is unknown.
        """
        B = clean_imgs.size(0)
        noise = torch.randn_like(clean_imgs)

        if self.loss_denoising == "classic":
            t = self.a + (1 - self.a) * torch.rand(B, device=self.device)
            t_b = t.view(B, *([1] * (clean_imgs.ndim - 1)))
            xt = t_b * clean_imgs + (1 - t_b) * noise
            x_hat = self.get_denoiser(xt, t)
            self.psnr.update(x_hat.detach(), clean_imgs.detach())
            w = torch.clamp(t_b, min=1e-3)
            return self.criterion(x_hat / w, clean_imgs / w)

        t = self.args.t_min + \
            torch.rand(B, device=self.device) * (self.args.t_max - self.args.t_min)
        t_b = t.view(B, *([1] * (clean_imgs.ndim - 1)))  # [B,1,1,1]
        xt = t_b * clean_imgs + (1 - t_b) * noise
        x_hat = self.get_denoiser(xt, t)
        self.psnr.update(x_hat.detach(), clean_imgs.detach())
        criterion = self._weighted_mse(x_hat, clean_imgs, t_b)

        if self.jacobian_penalisation_activated:
            pen = self._jacobian_penalty(xt, t)
            if pen is not None:
                return criterion + self.lmbda * pen
        return criterion

    def _update_ema(self, global_step, unbalanced_ema_model):
        """Refresh ``self.ema_model`` after an optimiser step.

        With ``ema_start == 0``, a plain EMA seeded with the initial weights is
        used. Otherwise an unnormalised EMA is accumulated from step
        ``ema_start`` on and bias-corrected by 1 / (1 - decay^(k+1)), where k
        is the number of steps since ``ema_start``.

        :return: the (possibly newly created) unnormalised EMA model, or the
            value passed in when nothing needed to be created.
        """
        if self.ema_start == 0:  # default ema, actually somewhat wrong, too much weight on the first model
            ema(self.model, self.ema_model, self.ema_decay)
            return unbalanced_ema_model
        if global_step < self.ema_start:
            return unbalanced_ema_model
        if global_step == self.ema_start:
            unbalanced_ema_model = copy.deepcopy(self.model)
            # unb = (1-ema_decay) * M0
            model_mul(1 - self.ema_decay, self.model, unbalanced_ema_model)
        else:
            ema(self.model, unbalanced_ema_model, self.ema_decay)
        # Applied from ema_start itself on, so ema_model is never stale there.
        factor = 1 / (1 - self.ema_decay ** (1 + global_step - self.ema_start))
        model_mul(factor, unbalanced_ema_model, self.ema_model)
        return unbalanced_ema_model

    def _log_final_models(self, x0, t):
        """Log the final model (and EMA model) with an input/output signature.

        :param x0: example input batch, or None to build one from args.
        :param t: example [B] times matching ``x0`` (ignored when x0 is None).
        """
        if x0 is None:
            # The periodic checkpoint branch never ran (e.g. single-batch loader).
            x0 = torch.randn(1, self.num_channels, self.d, self.d, device=self.device)
            t = torch.rand(1, device=self.device)
        with torch.no_grad():
            signature = infer_signature(
                x0.cpu().numpy(),
                self.model(x0, t).detach().cpu().numpy())
            mlflow.pytorch.log_model(
                self.model, artifact_path='denoiser_final', signature=signature)
            if self.ema:
                mlflow.pytorch.log_model(
                    self.ema_model, artifact_path='denoiser_ema_final', signature=signature)

    def train_denoiser(self, train_loader, opt, scheduler, num_epoch):
        """Optimise ``self.model`` and log metrics/checkpoints to the active MLflow run.

        :param train_loader: iterable of ``(clean_imgs, labels)`` batches.
        :param opt: optimiser over ``self.model`` parameters.
        :param scheduler: LR scheduler, stepped after every optimiser step.
        :param num_epoch: number of passes over ``train_loader``.
        """
        tq = tqdm(range(num_epoch), desc='loss')

        # Reference Inception features for FID. Periodic FID logging
        # (self.compute_fid with train_feat) is currently disabled.
        ft_extractor = InceptionFeatureExtractor(save_path="features")
        train_feat = None
        if self.args.dataset == "cifar10":
            train_feat = ft_extractor.get_features(
                CIFAR10(train=True, root="data", download=True), name="cifar10_train")

        # Epoch from which the Jacobian penalty is applied (previously hard-coded).
        jacobian_start_epoch = getattr(self.args, "jacobian_start_epoch", 380)

        global_step = 0  # counts every batch, including skipped empty ones
        unbalanced_ema_model = None  # running unnormalised EMA (see _update_ema)
        x0 = t = None  # example model input, reused for the final signature
        for ep in tq:
            self.psnr.reset()
            self.ep = ep
            if self.jacobian_penalisation and ep >= jacobian_start_epoch:
                self.jacobian_penalisation_activated = True

            n_updates = 0
            for iteration, (clean_imgs, _labels) in enumerate(train_loader):
                print("Epoch: {}, Iteration: {}".format(ep, iteration))
                global_step += 1

                if clean_imgs.size(0) == 0:
                    continue
                clean_imgs = clean_imgs.to(self.device, non_blocking=True)
                loss = self.get_loss(clean_imgs)
                opt.zero_grad(set_to_none=True)
                loss.backward()

                if self.args.grad_clip > 0:
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.args.grad_clip)

                opt.step()
                scheduler.step()
                n_updates += 1

                if self.ema:
                    unbalanced_ema_model = self._update_ema(
                        global_step, unbalanced_ema_model)

                with torch.no_grad():
                    if global_step % 100 == 0:
                        mlflow.log_metric(
                            "Loss", loss.item(), step=global_step)
                        mlflow.log_metric(
                            "LR", scheduler.get_last_lr()[0], step=global_step)

                    if ep % 25 == 0 and iteration == 1:
                        # Times are drawn before noise, as in the original code.
                        t = torch.rand(clean_imgs.shape[0], device=self.device)
                        x0 = torch.randn_like(clean_imgs)
                        mlflow.pytorch.log_model(
                            self.model, artifact_path=f'denoiser_epoch{ep}')
                        if self.ema:
                            mlflow.pytorch.log_model(
                                self.ema_model, artifact_path=f'denoiser_ema_epoch{ep}')

            # PSNR has no updates (and no meaning) for an epoch without batches.
            if n_updates > 0:
                with torch.no_grad():
                    psnr = self.psnr.compute()
                    mlflow.log_metric(
                        "PSNR", psnr.item(), step=global_step)

        self._log_final_models(x0, t)

    def generate_samples(self, integration_method="dopri5", tol=1e-5,
                         n_samples=1028, batch_size=None, num_channels=3,
                         integration_steps=100, tmax=1, use_ema=False):
        """Generate images by integrating the learned velocity field from noise.

        Each batch starts from x0 ~ N(0, I) at t=0 and is integrated with
        ``torchdiffeq.odeint`` up to t=tmax. Only the final state is kept.

        :param integration_method: torchdiffeq solver name (e.g. "dopri5", "euler").
        :param tol: value used for both ``rtol`` and ``atol``.
        :param n_samples: total number of images to generate.
        :param batch_size: images per ODE solve; defaults to ``n_samples``.
        :param num_channels: channels of the generated images.
        :param integration_steps: the time grid has
            ``int(tmax * integration_steps)`` points. For fixed-grid solvers
            this is one step fewer than that number.
        :param tmax: final integration time.
        :param use_ema: integrate with ``self.ema_model`` instead of ``self.model``.
        :return: torch.tensor [n_samples, num_channels, d, d].
        """
        network = self.ema_model if use_ema else self.model

        if n_samples <= 0:
            return torch.empty(0, num_channels, self.d, self.d, device=self.device)
        if batch_size is None:
            batch_size = n_samples

        n_full, remainder = divmod(n_samples, batch_size)
        batches = [batch_size] * n_full
        if remainder:
            batches.append(remainder)

        # Loop-invariant: build the time grid and the ODE right-hand side once.
        time_points = torch.linspace(
            0, tmax, int(tmax * integration_steps), device=self.device)
        velocity_field = torch_model_wrapper(
            network, self.class_denoiser, self.device, self.args)

        images_list = []
        with torch.no_grad():
            for batch in tqdm(batches):
                x0 = torch.randn(batch, num_channels, self.d,
                                 self.d, device=self.device)
                traj = odeint(velocity_field, x0, time_points, rtol=tol, atol=tol,
                              method=integration_method)
                images_list.append(traj[-1])

        return torch.cat(images_list, dim=0)

    def compute_fid(self, num_images_fid, train_feat, ft_extractor, batch_size=512, integration_method="dopri5", integration_steps=100,  epoch='final', use_ema=False):
        """Generate ``num_images_fid`` samples, return their FID and save a preview grid.

        :param train_feat: reference features from ``ft_extractor``.
        :param ft_extractor: feature extractor with ``get_tensor_features``.
        :param epoch: label used in the preview file name.
        :return: FID value from ``FID().compute_metric``.
        """
        gen_images = self.generate_samples(integration_method=integration_method, tol=1e-4,
                                           n_samples=num_images_fid, batch_size=batch_size, integration_steps=integration_steps, use_ema=use_ema)
        # [-1, 1] -> uint8 [0, 255]; +128 instead of +127.5 rounds before truncation.
        rescaled_imgs = (gen_images * 127.5 + 128).clip(0, 255).to(torch.uint8)
        gen_feat = ft_extractor.get_tensor_features(rescaled_imgs)

        fid_val = FID().compute_metric(train_feat, None, gen_feat)

        # Save the first 16 generated images as a grid. save_image expects
        # [0, 1], so map the [-1, 1] samples first.
        os.makedirs(f"training_images/{self.args.dataset}", exist_ok=True)
        images = (gen_images[:16] + 1) / 2
        save_image(
            images, f"training_images/{self.args.dataset}/gen_images_epoch{epoch}.png")
        return fid_val

    def train(self, data_loaders):
        """Start an MLflow run, log the configuration and train on ``data_loaders['train']``.

        This name shadows ``nn.Module.train(mode)``. A boolean argument (as
        passed by ``nn.Module.eval``) is therefore forwarded to it, so
        ``.eval()`` / ``.train(False)`` keep working on this module.

        :param data_loaders: dict with a 'train' loader, or a bool mode flag.
        """
        if isinstance(data_loaders, bool):
            return super().train(data_loaders)

        print("Training about to start")
        mlflow.set_experiment(f"{self.args.dataset}")
        mlflow.start_run()
        print("Training started")
        sigma_max = getattr(self.args, "sigma_max", 1.0)
        mlflow.set_tag(
            "mlflow.runName",
            f"object {self.args.train_object} loss {self.loss_denoising}, class {self.class_denoiser}, bs {self.args.batch_size_train}, lr {self.args.lr}, epochs_max {self.args.num_epoch}, t_min {self.args.t_min}, t_max {self.args.t_max}, sigma_max {sigma_max} jacobian_penalisation {self.jacobian_penalisation} max_jacob {self.args.max_jacob}")
        # Log all parameters to MLflow
        mlflow.log_params(vars(self.args))
        mlflow.log_artifact('config_dict.pkl', artifact_path="config")
        train_loader = data_loaders['train']

        warmup = 5000

        def warmup_lr(step):
            # Linear warm-up from 0 to the base LR over `warmup` steps.
            return min(step, warmup) / warmup

        opt = torch.optim.Adam(self.model.parameters(), lr=self.args.lr)
        scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=warmup_lr)

        if self.ema:
            self.ema_model = copy.deepcopy(self.model)

        self.train_denoiser(train_loader, opt, scheduler, num_epoch=self.args.num_epoch)


class torch_model_wrapper(GENERAL_DENOISER):
    """Expose a network's velocity field as the right-hand side ``f(t, x)`` used by ``torchdiffeq.odeint``."""

    def __init__(self, model, class_denoiser, device, args):
        # Only sampling is needed here, so no training loss is configured.
        super().__init__(model, None, class_denoiser, device, args)

    def forward(self, t, x):
        # The solver supplies one time for the whole batch; give every sample a copy.
        batch_times = torch.ones(len(x), device=self.device) * t
        velocity = self.get_velocity(x, batch_times)
        return velocity.detach()
