# Conditional Score-based generative model 2 
# Trained for 10000 epochs on 70 training sets (2000 samples each), then finetuned for 10000 epochs on 30 testing sets (1 sample each).
import time
import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision
import matplotlib.pyplot as plt

from dataprocess.data import training_idx, MyDataset_XU
from utils.model import LatentScoreOperator
from utils.loss import score_condition_loss
from utils.pca import compute_u_train_and_test
from utils.generate import generate_training_samples, generate_samples_z
from torchvision.utils import save_image
from dataprocess.classifier import evaluate_accuracy

# Reproducibility: seed both PyTorch and NumPy with the same value and ask
# cuDNN to use deterministic algorithms.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True

# Wall-clock reference taken right after seeding.
t0 = time.time()

# --- Data --------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_data_path = "data/mnist2d_train.npy"
num_samples = 2000  # samples kept per distribution (axis 1 of the array)
X_dim = 1024        # flattened size of one sample

# Load the array, rescale by 1/255 (presumably 8-bit pixel intensities --
# TODO confirm) and keep the first `num_samples` samples along axis 1.
X = np.load(train_data_path).astype(np.float32)/255
X = X[:, :num_samples, :]

# `training_idx` selects the distributions used for training; its complement
# within 0..99 is held out for the finetuning phase.
indices = training_idx
remaining_indices = np.setdiff1d(np.arange(100), indices)
print(indices)
X_selected = X[indices]
# Only the first sample of every held-out distribution is used.
X_unselected = X[remaining_indices][:, 0, :]
print(X_selected.shape,X_unselected.shape)

X_data = torch.from_numpy(X_selected).reshape(-1, X_dim).to(device)
X_test = torch.from_numpy(X_unselected).reshape(-1, X_dim).to(device)

# Conditioning value of a sample = index of its distribution / 100 (float32).
# `condition_ids` is a (100, 1) column holding 0..99; it is shared by the
# per-sample training labels and the one-per-distribution test labels.
condition_ids = torch.from_numpy(np.linspace(0, 99, 100)).unsqueeze(1).float()
Y = condition_ids.repeat(1, num_samples).numpy() / 100.
Y_selected = Y[indices]

Y_data = torch.from_numpy(Y_selected).reshape(-1, 1)
Y_test = torch.from_numpy((condition_ids.numpy() / 100.)[remaining_indices]).reshape(-1, 1)

# Loaders yield (sample, conditioning value) pairs; the samples already live on
# `device`, the conditioning values are moved there inside the training loops.
dataset = MyDataset_XU(X_data, Y_data)
dloader_train = torch.utils.data.DataLoader(dataset, batch_size=1000, shuffle=True)

dataset = MyDataset_XU(X_test, Y_test)
dloader_test = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=True)

# --- Model, optimizer and run configuration ----------------------------------
model = LatentScoreOperator(x_dim=X_dim, h_dim1=512, h_dim2=256, z_dim=10, num_examples=70, num_samples=num_samples, u_dim=1, u=True, condition=True)

# Move the model to its device *before* constructing the optimizer, as the
# torch.optim documentation recommends, so the optimizer is built from the
# parameters in their final location.
model = model.to(device)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Number of epochs for phase 1 (training) and phase 2 (finetuning).
totol_epoch_1, totol_epoch_2 = 10000, 10000

# `mode` selects which phases run. Only mode 0 is defined; any other value
# fails loudly here instead of surfacing later as a NameError on the flags.
mode = 0
if mode == 0:
    # first run: start from scratch, then train and finetune
    pretrained = False
    training = True
    finetune = True
else:
    raise ValueError(f"Unsupported mode: {mode!r}")

# Evaluation history. `stats` holds references to these same list objects,
# so appending through either name updates both.
epochs, total_time, train_acc, test_acc = [], [], [], []
stats = {"epoch":epochs, "time":total_time, "train_accuracy":train_acc, "test_accuracy":test_acc}

# Restart the clock so reported times exclude data loading and model setup.
t0 = time.time()

if pretrained:
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only
    # host (and vice versa); without it torch.load tries to restore tensors to
    # the device they were saved from.
    model_state_dict = torch.load('weights/score2_'+str(num_samples)+'_'+str(seed)+'.pth', map_location=device)
    model.load_state_dict(model_state_dict)

def eval_accuracy(num_test_samples=1000):
    """Generate samples for every condition and score them.

    For each index in the module-level ``indices`` (training conditions) and
    ``remaining_indices`` (held-out conditions), ``num_test_samples`` latent
    codes are drawn with ``generate_samples_z`` from ``model.scorenet``,
    decoded with ``model.vae.decoder`` and passed to ``evaluate_accuracy``
    together with the index each sample was conditioned on.

    The number of conditions is taken from the index arrays themselves rather
    than being hard-coded, so the function keeps working if the train/held-out
    split changes size.

    Side effect: the first decoded sample of every held-out condition is
    written to ``./samples/test.png`` as a 1x32x32 image.

    Args:
        num_test_samples: number of samples generated per condition.

    Returns:
        Tuple ``(train_accuracy, test_accuracy)`` as returned by
        ``evaluate_accuracy`` for the training and held-out conditions.
    """
    # --- training conditions ---
    # One label per generated sample: every index repeated num_test_samples
    # times in a row. The conditioning value is the label scaled by 1/100.
    train_labels = np.repeat(np.asarray(indices).reshape(-1), num_test_samples)
    u = torch.from_numpy(train_labels).float() / 100.
    z = generate_samples_z(10, model.scorenet, train_labels.size, u=u).detach().reshape(-1, 10)
    decoded = model.vae.decoder(z).detach()
    train_accuracy = evaluate_accuracy(decoded, train_labels)

    # --- held-out conditions ---
    held_out = np.asarray(remaining_indices).reshape(-1)
    test_labels = np.repeat(held_out, num_test_samples)
    u = torch.from_numpy(test_labels).float() / 100.
    z = generate_samples_z(10, model.scorenet, test_labels.size, u).detach().reshape(-1, 10)
    decoded = model.vae.decoder(z).detach()

    # Group samples by condition and keep the first one of each group.
    first_per_condition = decoded.reshape(len(held_out), num_test_samples, 1, 32, 32)[:, 0, :, :, :]
    save_image(first_per_condition.view(-1, 1, 32, 32), './samples/test.png')

    test_accuracy = evaluate_accuracy(decoded, test_labels)
    return train_accuracy, test_accuracy

def generate_images(epoch, num_test_samples=50):
    """Save a grid of samples generated for the first held-out condition.

    Draws ``num_test_samples`` latent codes conditioned on
    ``remaining_indices[0] / 100``, decodes them with ``model.vae.decoder`` and
    writes them as 1x32x32 images to ``./samples/finetune_<epoch>.png``.

    Args:
        epoch: value embedded in the output file name.
        num_test_samples: number of images to generate.
    """
    # The same conditioning index repeated once per requested sample.
    condition = np.repeat(remaining_indices[0], num_test_samples)
    u = torch.from_numpy(condition).float() / 100.

    latents = generate_samples_z(10, model.scorenet, num_test_samples, u).detach().reshape(-1, 10)
    images = model.vae.decoder(latents).detach()

    out_path = './samples/finetune_' + str(epoch) + '.png'
    save_image(images.reshape(num_test_samples, 1, 32, 32), out_path)

def _run_phase(loader, epoch_range):
    """Optimise ``model`` on ``loader`` for every epoch index in ``epoch_range``.

    Shared by the training and finetuning phases, which differ only in the
    data loader and the epoch numbering. Every 100 epochs the losses of the
    last mini-batch are printed; at every epoch count that is a multiple of
    100 the model is evaluated with ``eval_accuracy`` and the running
    ``stats`` are appended to and written to disk.

    Args:
        loader: iterable yielding ``(data, u)`` batches.
        epoch_range: iterable of zero-based global epoch indices.
    """
    stats_path = 'stats/score2_'+str(num_samples)+'_'+str(seed)+'.pth'
    for i_epoch in epoch_range:
        for data, u in loader:
            opt.zero_grad()
            # The samples are already on the device; only the conditioning
            # values need to be moved.
            u = u.to(device)
            loss, vae_loss, score_loss = score_condition_loss(model, data, u)
            loss.backward()
            opt.step()

        # Losses reported here are those of the last mini-batch of the epoch.
        if i_epoch % 100 == 0:
            print(f"{i_epoch+1} (time:{time.time() - t0}s): Loss:{loss.item()}, vae_loss:{vae_loss.item()}, score_loss:{score_loss.item()}")

        if (1+i_epoch) % 100 == 0:
            # Take the timestamp before evaluating so the current evaluation
            # is not counted in the reported time.
            tt = time.time() - t0
            train_accuracy, test_accuracy = eval_accuracy()
            print(f"{i_epoch+1} (time:{tt}s): Train_accuracy:{train_accuracy}, Test_accuracy:{test_accuracy}")
            stats["epoch"].append(1+i_epoch)
            stats["time"].append(tt)
            stats["train_accuracy"].append(train_accuracy)
            stats["test_accuracy"].append(test_accuracy)
            torch.save(stats, stats_path)


# Phase 1: train on the selected (training) distributions.
if training:
    _run_phase(dloader_train, range(totol_epoch_1))
    torch.save(model.state_dict(), 'weights/score2_'+str(num_samples)+'_'+str(seed)+'.pth')

# Phase 2: finetune on the held-out distributions. Epoch numbering continues
# after phase 1, and the final weights overwrite the phase-1 checkpoint file.
if finetune:
    _run_phase(dloader_test, range(totol_epoch_1, totol_epoch_1+totol_epoch_2))
    torch.save(stats, 'stats/score2_'+str(num_samples)+'_'+str(seed)+'.pth')
    torch.save(model.state_dict(), 'weights/score2_'+str(num_samples)+'_'+str(seed)+'.pth')


