import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
from scipy.stats import entropy
from torch.distributions.multivariate_normal import MultivariateNormal

from vae import *
from sampler_NN import *

# Kept after the wildcard imports (as in the original ordering) so that
# `random` is guaranteed to be the stdlib module even if one of the wildcard
# imports exports a name called `random`.
import random

def split_data(joint_img):
    """Split a batch of images into a masked main image and a random left-side crop.

    The input is indexed as ``(batch, channel, height, width)``. The fixed
    offset range ``[0, 14]`` means the height must be at least 28 and the width
    at least 14 for every crop to be a full 14x14 patch.

    Args:
        joint_img: Batch tensor holding the full images. It is not modified.

    Returns:
        A tuple ``(img, cor_img)``:
        ``img`` is a detached copy of ``joint_img`` whose leftmost 14 columns
        are set to zero; ``cor_img`` has shape ``(batch, channel, 14, 14)`` and
        holds, per sample, a 14-row window cut from the leftmost 14 columns at
        a vertical offset drawn uniformly from ``0..14`` (inclusive).
    """
    # Detach once; the per-sample slices below are copied into `cor_img` on
    # assignment, so no per-iteration clone of the whole batch is needed.
    source = joint_img.detach()
    batch_size, channels = source.shape[0], source.shape[1]

    cor_img = torch.zeros((batch_size, channels, 14, 14), device=source.device)
    for i in range(batch_size):
        start_y = random.randint(0, 14)  # inclusive on both ends
        cor_img[i] = source[i, :, start_y:start_y + 14, :14]

    # Clone before masking: `detach()` shares storage with the input, so
    # zeroing it directly would silently overwrite the caller's tensor.
    img = source.clone()
    img[..., :14] = 0
    return img, cor_img

def train_vae(model, total_epochs, train_loader, test_loader, beta):
    """Train ``model`` on ``beta * distortion + KL`` and log test distortion per epoch.

    Every batch is passed through ``split_data`` to obtain the masked main
    image and the side-information crop, both of which are fed to ``model``.
    Distortion is the summed squared error between the reconstruction and the
    main image. The KL term is the closed-form divergence between the diagonal
    Gaussian described by ``(mu, logvar)`` and a standard normal.

    Args:
        model: Module called as ``model(img, cor_img)`` and returning five
            values ``(x_hat, mu, logvar, z, hy_pj)``; only the first three are
            used here. Batches are moved to the device of its parameters.
        total_epochs: Number of passes over ``train_loader``.
        train_loader: Iterable of ``(images, labels)`` batches; labels are ignored.
        test_loader: Iterable of ``(images, labels)`` batches used for the
            per-epoch evaluation; labels are ignored.
        beta: Weight applied to the distortion term of the loss.

    Returns:
        The trained ``model`` (the same object that was passed in).
    """
    criterion = nn.MSELoss(reduction='sum')
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # Follow the model's device instead of forcing CUDA, so the function also
    # works when the caller placed the model on the CPU.
    device = next(model.parameters()).device

    for epoch in range(total_epochs):
        model.train()
        print(f'Epoch {epoch + 1:3d}\n=========', flush=True)
        for i, (joint_img, _) in enumerate(train_loader):
            optimizer.zero_grad()
            img, cor_img = split_data(joint_img)
            img = img.to(device).float()
            cor_img = cor_img.to(device).float()

            x_hat, mu, logvar, _z, _hy_pj = model(img, cor_img)

            kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            distortion = criterion(x_hat, img)

            loss = beta * distortion + kld_loss
            loss.backward()

            optimizer.step()
            if i % 100 == 0:
                # Normalise by the actual element count of this batch rather
                # than an assumed 64 x 28 x 28.
                print(f'{i:3d}/{len(train_loader):3d}: {distortion.item() / img.numel():.6f}', flush=True)

        model.eval()
        tot_distortion = 0.0
        total_elements = 0
        with torch.no_grad():
            for joint_img, _ in test_loader:
                img, cor_img = split_data(joint_img)
                img = img.to(device).float()
                cor_img = cor_img.to(device).float()

                x_hat, _mu, _logvar, _z, _hy_pj = model(img, cor_img)
                tot_distortion += criterion(x_hat, img).item()
                total_elements += img.numel()

        # An empty test loader yields NaN instead of raising.
        mean_distortion = tot_distortion / total_elements if total_elements else float('nan')
        print(f'Test:    {mean_distortion:.6f}\n', flush=True)
    return model

def _nce_batch_loss(classifier, model, joint_img, bce_fn, device):
    """Build one real-vs-noise batch and return ``(loss, pair_count)``.

    The first half of the batch pairs each side-information crop with the
    latent the model produced for it (target 1); the second half pairs the
    remaining crops with standard-normal noise of the same shape (target 0).
    """
    # The VAE is frozen here: nothing built in this section needs gradients.
    with torch.no_grad():
        img, cor_img = split_data(joint_img)
        img = img.to(device).float()
        cor_img = cor_img.to(device).float()

        _x_hat, _mu, _logvar, z, _hy_pj = model(img, cor_img)
        half = img.shape[0] // 2
        U_pos = z[:half]
        U_neg = torch.randn_like(U_pos)
        U = torch.cat((U_pos, U_neg), dim=0)
        target = torch.cat(
            (torch.ones((half, 1), device=device), torch.zeros((half, 1), device=device)),
            dim=0,
        )
        # With an odd batch size the last crop has no partner latent; drop it
        # so the classifier inputs have matching row counts.
        cor_img = cor_img[:2 * half]

    prob_, _llr = classifier(cor_img, U)
    loss = bce_fn(torch.flatten(prob_), torch.flatten(target))
    return loss, 2 * half


def train_classifier(classifier, model, total_epochs, train_loader, test_loader):
    """Train ``classifier`` to tell model latents from Gaussian noise, given the crop.

    ``model`` is kept in eval mode and is never updated; only the classifier's
    parameters are optimised (Adam, default settings) with binary cross-entropy.
    After each epoch the mean validation loss over ``test_loader`` is printed.

    Args:
        classifier: Module called as ``classifier(cor_img, U)`` and returning
            two values; the first is passed to ``nn.BCELoss`` and therefore
            has to be a probability.
        model: Module called as ``model(img, cor_img)`` and returning five
            values, of which the fourth (``z``) supplies the positive samples.
            Batches are moved to the device of its parameters; the classifier
            is assumed to live on the same device.
        total_epochs: Number of passes over ``train_loader``.
        train_loader: Iterable of ``(images, labels)`` batches; labels are ignored.
        test_loader: Iterable of ``(images, labels)`` batches used for validation.

    Returns:
        The trained ``classifier`` (the same object that was passed in).
    """
    bce_fn = nn.BCELoss()
    opt_nce = torch.optim.Adam(classifier.parameters())
    device = next(model.parameters()).device

    for epoch in range(total_epochs):
        model.eval()
        classifier.train()
        print(f'Epoch {epoch + 1:3d}\n=========', flush=True)
        for i, (joint_img, _) in enumerate(train_loader):
            if joint_img.shape[0] < 2:
                # A single sample cannot form a positive/negative pair.
                continue
            opt_nce.zero_grad()

            loss, _pairs = _nce_batch_loss(classifier, model, joint_img, bce_fn, device)
            loss.backward()

            opt_nce.step()
            if i % 100 == 0:
                # BCELoss already averages over the batch, so report it as is.
                print(f'{i:3d}/{len(train_loader):3d}: {loss.item():.6f}', flush=True)

        model.eval()
        classifier.eval()
        total_val_loss = 0.0
        total_pairs = 0
        with torch.no_grad():
            for joint_img, _ in test_loader:
                if joint_img.shape[0] < 2:
                    continue
                loss, pairs = _nce_batch_loss(classifier, model, joint_img, bce_fn, device)
                # Weight each batch mean by its size to get a per-sample mean
                # even when the last batch is smaller.
                total_val_loss += loss.item() * pairs
                total_pairs += pairs

        mean_val_loss = total_val_loss / total_pairs if total_pairs else float('nan')
        print(f'Test:    {mean_val_loss:.6f}\n', flush=True)
    return classifier

# Define the transformation to apply to the dataset
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert the images to tensors
    transforms.Normalize((0.5,), (0.5,))  # Normalize the pixel values to the range [-1, 1]
])

# Load the training dataset
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)

# Create a data loader for the training dataset
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# Load the test dataset and create its data loader
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

total_epochs = 30
latent_dim = 4
beta = [0.15, 0.35, 0.55, 0.75, 0.95]

# Create the checkpoint directory up front so that a missing folder cannot
# throw away a finished training run when the weights are saved.
os.makedirs('model', exist_ok=True)

# Train one VAE and one classifier per distortion weight and save both.
for b in beta:
    print(f'Training model for beta={b:.2f}', flush=True)
    model = VAE(latent_dim=latent_dim).to(device)
    classifier = NCEClassifier(u_dim=latent_dim).to(device)
    model = train_vae(model, total_epochs, train_loader, test_loader, b)
    classifier = train_classifier(classifier, model, total_epochs, train_loader, test_loader)

    # round() rather than int(): a product such as 0.29 * 100 lands just below
    # the integer in floating point and int() would truncate it to 28.
    # Zero-padding keeps single-digit tags free of spaces in the filename.
    tag = round(b * 100)
    torch.save(model.state_dict(), f'model/model_{tag:02d}.pth')
    torch.save(classifier.state_dict(), f'model/classifier_{tag:02d}.pth')
