# Score Neural Operator with KME embedding method in 1024 dim pixel space

import time
import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision
import matplotlib.pyplot as plt

from dataprocess.data import MyDataset_XU, training_idx
from utils.model import ScoreNetwork
from utils.loss import sm_loss
from utils.generate import generate_samples
from torchvision.utils import save_image
from dataprocess.compute_u import compute_u_train_test
from utils.save_images import generate_images
from dataprocess.classifier import evaluate_accuracy
# Seed NumPy's global RNG.
# NOTE(review): torch's RNG is not seeded anywhere in the visible code, so
# shuffling / sampling done through torch is presumably not reproducible —
# confirm whether that is intended.
np.random.seed(42)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_data_path = "data/mnist2d_train.npy"
num_samples = 2000  # samples kept along the second axis of the data array
X_dim = 1024        # flattened 32x32 image

# `indices` are the first-axis entries used for training; everything else in
# range(100) is held out.
indices = training_idx
remaining_indices = np.setdiff1d(np.arange(100), indices)
print(indices)

# Load the raw array and scale it by 1/255
# (presumably 8-bit pixel intensities mapped to [0, 1] — TODO confirm).
X = np.load(train_data_path).astype(np.float32) / 255
# Keep the first `num_samples` entries of axis 1 and move the data to `device`.
X = torch.from_numpy(X[:, :num_samples, :]).float()
X = X.reshape(100, num_samples, 1, 32, 32).to(device)
# Flatten each 1x32x32 image back to a 1024-vector: (100, num_samples, 1024).
X = X.reshape(100, num_samples, 1024)


# Run mode selector:
#   0 -> first training run: nothing is loaded from disk.
#   1 -> fine-tune / resume: load score weights, cached u vectors and stats.
mode = 0
if mode == 0:
    # first time to train
    pretrained_score = False
    train_score = True
    load_u_train = False
    load_stat = False
elif mode == 1:
    # finetune
    pretrained_score = True
    train_score = True
    load_u_train = True
    load_stat = True
else:
    # Fail here instead of with a NameError on the first use of an unset flag.
    raise ValueError(f"Unsupported mode {mode!r}; expected 0 or 1")





# Obtain the u vectors for the training and held-out entries: either reload
# the cached tensors or compute them from X and cache the result.
u_train_path = f'weights/ori_u_train_{num_samples}.pth'
u_test_path = f'weights/ori_u_test_{num_samples}.pth'

if load_u_train:
    # map_location keeps the load working when the cache was written on a
    # different device (e.g. saved on GPU, reloaded on a CPU-only host).
    u_train = torch.load(u_train_path, map_location=device)
    u_test = torch.load(u_test_path, map_location=device)
else:
    u_train, u_test = compute_u_train_test(X)
    torch.save(u_train, u_train_path)
    torch.save(u_test, u_test_path)
print(u_train.shape,u_test.shape)


# Repeat every row of u_train `num_samples` times so that each image can be
# paired with a u vector (presumably one row per training class — TODO confirm).
U_data = u_train.unsqueeze(1).repeat(1, num_samples, 1)

# Split the data along the first axis into training and held-out entries.
X_selected = X[indices]
X_unselected = X[remaining_indices]

# Flatten (entry, sample, feature) -> (entry * sample, feature) for both tensors.
X_data = X_selected.reshape(-1, X.shape[-1]).detach()
U_data = U_data.reshape(-1, U_data.shape[-1]).detach()

dataset = MyDataset_XU(X_data, U_data)

# Build the model on the device selected at the top of the script. The previous
# hard-coded 'cuda:0' override discarded the CPU fallback and crashed on hosts
# without a GPU.
score_network = ScoreNetwork().to(device)

# Optimizer and data loader for the training loop. The optimizer is created
# after the model has been moved to its device.
opt = torch.optim.AdamW(score_network.parameters(), lr=3e-4)
dloader = torch.utils.data.DataLoader(dataset, batch_size=512, shuffle=True)


# Epoch count a complete run should reach; a resumed run trains the remainder.
target_epochs = 20000
current_epochs, totol_epoch = 0, target_epochs

if pretrained_score:
    # map_location lets a checkpoint written on another device be restored here.
    model_state_dict = torch.load(
        'weights/ori_operator_score_'+str(num_samples)+'.pth',
        map_location=device,
    )
    score_network.load_state_dict(model_state_dict)
    # NOTE(review): no optimizer state is restored in the visible code, so the
    # AdamW state starts fresh on resume — confirm this is acceptable.

if load_stat:
    stats = torch.load('stats/ori_operator_'+str(num_samples)+'.pth')
    if stats["epoch"]:
        # Continue from the last recorded checkpoint and shift t0 so that
        # reported times keep accumulating across runs.
        current_epochs = int(stats["epoch"][-1])
        t0 = time.time() - stats['time'][-1]
    else:
        # Stats file exists but holds no checkpoint yet: start from epoch 0.
        t0 = time.time()
    totol_epoch = target_epochs - current_epochs
else:
    # Fresh run: empty history lists, shared with the `stats` dict.
    epochs, total_time, train_acc, test_acc = [],[],[],[]
    stats = {"epoch":epochs, "time":total_time, "train_accuracy":train_acc, "test_accuracy":test_acc}
    t0 = time.time()
    totol_epoch = target_epochs
#classification accuracy
def _evaluate_split(u_split, class_labels, image_path, per_batch, num_batch):
    """Generate samples for every row of ``u_split`` and score them.

    For each of ``num_batch`` rounds, ``per_batch`` samples are generated per
    row of ``u_split``. One sample per row (from the first round) is written
    to ``image_path`` as a preview grid, and all samples are passed to
    ``evaluate_accuracy`` together with the matching labels.
    """
    n_classes = u_split.shape[0]
    labels = np.asarray(class_labels)
    if labels.shape[0] != n_classes:
        raise ValueError(
            f"expected {n_classes} labels, got {labels.shape[0]}"
        )

    # Row layout of `u`: class-major, i.e. each row repeated `per_batch` times.
    u = u_split.unsqueeze(1).repeat(1, per_batch, 1).reshape(-1, u_split.shape[-1])
    samples = torch.cat(
        [generate_samples(score_network, n_classes * per_batch, u).detach()
         for _ in range(num_batch)],
        dim=0,
    )

    # The concatenated samples are ordered (batch, class, sample-in-batch);
    # take the first sample of each class from the first batch so the grid
    # really shows one image per class, in class order.
    preview = samples.reshape(num_batch, n_classes, per_batch, 1, 32, 32)[0, :, 0]
    save_image(preview, image_path)

    # Labels follow the same (batch, class, sample-in-batch) ordering.
    Y = np.tile(np.repeat(labels, per_batch), num_batch)
    return evaluate_accuracy(samples, Y)


def eval_accuracy(num_test_samples = 1000, num_batch = 20):
    """Return ``(train_accuracy, test_accuracy)`` of generated samples.

    Args:
        num_test_samples: samples to generate per class; generation is split
            into ``num_batch`` rounds of ``num_test_samples // num_batch``
            samples per class.
        num_batch: number of generation rounds.

    Returns:
        Tuple of the values returned by ``evaluate_accuracy`` for samples
        conditioned on ``u_train`` (labels ``indices``) and on ``u_test``
        (labels ``remaining_indices``).

    Raises:
        ValueError: if ``num_test_samples // num_batch`` is smaller than 1.

    Side effects:
        Writes preview grids to ``./samples/train.png`` and
        ``./samples/test.png``.
    """
    per_batch = num_test_samples // num_batch
    if per_batch < 1:
        raise ValueError("num_test_samples must be at least num_batch")

    train_accuracy = _evaluate_split(
        u_train, indices, './samples/train.png', per_batch, num_batch
    )
    test_accuracy = _evaluate_split(
        u_test, remaining_indices, './samples/test.png', per_batch, num_batch
    )

    return train_accuracy, test_accuracy


# Main training loop: runs epochs current_epochs .. current_epochs+totol_epoch-1,
# logs the last mini-batch loss every 100 epochs, and every 1000 epochs
# evaluates accuracy and checkpoints both the stats and the model weights.
if train_score:
    stats_path = 'stats/ori_operator_'+str(num_samples)+'.pth'
    weights_path = 'weights/ori_operator_score_'+str(num_samples)+'.pth'

    for i_epoch in range(current_epochs, current_epochs+totol_epoch):

        for data, u in dloader:
            data = data.reshape(data.shape[0], -1).to(device)
            u = u.to(device)
            opt.zero_grad()

            # training step
            loss = sm_loss(score_network, data, u)
            loss.backward()
            opt.step()

        # print the training stats (loss of the last mini-batch of the epoch)
        if (i_epoch) % 100 == 0:
            print(f"{i_epoch} ({time.time() - t0}s): Loss:{loss.detach().item()}")

        if (i_epoch+1) % 1000 == 0:
            # Take the timestamp before evaluation so that evaluation time is
            # not counted in this checkpoint's recorded time.
            tt = time.time() - t0
            train_accuracy, test_accuracy = eval_accuracy()
            print(f"{i_epoch+1} (time:{tt}s): Train_accuracy:{train_accuracy}, Test_accuracy:{test_accuracy}")
            stats["epoch"].append(1+i_epoch)
            stats["time"].append(tt)
            stats["train_accuracy"].append(train_accuracy)
            stats["test_accuracy"].append(test_accuracy)
            torch.save(stats, stats_path)
            # Checkpoint the weights together with the stats: a resumed run
            # reads its epoch counter from the stats file, so the weights on
            # disk must correspond to that same epoch.
            torch.save(score_network.state_dict(), weights_path)

    # Final save once all epochs have completed.
    torch.save(stats, stats_path)
    torch.save(score_network.state_dict(), weights_path)

