# Conditional score-based generative model 1:
# trained for 10000 epochs on 70 training sets (2000 samples each), then fine-tuned for 10000 epochs on the 30 held-out test sets (2000 samples each).
import os
import time

import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision
import matplotlib.pyplot as plt
from torchvision.utils import save_image

from dataprocess.data import training_idx, MyDataset_XU
from utils.model import LatentScoreOperator
from utils.loss import score_condition_loss
from utils.pca import compute_u_train_and_test
from utils.generate import generate_training_samples, generate_samples_z
from dataprocess.classifier import evaluate_accuracy

# --- reproducibility & device ------------------------------------------------
# Seed the torch and numpy RNGs; `seed` is also part of the checkpoint and
# stats file names written later in this script.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True  # ask cuDNN for deterministic kernels
t0 = time.time()  # script start time

use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
# --- data ----------------------------------------------------------------------
train_data_path = "data/mnist2d_train.npy"
num_samples = 2000  # samples kept per set
X_dim = 1024        # length of one flattened image

# Divide raw values by 255 (assumes 8-bit pixel data) and keep only the first
# `num_samples` samples of every set.
X = (np.load(train_data_path).astype(np.float32) / 255)[:, :num_samples, :]

# Split the 100 sets into the training sets (`training_idx`) and the remainder.
indices = training_idx
remaining_indices = np.setdiff1d(np.arange(100), indices)
print(indices)

# Conditioning value of set k is k / 100 (float32), repeated for each of its samples.
Y = np.repeat((np.arange(100, dtype=np.float32) / 100.)[:, None], num_samples, axis=1)

X_selected, X_unselected = X[indices], X[remaining_indices]
Y_selected, Y_unselected = Y[indices], Y[remaining_indices]

# Flattened images are placed on `device`; the conditioning values stay on the CPU.
X_data, X_test = (torch.from_numpy(arr).reshape(-1, X_dim).to(device) for arr in (X_selected, X_unselected))
Y_data, Y_test = (torch.from_numpy(arr).reshape(-1, 1) for arr in (Y_selected, Y_unselected))

dloader_train = torch.utils.data.DataLoader(MyDataset_XU(X_data, Y_data), batch_size=1000, shuffle=True)

dataset = MyDataset_XU(X_test, Y_test)
dloader_test = torch.utils.data.DataLoader(dataset, batch_size=1000, shuffle=True)

# --- model & optimizer ---------------------------------------------------------
# NOTE(review): num_examples=70 presumably matches len(training_idx) — confirm.
model = LatentScoreOperator(x_dim=X_dim, h_dim1=512, h_dim2=256, z_dim=10, num_examples=70, num_samples=num_samples, u_dim=1, u=True, condition=True)

# Move the model to its device *before* constructing the optimizer, as the
# torch.optim documentation recommends, so the optimizer is built over the
# parameters in their final placement.
model = model.to(device)

opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
totol_epoch_1, totol_epoch_2 = 10000, 10000
mode = 0
if mode == 0:
    # first time to train
    pretrained = False
    training = True
    finetune = True




epochs, total_time, train_acc, test_acc = [],[],[],[]
stats = {"epoch":epochs, "time":total_time, "train_accuracy":train_acc, "test_accuracy":test_acc}
t0 = time.time()
 

if pretrained:
    # Restore weights saved by an earlier run. map_location remaps the stored
    # tensors onto `device`, so a checkpoint written on a GPU can still be
    # loaded on a CPU-only machine (and vice versa).
    model_state_dict = torch.load('weights/score_'+str(num_samples)+'_'+str(seed)+'.pth', map_location=device)
    model.load_state_dict(model_state_dict)

def _sample_and_score(labels, image_path, num_test_samples, z_dim=10):
    """Sample images for each label, save a preview grid, and return the classifier accuracy.

    For every entry of `labels`, draws `num_test_samples` latent codes from
    `model.scorenet` via `generate_samples_z`, conditioned on label / 100, and
    decodes them with `model.vae.decoder`. The first sample of each label is
    written to `image_path` as an image grid. Returns `evaluate_accuracy` on
    all decoded images against their labels.
    """
    labels = np.asarray(labels).reshape(-1)
    n_labels = labels.shape[0]
    # Each label repeated num_test_samples times, label-major
    # (same order as np.tile(labels[:, None], (1, n)).reshape(-1)).
    y = np.repeat(labels, num_test_samples)
    # Conditioning value = label / 100, the same scaling used for the training labels.
    u = torch.from_numpy(y).float() / 100.
    # NOTE(review): u is left on the CPU (as before); presumably
    # generate_samples_z moves it to the model's device — confirm.
    z = generate_samples_z(z_dim, model.scorenet, n_labels * num_test_samples, u=u).detach().reshape(-1, z_dim)
    with torch.no_grad():  # decoding for evaluation only; no autograd graph needed
        x = model.vae.decoder(z)

    os.makedirs(os.path.dirname(image_path) or '.', exist_ok=True)
    # assumes decoder outputs are flattened 32x32 single-channel images
    previews = x.reshape(n_labels, num_test_samples, 1, 32, 32)[:, 0]
    save_image(previews, image_path)
    return evaluate_accuracy(x, y)


def eval_accuracy(num_test_samples=1000):
    """Return (train_accuracy, test_accuracy) measured on generated samples.

    Generates `num_test_samples` samples for every training label (`indices`)
    and every held-out label (`remaining_indices`). Saves previews to
    ./samples/train.png and ./samples/test.png and scores each group with
    `evaluate_accuracy`. The label counts are taken from the index arrays
    instead of being hard-coded as 70 and 30.
    """
    # NOTE(review): the model is not switched to eval() mode for sampling;
    # confirm it has no dropout/batch-norm layers that would make this matter.
    train_accuracy = _sample_and_score(indices, './samples/train.png', num_test_samples)
    test_accuracy = _sample_and_score(remaining_indices, './samples/test.png', num_test_samples)
    return train_accuracy, test_accuracy

# train on the 70 training sets
if training:
    # torch.save does not create missing directories; make sure the output
    # folders exist now instead of crashing at the first checkpoint (epoch 1000).
    os.makedirs('stats', exist_ok=True)
    os.makedirs('weights', exist_ok=True)
    stats_path = 'stats/score_'+str(num_samples)+'_'+str(seed)+'.pth'
    weights_path = 'weights/score_'+str(num_samples)+'_'+str(seed)+'.pth'

    for i_epoch in range(totol_epoch_1):
        for data, u in dloader_train:
            opt.zero_grad()
            u = u.to(device)
            loss, vae_loss, score_loss = score_condition_loss(model, data, u)
            loss.backward()
            opt.step()

        # progress log every 100 epochs; the losses shown are from the last mini-batch
        if i_epoch % 100 == 0:
            print(f"{i_epoch+1} (time:{time.time() - t0}s): Loss:{loss.item()}, vae_loss:{vae_loss.item()}, score_loss:{score_loss.item()}")

        # every 1000 epochs: sample-based accuracy evaluation and a stats checkpoint
        if (1+i_epoch) % 1000 == 0:
            tt = time.time() - t0  # taken before evaluating, so excludes this evaluation's cost
            train_accuracy, test_accuracy = eval_accuracy()
            print(f"{i_epoch+1} (time:{tt}s): Train_accuracy:{train_accuracy}, Test_accuracy:{test_accuracy}")
            stats["epoch"].append(1+i_epoch)
            stats["time"].append(tt)
            stats["train_accuracy"].append(train_accuracy)
            stats["test_accuracy"].append(test_accuracy)
            torch.save(stats, stats_path)

    torch.save(model.state_dict(), weights_path)
    
# finetune on the remaining 30 sets
if finetune:
    # torch.save does not create missing directories (and they may not exist
    # yet if the training phase was skipped), so create them up front.
    os.makedirs('stats', exist_ok=True)
    os.makedirs('weights', exist_ok=True)
    stats_path = 'stats/score_'+str(num_samples)+'_'+str(seed)+'.pth'
    weights_path = 'weights/score_'+str(num_samples)+'_'+str(seed)+'.pth'

    # Epoch numbering continues after the first phase. The same `opt`
    # (and its AdamW state) is reused.
    for i_epoch in range(totol_epoch_1, totol_epoch_1+totol_epoch_2):
        for data, u in dloader_test:
            opt.zero_grad()
            # training step
            u = u.to(device)
            loss, vae_loss, score_loss = score_condition_loss(model, data, u)
            loss.backward()
            opt.step()

        # progress log every 100 epochs; the losses shown are from the last mini-batch
        if i_epoch % 100 == 0:
            print(f"{i_epoch+1} (time:{time.time() - t0}s): Loss:{loss.item()}, vae_loss:{vae_loss.item()}, score_loss:{score_loss.item()}")

        # every 100 epochs: sample-based accuracy evaluation and a stats checkpoint
        if (1+i_epoch) % 100 == 0:
            tt = time.time() - t0  # taken before evaluating, so excludes this evaluation's cost
            train_accuracy, test_accuracy = eval_accuracy()
            print(f"{i_epoch+1} (time:{tt}s): Train_accuracy:{train_accuracy}, Test_accuracy:{test_accuracy}")
            stats["epoch"].append(1+i_epoch)
            stats["time"].append(tt)
            stats["train_accuracy"].append(train_accuracy)
            stats["test_accuracy"].append(test_accuracy)
            torch.save(stats, stats_path)

    # final save, in case the last epochs did not end on an evaluation boundary
    torch.save(stats, stats_path)
    torch.save(model.state_dict(), weights_path)


