import torch
from guided_diffusion.script_util import create_classifier
import numpy as np
from cm import logger
import torch.nn as nn
from ldm.modules.diffusionmodules.model import Decoder

class decoder_(nn.Module):
    """Wrapper that holds a ``decoder__`` under the attribute ``first_stage_model``.

    NOTE(review): the attribute name presumably mirrors the key prefix of a
    pretrained checkpoint -- confirm before renaming.
    """

    def __init__(self, embed_dim=4, z_channels=4):
        super().__init__()
        self.first_stage_model = decoder__(embed_dim, z_channels)

    def forward(self, z, teacher=False, mult=1.0, no_grad=True):
        """Decode ``z`` with the wrapped ``first_stage_model``.

        Args:
            z: Latent tensor handed to the wrapped decoder.
            teacher: Accepted for caller compatibility; not used.
            mult: Accepted for caller compatibility; not used.
            no_grad: Forwarded to the wrapped decoder's ``forward``.

        Returns:
            Whatever the wrapped decoder returns for ``z``.
        """
        # The wrapped decoder's forward only takes (z, num, no_grad); passing
        # ``teacher``/``mult`` through raised TypeError on every call.
        return self.first_stage_model(z, no_grad=no_grad)


class decoder__(nn.Module):
    """Frozen decoder: a 1x1 ``post_quant_conv`` followed by ``Decoder``.

    All parameters of both sub-modules have ``requires_grad`` disabled at
    construction time.
    """

    def __init__(self, embed_dim=4, z_channels=4):
        super().__init__()
        self.decoder = Decoder()
        self.post_quant_conv = nn.Conv2d(embed_dim, z_channels, 1)
        # Freeze every weight of both sub-modules.
        for param in self.decoder.parameters():
            param.requires_grad_(False)
        for param in self.post_quant_conv.parameters():
            param.requires_grad_(False)

    def forward(self, z, num=4, no_grad=True):
        """Decode ``z`` and clamp the result to ``[-1, 1]``.

        Args:
            z: Input tensor, batched along dim 0. It is NOT modified.
            num: Maximum number of micro-batches used when ``no_grad`` is
                true (chunk size is ``ceil(batch / num)``).
            no_grad: When true, decode chunk by chunk under
                ``torch.no_grad()``; otherwise decode the whole batch in one
                call with autograd left as the caller configured it.

        Returns:
            The decoded tensor clamped to ``[-1.0, 1.0]``.
        """
        # Out-of-place division: the original ``z /= ...`` silently rescaled
        # the caller's tensor and fails on leaf tensors that require grad.
        # NOTE(review): the constant is presumably the latent scale factor of
        # the pretrained autoencoder -- confirm.
        z = z / 0.24578019976615906

        if not no_grad:
            decoded = self.decoder(self.post_quant_conv(z))
            return torch.clamp(decoded, min=-1.0, max=1.0)

        batch = z.shape[0]
        micro_batch_size = max(1, (batch - 1) // num + 1)
        with torch.no_grad():
            # Only non-empty chunks are decoded, so num > batch is harmless.
            chunks = [
                self.decoder(self.post_quant_conv(z[start:start + micro_batch_size]))
                for start in range(0, batch, micro_batch_size)
            ]
            if chunks:
                decoded = torch.cat(chunks)
            else:
                # Empty batch: decode once so shape/dtype come from the decoder.
                decoded = self.decoder(self.post_quant_conv(z))
        return torch.clamp(decoded, min=-1.0, max=1.0)

def get_discriminator(latent_extractor_ckpt, discriminator_ckpt, condition, img_resolution=32, device='cuda', enable_grad=True):
    """Return a closure that scores inputs with a feature extractor plus discriminator.

    The returned ``evaluate(perturbed_inputs, timesteps=None, condition=None)``
    runs the loaded classifier in feature mode, feeds the features to the
    loaded discriminator with ``sigmoid=True`` and flattens the result.
    Gradient tracking inside the closure follows ``enable_grad``.

    NOTE(review): the arguments passed to ``load_classifier`` and
    ``load_discriminator`` below do not match the signatures of the functions
    with those names in this file (``load_classifier(args, eval=True)`` and
    ``load_discriminator(args, ckpt_path, ...)``), so these calls look like
    they would raise TypeError -- confirm which loaders are intended.
    """
    feature_extractor = load_classifier(latent_extractor_ckpt, img_resolution, device, eval=True)
    scoring_head = load_discriminator(discriminator_ckpt, device, condition, eval=True)
    # Pick the autograd context factory once; a fresh context is entered per call.
    grad_context = torch.enable_grad if enable_grad else torch.no_grad

    def evaluate(perturbed_inputs, timesteps=None, condition=None):
        with grad_context():
            adm_features = feature_extractor(perturbed_inputs, timesteps=timesteps, feature=True)
            scores = scoring_head(adm_features, timesteps, sigmoid=True, condition=condition)
            prediction = scores.view(-1)
        return prediction

    return evaluate

def load_classifier(args, eval=True):
    """Build the classifier described by ``args`` and optionally load weights.

    Args:
        args: Namespace-like object. Reads ``data_name``, ``image_size``,
            ``use_fp16``, ``classifier_pool``, ``out_channels``,
            ``classifier_model_path`` and, for the ``'church'`` dataset only,
            ``feature_aggregated``.
        eval: When true, switch the model to eval mode and disable gradients
            on all of its parameters.

    Returns:
        The classifier produced by ``create_classifier``.
    """
    is_latent = args.data_name in ['church']

    # The latent ('church') setting always uses depth 2 and 4 input channels.
    if is_latent:
        depth = 2
    elif args.image_size in [64, 32]:
        depth = 4
    else:
        depth = 2

    build_kwargs = {
        "image_size": args.image_size,
        "classifier_use_fp16": args.use_fp16,
        "classifier_width": 128,
        "classifier_depth": depth,
        "classifier_attention_resolutions": "32,16,8",
        "classifier_use_scale_shift_norm": True,
        "classifier_resblock_updown": True,
        "classifier_pool": args.classifier_pool,
        "out_channels": args.out_channels,
        "in_channels": 4 if is_latent else 3,
    }
    if is_latent:
        build_kwargs["feature_aggregated"] = args.feature_aggregated
        build_kwargs["channel_mult"] = [1, 1, 1, 1, 1]
    classifier = create_classifier(**build_kwargs)

    checkpoint = args.classifier_model_path
    if checkpoint is not None:
        logger.log(f"loading classifier model from checkpoint: {checkpoint}...")
        classifier.load_state_dict(torch.load(checkpoint, map_location="cpu"))

    if eval:
        classifier.eval()
        classifier.requires_grad_(False)
    return classifier

def load_discriminator(args, ckpt_path, channel_mult=[1,1,1],
                       classifier_width=128, classifier_depth=2, classifier_attention_resolutions="32,16,8"):
    """Construct the discriminator selected by ``args.d_architecture``.

    Args:
        args: Namespace-like object. Always reads ``data_name`` and
            ``d_architecture``; other attributes are read per branch
            (``discriminator_input``, ``image_size``, ``d_out_res``,
            ``use_d_fp16``, ``discriminator_input_channel``, ``class_cond``,
            ``embed``, ``discriminator_checkpoint``).
        ckpt_path: Optional state-dict path; only used by the ``'unet'``
            architecture.
        channel_mult: List of ints, or a comma-separated string of ints
            (``'unet'`` only).
        classifier_width: Forwarded to ``create_classifier`` (``'unet'``).
        classifier_depth: Forwarded to ``create_classifier`` (``'unet'``).
        classifier_attention_resolutions: Forwarded to ``create_classifier``
            (``'unet'``).

    Returns:
        The discriminator module, or for ``'stylegan-xl'`` the tuple
        ``(discriminators, feature_networks)`` of ``nn.ModuleDict`` objects.

    Raises:
        NotImplementedError: For an unsupported dataset/input combination,
            the unimplemented ``'stylegan2'`` branch, an unknown
            architecture, or ``'ddgan'`` with ``image_size != 32``.
        ValueError: If ``channel_mult`` is a string that is not a
            comma-separated list of integers.
    """
    # Resolve the spatial size of the discriminator input. It stays None when
    # the dataset/input combination is not covered; only the 'unet' branch
    # needs it, so other architectures are not rejected here.
    input_size = None
    if args.data_name in ['church']:
        if args.discriminator_input == 'latent':
            input_size = args.image_size
        elif args.discriminator_input == 'feature':
            input_size = args.d_out_res
        else:
            raise NotImplementedError
    elif args.data_name == 'cifar10':
        if args.discriminator_input == 'data':
            input_size = 32
        elif args.discriminator_input == 'lpips_feature':
            input_size = 28
    elif args.data_name == 'in64':
        input_size = 64

    if args.d_architecture == 'unet':
        if input_size is None:
            # Previously surfaced later as an UnboundLocalError.
            raise NotImplementedError(
                f"cannot determine discriminator input size for data_name={args.data_name!r}")
        if isinstance(channel_mult, str):
            # Malformed strings now raise ValueError instead of being
            # forwarded unparsed.
            channel_mult = [int(x) for x in channel_mult.split(',')]
        elif isinstance(channel_mult, list):
            # Copy so the shared default list is never aliased downstream.
            channel_mult = list(channel_mult)
        discriminator_args = dict(
            image_size=input_size,
            classifier_use_fp16=args.use_d_fp16,
            classifier_width=classifier_width,
            classifier_depth=classifier_depth,
            classifier_attention_resolutions=classifier_attention_resolutions,
            classifier_use_scale_shift_norm=True,
            classifier_resblock_updown=True,
            classifier_pool="attention",
            channel_mult=channel_mult,
            out_channels=1,
            in_channels=args.discriminator_input_channel,
            condition=args.class_cond,
            embed=args.embed,
        )

        print("discriminator use fp16: ", args.use_d_fp16)
        discriminator = create_classifier(**discriminator_args)
        if ckpt_path is not None:
            logger.log(f"loading discriminator model from checkpoint: {ckpt_path}...")
            discriminator_state = torch.load(ckpt_path, map_location="cpu")
            discriminator.load_state_dict(discriminator_state)
        return discriminator

    if args.d_architecture == 'ddgan':
        from score_sde.models.discriminator import Discriminator_small, Discriminator_large
        if args.image_size != 32:
            raise NotImplementedError
        return Discriminator_small(nc=args.discriminator_input_channel, act=torch.nn.LeakyReLU(0.2),
                                   embed=args.embed)

    if args.d_architecture == 'stylegan2':
        # This branch was an empty ``pass`` that fell through to an
        # UnboundLocalError at the final return.
        raise NotImplementedError("d_architecture 'stylegan2' is not implemented")

    if args.d_architecture == 'stylegan-ada':
        import pickle
        print(f'Loading discriminator network from "{args.discriminator_checkpoint}"...')
        # SECURITY: pickle.load executes arbitrary code from the file; only
        # point this at checkpoints from a trusted source.
        with open(args.discriminator_checkpoint, 'rb') as f:
            discriminator = pickle.load(f)['D']
        return discriminator

    if args.d_architecture == 'stylegan-xl':
        from pg_modules.projector import F_RandomProj
        from pg_modules.discriminator import MultiScaleD
        backbones = ['deit_base_distilled_patch16_224', 'tf_efficientnet_lite0']
        feature_networks, discriminators = [], []
        backbone_kwargs = {'im_res': args.image_size}
        for bb_name in backbones:
            feat = F_RandomProj(bb_name, **backbone_kwargs)
            disc = MultiScaleD(
                channels=feat.CHANNELS,
                resolutions=feat.RESOLUTIONS,
                **backbone_kwargs,
            )
            feature_networks.append([bb_name, feat])
            discriminators.append([bb_name, disc])
        # Feature networks are put in eval mode with gradients disabled.
        feature_networks = nn.ModuleDict(feature_networks)
        feature_networks = feature_networks.train(False)
        feature_networks.requires_grad_(False)
        discriminators = nn.ModuleDict(discriminators)
        return discriminators, feature_networks

    raise NotImplementedError(f"unknown d_architecture: {args.d_architecture!r}")

def get_grad_log_ratio(discriminator, vpsde, unnormalized_input, std_wve_t, img_resolution, time_min, time_max, class_labels, log=False):
    """Gradient of the discriminator log-ratio w.r.t. the input, rescaled.

    Args:
        discriminator: Callable scoring ``(input, timesteps=..., condition=...)``,
            or ``None`` to disable guidance.
        vpsde: Object providing ``transform_unnormalized_wve_to_normalized_vp``
            and ``compute_t_cos_from_t_lin``.
        unnormalized_input: Input tensor; it is multiplied by
            ``mean_vp_tau[:, None, None, None]`` before scoring
            (assumes a 4-D batch -- TODO confirm).
        std_wve_t: Noise levels, indexed as ``[:, None, None, None]``.
        img_resolution: When equal to 64, ``tau`` is remapped with
            ``vpsde.compute_t_cos_from_t_lin``.
        time_min: Lower bound on ``tau.min()`` for guidance to be applied.
        time_max: Upper bound on ``tau.min()`` for guidance to be applied.
        class_labels: Forwarded to the discriminator as ``condition``.
        log: When true, also return the per-sample log-ratio.

    Returns:
        The guidance tensor (same shape as ``unnormalized_input``), or the
        tuple ``(guidance, log_ratio)`` when ``log`` is true. When the
        discriminator is ``None`` or ``tau.min()`` lies outside
        ``[time_min, time_max]``, the guidance is all zeros and the logged
        ratio is filled with ``1e7``.
    """
    mean_vp_tau, tau = vpsde.transform_unnormalized_wve_to_normalized_vp(std_wve_t) ## VP pretrained classifier
    tau_min = tau.min()
    # NOTE(review): both bounds are tested against tau.min(), as in the
    # original code -- confirm the upper bound should not use tau.max().
    if discriminator is None or tau_min > time_max or tau_min < time_min:
        zero_score = torch.zeros_like(unnormalized_input)
        if log:
            sentinel = 10000000. * torch.ones(unnormalized_input.shape[0], device=unnormalized_input.device)
            return zero_score, sentinel
        return zero_score

    scaled_input = mean_vp_tau[:, None, None, None] * unnormalized_input
    with torch.enable_grad():
        # Fresh leaf so the gradient is taken w.r.t. the scaled input only.
        x_ = scaled_input.float().clone().detach().requires_grad_()
        if img_resolution == 64: # ADM trained UNet classifier for 64x64 with Cosine VPSDE
            tau = vpsde.compute_t_cos_from_t_lin(tau)
        # Broadcast tau to one entry per batch element.
        tau = torch.ones(scaled_input.shape[0], device=tau.device) * tau
        log_ratio = get_log_ratio(discriminator, x_, tau, class_labels)
        discriminator_guidance_score = torch.autograd.grad(outputs=log_ratio.sum(), inputs=x_, retain_graph=False)[0]
        discriminator_guidance_score *= - ((std_wve_t[:, None, None, None] ** 2) * mean_vp_tau[:, None, None, None])
    if log:
        return discriminator_guidance_score, log_ratio
    return discriminator_guidance_score

def get_log_ratio(discriminator, input, time, class_labels):
    """Return ``log(p / (1 - p))`` of the discriminator output ``p``.

    Args:
        discriminator: Callable invoked as
            ``discriminator(input, timesteps=time, condition=class_labels)``,
            or ``None``.
        input: Tensor passed to the discriminator; dim 0 is the batch.
        time: Forwarded as ``timesteps``.
        class_labels: Forwarded as ``condition``.

    Returns:
        A tensor of log-ratios. When ``discriminator`` is ``None`` this is a
        zero vector of length ``input.shape[0]`` on ``input.device``.
    """
    if discriminator is None:
        return torch.zeros(input.shape[0], device=input.device)
    # The output is treated as a probability: clip away from 0 and 1 so the
    # log-odds below stay finite.
    probability = discriminator(input, timesteps=time, condition=class_labels)
    prediction = torch.clip(probability, 1e-5, 1. - 1e-5)
    return torch.log(prediction / (1. - prediction))

class vpsde():
    """Helpers converting between VP time ``t`` and the noise level ``std / mean``.

    Two closed forms of the integrated noise schedule are supported, selected
    by ``multiplier``:

    * ``1``: ``0.5 * t**2 * (beta_1 - beta_0) + t * beta_0``
    * ``2``: ``a * t**3 + b * t**2 + c * t`` with
      ``a = (sqrt(beta_1) - sqrt(beta_0))**2 / 3``,
      ``b = sqrt(beta_0) * (sqrt(beta_1) - sqrt(beta_0))``, ``c = beta_0``.

    Methods taking ``multiplier`` treat ``-1`` as "use ``self.multiplier``".
    """

    def __init__(self, beta_min=0.1, beta_max=20., multiplier=1., cos_t_classifier=False,):
        self.beta_0 = beta_min
        self.beta_1 = beta_max
        self.multiplier = multiplier
        # Cubic coefficients of the multiplier-2 integral (see class docstring).
        self.a = (self.beta_1 ** 0.5 - self.beta_0 ** 0.5) ** 2 / 3.
        self.b = (self.beta_0 ** 0.5) * (self.beta_1 ** 0.5 - self.beta_0 ** 0.5)
        self.c = self.beta_0
        # Offset and normaliser used by compute_t_cos_from_t_lin.
        self.s = 0.008
        self.f_0 = np.cos(self.s / (1. + self.s) * np.pi / 2.) ** 2
        self.cos_t_classifier = cos_t_classifier

    @property
    def T(self):
        """Final diffusion time (always 1)."""
        return 1

    def _resolve_multiplier(self, multiplier):
        """Map the ``-1`` sentinel to ``self.multiplier``; reject unsupported values."""
        effective = self.multiplier if multiplier == -1 else multiplier
        if effective != 1 and effective != 2:
            raise ValueError(f"unsupported multiplier {effective!r}; expected 1 or 2")
        return effective

    def compute_tau(self, std_wve_t, multiplier=-1.):
        """Solve ``integral_beta(tau) = log(1 + std_wve_t**2)`` for ``tau``.

        Args:
            std_wve_t: Tensor of noise levels.
            multiplier: Schedule selector (``-1`` uses ``self.multiplier``).

        Raises:
            ValueError: If the effective multiplier is neither 1 nor 2.
        """
        effective = self._resolve_multiplier(multiplier)
        if effective == 1:
            # Positive root of the quadratic in tau.
            tau = -self.beta_0 + torch.sqrt(self.beta_0 ** 2 + 2. * (self.beta_1 - self.beta_0) * torch.log(1. + std_wve_t ** 2))
            tau /= self.beta_1 - self.beta_0
            return tau
        # Closed-form root of a*tau^3 + b*tau^2 + c*tau + d = 0.
        d = - torch.log(1. + std_wve_t ** 2)
        out_ = 2 * (self.b ** 3) - 9 * self.a * self.b * self.c + 27. * (self.a ** 2) * d
        in_ = out_ ** 2 - 4 * (((self.b ** 2) - 3 * self.a * self.c) ** 3)
        root = in_ ** 0.5
        plus = out_ + root
        minus = out_ - root
        # sign * |x| ** (1/3) is a real cube root that also works for x < 0.
        cbrt_plus = torch.sign(plus) * ((torch.abs(plus) / 2.) ** (1 / 3.))
        cbrt_minus = torch.sign(minus) * ((torch.abs(minus) / 2.) ** (1 / 3.))
        return - self.b / (3. * self.a) - cbrt_plus / (3. * self.a) - cbrt_minus / (3. * self.a)

    def marginal_prob(self, t, multiplier=-1.):
        """Return ``(mean, std)`` coefficients at time ``t``.

        ``mean = exp(-0.5 * integral_beta(t))`` and ``std = sqrt(1 - mean**2)``.
        """
        log_mean_coeff = - 0.5 * self.integral_beta(t, multiplier)
        mean = torch.exp(log_mean_coeff)
        std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
        return mean, std

    def transform_normalized_vp_to_unnormalized_wve(self, t, multiplier=-1.):
        """Return the noise level ``std / mean`` at VP time ``t``."""
        mean, std = self.marginal_prob(t, multiplier=multiplier)
        return std / mean

    def sampling_std(self, num_step):
        """Noise levels for ``num_step`` evenly spaced steps out of 1000, descending.

        Prints the selected timesteps and the resulting noise levels.
        ``num_step`` must divide 1000.
        """
        assert 1000 % num_step == 0
        ddim_timesteps = torch.from_numpy(np.array(list(range(0, 1000, 1000 // num_step)))[::-1].copy())
        print(ddim_timesteps)
        steps_out = ddim_timesteps + 1
        std = self.transform_normalized_vp_to_unnormalized_wve(steps_out / 1000.)
        print(std)
        return std

    def transform_unnormalized_wve_to_normalized_vp(self, t, std_out=False, multiplier=-1.):
        """Convert noise level ``t`` to ``(mean, tau)`` or ``(mean, std, tau)``.

        ``tau`` is remapped with ``compute_t_cos_from_t_lin`` when
        ``self.cos_t_classifier`` is set; ``mean``/``std`` are always computed
        from the un-remapped ``tau``.
        """
        tau = self.compute_tau(t, multiplier=multiplier)
        mean_vp_tau, std_vp_tau = self.marginal_prob(tau, multiplier=multiplier)
        if self.cos_t_classifier:
            tau = self.compute_t_cos_from_t_lin(tau)
        if std_out:
            return mean_vp_tau, std_vp_tau, tau
        return mean_vp_tau, tau

    def from_rescaled_t_to_original_std(self, rescaled_t):
        """Return ``exp(rescaled_t / 250) - 1e-44``."""
        # NOTE(review): 1e-44 is far below float32 resolution for results
        # near 1, so the subtraction is presumably a no-op -- confirm intent.
        return torch.exp(rescaled_t / 250.) - 1e-44

    def compute_t_cos_from_t_lin(self, t_lin):
        """Map a multiplier-1 time to the matching offset-cosine time.

        Inverts ``cos((t + s) / (1 + s) * pi / 2) = sqrt(f_0) * sqrt_alpha_bar``
        where ``sqrt_alpha_bar`` is the multiplier-1 mean coefficient at
        ``t_lin``.
        """
        sqrt_alpha_t_bar = torch.exp(-0.25 * t_lin ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t_lin * self.beta_0)
        # Plain float so the product does not depend on NumPy-scalar dispatch.
        time = torch.arccos(float(np.sqrt(self.f_0)) * sqrt_alpha_t_bar)
        t_cos = self.T * ((1. + self.s) * 2. / np.pi * time - self.s)
        return t_cos

    def get_diffusion_time(self, batch_size, batch_device, t_min=1e-5, importance_sampling=True):
        """Sample ``batch_size`` diffusion times in ``[t_min, T]``.

        Returns:
            ``(times, Z)`` where ``Z`` is the detached normalizing constant
            with importance sampling, or ``(times, 1)`` for uniform sampling.
        """
        if importance_sampling:
            Z = self.normalizing_constant(t_min)
            u = torch.rand(batch_size, device=batch_device)
            # Inverse-CDF sampling using the multiplier-1 closed form.
            return (-self.beta_0 + torch.sqrt(self.beta_0 ** 2 + 2 * (self.beta_1 - self.beta_0) *
                    torch.log(1. + torch.exp(Z * u + self.antiderivative(t_min))))) / (self.beta_1 - self.beta_0), Z.detach()
        else:
            return torch.rand(batch_size, device=batch_device) * (self.T - t_min) + t_min, 1

    def antiderivative(self, t, stabilizing_constant=0.):
        """Return ``log(1 - exp(-B(t)) + stabilizing_constant) + B(t)`` with ``B = integral_beta``."""
        if isinstance(t, (float, int)):
            t = torch.tensor(t).float()
        integral = self.integral_beta(t)
        return torch.log(1. - torch.exp(- integral) + stabilizing_constant) + integral

    def normalizing_constant(self, t_min):
        """Return ``antiderivative(T) - antiderivative(t_min)``."""
        return self.antiderivative(self.T) - self.antiderivative(t_min)

    def integral_beta(self, t, multiplier=-1.):
        """Integrated noise schedule at ``t`` (see class docstring).

        Raises:
            ValueError: If the effective multiplier is neither 1 nor 2
                (this previously returned ``None`` silently).
        """
        if self._resolve_multiplier(multiplier) == 1:
            return 0.5 * t ** 2 * (self.beta_1 - self.beta_0) + t * self.beta_0
        return ((self.beta_1 ** 0.5 - self.beta_0 ** 0.5) ** 2) * (t ** 3) / 3. \
            + (self.beta_0 ** 0.5) * (self.beta_1 ** 0.5 - self.beta_0 ** 0.5) * (t ** 2) + self.beta_0 * t