# Compute probability embedding u in pixel space using KME 
# Train Score Neural Operator with u in latent space (end-to-end training)
import time
import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision
import matplotlib.pyplot as plt

from dataprocess.data import training_idx
from utils.model import LatentScoreOperator
from utils.loss import score_condition_loss
from dataprocess.compute_u import compute_u_train_test
from utils.save_images import generate_images
from utils.generate import generate_training_samples, generate_samples_z
from torchvision.utils import save_image
from dataprocess.classifier import evaluate_accuracy
from dataprocess.data import MyDataset_XU2

# Reproducibility: seed NumPy and torch, and ask cuDNN for deterministic kernels.
seed = 2
torch.manual_seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
t0 = time.time()

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

train_data_path = "data/mnist2d_train.npy"
num_samples = 2000  # samples kept per group (second axis of the loaded array)
X_dim = 1024        # flattened image size; images are viewed as 32 x 32 elsewhere

# Load every group, divide by 255 (assumes 8-bit pixel values -- TODO confirm),
# then keep only the first num_samples entries of each group.
X = (np.load(train_data_path).astype(np.float32)/255)[:, :num_samples, :]

# Groups used for training vs. the held-out groups among the 100 in the file.
indices = training_idx
remaining_indices = np.setdiff1d(np.arange(100), indices)
print(indices)

X_selected, X_unselected = X[indices], X[remaining_indices]


# Run mode:
#   0 - first time to train
#   1 - finetune (resume from saved stats and weights)
#   2 - test (load weights and u, no training)
mode = 0

# mode -> (pretrained, training, load_stat, load_u_train)
_MODE_FLAGS = {
    0: (False, True, False, False),
    1: (True, True, True, True),
    2: (True, False, False, True),
}
# Fail fast on a typo instead of hitting a NameError on an undefined flag later.
if mode not in _MODE_FLAGS:
    raise ValueError(f"unknown mode {mode!r}; expected one of {sorted(_MODE_FLAGS)}")
pretrained, training, load_stat, load_u_train = _MODE_FLAGS[mode]



# Embeddings u for the training and held-out groups: either computed from all 100
# groups (and cached to disk) or reloaded from that cache.
_u_train_path = f'weights/end2end_u_train_{num_samples}.pth'
_u_test_path = f'weights/end2end_u_test_{num_samples}.pth'
if not load_u_train:
    t_u = time.time()
    all_groups = torch.from_numpy(X.reshape(100, num_samples, X_dim))
    u_train, u_test = compute_u_train_test(all_groups)
    print("u computation time:", time.time() - t_u)
    torch.save(u_train, _u_train_path)
    torch.save(u_test, _u_test_path)
else:
    u_train = torch.load(_u_train_path)
    u_test = torch.load(_u_test_path)

# Training groups stay grouped (70, num_samples, X_dim); held-out images are pooled flat.
X_data = torch.from_numpy(X_selected).reshape(70, num_samples, X_dim).to(device)
X_test = torch.from_numpy(X_unselected).reshape(-1, X_dim).to(device)


# Model sized for the 70 training groups held in X_data.
model = LatentScoreOperator(x_dim=X_dim, h_dim1=512, h_dim2=256, z_dim=10, num_examples=70, num_samples=num_samples, u=True)

# Each DataLoader item comes from MyDataset_XU2(X_data, u_train); presumably one group's
# samples plus its embedding u (TODO confirm). Five items per optimization step.
dataset = MyDataset_XU2(X_data, u_train)
dloader = torch.utils.data.DataLoader(dataset, batch_size=5, shuffle=True)

# Move the model to its device *before* constructing the optimizer, as the torch.optim
# documentation recommends, so the optimizer is built over the final parameters.
model = model.to(device)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)

current_epochs, totol_epoch = 0, 400000

if load_stat:
    # Resume the epoch counter and the wall clock from the saved stats history.
    stats = torch.load('stats/end2end_kme_'+str(num_samples)+'_'+str(seed)+'.pth')
    if stats["epoch"]:
        current_epochs = int(stats["epoch"][-1])
        t0 = time.time() - stats['time'][-1]
    else:
        # A history saved before the first evaluation has no entries; start the
        # counters fresh instead of failing with IndexError.
        t0 = time.time()
else:
    epochs, total_time, train_acc, test_acc = [], [], [], []
    stats = {"epoch": epochs, "time": total_time, "train_accuracy": train_acc, "test_accuracy": test_acc}
    t0 = time.time()
 

if pretrained:
    # map_location lets a checkpoint saved from a CUDA run load on a CPU-only machine
    # (and vice versa); load_state_dict then copies the values into the model's parameters.
    model_state_dict = torch.load('weights/end2end_kme_'+str(num_samples)+'_'+str(seed)+'.pth', map_location=device)
    model.load_state_dict(model_state_dict)


if mode == 2:
    # Test mode: for each of the 70 training groups, save a grid of real images
    # (*_ori.png) and a grid of images generated from that group's embedding (*_ref.png).
    num_test_samples = 100
    for k in range(70):
        samples = X_data[k, :num_test_samples]
        save_image(samples.view(-1, 1, 32, 32), f'./samples/MNIST_new_{k}_ori.png')

    for k in range(70):
        # Condition every generated sample on group k's embedding: (num_test_samples, 10).
        u = u_train[k].reshape(1, 10).repeat(num_test_samples, 1)
        z = generate_samples_z(10, model.scorenet, num_test_samples, u).detach().reshape(-1, 10)
        # Decoded images are only written to disk, so don't build an autograd graph.
        with torch.no_grad():
            samples = model.vae.decoder(z)
        save_image(samples.view(-1, 1, 32, 32), f'./samples/MNIST_new_{k}_ref.png')



def _sample_and_score(u_cond, labels, num_test_samples, image_path):
    """Generate images conditioned on each row of ``u_cond`` and classify them.

    Args:
        u_cond: 2-D tensor of group embeddings, one row per group.
        labels: Group labels aligned with the rows of ``u_cond`` (also passed to
            ``generate_images`` for annotating the saved grid).
        num_test_samples: Number of images generated per group.
        image_path: Where to save a grid holding the first generated image of each group.

    Returns:
        Whatever ``evaluate_accuracy`` returns for the generated images and their labels.
    """
    num_groups = u_cond.shape[0]
    # Each group's row repeated num_test_samples times, groups kept in order.
    u = u_cond.repeat_interleave(num_test_samples, dim=0)
    z = generate_samples_z(10, model.scorenet, num_groups * num_test_samples, u).detach().reshape(-1, 10)
    # Decoded images are only scored and saved; skipping autograd avoids holding
    # activations for every generated sample.
    with torch.no_grad():
        samples = model.vae.decoder(z)
    first_per_group = samples.reshape(num_groups, num_test_samples, 1, 32, 32)[:, 0].cpu().reshape(-1, 1, 32, 32)
    generate_images(first_per_group, image_path, labels)

    # The k-th run of num_test_samples images carries label labels[k].
    sample_labels = np.repeat(np.asarray(labels).reshape(-1), num_test_samples)
    return evaluate_accuracy(samples, sample_labels)


def eval_accuracy(num_test_samples=1000):
    """Measure class-conditional generation quality via a classifier.

    Generates ``num_test_samples`` images per training group (conditioned on ``u_train``)
    and per held-out group (conditioned on ``u_test``), saves a preview grid of each to
    ./samples/train.png and ./samples/test.png, and scores them with ``evaluate_accuracy``
    against ``indices`` / ``remaining_indices``.

    NOTE(review): the model is left in whatever train/eval mode it is currently in;
    confirm whether model.eval() is needed for the VAE/score net during sampling.

    Returns:
        (train_accuracy, test_accuracy)
    """
    train_accuracy = _sample_and_score(u_train, indices, num_test_samples, './samples/train.png')
    test_accuracy = _sample_and_score(u_test, remaining_indices, num_test_samples, './samples/test.png')
    return train_accuracy, test_accuracy


def _checkpoint():
    """Write the stats history and the current model weights to their checkpoint files."""
    tag = f'{num_samples}_{seed}'
    torch.save(stats, f'stats/end2end_kme_{tag}.pth')
    torch.save(model.state_dict(), f'weights/end2end_kme_{tag}.pth')


if training:
    u_dim = u_train.shape[-1]
    for i_epoch in range(current_epochs, totol_epoch):
        for data, u in dloader:
            # Give every one of a group's num_samples images that group's embedding row.
            u = u.unsqueeze(1).repeat(1, num_samples, 1).reshape(-1, u_dim).to(device)
            data = data.reshape(-1, X_dim)

            opt.zero_grad()
            loss, vae_loss, score_loss = score_condition_loss(model, data, u)
            loss.backward()
            opt.step()

        # Log the last batch's losses every 100 epochs.
        if i_epoch % 100 == 0:
            print(f"{i_epoch+1} (time:{time.time() - t0}s): Loss:{loss.item()}, vae_loss:{vae_loss.item()}, score_loss:{score_loss.item()}")

        # Every 1000 epochs: evaluate, extend the history, and checkpoint.
        finished = 1 + i_epoch
        if finished % 1000 == 0:
            tt = time.time() - t0
            train_accuracy, test_accuracy = eval_accuracy()
            print(f"{i_epoch+1} (time:{tt}s): Train_accuracy:{train_accuracy}, Test_accuracy:{test_accuracy}")
            for key, value in (("epoch", finished), ("time", tt),
                               ("train_accuracy", train_accuracy), ("test_accuracy", test_accuracy)):
                stats[key].append(value)
            _checkpoint()

    _checkpoint()

