# Refer to https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial8/Deep_Energy_Models.html

## Standard libraries
import argparse
import os
import json
import math
import numpy as np
import random
import itertools

from tps_grid_gen import TPSGridGen

# Command-line interface: a single option selecting which digit label to use.
parser = argparse.ArgumentParser()
# The label is later used as a key into idx_map, which only defines entries
# for 0-9. Restricting the choices here turns an out-of-range value into a
# clear usage error instead of a KeyError further down the script.
parser.add_argument('-l', '--label', type=int, choices=range(10),
                    default=0, help='the choice of the label')
args = parser.parse_args()
print(args)

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
# Torchvision
import torchvision
from torchvision.datasets import MNIST
from torchvision import transforms
from torchvision.transforms import Compose, Pad, RandomChoice, Resize, RandomCrop, RandomRotation,GaussianBlur, ToTensor, ColorJitter, Lambda, RandomAffine
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

from models import DeepEnergyModel, GenerateCallback, SamplerCallback, OutlierCallback

# Root folder the dataset is read from / downloaded into
DATASET_PATH = "./data"
# Root folder under which trained models are saved
CHECKPOINT_PATH = "./saved_models/"

# Fix the global random seed so repeated runs start from the same state
pl.seed_everything(42)

# Ask cuDNN for deterministic kernels and switch off its autotuner, again for
# reproducibility when a GPU is used
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# First CUDA device when one is available, CPU otherwise
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Make sure the checkpoint root exists (no error if it is already there)
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# Set up the TPS grid generator: the control points form a regular 5x5
# lattice, i.e. every (a, b) pair of five evenly spaced values in [-1, 1].
# The two 28s match the 28x28 grid shape that tps_transform reshapes to.
_lattice_ticks = torch.linspace(-1, 1, 5)
target_control_points = torch.tensor(
    list(itertools.product(_lattice_ticks, _lattice_ticks)),
    dtype=torch.float,
)
tps = TPSGridGen(28, 28, target_control_points)


# Define a function to apply TPS transformation to a batch of images
def tps_transform(images, source_control_points):
    """Warp a batch of images with the module-level TPS grid generator.

    Args:
        images: Tensor handed to ``F.grid_sample`` as its input; its first
            dimension is taken as the batch size.
        source_control_points: Control points passed to ``tps`` to produce
            the sampling grid.

    Returns:
        The result of sampling ``images`` on the generated grid, which is
        reshaped to ``(batch, 28, 28, 2)`` before sampling.
    """
    source_grid = tps(source_control_points)
    source_grid = source_grid.view(images.size(0), 28, 28, 2)
    # align_corners=False is the value grid_sample falls back to when the
    # argument is omitted, so the output is unchanged; stating it explicitly
    # avoids the UserWarning about the implicit default.
    # NOTE(review): if TPSGridGen lays out its target coordinates with the
    # outermost pixels exactly at -1/+1, align_corners=True may be the
    # intended convention - confirm against tps_grid_gen before changing.
    transformed_images = F.grid_sample(images, source_grid, align_corners=False)
    return transformed_images

class TPSTransform:
    """Callable transform that applies a randomly perturbed TPS warp to one image.

    Args:
        tps_grid_gen: Object exposing ``num_points`` and
            ``target_control_points``; both are read on every call.
        control_point_noise: Factor multiplied onto the standard-normal
            offsets that are added to the target control points.
    """

    def __init__(self, tps_grid_gen, control_point_noise=0.1):
        self.control_point_noise = control_point_noise
        self.tps_grid_gen = tps_grid_gen

    def __call__(self, img):
        """Return a warped copy of ``img`` as a PIL image.

        ``img`` is converted with ``to_tensor``, warped through
        ``tps_transform`` as a batch of one, and converted back with
        ``to_pil_image``.
        """
        grid_gen = self.tps_grid_gen
        n_points = grid_gen.num_points

        # Source points = target control points plus scaled Gaussian jitter
        anchors = grid_gen.target_control_points.view(1, n_points, 2)
        jitter = torch.randn(1, n_points, 2) * self.control_point_noise

        # tps_transform works on batches, so add a leading batch dimension
        # going in and strip it again coming out
        batch = transforms.functional.to_tensor(img).unsqueeze(0)
        warped = tps_transform(batch, anchors + jitter)
        return transforms.functional.to_pil_image(warped.squeeze(0))


# Define custom Gaussian noise function
def add_gaussian_noise(img, mean=0, std=0.01):
    """Return ``img`` plus Gaussian noise, clamped to the range [0, 1].

    Args:
        img: Tensor to perturb; it is not modified in place.
        mean: Offset added to the noise.
        std: Factor the standard-normal noise is multiplied by.

    Returns:
        A new tensor with the same shape as ``img`` and values in [0, 1].
    """
    # Draw the noise on the same device as the image so the addition also
    # works for tensors that do not live on the default device.
    noise = torch.randn(img.size(), device=img.device) * std + mean
    return torch.clamp(img + noise, 0, 1)

def _rescale_rotate_crop(i, j):
    """Build one resize -> pad -> rotate -> random-crop pipeline for step (i, j).

    The resize target is ``(14 + 2*j, 14 + 2*i)`` and the same
    ``(7 - i, 7 - j)`` padding is used for both ``Pad`` and ``RandomCrop``.
    """
    margin = (7 - i, 7 - j)
    return Compose([
        Resize((14 + 2 * j, 14 + 2 * i)),
        Pad(margin),
        RandomRotation(35),
        RandomCrop(28, padding=margin),
    ])


# One candidate pipeline per (i, j) in 4..7 x 4..7, with i varying slowest
randomChoices = [
    _rescale_rotate_crop(i, j)
    for i, j in itertools.product(range(4, 8), repeat=2)
]

# Full augmentation chain applied to every sample drawn from the dataset
_augmentation_steps = [
    TPSTransform(tps),                    # random thin-plate-spline warp
    RandomChoice(randomChoices),          # one of the rescale/rotate/crop variants
    ColorJitter(brightness=(1.5, 2), contrast=0, saturation=0, hue=0),
    ToTensor(),
    Lambda(add_gaussian_noise),           # noise is added on the tensor
    transforms.Normalize((0.5,), (0.5,)),
]
transform = Compose(_augmentation_steps)

# MNIST test split (train=False), downloaded on demand, with the augmentation attached
test_set = MNIST(
    root=DATASET_PATH,
    train=False,
    transform=transform,
    download=True,
)

class DigitDataset(data.Dataset):
    """Fixed-length dataset that serves one element of another dataset repeatedly.

    Every valid index returns ``dataset[i]``. The wrapped dataset is indexed
    anew on each access, so any per-access behaviour it has is re-run for
    every item.

    Args:
        dataset: Indexable object the element is taken from.
        i: Index into ``dataset`` of the element to serve.
        n_elements: Reported length of this dataset.
    """

    def __init__(self, dataset, i, n_elements=5000):
        self.dataset = dataset
        self.i = i
        self.n_elements = n_elements

    def __len__(self):
        return self.n_elements

    def __getitem__(self, index):
        """Return ``dataset[i]`` for any index within the reported length.

        Raises:
            IndexError: If ``index`` lies outside
                ``[-n_elements, n_elements)``.
        """
        # Without this check no index ever fails, so iterating the dataset
        # directly (``for x in ds`` / ``list(ds)``), which relies on
        # IndexError to stop, would never terminate.
        if not -self.n_elements <= index < self.n_elements:
            raise IndexError(
                f"index {index} is out of range for a dataset of length {self.n_elements}"
            )
        return self.dataset[self.i]

# Label -> index into test_set of the single sample used for that label
# (presumably an image of that digit - verify against the MNIST test split)
idx_map = {
    0: 71,
    1: 135,
    2: 186,
    3: 32,
    4: 109,
    5: 165,
    6: 21,
    7: 34,
    8: 181,
    9: 12,
}
digit_dataset = DigitDataset(test_set, idx_map[args.label], n_elements=5000)

# Both loaders draw from the same dataset; they differ only in batch size,
# in whether the last incomplete batch is dropped, and in memory pinning.
train_loader = data.DataLoader(
    digit_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=0,
    drop_last=True,
    pin_memory=True,
)
test_loader = data.DataLoader(
    digit_dataset,
    batch_size=256,
    shuffle=True,
    num_workers=0,
    drop_last=False,
)

# From here on CHECKPOINT_PATH points at the sub-folder for the chosen label
checkpoint_name = f"MNIST_aug_{args.label}_single"
CHECKPOINT_PATH = os.path.join(CHECKPOINT_PATH, checkpoint_name)

def train_model(**kwargs):
    """Train a ``DeepEnergyModel`` and return the best checkpointed version.

    Args:
        **kwargs: Forwarded unchanged to the ``DeepEnergyModel`` constructor.

    Returns:
        The model reloaded from the checkpoint callback's
        ``best_model_path`` after fitting on ``train_loader`` with
        ``test_loader`` as the validation loader.
    """
    if str(device).startswith("cuda"):
        accelerator = "gpu"
    else:
        accelerator = "cpu"

    # Keep only the single best checkpoint (weights only), judged by the
    # lowest 'val_contrastive_divergence'.
    best_checkpoint = ModelCheckpoint(
        dirpath=CHECKPOINT_PATH,
        save_top_k=1,
        filename='MNIST_aug_EBM_best',
        save_weights_only=True,
        mode="min",
        monitor='val_contrastive_divergence',
    )
    callbacks = [
        best_checkpoint,
        GenerateCallback(every_n_epochs=5, num_steps=1024),
        SamplerCallback(every_n_epochs=5),
        OutlierCallback(),
        LearningRateMonitor("epoch"),
    ]

    trainer = pl.Trainer(
        default_root_dir=CHECKPOINT_PATH,
        accelerator=accelerator,
        devices=1,
        max_epochs=256,
        gradient_clip_val=0.1,
        callbacks=callbacks,
    )

    # Re-seed right before the model is built so its initialisation starts
    # from the same random state on every run.
    pl.seed_everything(42)
    model = DeepEnergyModel(**kwargs)
    trainer.fit(model, train_loader, test_loader)

    # Hand back the best checkpoint rather than the final-epoch weights
    best_path = trainer.checkpoint_callback.best_model_path
    return DeepEnergyModel.load_from_checkpoint(best_path)

# Hyperparameters handed to DeepEnergyModel via train_model; the batch size
# is read from the training loader so the two stay in sync.
_model_hparams = {
    "img_shape": (1, 28, 28),
    "batch_size": train_loader.batch_size,
    "lr": 1e-4,
    "beta1": 0.0,
}
model = train_model(**_model_hparams)

