# -*- coding: utf-8 -*-
"""
Created on Fri Apr 22 15:50:37 2022
This code is based on the work of 
https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial8/Deep_Energy_Models.html

"""
## Standard libraries
import os
import json
import math
import numpy as np
import random



## Imports for plotting
import matplotlib.pyplot as plt
from matplotlib import cm
#matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb

import matplotlib
from mpl_toolkits.mplot3d.axes3d import Axes3D
from mpl_toolkits.mplot3d import proj3d
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()

 
## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
# Torchvision
import torchvision
from torchvision.datasets import MNIST
from torchvision import transforms
# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    print('Please install pytorch-lightning then retry')
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint


import torchmetrics
#from torchmetrics.image.fid import FrechetInceptionDistance
#from torchmetrics.image.inception import InceptionScore

from fid import FrechetInceptionDistance
from inception import InceptionScore

import PIL.Image as Image


# Seed passed to pl.seed_everything (alternative values tried: 40, 21).
seed_in = 10
pl.seed_everything(seed_in)
# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = "../data"

# Weight of the feature-diversity term in the training loss. A value of exactly
# 0 selects the "standard" checkpoint folder; anything else the "ours" folder.
curr_rate = 0.000000000001
if curr_rate == 0:
    CHECKPOINT_PATH = "saved_models/model_" + str(seed_in) + "ebm_standard"
else:
    CHECKPOINT_PATH = "saved_models/model_" + str(seed_in) + "ebm_ours" + str(curr_rate)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility.
# NOTE: the attribute is spelled `deterministic`; a misspelled name is silently
# accepted by Python and leaves cuDNN in its default (non-deterministic) mode.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

##################dataset ##########################

# Per-image preprocessing: convert to a tensor, then shift/scale with
# mean 0.5 and std 0.5 so that pixel values end up between -1 and 1.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Arguments shared by the training and the test split of MNIST.
_mnist_common = dict(root=DATASET_PATH, transform=transform, download=True)

# Training split
train_set = MNIST(train=True, **_mnist_common)

# Test split
test_set = MNIST(train=False, **_mnist_common)

# Data loaders used further down in this file. The training loader shuffles and
# drops the last incomplete batch; the test loader keeps order and every sample.
train_loader = data.DataLoader(
    train_set,
    batch_size=128,
    shuffle=True,
    drop_last=True,
    num_workers=4,
    pin_memory=True,
)
test_loader = data.DataLoader(
    test_set,
    batch_size=128,
    shuffle=False,
    drop_last=False,
    num_workers=4,
)

###############  CNN Model ###############

class Swish(nn.Module):
    """Swish activation: the input multiplied element-wise by its sigmoid."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return torch.mul(x, gate)


class CNNModel(nn.Module):
    """Small convolutional network producing one output per image.

    Four stride-2 convolutions, each followed by Swish, are followed by a
    flatten, a linear layer and another Swish. A final linear layer maps the
    resulting feature vector to ``out_dim`` values.
    """

    def __init__(self, hidden_features=32, out_dim=1, **kwargs):
        """
        Inputs:
            hidden_features - Base channel width; the layers use half, one and two times this value
            out_dim - Size of the final linear output
            kwargs - Accepted for call compatibility and not used
        """
        super().__init__()
        # Channel widths grow with depth.
        narrow = hidden_features // 2
        medium = hidden_features
        wide = hidden_features * 2

        # (in_channels, out_channels, kernel_size, padding) for each stride-2 convolution.
        # The first one uses a larger padding (original note: "to get 32x32 image").
        conv_specs = [
            (1, narrow, 5, 4),       # [16x16]
            (narrow, medium, 3, 1),  # [8x8]
            (medium, wide, 3, 1),    # [4x4]
            (wide, wide, 3, 1),      # [2x2]
        ]

        stack = []
        for in_ch, out_ch, kernel, pad in conv_specs:
            stack.append(nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=2, padding=pad))
            stack.append(Swish())
        stack.extend([nn.Flatten(), nn.Linear(wide * 4, wide), Swish()])

        self.cnn_layers = nn.Sequential(*stack)
        self.output_layer = nn.Linear(wide, out_dim)

    def forward(self, x):
        """Return ``(output, features)``: the final linear output with its last
        dimension squeezed, and the feature vector it was computed from."""
        features = self.cnn_layers(x)
        output = self.output_layer(features).squeeze(dim=-1)
        return output, features


################ sampling buffer ###################
class Sampler:
    """Replay buffer of generated images plus the MCMC routine that refines them."""

    def __init__(self, model, img_shape, sample_size, max_len=8192):
        """
        Inputs:
            model - Neural network to use for modeling E_theta
            img_shape - Shape of the images to model
            sample_size - Batch size of the samples
            max_len - Maximum number of data points to keep in the buffer
        """
        super().__init__()
        self.model = model
        self.img_shape = img_shape
        self.sample_size = sample_size
        self.max_len = max_len
        # The buffer starts with `sample_size` uniform-noise images in [-1, 1),
        # each stored with a leading batch dimension of 1.
        self.examples = [(torch.rand((1,)+img_shape)*2-1) for _ in range(self.sample_size)]

    def sample_new_exmps(self, steps=60, step_size=10):
        """
        Function for getting a new batch of "fake" images.
        Inputs:
            steps - Number of iterations in the MCMC algorithm
            step_size - Learning rate nu in the algorithm above
        Returns:
            Tensor with `sample_size` images on the module-level `device`.
        """
        # Choose 95% of the batch from the buffer, 5% generate from scratch
        n_new = np.random.binomial(self.sample_size, 0.05)
        rand_imgs = torch.rand((n_new,) + self.img_shape) * 2 - 1
        old_imgs = torch.cat(random.choices(self.examples, k=self.sample_size-n_new), dim=0)
        inp_imgs = torch.cat([rand_imgs, old_imgs], dim=0).detach().to(device)

        # Perform MCMC sampling
        inp_imgs = Sampler.generate_samples(self.model, inp_imgs, steps=steps, step_size=step_size)

        # Add new images to the front of the buffer and drop the oldest ones if needed
        self.examples = list(inp_imgs.to(torch.device("cpu")).chunk(self.sample_size, dim=0)) + self.examples
        self.examples = self.examples[:self.max_len]
        return inp_imgs

    @staticmethod
    def generate_samples(model, inp_imgs, steps=60, step_size=10, return_img_per_step=False):
        """
        Function for sampling images for a given model.

        The model's train/eval mode, the `requires_grad` flag of every parameter
        and the global autograd mode are restored to what they were on entry,
        even if sampling raises.

        Inputs:
            model - Neural network to use for modeling E_theta. Must return a pair whose first element is the per-image output.
            inp_imgs - Images to start from for sampling. If you want to generate new images, enter noise between -1 and 1. Modified in place.
            steps - Number of iterations in the MCMC algorithm.
            step_size - Learning rate nu in the algorithm above
            return_img_per_step - If True, we return the sample at every iteration of the MCMC
        Returns:
            `inp_imgs` after the last step, or a stack of the samples after each step.
        """
        # Remember the state we are about to change so that it can be restored exactly.
        is_training = model.training
        params = list(model.parameters())
        prev_requires_grad = [p.requires_grad for p in params]
        had_gradients_enabled = torch.is_grad_enabled()

        # List for storing generations at each step (for later analysis)
        imgs_per_step = []

        try:
            # Before MCMC: set model parameters to "requires_grad=False"
            # because we are only interested in the gradients of the input.
            model.eval()
            for p in params:
                p.requires_grad = False
            inp_imgs.requires_grad = True

            # Enable gradient calculation if not already the case
            torch.set_grad_enabled(True)

            # We use a buffer tensor in which we generate noise each loop iteration.
            # More efficient than creating a new tensor every iteration.
            noise = torch.randn(inp_imgs.shape, device=inp_imgs.device)

            # Loop over K (steps)
            for _ in range(steps):
                # Part 1: Add noise to the input.
                noise.normal_(0, 0.005)
                inp_imgs.data.add_(noise.data)
                inp_imgs.data.clamp_(min=-1.0, max=1.0)

                # Part 2: calculate gradients of the negated model output w.r.t. the input.
                out_imgs, _ = model(inp_imgs)
                out_imgs = -out_imgs
                out_imgs.sum().backward()
                inp_imgs.grad.data.clamp_(-0.03, 0.03) # For stabilizing and preventing too high gradients

                # Apply gradients to our current samples
                inp_imgs.data.add_(-step_size * inp_imgs.grad.data)
                inp_imgs.grad.detach_()
                inp_imgs.grad.zero_()
                inp_imgs.data.clamp_(min=-1.0, max=1.0)

                if return_img_per_step:
                    imgs_per_step.append(inp_imgs.clone().detach())
        finally:
            # Give every parameter back its own flag instead of forcing True, so
            # parameters that were frozen on purpose stay frozen.
            for p, flag in zip(params, prev_requires_grad):
                p.requires_grad = flag
            model.train(is_training)

            # Reset gradient calculation to setting before this function
            torch.set_grad_enabled(had_gradients_enabled)

        if return_img_per_step:
            return torch.stack(imgs_per_step, dim=0)
        else:
            return inp_imgs

def interpolate(batch):
    """Convert a batch of normalised images into 299x299 uint8 RGB images.

    The previous implementation passed values in [-1, 1] straight to
    ``ToPILImage`` and then cast the [0, 1] output of ``ToTensor`` to uint8,
    which truncated almost every pixel to 0. Here the images are first mapped
    back to [0, 1], resized, and only then scaled to the 0..255 integer range.

    Inputs:
        batch - Float tensor of shape (N, C, H, W) with values in [-1, 1]
                (the range produced by Normalize(0.5, 0.5) and by the sampler's
                clamping). Single-channel images are replicated to 3 channels.
    Returns:
        uint8 CPU tensor of shape (N, 3, 299, 299) with values in 0..255.
        NOTE(review): assumes the metrics in the local `fid` / `inception`
        modules expect uint8 images in 0..255, as torchmetrics does - confirm.
    """
    imgs = batch.detach().float()
    # Undo the (x - 0.5) / 0.5 normalisation; clamp guards against tiny overshoots.
    imgs = (imgs * 0.5 + 0.5).clamp(0.0, 1.0)
    if imgs.size(1) == 1:
        # Grey-scale -> RGB by repeating the single channel.
        imgs = torch.tile(imgs, (1, 3, 1, 1))
    imgs = F.interpolate(imgs, size=(299, 299), mode="bilinear", align_corners=False)
    # Scale to 0..255 and round before the integer cast so values are not truncated.
    result_final = (imgs * 255.0).round().clamp(0, 255).to(torch.uint8)
    return result_final.cpu()

def compute_nnl_val(preds,gnds):
    """Per-pixel two-class negative log-likelihood of `gnds` under `preds`.

    Both inputs are flattened to 28*28 pixels per image and rescaled with
    ``x * 0.5 + 0.5``. The rescaled prediction is used as the probability of
    class 1 (and one minus it as the probability of class 0). The rescaled
    ground truth is cast to uint8, i.e. truncated to an integer class index.

    Inputs:
        preds - Predicted images, viewable as (N, 1, 784)
        gnds - Reference images, viewable as (N, 784)
    Returns:
        The loss as a Python float, computed on the CPU.
    """
    eps = 0.00001  # keeps the logarithm finite for zero probabilities
    n_pixels = 28*28

    # Integer class targets on the CPU.
    targets = gnds.view(-1, n_pixels)
    targets = ((targets * 0.5) + 0.5).type(torch.uint8)
    targets = targets.type(torch.LongTensor)

    # Class probabilities stacked along dim 1: index 0 -> class 0, index 1 -> class 1.
    prob_one = (preds.view(-1, 1, n_pixels) * 0.5) + 0.5
    prob_zero = 1 - prob_one
    class_probs = torch.cat([prob_zero, prob_one], dim=1)

    log_probs = torch.log(class_probs + eps).cpu()
    return F.nll_loss(log_probs, targets).item()

########### deep energy model#####################
class DeepEnergyModel(pl.LightningModule):
    """Lightning module that trains `CNNModel` with contrastive divergence.

    On top of the contrastive-divergence and regularisation losses, the training
    loss subtracts `div_rate` times a feature-diversity term computed on the
    intermediate features of the real images. During validation the module also
    accumulates FID / Inception-Score metrics and writes per-epoch score
    histories to text files under `CHECKPOINT_PATH`.
    """

    def __init__(self, img_shape, batch_size, alpha=0.1, lr=1e-4, beta1=0.0, div_rate=0.005, **CNN_args):
        """
        Inputs:
            img_shape - Shape of one image, e.g. (1, 28, 28)
            batch_size - Number of samples the sampler produces per call
            alpha - Weight of the regularisation loss
            lr - Learning rate of the Adam optimizer
            beta1 - First beta of the Adam optimizer
            div_rate - Weight of the feature-diversity term
            CNN_args - Extra keyword arguments forwarded to CNNModel
        """
        super().__init__()
        self.save_hyperparameters()
        self.div_rate = div_rate
        self.cnn = CNNModel(**CNN_args)
        self.sampler = Sampler(self.cnn, img_shape=img_shape, sample_size=batch_size)
        self.example_input_array = torch.zeros(1, *img_shape)
        self.fid_metric = FrechetInceptionDistance()
        self.si_metric = InceptionScore()
        # Histories: one entry per validation epoch.
        self.all_fid_scores = []
        self.all_si_scores = []
        # Per-step values collected since the last validation epoch, and their per-epoch means.
        self.all_div_scores = []
        self.all_div_scores_means = []
        self.all_cdiv_scores = []
        self.all_cdiv_scores_means = []

        self.all_nnvalue_scores = []
        self.all_nnvalue_scores_means = []

    @staticmethod
    def _epoch_mean(scores):
        """Return the mean of `scores`, or NaN when nothing was recorded.

        NaN is what `np.mean` yields for an empty input as well; handling it
        here only avoids the NumPy warning (e.g. when validation runs before
        any training step has been recorded).
        """
        if len(scores) == 0:
            return float('nan')
        return np.mean(np.array(scores))

    def forward(self, x):
        """Return the `(output, intermediate_features)` pair of the CNN."""
        z, inter_ = self.cnn(x)
        return z, inter_

    def configure_optimizers(self):
        # Energy models can have issues with momentum as the loss surfaces changes with its parameters.
        # Hence, we set it to 0 by default.
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr, betas=(self.hparams.beta1, 0.999))
        scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.97) # Exponential decay over epochs
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch of real images."""
        # We add minimal noise to the original images to prevent the model from focusing on purely "clean" inputs
        real_imgs, _ = batch
        small_noise = torch.randn_like(real_imgs) * 0.005
        real_imgs.add_(small_noise).clamp_(min=-1.0, max=1.0)

        # Obtain samples
        fake_imgs = self.sampler.sample_new_exmps(steps=60, step_size=10)

        # Predict energy score for all images
        inp_imgs = torch.cat([real_imgs, fake_imgs], dim=0)
        all_outs, all_inter = self.cnn(inp_imgs)
        real_out, fake_out = all_outs.chunk(2, dim=0)
        # We only use the intermediate representation of the real images
        real_inter, _ = all_inter.chunk(2, dim=0)

        # Calculate losses
        reg_loss = self.hparams.alpha * (real_out ** 2 + fake_out ** 2).mean()
        cdiv_loss = fake_out.mean() - real_out.mean()
        loss = reg_loss + cdiv_loss

        # Feature diversity: for every real image, the squared differences between
        # all pairs of its feature components, summed over the pairs.
        pos_inter_feat = real_inter
        batch_size = pos_inter_feat.size(dim=0)
        pos_inter_feat = pos_inter_feat.view(batch_size, -1, 1)
        pos_inter_feat_transpose = pos_inter_feat.view(batch_size, 1, -1)
        dist = pos_inter_feat - pos_inter_feat_transpose
        dist = dist ** 2

        total_diversity = dist.sum((1, 2))
        total_diversity = total_diversity / 2 # every pair appears twice in the difference matrix
        total_diversity = total_diversity.view(-1)

        # Subtracting the term rewards larger diversity.
        loss = loss - self.div_rate * total_diversity.mean()

        # Keep the value for the per-epoch statistics. Detach first: otherwise each
        # stored tensor keeps its autograd graph alive until the epoch ends and
        # cannot be converted to NumPy.
        curr_diversity = total_diversity.mean()
        self.all_div_scores.append(curr_diversity.detach().cpu())

        # Logging
        self.log('loss', loss)
        self.log('loss_regularization', reg_loss)
        self.log('loss_contrastive_divergence', cdiv_loss)
        self.log('metrics_avg_real', real_out.mean())
        self.log('metrics_avg_fake', fake_out.mean())
        return loss

    def validation_step(self, batch, batch_idx):
        """Log contrastive divergence against random noise and update the image metrics."""
        # For validating, we calculate the contrastive divergence between purely random images and unseen examples
        # Note that the validation/test step of energy-based models depends on what we are interested in the model
        real_imgs, _ = batch
        fake_imgs = torch.rand_like(real_imgs) * 2 - 1

        inp_imgs = torch.cat([real_imgs, fake_imgs], dim=0)

        all_outs, _ = self.cnn(inp_imgs)
        real_out, fake_out = all_outs.chunk(2, dim=0)

        cdiv = fake_out.mean() - real_out.mean()
        self.log('val_contrastive_divergence', cdiv)
        self.log('val_fake_out', fake_out.mean())
        self.log('val_real_out', real_out.mean())

        self.all_cdiv_scores.append(cdiv.detach().cpu())

        ### we compute fid score between real images and generated images:

        # Obtain generated samples
        fake_gen_imgs = self.sampler.sample_new_exmps(steps=60, step_size=10)

        # Move to the module's own device rather than assuming CUDA is available.
        fake_gen_imgs_temp = interpolate(fake_gen_imgs).to(self.device)
        real_imgs_temp = interpolate(real_imgs).to(self.device)
        self.fid_metric.update(real_imgs_temp, real=True)
        self.fid_metric.update(fake_gen_imgs_temp, real=False)

        ### compute SI score
        self.si_metric.update(fake_gen_imgs_temp)
        # The pixel-wise loss pairs generated and real images one-to-one, so it is
        # only computed when both batches have the same size (this skips a final
        # partial batch of real images).
        if real_imgs.size(0) == fake_gen_imgs.size(0):
            curr_nnl = compute_nnl_val(fake_gen_imgs, real_imgs)
            self.all_nnvalue_scores.append(curr_nnl)

    def validation_epoch_end(self, validation_step_outputs):
        """Finalize the metrics of this validation epoch and write the score histories to disk."""
        # The score files live directly in CHECKPOINT_PATH, which may not exist yet.
        os.makedirs(CHECKPOINT_PATH, exist_ok=True)

        fid_score = self.fid_metric.compute()
        self.log('fid_score is:', fid_score)
        self.all_fid_scores.append(fid_score.cpu().numpy())
        np.savetxt(CHECKPOINT_PATH+'/mnist_fid_scores.txt', np.array(self.all_fid_scores), delimiter=',')

        # compute() yields a pair here; only the first element is recorded.
        si_score, _ = self.si_metric.compute()
        self.log('si_score is:', si_score)
        self.all_si_scores.append(si_score.cpu().numpy())

        np.savetxt(CHECKPOINT_PATH+'/mnist_si_scores.txt', np.array(self.all_si_scores), delimiter=',')

        self.all_div_scores_means.append(self._epoch_mean(self.all_div_scores))
        np.savetxt(CHECKPOINT_PATH+'/mnist_diversity_scores.txt', np.array(self.all_div_scores_means), delimiter=',')
        self.all_div_scores = []

        self.all_cdiv_scores_means.append(self._epoch_mean(self.all_cdiv_scores))
        np.savetxt(CHECKPOINT_PATH+'/mnist_cdiv_scores.txt', np.array(self.all_cdiv_scores_means), delimiter=',')
        self.all_cdiv_scores = []

        self.all_nnvalue_scores_means.append(self._epoch_mean(self.all_nnvalue_scores))
        np.savetxt(CHECKPOINT_PATH+'/mnist_nnloss_scores.txt', np.array(self.all_nnvalue_scores_means), delimiter=',')
        self.all_nnvalue_scores = []

        self.fid_metric.reset()
        self.si_metric.reset()

################ callbacks ###############################
class GenerateCallback(pl.Callback):
    """Callback that generates images from noise and logs snapshots of the sampling chain."""

    def __init__(self, batch_size=8, vis_steps=8, num_steps=256, every_n_epochs=5):
        super().__init__()
        self.batch_size = batch_size         # Number of images to generate
        self.vis_steps = vis_steps           # Number of steps within generation to visualize
        self.num_steps = num_steps           # Number of steps to take during generation
        self.every_n_epochs = every_n_epochs # Only save those images every N epochs (otherwise tensorboard gets quite large)

    def on_epoch_end(self, trainer, pl_module):
        """Every `every_n_epochs` epochs, log one image grid per generated sample."""
        # Skip for all other epochs
        if trainer.current_epoch % self.every_n_epochs != 0:
            return
        # Generate images
        imgs_per_step = self.generate_imgs(pl_module)
        # Distance (in sampling steps) between two visualized snapshots
        step_size = self.num_steps // self.vis_steps
        # Plot and add to the logger
        for i in range(imgs_per_step.shape[1]):
            imgs_to_plot = imgs_per_step[step_size-1::step_size,i]
            grid = torchvision.utils.make_grid(imgs_to_plot, nrow=imgs_to_plot.shape[0], normalize=True, range=(-1,1))
            trainer.logger.experiment.add_image(f"generation_{i}", grid, global_step=trainer.current_epoch)

    def generate_imgs(self, pl_module):
        """Run `num_steps` sampling steps from uniform noise and return every intermediate sample.

        The module's train/eval mode and the global autograd mode are restored
        to their state on entry (previously the method always left the module in
        train mode and autograd globally disabled).

        Returns:
            Stack of the samples after each step, with the step index as first dimension.
        """
        was_training = pl_module.training
        grad_was_enabled = torch.is_grad_enabled()
        pl_module.eval()
        try:
            start_imgs = torch.rand((self.batch_size,) + pl_module.hparams["img_shape"]).to(pl_module.device)
            start_imgs = start_imgs * 2 - 1
            torch.set_grad_enabled(True)  # Tracking gradients for sampling necessary
            imgs_per_step = Sampler.generate_samples(pl_module.cnn, start_imgs, steps=self.num_steps, step_size=10, return_img_per_step=True)
        finally:
            torch.set_grad_enabled(grad_was_enabled)
            pl_module.train(was_training)
        return imgs_per_step




class SamplerCallback(pl.Callback):
    """Callback that logs a grid of images drawn at random from the sampler's buffer."""

    def __init__(self, num_imgs=32, every_n_epochs=5):
        super().__init__()
        # How many buffer images go into the grid
        self.num_imgs = num_imgs
        # Log only every N epochs (otherwise tensorboard gets quite large)
        self.every_n_epochs = every_n_epochs

    def on_epoch_end(self, trainer, pl_module):
        """Every `every_n_epochs` epochs, log `num_imgs` buffer images (drawn with replacement)."""
        epoch = trainer.current_epoch
        if epoch % self.every_n_epochs != 0:
            return
        picked = random.choices(pl_module.sampler.examples, k=self.num_imgs)
        exmp_imgs = torch.cat(picked, dim=0)
        grid = torchvision.utils.make_grid(exmp_imgs, nrow=4, normalize=True, range=(-1,1))
        trainer.logger.experiment.add_image("sampler", grid, global_step=epoch)


class OutlierCallback(pl.Callback):
    """Callback that logs the model's mean output on uniform-noise images."""

    def __init__(self, batch_size=1024):
        super().__init__()
        # Number of noise images evaluated per epoch
        self.batch_size = batch_size

    def on_epoch_end(self, trainer, pl_module):
        """Evaluate the CNN on fresh noise in [-1, 1) and log the mean as "rand_out"."""
        noise_shape = (self.batch_size,) + pl_module.hparams["img_shape"]
        with torch.no_grad():
            pl_module.eval()
            rand_imgs = torch.rand(noise_shape).to(pl_module.device)
            rand_imgs = rand_imgs * 2 - 1.0
            noise_out, _ = pl_module.cnn(rand_imgs)
            rand_out = noise_out.mean()
            # The module is always put back into train mode afterwards.
            pl_module.train()

        trainer.logger.experiment.add_scalar("rand_out", rand_out, global_step=trainer.current_epoch)





############ running the model ##################
def train_model(**kwargs):
    """Return a DeepEnergyModel, loaded from a checkpoint or freshly trained.

    If "MNIST.ckpt" exists in CHECKPOINT_PATH it is loaded and training is
    skipped. Otherwise a model is built from `kwargs`, fitted on `train_loader`
    with `test_loader` as validation data, and the checkpoint with the lowest
    'val_contrastive_divergence' is loaded and returned.
    """
    use_gpu = str(device).startswith("cuda")
    trainer_callbacks = [
        ModelCheckpoint(save_weights_only=True, mode="min", monitor='val_contrastive_divergence'),
        GenerateCallback(every_n_epochs=5),
        SamplerCallback(every_n_epochs=5),
        OutlierCallback(),
        LearningRateMonitor("epoch"),
    ]
    # Create a PyTorch Lightning trainer with the callbacks above
    trainer = pl.Trainer(
        default_root_dir=os.path.join(CHECKPOINT_PATH, "MNIST"),
        gpus=1 if use_gpu else 0,
        max_epochs=60,
        gradient_clip_val=0.1,
        callbacks=trainer_callbacks,
        progress_bar_refresh_rate=1,
    )

    # A pretrained checkpoint short-circuits training.
    pretrained_filename = os.path.join(CHECKPOINT_PATH, "MNIST.ckpt")
    if os.path.isfile(pretrained_filename):
        print("Found pretrained model, loading...")
        return DeepEnergyModel.load_from_checkpoint(pretrained_filename)

    pl.seed_everything(seed_in)
    fresh_model = DeepEnergyModel(**kwargs)
    trainer.fit(fresh_model, train_loader, test_loader)
    # No testing as we are more interested in other properties
    best_path = trainer.checkpoint_callback.best_model_path
    return DeepEnergyModel.load_from_checkpoint(best_path)


# Train (or load) the model with the diversity weight configured at the top of the file.
model = train_model(img_shape=(1,28,28),
                    batch_size=train_loader.batch_size,
                    lr=1e-4,
                    beta1=0.0,div_rate=curr_rate)

# Generate a few sampling chains from noise with a fixed seed.
model.to(device)
pl.seed_everything(43)
callback = GenerateCallback(batch_size=4, vis_steps=8, num_steps=256)
imgs_per_step = callback.generate_imgs(model)
imgs_per_step = imgs_per_step.cpu()

print('done done done')

# The figures are written into "figs/", which may not exist yet.
os.makedirs("figs", exist_ok=True)
fig_prefix = "figs/" + str(seed_in) + "mnistours_" + str(curr_rate)

# Distance (in sampling steps) between two plotted snapshots
step_size = callback.num_steps // callback.vis_steps
num_chains = imgs_per_step.shape[1]

for i in range(num_chains):
    # Snapshots every `step_size` steps, preceded by the sample after the first step.
    imgs_to_plot = imgs_per_step[step_size-1::step_size,i]
    imgs_to_plot = torch.cat([imgs_per_step[0:1,i],imgs_to_plot], dim=0)
    grid = torchvision.utils.make_grid(imgs_to_plot, nrow=imgs_to_plot.shape[0], normalize=True, range=(-1,1), pad_value=0.5, padding=2)
    grid = grid.permute(1, 2, 0)
    plt.figure(figsize=(8,8))
    plt.imshow(grid)
    plt.xlabel("Generation iteration")
    plt.xticks([(imgs_per_step.shape[-1]+2)*(0.5+j) for j in range(callback.vis_steps+1)],
               labels=[1] + list(range(step_size,imgs_per_step.shape[0]+1,step_size)))
    plt.yticks([])

    # Save BEFORE showing: once plt.show() returns the figure may already be
    # closed, and a later savefig would write an empty canvas.
    plt.savefig(fig_prefix + "_" + str(i) + ".pdf", dpi=150)
    if i == num_chains - 1:
        # The un-numbered files hold the last chain, in both formats.
        plt.savefig(fig_prefix + ".pdf", dpi=150)
        plt.savefig(fig_prefix + ".png", dpi=150)
    plt.show()



