# Train Score Neural Operator with prototype embedding method in latent space (end-to-end training)
import time
import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision
import matplotlib.pyplot as plt

from dataprocess.data import training_idx
from utils.model import LatentScoreOperator
from utils.loss import score_operator_loss
from utils.pca import compute_u_train_and_test
from utils.save_images import generate_images
from utils.generate import generate_training_samples, generate_samples_z
from torchvision.utils import save_image
from dataprocess.classifier import evaluate_accuracy
from dataprocess.data import MyDataset_X2

# Reproducibility: seed NumPy and PyTorch and set the cuDNN deterministic flag.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
t0 = time.time()

# Run configuration.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_data_path = "data/mnist2d_train.npy"
num_samples = 2000
X_dim = 1024

# Load the array as float32, divide by 255 and keep the first `num_samples`
# entries along axis 1.
X = (np.load(train_data_path).astype(np.float32) / 255)[:, :num_samples, :]

# `training_idx` selects the training rows along axis 0; the held-out rows are
# whatever is left of 0..99.
indices = training_idx
print(indices)
remaining_indices = np.setdiff1d(np.arange(100), indices)

X_selected, X_unselected = X[indices], X[remaining_indices]

# Flatten each split to rows of length X_dim and move it to the compute device.
X_data, X_test = (
    torch.from_numpy(split).reshape(-1, X_dim).to(device)
    for split in (X_selected, X_unselected)
)

model = LatentScoreOperator(
    x_dim=X_dim,
    h_dim1=512,
    h_dim2=256,
    z_dim=10,
    num_examples=70,
    num_samples=num_samples,
    u=True,
)

# One dataset item per training example (70 of them); batches hold 5 examples.
dataset = MyDataset_X2(X_data.reshape(70, num_samples, X_dim))
dloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=5)
# Set up the optimizer and the run mode, then restore saved statistics and
# weights when the selected mode asks for them.

# Move the model to the device first so the optimizer is constructed over the
# parameters that will actually be trained (PyTorch recommends this order).
model = model.to(device)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)

current_epochs, totol_epoch = 0, 200000
mode = 0

# mode -> (pretrained, training, load_stat)
#   0: first training run (nothing is loaded)
#   1: fine-tune (load weights and statistics, keep training)
#   2: test only (load weights, no training)
_MODE_FLAGS = {
    0: (False, True, False),
    1: (True, True, True),
    2: (True, False, False),
}
if mode not in _MODE_FLAGS:
    # Fail here with a clear message instead of a NameError further down.
    raise ValueError(f"Unknown mode {mode!r}; expected one of {sorted(_MODE_FLAGS)}")
pretrained, training, load_stat = _MODE_FLAGS[mode]

if load_stat:
    # Resume: continue from the last recorded epoch and shift t0 back so that
    # elapsed-time measurements continue from the recorded total.
    stats = torch.load(f'stats/end2end_{num_samples}_{seed}.pth')
    current_epochs = int(stats["epoch"][-1])
    t0 = time.time() - stats['time'][-1]
else:
    epochs, total_time, train_acc, test_acc = [], [], [], []
    stats = {
        "epoch": epochs,
        "time": total_time,
        "train_accuracy": train_acc,
        "test_accuracy": test_acc,
    }
    t0 = time.time()

if pretrained:
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only
    # machine (and vice versa).
    model_state_dict = torch.load(
        f'weights/end2end_{num_samples}_{seed}.pth', map_location=device
    )
    model.load_state_dict(model_state_dict)

if mode == 2:
    # Test mode: build one prototype embedding per example (mean of the fourth
    # VAE output over that example's samples), draw latent samples conditioned
    # on the prototypes, decode them to 32x32 images and save everything.
    num_test_samples = 1
    z_dim = 10  # must match the z_dim given to LatentScoreOperator above

    # Encoding is inference only: without no_grad the full train and test sets
    # would each keep an autograd graph alive for nothing.
    with torch.no_grad():
        _, _, _, Z = model.vae(X_data)
        u_train = torch.mean(Z.reshape(70, num_samples, -1), dim=1)
        _, _, _, Z = model.vae(X_test)
        u_test = torch.mean(Z.reshape(30, num_samples, -1), dim=1)

    # NOTE(review): the sampler is deliberately left outside no_grad because
    # its internals are not visible here — confirm whether it needs autograd.
    z = generate_samples_z(z_dim, model.scorenet, 70*num_test_samples, u_train).detach().reshape(-1, z_dim)
    with torch.no_grad():
        X = model.vae.decoder(z)
    save_image(X.view(-1, 1, 32, 32), './samples/new_digit_train_70.png')

    z = generate_samples_z(z_dim, model.scorenet, 30*num_test_samples, u_test).detach().reshape(-1, z_dim)
    with torch.no_grad():
        X = model.vae.decoder(z)
    save_image(X.view(-1, 1, 32, 32), './samples/new_digit_test_70.png')

    # The prototypes are now plain tensors (no autograd history) when saved.
    torch.save(u_train, 'data/u_train_prototype.pth')
    torch.save(u_test, 'data/u_test_prototype.pth')
    print("Done")

def _prototype_accuracy(prototypes, labels, z_dim, num_test_samples, image_path):
    """Sample around each prototype, decode, save a preview and score the result.

    Args:
        prototypes: 2-D tensor with one prototype embedding per row.
        labels: array with one label per prototype (same order as the rows).
        z_dim: latent dimensionality handed to the sampler.
        num_test_samples: number of samples generated per prototype.
        image_path: where the preview (first sample of each prototype) is written.

    Returns:
        Whatever ``evaluate_accuracy`` returns for the decoded samples.
    """
    num_prototypes = prototypes.shape[0]
    # Repeat every prototype once per requested sample, prototype-major.
    u = prototypes.unsqueeze(1).repeat(1, num_test_samples, 1).reshape(-1, prototypes.shape[-1])
    z = generate_samples_z(z_dim, model.scorenet, num_prototypes * num_test_samples, u).detach().reshape(-1, z_dim)
    # Decoding is inference only; no autograd graph is needed.
    with torch.no_grad():
        X = model.vae.decoder(z)
    preview = X.reshape(num_prototypes, num_test_samples, 1, 32, 32)[:, 0].cpu().reshape(-1, 1, 32, 32)
    generate_images(preview, image_path, labels)

    # One label per generated sample, in the same prototype-major order as X.
    Y = np.tile(labels.reshape(num_prototypes, 1), (1, num_test_samples)).reshape(-1)
    return evaluate_accuracy(X, Y)


def eval_accuracy(X_data, num_test_samples=1000):
    """Return (train_accuracy, test_accuracy) of samples generated from prototypes.

    The training data ``X_data`` and the module-level ``X_test`` are encoded by
    the VAE; the fourth encoder output is averaged over each example's
    ``num_samples`` rows to give 70 training and 30 held-out prototypes. For
    every prototype ``num_test_samples`` latents are sampled, decoded to 32x32
    images and scored with ``evaluate_accuracy`` against the example labels.
    Preview grids are written to ./samples/train.png and ./samples/test.png.

    Args:
        X_data: flattened training images, 70 * num_samples rows.
        num_test_samples: number of generated samples per prototype.

    Returns:
        Tuple ``(train_accuracy, test_accuracy)``.
    """
    # This runs in the middle of training: without no_grad the full-dataset
    # forward passes would retain an autograd graph that is never used.
    with torch.no_grad():
        _, _, _, Z = model.vae(X_data)
        u_train = torch.mean(Z.reshape(70, num_samples, -1), dim=1)
        _, _, _, Z = model.vae(X_test)
        u_test = torch.mean(Z.reshape(30, num_samples, -1), dim=1)
    z_dim = Z.shape[-1]

    train_accuracy = _prototype_accuracy(u_train, indices, z_dim, num_test_samples, './samples/train.png')
    test_accuracy = _prototype_accuracy(u_test, remaining_indices, z_dim, num_test_samples, './samples/test.png')
    return train_accuracy, test_accuracy


if training:
    # Statistics and weights files share one run tag (sample count + seed).
    run_tag = str(num_samples) + '_' + str(seed)
    stats_file = 'stats/end2end_' + run_tag + '.pth'
    weights_file = 'weights/end2end_' + run_tag + '.pth'

    for i_epoch in range(current_epochs, totol_epoch):
        # One optimisation step per mini-batch; each batch is flattened to
        # rows of length X_dim before the loss is computed.
        for batch in dloader:
            flat_batch = batch.reshape(-1, X_dim)
            opt.zero_grad()
            loss, vae_loss, score_loss = score_operator_loss(model, flat_batch)
            loss.backward()
            opt.step()

        # Every 100 epochs, report the losses of the last mini-batch.
        if i_epoch % 100 == 0:
            print(f"{i_epoch+1} (time:{time.time() - t0}s): Loss:{loss.item()}, vae_loss:{vae_loss.item()}, score_loss:{score_loss.item()}")

        # Only every 1000th epoch is evaluated and checkpointed.
        if (1 + i_epoch) % 1000 != 0:
            continue

        # Elapsed time is taken before the evaluation starts.
        tt = time.time() - t0
        train_accuracy, test_accuracy = eval_accuracy(X_data)
        print(f"{i_epoch+1} (time:{tt}s): Train_accuracy:{train_accuracy}, Test_accuracy:{test_accuracy}")
        for key, value in (
            ("epoch", 1 + i_epoch),
            ("time", tt),
            ("train_accuracy", train_accuracy),
            ("test_accuracy", test_accuracy),
        ):
            stats[key].append(value)
        torch.save(stats, stats_file)
        torch.save(model.state_dict(), weights_file)

    # Final checkpoint once the epoch range is exhausted.
    torch.save(stats, stats_file)
    torch.save(model.state_dict(), weights_file)


