import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import sys

import torch
from sklearn import datasets
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.functional import F
import numpy as np
import copy
from torchvision.datasets import FashionMNIST, MNIST, CIFAR10
from torchvision import transforms
from sklearn.preprocessing import PolynomialFeatures
from diffusers import UNet2DModel
import torchvision
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
import torch.utils
from  torch.optim.lr_scheduler import ReduceLROnPlateau


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def clip_noise_schedule(alphas2, clip_value=0.001):
    """
    Clip the step-to-step ratios of a cumulative alpha^2 schedule.

    The ratio alphas2[t] / alphas2[t-1] is clipped to [clip_value, 1] and the
    schedule is rebuilt as the cumulative product of the clipped ratios. This
    may help improve stability during sampling.

    Because the ratios are taken between consecutive entries, the returned
    array has one entry fewer than `alphas2` along axis 0.
    """
    previous, current = alphas2[:-1], alphas2[1:]
    step_ratios = np.clip(current / previous, a_min=clip_value, a_max=1.0)
    return np.cumprod(step_ratios, axis=0)

def polynomial_decay(timesteps: int, s=1e-5, clip_value=0.001, power=2.0):
    """
    Polynomial noise schedule built from 1 - x^power (squared), with ratio clipping.
    from https://arxiv.org/abs/2203.17003

    The raw schedule is evaluated on `timesteps + 1` evenly spaced points,
    passed through `clip_noise_schedule` (which drops one entry), and finally
    rescaled by the affine map a -> (1 - 2*s) * a + s.
    """
    grid = np.linspace(0, timesteps, timesteps + 1, dtype=np.float64)
    raw_alphas2 = (1 - np.power(grid / timesteps, power)) ** 2

    clipped_alphas2 = clip_noise_schedule(raw_alphas2, clip_value=clip_value)

    # Affine rescale controlled by `s`; keeps the result away from exactly 0.
    return (1 - 2 * s) * clipped_alphas2 + s

def cosine_decay(timesteps, s: float = 0.008, clip_value: float = 0.999):
    """
    Cosine noise schedule with beta clipping.
    from https://arxiv.org/abs/2102.09672

    Returns the cumulative product of (1 - beta_t), where each beta_t is
    derived from consecutive values of the normalized cosine curve and clipped
    to [0, clip_value]. The returned array has `timesteps` entries.
    """
    grid = np.linspace(0, timesteps, timesteps + 1, dtype=np.float64)
    angle = (grid / timesteps + s) / (1 + s) * np.pi * 0.5
    f_t = np.cos(angle) ** 2
    normalized = f_t / f_t[0]

    # Clipping the per-step betas gives a more stable schedule near the end.
    raw_betas = 1 - (normalized[1:] / normalized[:-1])
    stable_alphas = 1 - np.clip(raw_betas, a_min=0.0, a_max=clip_value)

    return np.cumprod(stable_alphas, axis=0)

def linear_decay(timesteps, beta_start=0.0001, beta_end=0.02):
    """
    Linear beta schedule.

    Betas are spaced evenly from `beta_start` to `beta_end` over `timesteps`
    steps; the function returns the cumulative product of (1 - beta_t), i.e.
    an array with `timesteps` entries.

    The schedule length now follows `timesteps` (it was previously fixed at
    1000 regardless of the argument); `linear_decay(1000)` is unchanged.
    """
    betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float64)
    return np.cumprod(1 - betas)

def diffuse(x, sqrt_alpha_bar, sqrt_beta_bar, eps=None):
    """
    Apply the forward noising step: sqrt_alpha_bar * x + sqrt_beta_bar * eps.

    If `eps` is not supplied, standard normal noise shaped like `x` is drawn.
    Returns a tuple (noisy_x, eps) so the caller can use the noise as a target.
    """
    noise = torch.randn_like(x) if eps is None else eps
    noisy = sqrt_alpha_bar * x + sqrt_beta_bar * noise
    return noisy, noise


def update_ema_variables(model, ema_model, ema_decay=0.999):
    """
    Blend the parameters of `model` into `ema_model` in place.

    Each EMA parameter becomes ema * ema_decay + (1 - ema_decay) * param.
    Parameters are paired positionally via `parameters()`; only parameters
    (not buffers) are updated. Returns None.
    """
    with torch.no_grad():
        paired = zip(ema_model.parameters(), model.parameters())
        for averaged, live in paired:
            blended = averaged.data * ema_decay + (1 - ema_decay) * live.data
            averaged.data.copy_(blended)
            
            
def predict_time_noise(noisy_bat, model, time_model):
    """
    Predict the noise in `noisy_bat` using a timestep estimated by `time_model`.

    `time_model` is called with an all-zero timestep input; its output map is
    averaged over dims (1, 2, 3) per sample, scaled by the module-level `T`,
    clamped to [0, T - 1] and truncated to an integer step index. That index
    (detached, so no gradient flows from the noise model into the time model)
    is then fed to `model`.

    Returns (eps_pred, diff_step_pred), where `diff_step_pred` is the raw,
    un-averaged output of `time_model`.
    """
    batch_size = noisy_bat.shape[0]
    zero_steps = torch.zeros(batch_size, dtype=torch.int32).to(device)
    diff_step_pred = time_model(noisy_bat, zero_steps, return_dict=False)[0]

    # Collapse the predicted map to one scalar per sample and map it to a step index.
    scaled_estimate = diff_step_pred.detach().mean((1, 2, 3)) * T
    step_index = torch.clamp(scaled_estimate, min=0, max=T - 1).long()

    eps_pred = model(noisy_bat, step_index.to(device), return_dict=False)[0]
    return eps_pred, diff_step_pred

def _step(bat, model , time_model=None):
    """
    Compute the losses for one batch.

    A random diffusion step in [0, T) is drawn per sample, the batch is noised
    with the module-level schedule tensors, and `model` predicts the noise.
    Without `time_model` the true step is given to `model`; with it, the step
    is estimated by `predict_time_noise` and an extra MSE loss pulls the time
    model's output towards diff_step / T.

    Returns a dict with keys 'loss', 'time_loss' and 'noise_loss'. When
    `time_model` is None, 'time_loss' is a one-element zero tensor and 'loss'
    is a copy of 'noise_loss'; otherwise 'loss' is the equal-weight average of
    the two losses.
    """
    batch_size = bat.shape[0]
    diff_step = torch.randint(0, T, (batch_size,)).to(device)

    # Per-sample coefficients reshaped to (batch, 1, 1, 1) so they broadcast over 4-D batches.
    signal_scale = sqrt_alphas_bar[diff_step][:, None, None, None]
    noise_scale = sqrt_betas_bar[diff_step][:, None, None, None]
    noisy_bat, eps_true = diffuse(bat, signal_scale, noise_scale)

    if time_model is None:
        eps_pred = model(noisy_bat, diff_step.to(device), return_dict=False)[0]
        noise_loss = F.mse_loss(eps_pred, eps_true)
        time_loss = torch.zeros(1,)
        loss = noise_loss.clone()
    else:
        eps_pred, diff_step_pred = predict_time_noise(noisy_bat, model, time_model)
        noise_loss = F.mse_loss(eps_pred, eps_true)
        # Target is the normalized true step, broadcast to the time model's output shape.
        time_target = torch.ones(diff_step_pred.shape).to(device) * (diff_step.float()/T)[:,None,None,None]
        time_loss = F.mse_loss(diff_step_pred, time_target)
        loss = 0.5 * noise_loss + 0.5 * time_loss

    return {
        'loss': loss,
        'time_loss': time_loss,
        'noise_loss': noise_loss,
    }

    
    
def train_loop(train_loader, val_loader, optimizer, model, time_model=None, epochs=10, exp='mnist_cond', scheduler=None):
    """
    Train `model` (and optionally `time_model`) while tracking EMA copies.

    Per epoch this:
      1. saves the first validation batch as 'orig.png' and, every 50th epoch,
         grids of noised samples and their one-shot reconstructions under
         '<exp>_imgs';
      2. runs one pass over `train_loader`, updating the EMA copies after
         every optimizer step;
      3. evaluates the EMA copies on `val_loader` and steps `scheduler`
         (if given) with the mean validation loss;
      4. checkpoints the EMA weights under '<exp>_models' (best so far, every
         5th epoch, and the final epoch).

    Relies on the module-level names `device`, `T`, `sqrt_alphas_bar` and
    `sqrt_betas_bar` being defined before it is called.

    Args:
        train_loader / val_loader: iterables of batches whose first element is
            the image tensor.
        optimizer: optimizer over the trainable parameters.
        model: noise-prediction network.
        time_model: optional timestep-prediction network; None disables the
            time loss and feeds the true timestep to `model`.
        epochs: number of epochs to run.
        exp: prefix for the output directories.
        scheduler: optional scheduler whose `step` accepts the validation loss.

    Returns:
        (model, time_model, ema_model, ema_time_model); the last is None when
        `time_model` is None.
    """
    model_dir = exp + '_models'
    img_dir = exp + '_imgs'
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(img_dir, exist_ok=True)

    writer = SummaryWriter()

    model.to(device)
    ema_model = copy.deepcopy(model).to(device)

    if time_model is not None:
        time_model.to(device)
        ema_time_model = copy.deepcopy(time_model).to(device)
    else:
        ema_time_model = None

    n_iter = 0
    current_best_loss = float('inf')
    previous_loss = float('inf')
    early_stop_counter = 0

    # try/finally so the TensorBoard writer is flushed and closed even if training raises.
    try:
        for epoch in range(epochs):

            sample_img = next(iter(val_loader))[0].to(device)
            sample_noise = torch.randn_like(sample_img)
            save_images(sample_img, 'orig.png')

            # Previously gated by `epoch in range(0, 1000, 50)`, which silently
            # stopped producing previews for runs longer than 1000 epochs.
            if epoch % 50 == 0:
                _save_denoising_preview(sample_img, sample_noise, ema_model, ema_time_model, img_dir, epoch)

            # Only the EMA copies are ever switched to eval mode, so setting
            # train mode once per epoch is equivalent to doing it every step.
            model.train()
            if time_model is not None:
                time_model.train()

            with tqdm(total=len(train_loader), position=0, leave=True) as progress_bar:
                progress_bar.set_description(f"Epoch {epoch} - train")
                for bat in train_loader:
                    bat = bat[0].to(device)

                    optimizer.zero_grad()
                    loss_dict = _step(bat, model, time_model)
                    logs = _log_losses(writer, 'train', loss_dict, n_iter)

                    loss_dict['loss'].backward()
                    optimizer.step()

                    update_ema_variables(model, ema_model)
                    if time_model is not None:
                        update_ema_variables(time_model, ema_time_model)

                    progress_bar.update(1)
                    n_iter += 1
                    progress_bar.set_postfix(**logs)

            ema_model.eval()
            if ema_time_model is not None:
                ema_time_model.eval()

            losses = []
            with torch.no_grad(), tqdm(total=len(val_loader), position=0, leave=True) as progress_bar:
                progress_bar.set_description(f"Epoch {epoch} - val")
                for bat in val_loader:
                    bat = bat[0].to(device)

                    loss_dict = _step(bat, ema_model, ema_time_model)
                    logs = _log_losses(writer, 'test', loss_dict, n_iter)
                    losses.append(logs['loss'])

                    progress_bar.update(1)
                    n_iter += 1
                    progress_bar.set_postfix(**logs)

            val_loss = np.mean(losses)

            if scheduler is not None:
                scheduler.step(val_loss)

            # save model
            if val_loss < current_best_loss:
                current_best_loss = val_loss
                print('new best model at epoch: ', epoch, ' with loss: ', val_loss)
                torch.save(ema_model.state_dict(), f"{model_dir}/best_model.pt")
                if ema_time_model is not None:
                    torch.save(ema_time_model.state_dict(), f"{model_dir}/best_time_model.pt")

            if epoch % 5 == 0 or epoch == epochs - 1:
                torch.save(ema_model.state_dict(), f"{model_dir}/model_{epoch}.pt")
                if ema_time_model is not None:
                    torch.save(ema_time_model.state_dict(), f"{model_dir}/time_model_{epoch}.pt")

            if previous_loss <= val_loss:
                early_stop_counter += 1
            else:
                previous_loss = val_loss
                early_stop_counter = 0

            # NOTE(review): this threshold exceeds any realistic epoch count,
            # so early stopping is effectively disabled; kept as in the original.
            if early_stop_counter > 100000:
                print('interrupted by early stopping !')
                break
    finally:
        writer.close()

    return model, time_model, ema_model, ema_time_model


def _log_losses(writer, split, loss_dict, n_iter):
    """Write the three losses to TensorBoard under '<Name>/<split>' and return them as floats."""
    logs = {"loss": loss_dict['loss'].detach().item(),
            "noise_loss": loss_dict['noise_loss'].detach().item(),
            "time_loss": loss_dict['time_loss'].detach().item()}
    writer.add_scalar(f'Loss/{split}', logs['loss'], n_iter)
    writer.add_scalar(f'Noise_loss/{split}', logs['noise_loss'], n_iter)
    writer.add_scalar(f'Time_loss/{split}', logs['time_loss'], n_iter)
    return logs


def _save_denoising_preview(sample_img, sample_noise, ema_model, ema_time_model, img_dir, epoch):
    """Save grids of progressively noised samples and their one-shot reconstructions from the EMA model."""
    # Inference only: eval mode (matters at epoch 0, before validation has
    # switched the EMA copies to eval) and no autograd graph.
    ema_model.eval()
    if ema_time_model is not None:
        ema_time_model.eval()

    noisy_img_list = []
    recon_img_list = []
    with torch.no_grad():
        # Ten evenly spaced steps for T >= 10; equals range(0, 1000, 100) when T == 1000.
        for diff_step in range(0, T, max(1, T // 10)):
            noisy_samp = diffuse(sample_img, sqrt_alphas_bar[diff_step], sqrt_betas_bar[diff_step], eps=sample_noise)[0]
            noisy_img_list.append(noisy_samp)

            if ema_time_model is None:
                step_ids = torch.ones(noisy_samp.shape[0], dtype=torch.int32).to(device) * diff_step
                noise_pred = ema_model(noisy_samp, step_ids, return_dict=False)[0]
            else:
                noise_pred, _ = predict_time_noise(noisy_samp, ema_model, ema_time_model)

            # Invert the forward process in one shot using the predicted noise.
            recon_img_list.append((noisy_samp - sqrt_betas_bar[diff_step] * noise_pred) / sqrt_alphas_bar[diff_step])

    save_images(torch.cat(noisy_img_list, 0), f"{img_dir}/noisy_images_{epoch}.png", nrow=sample_img.shape[0])
    save_images(torch.cat(recon_img_list, 0), f"{img_dir}/reconstructed_{epoch}.png", nrow=sample_img.shape[0])


def save_images(batch, name, nrow=1):
    """
    Arrange `batch` into an image grid and write it to the file `name`.

    `nrow` is forwarded to `torchvision.utils.make_grid` as its second
    positional argument. Returns None.

    The original implementation also re-opened the written file with
    `Image.open` and discarded the result; that opened a file handle for no
    purpose and has been removed.
    """
    image_grid = torchvision.utils.make_grid(batch, nrow)
    torchvision.utils.save_image(image_grid, name)
    

# Script entry point.
# Usage: <script> <experiment: cifar|mnist> <conditioning: cond|pred>
# This block also defines the module-level names (T, sqrt_alphas_bar,
# sqrt_betas_bar, ...) that the functions above read as globals, so those
# functions only work after this block has run.
if __name__ == "__main__":
    
    # Both positional arguments are required; a missing one raises IndexError here.
    experiment = sys.argv[1]
    cond = sys.argv[2]
    
    if experiment not in ['cifar', 'mnist']:
        raise Exception(f"experiment {experiment} not supported, only {['cifar', 'mnist']}")
    if cond not in ['cond', 'pred']:
        raise Exception(f"conditioning {cond} not supported, only {['cond', 'pred']}")
    
    print(device)
    print(experiment + ' ' + cond)
    
    # Number of diffusion steps; used by _step / predict_time_noise as a global.
    T = 1000

    # Noise schedule and derived quantities, first as float64 numpy arrays.
    alphas_bar = cosine_decay(T)
    sqrt_alphas_bar =  np.sqrt(alphas_bar)
    betas_bar = 1.0 - alphas_bar
    sqrt_betas_bar = np.sqrt(betas_bar)
    # Per-step alphas recovered from the cumulative product (first entry kept as-is).
    alphas = np.concatenate([alphas_bar[:1], alphas_bar[1:] / alphas_bar[:-1]])
    sqrt_alphas = np.sqrt(alphas)
    betas = 1.0 - alphas
    sigmas = betas[1:] * (betas_bar[:-1] / betas_bar[1:])
    # sigmas has one entry fewer than betas; its first value is duplicated to pad the front.
    sqrt_sigmas = np.sqrt(np.append(sigmas[0], sigmas))

    # Move everything to float32 tensors on the target device.
    # NOTE(review): only sqrt_alphas_bar and sqrt_betas_bar are referenced by the
    # training code visible in this file; the others are presumably for sampling
    # code elsewhere — confirm before removing.
    alphas_bar = torch.from_numpy(alphas_bar).float().to(device)
    sqrt_alphas_bar = torch.from_numpy(sqrt_alphas_bar).float().to(device)
    betas_bar = torch.from_numpy(betas_bar).float().to(device)
    sqrt_betas_bar = torch.from_numpy(sqrt_betas_bar).float().to(device)
    alphas = torch.from_numpy(alphas).float().to(device)
    sqrt_alphas = torch.from_numpy(sqrt_alphas).float().to(device)
    betas = torch.from_numpy(betas).float().to(device)
    sigmas = torch.from_numpy(sigmas).float().to(device)
    sqrt_sigmas = torch.from_numpy(sqrt_sigmas).float().to(device)
    
    data_root = './toy_data'
    if experiment == 'mnist':
        # Train and test splits share the same ToTensor + Normalize pipeline.
        train_dataset = MNIST(data_root, train=True,
                                     download=True, transform=transforms.Compose(
                                        [transforms.ToTensor(),
                     transforms.Normalize((0.1307,), (0.3081,))]
                ))
        test_dataset = MNIST(data_root, train=False,
                                    download=True, transform=transforms.Compose(
                    [transforms.ToTensor(),
                     transforms.Normalize((0.1307,), (0.3081,))]
            ))
        
        # Noise-prediction UNet: three plain ResNet stages, no attention blocks.
        model = UNet2DModel(
        sample_size=28,  # the target image resolution
        in_channels=1,  # the number of input channels (single-channel images here)
        out_channels=1,  # the number of output channels
        layers_per_block=2,  # how many ResNet layers to use per UNet block
        block_out_channels=(64, 128, 128),  # the number of output channels for each UNet block
        down_block_types=(
            "DownBlock2D",  # a regular ResNet downsampling block
            "DownBlock2D",  # a regular ResNet downsampling block
            "DownBlock2D",  # a regular ResNet downsampling block
        ),
        up_block_types=(
            "UpBlock2D",  # a regular ResNet upsampling block
            "UpBlock2D",  # a regular ResNet upsampling block
            "UpBlock2D",  # a regular ResNet upsampling block
        ),
        )

        # Timestep-prediction UNet: same architecture as `model`, separately initialized.
        time_model = UNet2DModel(
            sample_size=28,  # the target image resolution
            in_channels=1,  # the number of input channels (single-channel images here)
            out_channels=1,  # the number of output channels
            layers_per_block=2,  # how many ResNet layers to use per UNet block
            block_out_channels=(64, 128, 128),  # the number of output channels for each UNet block
            down_block_types=(
                "DownBlock2D",  # a regular ResNet downsampling block
                "DownBlock2D",  # a regular ResNet downsampling block
                "DownBlock2D",  # a regular ResNet downsampling block
            ),
            up_block_types=(
                "UpBlock2D",  # a regular ResNet upsampling block
                "UpBlock2D",  # a regular ResNet upsampling block
                "UpBlock2D",  # a regular ResNet upsampling block
            ),
        )
    else:
        # experiment == 'cifar' (validated above).
        train_dataset = CIFAR10(data_root, train=True,
                                 download=True, transform=transforms.Compose(
                [transforms.ToTensor(),
                 transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
                 ]
            )
            )
        test_dataset = CIFAR10(data_root, train=False,
                                    download=True, transform=transforms.Compose(
                    [transforms.ToTensor(),
                     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
                     ]
                )
                )
        
        # Noise-prediction UNet: four attention stages with dropout.
        model = UNet2DModel(
            sample_size=32,  # the target image resolution
            in_channels=3,  # the number of input channels, 3 for RGB images
            out_channels=3,  # the number of output channels
            layers_per_block=2,  # how many ResNet layers to use per UNet block
            block_out_channels=( 128,
            256,
            256,
            256),  # the number of output channels for each UNet block
            down_block_types=(
                "AttnDownBlock2D",
                "AttnDownBlock2D",
                "AttnDownBlock2D",
                "AttnDownBlock2D"
            ),
            up_block_types=(
                "AttnUpBlock2D",
                "AttnUpBlock2D",
                "AttnUpBlock2D",
                "AttnUpBlock2D"
            ),
            dropout=0.3,
            attention_head_dim = None
        )
        
        # Timestep-prediction UNet: same architecture as `model`, separately initialized.
        time_model = UNet2DModel(
            sample_size=32,  # the target image resolution
            in_channels=3,  # the number of input channels, 3 for RGB images
            out_channels=3,  # the number of output channels
            layers_per_block=2,  # how many ResNet layers to use per UNet block
            block_out_channels=( 128,
            256,
            256,
            256),  # the number of output channels for each UNet block
            down_block_types=(
                "AttnDownBlock2D",
                "AttnDownBlock2D",
                "AttnDownBlock2D",
                "AttnDownBlock2D"
            ),
            up_block_types=(
                "AttnUpBlock2D",
                "AttnUpBlock2D",
                "AttnUpBlock2D",
                "AttnUpBlock2D"
            ),
            dropout=0.3,
            attention_head_dim = None
        )
    
    # 'cond': the noise model is given the true timestep, so the time model is
    # dropped (it has already been constructed above and is simply discarded).
    # 'pred': the time model is kept and trained jointly.
    if cond == 'cond':
        time_model = None
    
    print("num model's parameters", sum(p.numel() for p in model.parameters()))
    
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
    val_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False)
    
    # The conditional expression binds loosest, so this reads as
    # (model params + time-model params) if time_model is not None else model params.
    params = list(model.parameters()) + list(time_model.parameters()) if time_model is not None else model.parameters()
    
    optimizer = torch.optim.AdamW(params, lr=1e-4)
    # NOTE(review): train_loop steps the scheduler once per epoch, and patience=4000
    # exceeds epochs=2500 below, so the learning rate is never reduced in this run.
    # NOTE(review): `verbose` is reported as deprecated in recent PyTorch releases —
    # confirm against the installed version.
    scheduler =  ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=4000, verbose=True)
    
    model, time_model, ema_model, ema_time_model = train_loop(train_loader, val_loader, optimizer, model, time_model=time_model, epochs=2500, exp=f"{experiment}_{cond}_2", scheduler=scheduler)