# ---------------------------------------------------------------
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# for I2SB. To view a copy of this license, see the LICENSE file.
# ---------------------------------------------------------------

import os
import numpy as np
import pickle

import torch
import torch.nn.functional as F
from torch.optim import AdamW, lr_scheduler
from torch.nn.parallel import DistributedDataParallel as DDP

from torch_ema import ExponentialMovingAverage
import torchvision.utils as tu
import torchmetrics
import matplotlib.pyplot as plt

from . import util
from .network import SharpNet
from .adaptive_network import Adative_Network
import random
from ipdb import set_trace as debug
from tqdm import tqdm
# import corruption.superresolution as sr
import corruption.sisr as sisr
import corruption.blur as blur
def make_beta_schedule(n_timestep=1000, linear_start=1e-4, linear_end=2e-2):
    """Return a beta schedule that is linearly spaced in ``sqrt(beta)``.

    Args:
        n_timestep: Number of entries in the schedule.
        linear_start: First beta value (its square root is the first
            interpolation endpoint).
        linear_end: Last beta value (its square root is the second
            interpolation endpoint).

    Returns:
        A float64 NumPy array with ``n_timestep`` entries.
    """
    sqrt_first = linear_start ** 0.5
    sqrt_last = linear_end ** 0.5
    # Interpolate in the square-root domain, then square back to betas.
    sqrt_betas = torch.linspace(sqrt_first, sqrt_last, n_timestep, dtype=torch.float64)
    return sqrt_betas.pow(2).numpy()


def all_cat_cpu(opt, log, t):
    """Return a detached CPU copy of ``t``, concatenated across ranks if distributed.

    Args:
        opt: Options object; reads ``opt.distributed`` and, in the
            distributed case, ``opt.device``.
        log: Logger forwarded to ``dist_util.all_gather``.
        t: Tensor to collect.

    Returns:
        A detached tensor on the CPU. When ``opt.distributed`` is falsy this
        is just ``t``; otherwise it is ``torch.cat`` of the gathered tensors.

    NOTE(review): ``dist_util`` is not bound by any import visible in this
    file, so the distributed branch looks like it would raise ``NameError``
    -- confirm where it is supposed to come from.
    """
    if opt.distributed:
        pieces = dist_util.all_gather(t.to(opt.device), log=log)
        merged = torch.cat(pieces)
        return merged.detach().cpu()
    return t.detach().cpu()


class Runner(object):
    def __init__(self, opt, log):
        """Build the network, its EMA tracker and the adaptive wrapper.

        Args:
            opt: Options object. Reads ``interval``, ``beta_max``, ``t0``,
                ``T``, ``device``, ``ema`` and ``load``; ``beta_min`` is
                optional.
            log: Logger; only ``log.info`` is used here.
        """
        super().__init__()

        # ``beta_min`` is optional: when it is absent (or None) fall back to
        # the default ``linear_start`` of ``make_beta_schedule``. This is an
        # explicit check instead of a bare ``except`` so that unrelated
        # errors are not silently swallowed.
        beta_min = getattr(opt, "beta_min", None)
        if beta_min is None:
            betas = make_beta_schedule(
                n_timestep=opt.interval,
                linear_end=opt.beta_max / opt.interval,
            )
        else:
            betas = make_beta_schedule(
                n_timestep=opt.interval,
                linear_start=beta_min / opt.interval,
                linear_end=opt.beta_max / opt.interval,
            )

        # Mirror the first half of the schedule to make it symmetric.
        # NOTE(review): for an odd ``opt.interval`` this yields
        # ``opt.interval - 1`` betas while ``restoration_level`` below has
        # ``opt.interval`` entries -- confirm odd intervals are never used.
        first_half = betas[:opt.interval // 2]
        betas = np.concatenate([first_half, np.flip(first_half)])

        restoration_level = torch.linspace(opt.t0, opt.T, opt.interval, device=opt.device) * opt.interval
        self.net = SharpNet(restoration_level=restoration_level)
        self.ema = ExponentialMovingAverage(self.net.parameters(), decay=opt.ema)
        self.adaptive_net = Adative_Network(betas, opt.device, self.net)

        if opt.load:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(opt.load, map_location="cpu")
            # Rewrite the "diffusion_model" substring in parameter names to
            # "SharpNet" before loading them into ``self.net``.
            checkpoint['net'] = {
                key.replace("diffusion_model", "SharpNet"): value
                for key, value in checkpoint['net'].items()
            }

            self.net.load_state_dict(checkpoint['net'])
            log.info(f"[Net] Loaded network ckpt: {opt.load}!")
            self.ema.load_state_dict(checkpoint["ema"])
            log.info(f"[Ema] Loaded ema ckpt: {opt.load}!")

        self.net.to(opt.device)
        self.ema.to(opt.device)
        self.device = opt.device

        self.log = log

    def sharp_iteration(self, steps, x_init, y, kernel_A, kernel_H, t_all):
        """Iteratively update ``x_init`` using a data term and a network prior term.

        Each step combines the gradient of the data residual
        ``sisr.fmult(xt, kernel_A) - y`` with a prior term built from the
        prediction of ``self.adaptive_net`` and the ``kernel_H`` blur
        operator, then takes a fixed-size step on ``xt``.

        Args:
            steps: Number of update steps to run.
            x_init: Initial estimate; detached and moved to ``self.device``.
            y: Measurement compared against ``sisr.fmult(xt, kernel_A)``.
            kernel_A: Kernel passed to ``sisr.fmult`` / ``sisr.ftran``.
            kernel_H: Kernel passed to ``blur.fmult`` / ``blur.ftran``.
            t_all: Indexable collection of time values; one of indices
                0..5 (inclusive) is drawn at random on every step, so it
                must hold at least six entries.

        Returns:
            Tuple ``(xt_trajectory, pred_x0_trajectory)``. Each is the
            per-step CPU tensors stacked along dim 1 and flipped along that
            dim, so the last step comes first.
        """
        xt = x_init.detach().to(self.device)
        # The measurement does not change between steps; move it once.
        y = y.to(self.device)
        xs_forward = []
        pred_x0s = []

        for step in tqdm(range(steps)):
            # Data-fidelity gradient: A^T (A x - y).
            residual = sisr.fmult(xt.detach(), kernel_A) - y
            gradient = sisr.ftran(residual.cpu(), kernel_A).to(self.device)

            x_input = blur.fmult(xt, kernel_H).to(xt.device)

            # NOTE(review): the upper index 5 is hard-coded -- confirm it
            # matches the intended length of ``t_all``.
            t = t_all[random.randint(0, 5)]

            pred_x0, mu_xn, mu_x0 = self.adaptive_net.pred(x_input, xt, t=t, total_nfe=999)

            # Prior term: apply (mu_xn * H + mu_x0 * I), then
            # (mu_xn * H^T + mu_x0 * I) with the same residual re-added.
            residual = xt - pred_x0
            Hprior = blur.fmult(residual, kernel_H).to(residual.device)
            Hprior = mu_xn * Hprior + mu_x0 * residual
            HTHprior = blur.ftran(Hprior, kernel_H).to(Hprior.device)
            prior_term = mu_xn * HTHprior + mu_x0 * residual

            # Larger step for the first 200 iterations, smaller afterwards.
            step_size = 0.1 if step < 200 else 0.05
            # Detach so the autograd graph does not grow across steps; only
            # detached copies are ever returned, so values are unchanged.
            xt = (xt - step_size * (gradient + 0.6 * prior_term)).detach()

            pred_x0s.append(pred_x0.detach().cpu())
            xs_forward.append(xt.cpu())

        stack_bwd_traj = lambda z: torch.flip(torch.stack(z, dim=1), dims=(1,))
        return stack_bwd_traj(xs_forward), stack_bwd_traj(pred_x0s)

    # @torch.no_grad()
    # def run_sharp(self, opt, x_init, y, nfe=300, stochastic=True):
    #     x_init = x_init.to(opt.device)
    #
    #     with self.ema.average_parameters():
    #         self.net.eval()
    #         xs_forward, pred_x0 = self.sharp_iteration(
    #             nfe, x_init, y)
    #
    #     return xs_forward, pred_x0