# Test the few-shot capability of the trained SNO with the prototype-embedding method (Figure 6 in the paper).
# Requires the pretrained weights of the SNO prototype-embedding model in latent space (produced by running SNO_prototype_latent.py).
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torch.utils.data import Dataset
from torchvision.utils import save_image

from dataprocess.classifier import evaluate_accuracy
from dataprocess.compute_u import compute_u_test
from dataprocess.data import MyDataset_XU2
from dataprocess.data import training_idx
from utils.generate import generate_training_samples, generate_samples_z
from utils.loss import score_condition_loss
from utils.model import LatentScoreOperator
from utils.save_images import generate_images



# ---------------------------------------------------------------------------
# Few-shot experiment: for each K in `num_few_shot_samples_list`, build a
# prototype embedding u for every held-out distribution from only K of its
# samples, then
#   (a) generate `num_new_samples` images for distribution `test_idx`
#       (collected in `samples_list` for the figure), and
#   (b) generate `num_test_samples` images per held-out distribution and
#       score them with the pretrained classifier (`test_accuracy_list`).
# ---------------------------------------------------------------------------
seed = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_data_path = "data/mnist2d_train.npy"
num_samples = 2000
num_test_samples = 1000
img_size = 32
X_dim = img_size * img_size  # 1024: flattened 32x32 images
z_dim = 10
# number of samples to generate per K (one figure panel)
num_new_samples = 6
# held-out distribution whose generated samples are visualised
test_idx = 11

num_few_shot_samples_list = [1, 10, 100, 2000]

# `seed` previously only selected the weight file; also seed the RNGs so the
# figure and accuracies are reproducible (torch.manual_seed covers CUDA too).
np.random.seed(seed)
torch.manual_seed(seed)

# Held-out distributions = the first 100 distributions minus the training ones.
indices = training_idx
remaining_indices = np.setdiff1d(np.arange(100), indices)
num_test_dists = len(remaining_indices)

# Load the data ONCE (it does not depend on K) and keep only the held-out
# distributions; the elementwise astype/scale is identical to scaling first.
X_unselected_all = (
    np.load(train_data_path)[remaining_indices, :num_samples, :].astype(np.float32) / 255
)

# The pretrained weights do not change across K, so build/load the model once.
# map_location lets a checkpoint saved on GPU load on the CPU fallback device.
# num_examples=70 presumably equals len(training_idx) and must match the
# checkpoint -- TODO confirm.
# NOTE(review): the model is not switched to eval(); if LatentScoreOperator has
# dropout/batch-norm layers, consider model.eval() -- confirm against training.
model = LatentScoreOperator(
    x_dim=X_dim, h_dim1=512, h_dim2=256, z_dim=z_dim,
    num_examples=70, num_samples=num_samples, u=True,
)
model = model.to(device)
weights_path = f"weights/end2end_{num_samples}_{seed}.pth"
model.load_state_dict(torch.load(weights_path, map_location=device))

# Ground-truth labels for the classification check: each held-out
# distribution's index repeated num_test_samples times (matches the row order
# produced by repeat_interleave below).
Y = np.repeat(remaining_indices, num_test_samples)

samples_list = []
test_accuracy_list = []
for num_few_shot_samples in num_few_shot_samples_list:
    # K shots from each held-out distribution, flattened to (dists*K, X_dim).
    X_unselected = X_unselected_all[:, :num_few_shot_samples, :]
    shots_per_dist = X_unselected.shape[1]
    X_test = torch.from_numpy(X_unselected).reshape(-1, X_dim).to(device)

    # Encoding only needs forward values; skip building an autograd graph.
    with torch.no_grad():
        _, _, _, Z = model.vae(X_test)
    # Prototype embedding: mean latent code over the K shots of each distribution.
    u_test = Z.reshape(num_test_dists, shots_per_dist, -1).mean(dim=1)

    # (a) Generate distinct samples for the visualised test distribution.
    u = u_test[test_idx].unsqueeze(0).repeat(num_new_samples, 1)
    z = generate_samples_z(z_dim, model.scorenet, num_new_samples, u).detach().reshape(-1, z_dim)
    with torch.no_grad():
        samples = model.vae.decoder(z).cpu()
    samples_list.append(samples.reshape(-1, 1, img_size, img_size))

    # (b) Evaluate classification accuracy on samples from every held-out distribution.
    u = u_test.repeat_interleave(num_test_samples, dim=0)
    z = generate_samples_z(
        z_dim, model.scorenet, num_test_dists * num_test_samples, u
    ).detach().reshape(-1, z_dim)
    with torch.no_grad():
        X_generated = model.vae.decoder(z)
    test_accuracy = evaluate_accuracy(X_generated, Y)
    test_accuracy_list.append(round(test_accuracy, 4))
    print("num few shot samples {}, test classification accuracy {}".format(num_few_shot_samples, test_accuracy))


# ---------------------------------------------------------------------------
# Figure: one panel per K, two panels per figure row. Each panel is a text
# cell (K and accuracy) followed by `num_new_samples` generated images.
# The grid is derived from the experiment settings instead of a fixed 2x14.
# ---------------------------------------------------------------------------
panels_per_row = 2
cols_per_panel = num_new_samples + 1
num_panels = len(num_few_shot_samples_list)
num_rows = max(1, (num_panels + panels_per_row - 1) // panels_per_row)
num_cols = panels_per_row * cols_per_panel
# Same cell size as the original 2x14 grid at figsize (18, 4).
fig, axs = plt.subplots(
    num_rows,
    num_cols,
    figsize=(18 / 14 * num_cols, 2 * num_rows),
    squeeze=False,  # always a 2-D array of axes, even for a single row
)

# Hide every axis up front so unused cells (e.g. an odd panel count) stay blank.
for ax in axs.flat:
    ax.axis('off')

for k, (num_few_shot_samples, test_accuracy, samples) in enumerate(
    zip(num_few_shot_samples_list, test_accuracy_list, samples_list)
):
    row, panel = divmod(k, panels_per_row)
    col_start = panel * cols_per_panel

    label_ax = axs[row, col_start]
    label_ax.text(0.5, 0.6, f'K = {num_few_shot_samples}', fontsize=13, verticalalignment='center', horizontalalignment='center')
    label_ax.text(0.5, 0.4, f'ACC = {test_accuracy:.4f}', fontsize=10, verticalalignment='center', horizontalalignment='center')

    for i in range(num_new_samples):
        axs[row, col_start + 1 + i].imshow(samples[i].squeeze(), cmap='gray')

# tight_layout recomputes the subplot spacing itself, so a preceding
# subplots_adjust call would have no effect.
fig.tight_layout()
os.makedirs('samples', exist_ok=True)
fig.savefig('samples/fewshot.png', bbox_inches='tight', pad_inches=0)
plt.close(fig)
