import math

import torch
import torch.nn as nn
from torch_utils import persistence
from torch_utils import distributed as dist
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# Forward ("ODE") processes: each maps (images, t, eps) to the noised sample y_t.
@persistence.persistent_class
class ode_ve:
    """Variance-exploding forward process: ``y_t = images + eps * t``.

    Attributes:
        t_max: largest time value (default 80.0).
        t_min: smallest time value (default 0.002).
        ode_type: the tag ``'ve'``, checked by the loss classes.
    """

    def __init__(self, t_max = 80.0, t_min=0.002):
        self.ode_type = 've'
        self.t_min = t_min
        self.t_max = t_max

    def __call__(self, images, t, eps):
        """Return ``images + eps * t``; ``t`` may be a scalar or a tensor broadcastable to ``eps``."""
        return images + eps * t

@persistence.persistent_class
class ode_rflow:
    """Rectified-flow forward process: ``y_t = images * (1 - t) + eps * t``.

    Attributes:
        t_max: largest time value (default 1.0).
        t_min: smallest time value (default 0.002).
        ode_type: the tag ``'rflow'``, checked by the loss classes.
    """

    def __init__(self, t_max = 1.0, t_min=0.002):
        self.ode_type = 'rflow'
        self.t_min = t_min
        self.t_max = t_max

    def __call__(self, images, t, eps):
        """Linearly interpolate from ``images`` (t = 0) to ``eps`` (t = 1)."""
        return images * (1-t) + eps * t
    

@persistence.persistent_class
class loss_ADCM:
    """Consistency-training loss on an adaptively constructed time grid.

    ``update`` builds a decreasing grid of times ``self.ts`` from
    ``ode_fn.t_max`` down to ``ode_fn.t_min``. ``__call__`` then trains on
    adjacent grid pairs ``t = ts[i]`` (larger) and ``r = ts[i + 1]`` (smaller).

    Args:
        c: pseudo-Huber constant; distances are ``sqrt(||d||^2 + c^2) - c``.
        mean, std: mean / std of log-time used for lognormal interval sampling.
        lamda: step-size regulariser for ``update``. The step is divided by
            ``1 + 1/lamda``, so larger values give larger steps.
        sample_type: ``'lognormal'`` samples grid intervals by lognormal mass;
            any other value samples them uniformly.
        weight_type: ``'adaptive'`` divides the loss by the pseudo-Huber distance
            between ``D(y_r)`` and the clean images; any other value divides by ``t - r``.
    """

    def __init__(self, c=0.03, mean=-1.1, std=2.0, lamda=0.01, sample_type='lognormal', weight_type='adaptive'):
        self.ts = None
        self.c = c
        self.P_mean = mean
        self.P_std = std
        self.lamda = lamda
        self.sample_type = sample_type
        self.weight_type = weight_type

    def lognormal_timestep_distribution(
            self,
            num_samples,
            sigmas
    ):
        """Sample ``num_samples`` interval indices i in [0, len(sigmas) - 2].

        Index i is drawn with probability proportional to the lognormal
        (P_mean, P_std) mass between ``sigmas[i + 1]`` and ``sigmas[i]``.
        ``sigmas`` must be decreasing; otherwise the masses are negative.
        """
        pdf = torch.erf((torch.log(sigmas[:-1]) - self.P_mean) / (self.P_std * math.sqrt(2))) - torch.erf(
            (torch.log(sigmas[1:]) - self.P_mean) / (self.P_std * math.sqrt(2))
        )
        pdf = pdf / pdf.sum()
        timesteps = torch.multinomial(pdf, num_samples, replacement=True)
        return timesteps

    def uniform_timestep_distribution(
            self,
            num_samples,
            sigmas
    ):
        """Sample ``num_samples`` interval indices uniformly from [0, len(sigmas) - 2] (CPU tensor)."""
        timesteps = torch.randint(0, sigmas.shape[0]-1, (num_samples,))
        return timesteps

    def update(self, net=None, dataset_iterator=None, device=None, ode_fn = None):
        """Rebuild ``self.ts`` by stepping t from ``ode_fn.t_max`` down to ``ode_fn.t_min``.

        At each t, one ``(images, labels)`` batch is drawn from
        ``dataset_iterator`` and a JVP of ``F = net(y_t, t, labels)`` is taken
        along ``(dy_t/dt, 1)``. The step is
        ``|mean<J, F - x>| / (mean||J||^2 * (1 + 1/lamda))``; steps of at least
        ``t_max`` are replaced by ``t_max / 8``. ``t_min`` is appended once t
        reaches or passes it.

        ``net.requires_grad_(False)`` is set for the sweep, and
        ``requires_grad_(True)`` is restored afterwards even if the sweep raises.
        """
        net.requires_grad_(False)
        try:
            ts = torch.empty((0)).to(device)
            t = torch.tensor([ode_fn.t_max]).to(device)

            while t > ode_fn.t_min:
                ts = torch.cat((ts, t), dim=0)
                # Row 0: batch mean of ||J||^2; row 1: batch mean of <J, F - x>.
                jjf = torch.zeros([2, 1]).to(device)
                images, labels = next(dataset_iterator)
                # NOTE(review): assumes raw pixel values in [0, 255], mapped to [-1, 1] — confirm.
                images = images.to(device).to(torch.float32) / 127.5 - 1
                labels = labels.to(device)

                def j(yt, t):
                    out = net(yt, t, labels)
                    return out

                eps = torch.randn_like(images)
                yt = ode_fn(images=images, t=t, eps=eps)
                # dy_t/dt: eps for VE (y = x + t*eps); eps - x otherwise
                # (rectified flow, y = (1-t)x + t*eps).
                if ode_fn.ode_type == 've':
                    dyt = eps
                else:
                    dyt = eps - images
                F_theta, F_theta_grad = torch.func.jvp(
                        j,
                        primals=(yt, t),
                        tangents=(dyt, torch.ones_like(t)))
                F_theta_grad = F_theta_grad.detach()
                F_theta = F_theta.detach()
                with torch.no_grad():
                    jj = F_theta_grad * F_theta_grad
                    jj = torch.sum(jj.reshape(jj.shape[0], -1), dim=-1)
                    jf = F_theta_grad * (F_theta - images)
                    jf = torch.sum(jf.reshape(jf.shape[0], -1), dim=-1)
                    jjf[0] += jj.mean()
                    jjf[1] += jf.mean()
                delta = jjf[1] / (jjf[0] * (1.0 + 1.0 / self.lamda))
                # NOTE(review): a delta of exactly 0 leaves t unchanged for this
                # iteration; a NaN delta ends the loop without appending t_min.
                if delta.abs() >= ode_fn.t_max:
                    # Cap oversized steps at a fixed fraction of the time range.
                    t = t - ode_fn.t_max / 8.0
                else:
                    t = t - delta.abs()
                if t <= ode_fn.t_min:
                    ts = torch.cat((ts, torch.tensor([ode_fn.t_min]).to(device)), dim=0)
        finally:
            # Restore trainability even if the sweep raised (e.g. exhausted iterator).
            net.requires_grad_(True)

        self.ts = ts.clone()

    def __call__(self, net, images, eps, ode_fn=None, labels=None, augment_labels=None, teacher=None, ratio=None):
        """Per-sample consistency loss between adjacent grid times t > r.

        Requires ``update`` to have been called so that ``self.ts`` is set.
        ``teacher`` and ``ratio`` are accepted for interface compatibility and
        are unused. Uses the CUDA RNG state so that both network calls start
        from the same random state (presumably for stochastic layers such as
        dropout); this requires CUDA.

        Returns:
            ``(loss, mean((t - r) / t))``, where ``loss`` has one entry per sample.
        """
        sigmas = self.ts.to(images.device)

        if self.sample_type == 'lognormal':
            if ode_fn.ode_type == 'rflow':
                # Map rectified-flow times t to t / (1 - t) before the lognormal.
                timesteps = self.lognormal_timestep_distribution(images.shape[0], sigmas / (1.0 -sigmas)).to(images.device)
            else:
                timesteps = self.lognormal_timestep_distribution(images.shape[0], sigmas).to(images.device)
        else:
            timesteps = self.uniform_timestep_distribution(images.shape[0], sigmas).to(images.device)
        # ts is decreasing, so t = ts[i] > r = ts[i + 1].
        r = sigmas[timesteps+1].clone()[:, None, None, None]
        t = sigmas[timesteps].clone()[:, None, None, None]

        yt = ode_fn(images=images, t=t, eps=eps)
        yr = ode_fn(images=images, t=r, eps=eps)
        rng_state = torch.cuda.get_rng_state()
        D_yt = net(yt, t, labels, augment_labels=augment_labels)
        torch.cuda.set_rng_state(rng_state)
        with torch.no_grad():
            D_yr = net(yr, r, labels, augment_labels=augment_labels)
        # Pseudo-Huber distance between the two predictions.
        loss = (D_yt - D_yr) ** 2
        loss = torch.sum(loss.reshape(loss.shape[0], -1), dim=-1)
        loss = torch.sqrt(loss + self.c**2) - self.c
        if self.weight_type == 'adaptive':
            r_0 = (D_yr - images) ** 2
            r_0 = torch.sum(r_0.reshape(r_0.shape[0], -1), dim=-1)
            r_0 = torch.sqrt(r_0+ self.c**2) - self.c
            loss = loss / r_0
        else:
            loss = loss / (t-r).flatten()
        return loss, ((t - r) / t).mean()





@persistence.persistent_class
class loss_SCM:
    """Continuous-time consistency loss built from a teacher's time derivative.

    Samples lognormal times t, evaluates the student ``net`` at y_t, and
    regresses it onto a stop-gradient target formed from the teacher output
    F and the teacher's JVP along ``(dy_t/dt, 1)``.

    Args:
        mean, std: mean / std of log t for time sampling.
        c, lamda, sample_type, weight_type: stored for interface parity with
            the other losses; unused here.
    """

    def __init__(self, c=0.03, mean=-1.1, std=2.0, lamda=0.01, sample_type='lognormal', weight_type='adaptive'):
        self.ts = None
        self.c = c
        self.P_mean = mean
        self.P_std = std
        self.lamda = lamda
        self.sample_type = sample_type
        self.weight_type = weight_type

    def update(self, net=None, dataset_iterator=None, device=None, ode_fn = None):
        """No-op: this loss has no schedule to refresh (signature kept for parity)."""

    def __call__(self, net, images, eps, ode_fn=None, labels=None, augment_labels=None, teacher=None, ratio=None):
        """Per-sample squared-error loss against the teacher-derived target.

        ``teacher`` is required and is called as ``teacher(y, t)``. ``ratio`` is
        accepted for interface compatibility and ignored; previously it was fed
        to ``min`` and crashed on its ``None`` default. Uses the CUDA RNG state so
        both student/teacher evaluations start from the same state; this
        requires CUDA.

        Returns:
            ``(loss, (t / t).mean())``. The second value is 1 for the positive
            times sampled here.
        """
        def j(yt, t):
            out = teacher(yt, t)
            return out

        rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
        t = (rnd_normal * self.P_std + self.P_mean).exp()
        # NOTE(review): t is lognormal on (0, inf) for every ode_type. For 'rflow'
        # (t_max = 1), loss_ECM / loss_ICT first map t -> t / (1 + t) — confirm
        # whether that is also intended here.
        yt = ode_fn(images=images, t=t, eps=eps)
        rng_state = torch.cuda.get_rng_state()
        D_yt = net(yt, t, labels, augment_labels=augment_labels)
        torch.cuda.set_rng_state(rng_state)
        # Tangent dy_t/dt for each forward process.
        if ode_fn.ode_type == 've':
            dyt = eps
        elif ode_fn.ode_type == 'rflow':
            dyt = eps - images
        else:
            dyt = -torch.sin(t) * images + 0.5 * torch.cos(t) * eps
        F_theta, F_theta_grad = torch.func.jvp(
                j,
                primals=(yt, t),
                tangents=(dyt, torch.ones_like(t)))
        g = F_theta_grad.detach()
        F_theta = F_theta.detach()
        # Normalise g per sample by its RMS: norm over dims (1, 2, 3) scaled by
        # sqrt(B / numel) = 1 / sqrt(elements per sample); +0.1 avoids blow-up.
        g_norm = torch.linalg.vector_norm(g, dim=(1, 2, 3), keepdim=True)
        g_norm = g_norm * np.sqrt(g_norm.numel() / g.numel())
        g = g / (g_norm + 0.1)
        if ode_fn.ode_type == 've':
            loss = (g / t + D_yt - F_theta) ** 2
        else:
            # NOTE(review): every non-'ve' ode_type, including the trigonometric
            # tangent branch above, uses the (1 - t) / t weighting — confirm.
            loss = (g * (1-t) / t + D_yt - F_theta) ** 2
        loss = torch.sum(loss.reshape(loss.shape[0], -1), dim=-1)
        return loss, (t/ t).mean()


@persistence.persistent_class
class loss_CT:
    """Discrete consistency-training loss on a progressively refined Karras grid.

    The grid size grows with training progress ``ratio`` (see
    ``_num_timesteps``). Each sample uses an adjacent pair of grid points
    ``r = sigmas[i] < t = sigmas[i + 1]``.

    Args:
        Same as the other losses; none of them are used by this class.
    """

    def __init__(self, c=0.03, mean=-1.1, std=2.0, lamda=0.01, sample_type='lognormal', weight_type='adaptive'):
        self.ts = None
        self.c = c
        self.P_mean = mean
        self.P_std = std
        self.lamda = lamda
        self.sample_type = sample_type
        self.weight_type = weight_type

    def karras_schedule(
            self,
            num_timesteps,
            sigma_min=0.002,
            sigma_max=80.0,
            rho=7.0,
            device=None
    ):
        """Return ``num_timesteps`` increasing noise levels from ``sigma_min`` to ``sigma_max``,
        evenly spaced in ``sigma ** (1 / rho)``."""
        rho_inv = 1.0 / rho
        steps = torch.arange(num_timesteps, device=device) / max(num_timesteps - 1, 1)
        sigmas = sigma_min ** rho_inv + steps * (
                sigma_max ** rho_inv - sigma_min ** rho_inv
        )
        sigmas = sigmas ** rho
        return sigmas

    def uniform_timestep_distribution(
            self,
            num_samples,
            sigmas
    ):
        """Sample ``num_samples`` interval indices uniformly from [0, len(sigmas) - 2] (CPU tensor)."""
        timesteps = torch.randint(0, sigmas.shape[0]-1, (num_samples,))
        return timesteps

    def _num_timesteps(self, ratio, initial_timesteps=2, final_timesteps=150):
        """Grid size N for training progress ``ratio`` (0 = start, 1 = end).

        N = ceil(sqrt(ratio * ((s1 + 1)^2 - s0^2) + s0^2) - 1) + 1, with
        s0 = ``initial_timesteps`` and s1 = ``final_timesteps``. N rises from
        s0 at ratio = 0 to s1 + 1 at ratio = 1.
        """
        s0, s1 = initial_timesteps, final_timesteps
        # The "- 1" belongs outside the square root.
        return math.ceil(math.sqrt(ratio * ((s1 + 1) ** 2 - s0 ** 2) + s0 ** 2) - 1) + 1

    def update(self, net=None, dataset_iterator=None, device=None, ode_fn = None):
        """No-op: the grid is derived from ``ratio`` on every call (signature kept for parity)."""

    def __call__(self, net, images, eps, ode_fn=None, labels=None, augment_labels=None, teacher=None, ratio=None):
        """Per-sample L2 distance between ``net(y_t, t)`` and stop-gradient ``net(y_r, r)``.

        ``ratio`` (training progress) is required; ``teacher`` is unused. Uses
        the CUDA RNG state so that both network calls start from the same
        state; this requires CUDA.

        Returns:
            ``(loss, mean((t - r) / t))``, where ``loss`` has one entry per sample.
        """
        num_timesteps = self._num_timesteps(ratio)
        sigmas = self.karras_schedule(num_timesteps=num_timesteps, device=images.device)

        timesteps = self.uniform_timestep_distribution(images.shape[0], sigmas).to(images.device)
        # Karras sigmas are increasing, so t = sigmas[i + 1] > r = sigmas[i].
        t = sigmas[timesteps + 1].clone()[:, None, None, None]
        r = sigmas[timesteps].clone()[:, None, None, None]

        yt = ode_fn(images=images, t=t, eps=eps)
        yr = ode_fn(images=images, t=r, eps=eps)
        rng_state = torch.cuda.get_rng_state()
        D_yt = net(yt, t, labels, augment_labels=augment_labels)
        torch.cuda.set_rng_state(rng_state)
        with torch.no_grad():
            D_yr = net(yr, r, labels, augment_labels=augment_labels)
        loss = (D_yt - D_yr) ** 2
        loss = torch.sum(loss.reshape(loss.shape[0], -1), dim=-1)
        loss = torch.sqrt(loss)
        return loss, ((t - r) / t).mean()




@persistence.persistent_class
class loss_CD:
    """Consistency-distillation loss on a fixed 18-point Karras grid.

    For each sample, the teacher takes one Heun step from t = sigmas[i + 1]
    down to r = sigmas[i]. The student is trained so that ``net(y_t, t)``
    matches the stop-gradient ``net(y_r, r)``.

    Args:
        Same as the other losses; none of them are used by this class.
    """

    def __init__(self, c=0.03, mean=-1.1, std=2.0, lamda=0.01, sample_type='lognormal', weight_type='adaptive'):
        self.ts = None
        self.c = c
        self.P_mean = mean
        self.P_std = std
        self.lamda = lamda
        self.sample_type = sample_type
        self.weight_type = weight_type

    def karras_schedule(
            self,
            num_timesteps,
            sigma_min=0.002,
            sigma_max=80.0,
            rho=7.0,
            device=None
    ):
        """Return ``num_timesteps`` increasing noise levels from ``sigma_min`` to ``sigma_max``,
        evenly spaced in ``sigma ** (1 / rho)``."""
        rho_inv = 1.0 / rho
        steps = torch.arange(num_timesteps, device=device) / max(num_timesteps - 1, 1)
        sigmas = sigma_min ** rho_inv + steps * (
                sigma_max ** rho_inv - sigma_min ** rho_inv
        )
        sigmas = sigmas ** rho
        return sigmas

    def uniform_timestep_distribution(
            self,
            num_samples,
            sigmas
    ):
        """Sample ``num_samples`` interval indices uniformly from [0, len(sigmas) - 2] (CPU tensor)."""
        timesteps = torch.randint(0, sigmas.shape[0]-1, (num_samples,))
        return timesteps

    def _heun_step(self, teacher, yt, t, r):
        """One Heun step of dy/ds = (y - teacher(y, s)) / s from s = t to s = r.

        Runs under ``torch.no_grad()``. The result only feeds the no-grad target
        network call, so no autograd graph through the teacher is needed.
        """
        with torch.no_grad():
            predict_x0 = teacher(yt, t)
            d = (yt - predict_x0) / t
            pyr = yt + (r - t) * d
            predict_x0 = teacher(pyr, r)
            next_d = (pyr - predict_x0) / r
            return yt + (r - t) * (d + next_d) / 2.0

    def update(self, net=None, dataset_iterator=None, device=None, ode_fn = None):
        """No-op: the grid is fixed (signature kept for parity)."""

    def __call__(self, net, images, eps, ode_fn=None, labels=None, augment_labels=None, teacher=None, ratio=None):
        """Per-sample L2 distance between ``net(y_t, t)`` and stop-gradient ``net(y_r, r)``.

        ``teacher`` is required and is called as ``teacher(y, sigma)``; labels are
        not passed to it. NOTE(review): confirm the teacher is unconditional.
        ``ratio`` is unused. Uses the CUDA RNG state so that both student calls
        start from the same state; this requires CUDA.

        Returns:
            ``(loss, mean((t - r) / t))``, where ``loss`` has one entry per sample.
        """
        sigmas = self.karras_schedule(num_timesteps=18, device=images.device)

        timesteps = self.uniform_timestep_distribution(images.shape[0], sigmas).to(images.device)
        # Karras sigmas are increasing, so t = sigmas[i + 1] > r = sigmas[i].
        t = sigmas[timesteps + 1].clone()[:, None, None, None]
        r = sigmas[timesteps].clone()[:, None, None, None]

        yt = ode_fn(images=images, t=t, eps=eps)
        yr = self._heun_step(teacher, yt, t, r)

        rng_state = torch.cuda.get_rng_state()
        D_yt = net(yt, t, labels, augment_labels=augment_labels)
        torch.cuda.set_rng_state(rng_state)
        with torch.no_grad():
            D_yr = net(yr, r, labels, augment_labels=augment_labels)

        loss = (D_yt - D_yr) ** 2
        loss = torch.sum(loss.reshape(loss.shape[0], -1), dim=-1)
        loss = torch.sqrt(loss)
        return loss, ((t - r) / t).mean()
    


@persistence.persistent_class
class loss_ECM:
    """Consistency loss with a stage-dependent gap between t and r.

    t is sampled lognormally. r is derived from t by ``t_to_r_sigmoid``,
    whose gap shrinks as ``stage`` grows (each ``update`` increments it). The
    per-sample distance is weighted by ``1 / (t - r)``.

    Args:
        mean, std: mean / std of log t for time sampling.
        c, lamda, sample_type, weight_type: stored for interface parity; unused.
    """

    def __init__(self, c=0.03, mean=-1.1, std=2.0, lamda=0.01, sample_type='lognormal', weight_type='adaptive'):
        self.stage = 0
        self.weight_type = weight_type
        self.sample_type = sample_type
        self.lamda = lamda
        self.P_std = std
        self.P_mean = mean
        self.c = c
        self.ts = None

    def update(self, net=None, dataset_iterator=None, device=None, ode_fn = None):
        """Advance to the next stage and print the new stage index."""
        self.stage = self.stage + 1
        print(self.stage)

    def t_to_r_sigmoid(self, t):
        """Return ``max(0, t * (1 - (1 + 8 * sigmoid(-t)) / 2**stage))``.

        At stage 0 the factor is non-positive, so r is 0 for every t >= 0.
        """
        weight = 1 + 8.0 * torch.sigmoid(-t)
        scale = 1 / 2 ** (self.stage)
        return torch.clamp(t * (1 - scale * weight), min=0)

    def __call__(self, net, images, eps, ode_fn=None, labels=None, augment_labels=None, teacher=None, ratio=None):
        """Per-sample weighted L2 distance between ``net(y_t, t)`` and its target.

        The target is the stop-gradient ``net(y_r, r)`` where r > 0, and the
        clean ``images`` where r == 0. For ``'rflow'``, sigma is mapped to
        t = sigma / (1 + sigma), and the 1 / (t - r) weight is computed after
        mapping back to sigma space. Uses the CUDA RNG state so that both
        network calls start from the same state; this requires CUDA.

        Returns:
            ``(loss, mean((t - r) / t))``, with t and r in sigma space.
        """
        kind = ode_fn.ode_type
        noise = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
        t = (noise * self.P_std + self.P_mean).exp()
        if kind == 'rflow':
            t = t / (1.0 +t)

        if kind == 've':
            r = self.t_to_r_sigmoid(t)
        else:
            # Choose r in sigma space, then map it back into (0, 1).
            r_sigma = self.t_to_r_sigmoid(t / (1-t))
            r = r_sigma / (1+r_sigma)

        yt, yr = (ode_fn(images=images, t=s, eps=eps) for s in (t, r))
        state = torch.cuda.get_rng_state()
        pred_t = net(yt, t, labels, augment_labels=augment_labels)
        if r.max() > 0:
            torch.cuda.set_rng_state(state)
            with torch.no_grad():
                pred_r = torch.nan_to_num(net(yr, r, labels, augment_labels=augment_labels))
            positive = r > 0
            target = positive * pred_r + (~positive) * images
        else:
            target = images

        diff = (pred_t - target) ** 2
        dist = torch.sqrt(torch.sum(diff.reshape(diff.shape[0], -1), dim=-1))

        if kind == 'rflow':
            t, r = t / (1.0 -t), r / (1.0 -r)

        return dist / (t-r).flatten(), ((t - r) / t).mean()
    


@persistence.persistent_class
class loss_ICT:
    """Consistency-training loss on a Karras grid whose resolution doubles per stage.

    The grid has ``improved_timesteps_schedule()`` points. ``stage`` starts at
    -1 and each ``update`` increments it, doubling the interval count
    ``10 * 2**stage`` until it is capped at 1280. Adjacent pairs (r < t) are
    sampled by lognormal mass, and the distance is weighted by ``1 / (t - r)``.

    Args:
        mean, std: lognormal parameters for interval sampling.
        c, lamda, sample_type, weight_type: stored for interface parity; unused.
    """

    def __init__(self, c=0.03, mean=-1.1, std=2.0, lamda=0.01, sample_type='lognormal', weight_type='adaptive'):
        self.stage = -1
        self.weight_type = weight_type
        self.sample_type = sample_type
        self.lamda = lamda
        self.P_std = std
        self.P_mean = mean
        self.c = c
        self.ts = None

    def update(self, net=None, dataset_iterator=None, device=None, ode_fn = None):
        """Advance to the next refinement stage and print the new stage index."""
        self.stage = self.stage + 1
        print(self.stage)

    def improved_timesteps_schedule(
            self,
            initial_timesteps = 10,
            final_timesteps = 1280
    ):
        """Return the grid size ``min(initial * 2**stage, final) + 1``.

        ``math.pow`` makes the result a float while below the cap (e.g. ``11.0``).
        """
        intervals = initial_timesteps * math.pow(2, self.stage)
        return min(intervals, final_timesteps) + 1

    def karras_schedule(
            self,
            num_timesteps,
            sigma_min = 0.002,
            sigma_max = 80.0,
            rho = 7.0,
            device = None
    ):
        """Return ``num_timesteps`` increasing noise levels from ``sigma_min`` to ``sigma_max``,
        evenly spaced in ``sigma ** (1 / rho)``."""
        inv_rho = 1.0 / rho
        lo = sigma_min ** inv_rho
        span = sigma_max ** inv_rho - lo
        frac = torch.arange(num_timesteps, device=device) / max(num_timesteps - 1, 1)
        return (lo + frac * span) ** rho

    def lognormal_timestep_distribution(
            self,
            num_samples,
            sigmas
    ):
        """Sample ``num_samples`` interval indices i in [0, len(sigmas) - 2].

        Index i is drawn with probability proportional to the lognormal
        (P_mean, P_std) mass between ``sigmas[i]`` and ``sigmas[i + 1]``.
        ``sigmas`` must be increasing.
        """
        scale = self.P_std * math.sqrt(2)
        upper = torch.erf((torch.log(sigmas[1:]) - self.P_mean) / scale)
        lower = torch.erf((torch.log(sigmas[:-1]) - self.P_mean) / scale)
        weights = upper - lower
        return torch.multinomial(weights / weights.sum(), num_samples, replacement=True)

    def __call__(self, net, images, eps, ode_fn=None, labels=None, augment_labels=None, teacher=None, ratio=None):
        """Per-sample L2 distance between ``net(y_t, t)`` and stop-gradient ``net(y_r, r)``, over (t - r).

        For ``'rflow'``, both t and r are mapped with s / (1 + s) before use; the
        weighting uses the mapped values. Uses the CUDA RNG state so that both
        network calls start from the same state; this requires CUDA.

        Returns:
            ``(loss / (t - r), mean((t - r) / t))``.
        """
        grid = self.karras_schedule(num_timesteps=self.improved_timesteps_schedule(), device=images.device)
        idx = self.lognormal_timestep_distribution(images.shape[0], grid).to(images.device)
        t = grid[idx + 1].reshape(-1, 1, 1, 1)
        r = grid[idx].reshape(-1, 1, 1, 1)

        if ode_fn.ode_type == 'rflow':
            t, r = t / (1.0 + t), r / (1.0 + r)

        yt = ode_fn(images=images, t=t, eps=eps)
        yr = ode_fn(images=images, t=r, eps=eps)
        state = torch.cuda.get_rng_state()
        pred_t = net(yt, t, labels, augment_labels=augment_labels)
        torch.cuda.set_rng_state(state)
        with torch.no_grad():
            pred_r = net(yr, r, labels, augment_labels=augment_labels)
        sq = (pred_t - pred_r) ** 2
        dist = sq.reshape(sq.shape[0], -1).sum(dim=-1).sqrt()

        gap = t - r
        return dist / gap.flatten(), (gap / t).mean()