import os
import torch
import pytorch_lightning as pl
from omegaconf import OmegaConf
from torch.nn import functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from copy import deepcopy
from einops import rearrange
import numpy as np

from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config

# Lookup table from a string key to a model class imported from
# ldm.modules.diffusionmodules.openaimodel. Nothing else visible in this file
# reads it.
__models__ = {
    'class_label': EncoderUNetModel,
    'segmentation': UNetModel
}

def disabled_train(self, mode=True):
    """No-op replacement for a model's ``train`` method.

    Assign this over ``model.train`` to pin the model's current train/eval
    mode: ``mode`` is ignored and ``self`` is handed back unchanged.
    """
    return self


class NoisyLatentImageClassifier(pl.LightningModule):
    """Classifier trained on noise-perturbed latents of a frozen first stage.

    Images are encoded by the first-stage model, perturbed with the VP-SDE
    marginal at a diffusion time ``t`` and classified by ``classifier_model``
    using a cross-entropy loss against ``batch['class_label']``.

    ``learning_rate`` is read in ``configure_optimizers`` but is never assigned
    inside this class; it has to be set on the instance beforehand
    (presumably by the training script -- TODO confirm).
    """

    def __init__(self,
                 first_stage_config,
                 unet_config,
                 num_classes=1000,
                 scheduler_config=None,
                 weight_decay=1.e-2,
                 *args,
                 **kwargs):
        """Build the frozen first stage, the classifier and the SDE helper.

        Args:
            first_stage_config: config passed to ``instantiate_from_config``
                to create the first-stage (encoder) model.
            unet_config: config passed to ``instantiate_from_config`` to
                create the classifier network.
            num_classes: stored on the instance; not otherwise used here.
            scheduler_config: optional config of an object exposing a
                ``schedule`` callable; enables the LambdaLR scheduler.
            weight_decay: weight decay handed to AdamW.
            *args, **kwargs: forwarded to ``pl.LightningModule.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.num_classes = num_classes
        # Forward-noising process used to perturb the latents.
        self.vp = VPSDE()
        self.instantiate_first_stage(first_stage_config)
        self.classifier_model = instantiate_from_config(unet_config)
        self.scheduler_config = scheduler_config
        self.use_scheduler = self.scheduler_config is not None
        self.weight_decay = weight_decay

    def instantiate_first_stage(self, config):
        """Create the first-stage model, freeze it and pin it to eval mode."""
        model = instantiate_from_config(config)
        self.first_stage_model = model.eval()
        # train()/eval() on this module recurse into child modules; replacing
        # the child's `train` keeps the frozen first stage in eval mode.
        self.first_stage_model.train = disabled_train
        for param in self.first_stage_model.parameters():
            param.requires_grad = False

    def forward(self, x_noisy, t):
        """Return the classifier output for noisy latents ``x_noisy`` at ``t``."""
        return self.classifier_model(x_noisy, t)

    @torch.no_grad()
    def get_input(self, batch):
        """Encode ``batch['image']`` into a detached first-stage latent.

        A 3-D image tensor gets a trailing axis appended; the last axis is
        then moved to position 1 before encoding (assumes channels-last
        input, i.e. (B, H, W, C) -- TODO confirm against the dataset).
        """
        x = batch['image'].to(self.device)
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        encoder_posterior = self.first_stage_model.encode(x)
        z = self.get_first_stage_encoding(encoder_posterior).detach()

        return z

    def get_first_stage_encoding(self, encoder_posterior):
        """Turn the encoder output into a latent tensor.

        A ``DiagonalGaussianDistribution`` is sampled; a tensor is returned
        as is.

        Raises:
            NotImplementedError: for any other encoder output type.
        """
        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
        return z

    def shared_step(self, batch, t=None):
        """Return the mean cross-entropy loss on latents perturbed at ``t``.

        Args:
            batch: mapping with an ``'image'`` and a ``'class_label'`` entry.
            t: optional 1-D tensor of diffusion times, one per sample. When
                omitted, every sample uses the fixed time ``1e-3``.

        Returns:
            Scalar loss tensor.
        """
        encoded_latent = self.get_input(batch)
        if t is None:
            # NOTE(review): the default is a single fixed noise level. Sampling
            # uniformly, e.g. rand(B) * (1 - 1e-5) + 1e-5, would cover the whole
            # time range -- confirm which one training is meant to use.
            t = torch.ones_like(batch['class_label']) * 1e-3
        t = t.to(encoded_latent.device)

        noise = torch.randn_like(encoded_latent)
        mean, std = self.vp.marginal_prob(encoded_latent, t)
        perturbed_latent = mean + std[:, None, None, None] * noise

        # The classifier receives the time scaled by 1000.
        logits = self.forward(perturbed_latent, 1000. * t)
        targets = batch['class_label'].to(self.device)
        loss = F.cross_entropy(logits, targets, reduction='none')
        return loss.mean()

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch."""
        loss = self.shared_step(batch)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log and return the loss at a small (1e-3) and a large (0.9) time."""
        small_t = torch.ones_like(batch['class_label']) * 1e-3
        small_t_loss = self.shared_step(batch, t=small_t)
        large_t = torch.ones_like(batch['class_label']) * 9e-1
        large_t_loss = self.shared_step(batch, t=large_t)

        self.log("val/small_t_loss", small_t_loss)
        self.log("val/large_t_loss", large_t_loss)
        return {"small_t_loss": small_t_loss, "large_t_loss": large_t_loss}

    def configure_optimizers(self):
        """Create AdamW over the classifier and, if configured, a LambdaLR.

        Reads ``self.learning_rate``, which must have been set on the
        instance before this is called.
        """
        optimizer = AdamW(self.classifier_model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)

        if self.use_scheduler:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                }]
            return [optimizer], scheduler

        return optimizer


class VPSDE:
    """Variance-preserving SDE with a linear beta schedule."""

    def __init__(self, beta_min=0.1, beta_max=20):
        """Construct a Variance Preserving SDE.

        Args:
            beta_min: value of beta(0)
            beta_max: value of beta(1)
        """
        self.beta_0 = beta_min
        self.beta_1 = beta_max

    def marginal_prob(self, x, t):
        """Return mean and std of the perturbation kernel at time ``t``.

        Args:
            x: data tensor whose first axis is the batch axis; any rank.
            t: tensor of times, one per sample.

        Returns:
            ``(mean, std)`` where ``mean`` has the shape of ``x`` and ``std``
            has the shape of ``t``.
        """
        log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        # One trailing singleton axis per non-batch axis of x, so the
        # per-sample coefficient broadcasts whatever the rank of x is.
        coeff_shape = (-1,) + (1,) * (x.dim() - 1)
        mean = torch.exp(log_mean_coeff).reshape(coeff_shape) * x
        std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
        return mean, std
