# Score Neural Operator with prototype embedding method in 1024 dim pixel space
import time
import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision
import matplotlib.pyplot as plt

from dataprocess.data import MyDataset_XU, training_idx
from utils.model import ScoreNetwork, VAE
from utils.loss import sm_loss, vae_loss_fn
from utils.generate import generate_samples
from torchvision.utils import save_image
from dataprocess.compute_u import compute_u_train_test
from utils.save_images import generate_images
from dataprocess.classifier import evaluate_accuracy
# Seed both RNGs: numpy alone does not cover torch's weight initialisation,
# DataLoader shuffling or any torch sampling noise.
np.random.seed(42)
torch.manual_seed(42)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_data_path = "data/mnist2d_train.npy"
num_samples = 2000  # samples kept along axis 1 of the data file
X_dim = 1024  # flattened length of one 32x32 image
indices = training_idx
print(indices)

X = np.load(train_data_path).astype(np.float32) / 255
# Axis 0 indexes the distributions; take their count from the data instead of
# hard-coding 100.
num_classes = X.shape[0]
remaining_indices = np.setdiff1d(np.arange(num_classes), indices)

X = torch.from_numpy(X[:, :num_samples, :]).float().reshape(num_classes, num_samples, 1, 32, 32).to(device)
# X train
X_selected = X[indices]
# X test
X_unselected = X[remaining_indices]
X_VAE_Train = X_selected.reshape(-1, X_dim)
X = X.reshape(num_classes, num_samples, X_dim)

current_epochs, totol_epoch = 0, 20000
# mode 0: train the VAE and the score network from scratch.
# mode 1: fine-tune -- load both networks' weights and the saved stats, skip
#         VAE training and continue training the score network.
mode = 0
if mode == 0:
    # first time to train
    pretrained_score = False
    train_score = True
    pretrained_vae = False
    train_vae = True
    load_stat = False

elif mode == 1:
    # finetune
    pretrained_score = True
    train_score = True
    pretrained_vae = True
    train_vae = False
    load_stat = True

else:
    # Without this branch an unknown mode leaves every flag undefined and the
    # script dies much later with an unrelated NameError.
    raise ValueError(f"unknown mode {mode!r}; expected 0 (train) or 1 (finetune)")



  
# Train a separate VAE to derive the probability embedding.
vae_path = 'weights/operator_ori_vae_' + str(num_samples) + '.pth'

# Place the VAE on the device chosen at the top of the script (the data lives
# there too); a hard-coded .cuda() crashes on CPU-only machines.
vae = VAE(x_dim=32*32, h_dim1=512, h_dim2=256, z_dim=10).to(device)
if pretrained_vae:
    # map_location lets a checkpoint saved on GPU load on the active device.
    vae.load_state_dict(torch.load(vae_path, map_location=device))

if train_vae:
    opt_vae = torch.optim.AdamW(vae.parameters(), lr=1e-3)
    # Full-batch training: every step uses the whole of X_VAE_Train.
    for k in range(10000):
        opt_vae.zero_grad()
        recon_x, mu, log_var, _ = vae(X_VAE_Train)
        loss = vae_loss_fn(recon_x, X_VAE_Train, mu, log_var)
        loss.backward()
        opt_vae.step()
        # print the training stats
        if k % 100 == 0:
            print(f"{k} Loss:{loss.detach().item()}")

    torch.save(vae.state_dict(), vae_path)

# Compute the probability embedding u of each train/test distribution: the
# fourth VAE output for every image, averaged over that distribution's samples.
# No gradients are needed; without no_grad the autograd graph for every image
# would stay alive through the global u_train / u_test.
with torch.no_grad():
    _, _, _, u_train = vae(X_selected)
    u_train = torch.mean(u_train.reshape(len(indices), num_samples, -1), dim=1)
    _, _, _, u_test = vae(X_unselected)
    u_test = torch.mean(u_test.reshape(len(remaining_indices), num_samples, -1), dim=1)


# Pair every training image with the embedding of the distribution it came from.
U_data = u_train.unsqueeze(1).repeat(1, num_samples, 1)

X_data = X_selected.reshape(-1, X.shape[-1]).detach()
U_data = U_data.reshape(-1, U_data.shape[-1]).detach()

dataset = MyDataset_XU(X_data, U_data)

# Reuse the module-level `device` (CUDA if available, else CPU) instead of
# re-binding it to 'cuda:0', which broke the CPU fallback. The network is moved
# before the optimizer is built, as torch.optim recommends.
score_network = ScoreNetwork().to(device)
opt = torch.optim.AdamW(score_network.parameters(), lr=1e-3)
dloader = torch.utils.data.DataLoader(dataset, batch_size=512, shuffle=True)

if pretrained_score:
    # map_location lets a checkpoint saved on GPU load on the active device.
    model_state_dict = torch.load('weights/ori_vae_operator_score_' + str(num_samples) + '.pth', map_location=device)
    score_network.load_state_dict(model_state_dict)

if load_stat:
    # Resume the epoch counter and the wall clock from the saved stats.
    stats = torch.load('stats/ori_vae_operator_' + str(num_samples) + '.pth')
    if stats["epoch"]:
        current_epochs = int(stats["epoch"][-1])
        t0 = time.time() - stats["time"][-1]
    else:
        # Stats were saved before the first evaluation: nothing to resume from.
        t0 = time.time()
    # Only train for the epochs still missing from the configured total.
    totol_epoch = totol_epoch - current_epochs
else:
    t0 = time.time()
    epochs, total_time, train_acc, test_acc = [], [], [], []
    stats = {"epoch": epochs, "time": total_time, "train_accuracy": train_acc, "test_accuracy": test_acc}


# Classification accuracy of generated samples.
def eval_accuracy(num_test_samples = 1000, num_batch = 10):
    """Score the score network's samples on training and held-out distributions.

    For each distribution in the train split (``u_train`` / ``indices``) and the
    test split (``u_test`` / ``remaining_indices``), draws samples with
    ``generate_samples`` conditioned on its embedding. It saves one sample per
    distribution to ``./samples/train.png`` / ``./samples/test.png``, then
    classifies the samples with ``evaluate_accuracy``, using the distribution
    indices as labels.

    Args:
        num_test_samples: samples per distribution. They are drawn in
            ``num_batch`` rounds of ``num_test_samples // num_batch`` each; any
            remainder is dropped.
        num_batch: number of generation rounds. This bounds the size of each
            ``generate_samples`` call.

    Returns:
        ``(train_accuracy, test_accuracy)`` as returned by ``evaluate_accuracy``.

    Raises:
        ValueError: if ``num_test_samples // num_batch`` is not positive.
    """
    per_batch = num_test_samples // num_batch
    if per_batch <= 0:
        raise ValueError(
            f"num_test_samples ({num_test_samples}) must be >= num_batch ({num_batch})"
        )
    train_accuracy = _eval_split(u_train, indices, per_batch, num_batch, './samples/train.png')
    test_accuracy = _eval_split(u_test, remaining_indices, per_batch, num_batch, './samples/test.png')
    return train_accuracy, test_accuracy


def _eval_split(u_split, labels, per_batch, num_batch, image_path):
    """Generate samples for one split, save a one-per-class preview, return accuracy."""
    labels = np.asarray(labels).reshape(-1)
    n_cls = labels.shape[0]
    # One round = n_cls consecutive blocks of per_batch rows, each block
    # conditioned on one distribution's embedding.
    u = u_split.unsqueeze(1).repeat(1, per_batch, 1).reshape(-1, u_split.shape[-1])
    samples = torch.cat(
        [generate_samples(score_network, n_cls * per_batch, u).detach() for _ in range(num_batch)],
        dim=0,
    )
    # Rows are ordered (round, class, sample). Reshaping straight to
    # (n_cls, num_test_samples) would mix rounds and classes, so pick round 0,
    # sample 0 of every class for the preview grid.
    preview = samples.reshape(num_batch, n_cls, per_batch, 1, 32, 32)[0, :, 0]
    save_image(preview.reshape(-1, 1, 32, 32), image_path)
    # Labels in the same (round, class, sample) order as the samples.
    Y = np.tile(np.repeat(labels, per_batch), num_batch)
    return evaluate_accuracy(samples, Y)


if train_score:
    score_path = 'weights/ori_vae_operator_score_' + str(num_samples) + '.pth'
    stats_path = 'stats/ori_vae_operator_' + str(num_samples) + '.pth'
    for i_epoch in range(current_epochs, current_epochs + totol_epoch):
        for data, u in dloader:
            data = data.reshape(data.shape[0], -1).to(device)
            u = u.to(device)
            opt.zero_grad()

            # training step
            loss = sm_loss(score_network, data, u)
            loss.backward()
            opt.step()

        # Print the training stats; the loss shown is the epoch's last minibatch.
        if i_epoch % 100 == 0:
            print(f"{i_epoch} ({time.time() - t0}s): Loss:{loss.detach().item()}")

        if (i_epoch + 1) % 1000 == 0:
            tt = time.time() - t0
            train_accuracy, test_accuracy = eval_accuracy()
            print(f"{i_epoch+1} (time:{tt}s): Train_accuracy:{train_accuracy}, Test_accuracy:{test_accuracy}")
            stats["epoch"].append(1 + i_epoch)
            stats["time"].append(tt)
            stats["train_accuracy"].append(train_accuracy)
            stats["test_accuracy"].append(test_accuracy)
            torch.save(stats, stats_path)
            # Checkpoint the weights together with the stats. A resumed run
            # (load_stat) takes its epoch from stats["epoch"][-1], so the
            # weights must match that epoch, not the last completed run.
            torch.save(score_network.state_dict(), score_path)

    torch.save(stats, stats_path)
    torch.save(score_network.state_dict(), score_path)

