import argparse
import torch
import random
import numpy as np
from tqdm import tqdm
from dataset_mine import load_dataset
from numpy.random import RandomState
from visualizer import Visualizer
from metrics import (knn_classify,
                     mean_squared_error,
                     r_squared)
import matplotlib.pylab as plt
from models.NG_GPLVM import NG_GPLVM
def save_models(model, optimizer, epoch, losses, result_dir, data_name, save_model=True):
    '''Save a training checkpoint (model/optimizer state, epoch and losses).

    Parameters
    ----------
    model       :           object exposing ``state_dict()`` (e.g. a torch.nn.Module)
    optimizer   :           object exposing ``state_dict()`` (e.g. a torch optimizer)
    epoch       :           current epoch; stored in the checkpoint and used in the file name
    losses      :           loss history stored alongside the weights
    result_dir  :           directory to write into; created if it does not exist
    data_name   :           data name, used as the file-name prefix
    save_model  :           if False, nothing is computed or written

    Returns
    -------
    None. When ``save_model`` is True the checkpoint dict with keys
    ``'model'``, ``'optimizer'``, ``'epoch'`` and ``'losses'`` is written to
    ``os.path.join(result_dir, f"{data_name}_epoch{epoch}.pt")``.
    '''
    import os

    if not save_model:
        return

    state = {'model': model.state_dict(),
             'optimizer': optimizer.state_dict(),
             'epoch': epoch,
             'losses': losses}

    # An empty result_dir means "current directory"; os.makedirs('') would raise.
    if result_dir:
        os.makedirs(result_dir, exist_ok=True)
    # os.path.join works whether or not result_dir ends with a separator;
    # plain string concatenation produced "<dir><name>" in the latter case.
    log_dir = os.path.join(result_dir, f"{data_name}_epoch{epoch}.pt")
    torch.save(state, log_dir)

# Global seed: passed to reset_seed() and to the RandomState used for dataset
# loading and k-NN evaluation further down the script.
random_seed = 8
def reset_seed(seed: int) -> None:
    """Seed the global random number generators used by this script.

    Applies ``seed`` to Python's ``random`` module, NumPy's legacy global
    generator, PyTorch's global seed (``torch.manual_seed``) and the current
    CUDA device (``torch.cuda.manual_seed``), in that order.

    NOTE: this alone does not make GPU runs bit-for-bit deterministic;
    nondeterministic CUDA/cuDNN kernels are not controlled here.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,  # current GPU; ignored when CUDA is unavailable
    )
    for apply_seed in seeders:
        apply_seed(seed)

reset_seed(random_seed)

# Use the GPU when available; fall back to the CPU so the script still runs
# on machines without CUDA.
# device = "cuda:1"
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ------------------------------------ data loader ------------------------------------
# Seeded RandomState, passed to load_dataset and reused by knn_classify below.
rng = RandomState(random_seed)

# Available dataset names:
# ['bridges', 'congress', 's-curve', 'mnist', 'pm25', 'exchange', 'fiji', 'highschool',
#  'cifar', 'cmu', 'cmu1', 'cmu2', 'cmu3', 'cmu4', 'hippo', 'montreal', 'newsgroups',
#  'simdata1', 'spam', 'spikes', 'yale']

test_split = 0
ds = load_dataset(rng, 'mnist', 'gaussian', test_split=test_split)

# --------------------------------- parameter settings ---------------------------------
setting_dict = {}
setting_dict['num_m'] = 2            # if num_m = 1, it is using SE kernel
setting_dict['num_sample_pt'] = 50
setting_dict['num_total_pt'] = setting_dict['num_m'] * setting_dict['num_sample_pt']
setting_dict['num_batch'] = 1
setting_dict['lr_hyp'] = .01         # Adam learning rate
setting_dict['iter'] = 10000         # optimisation steps per repetition (plus one)
# NOTE(review): the repetition loop below is hard-coded to 5 runs and does not
# read this value.
setting_dict['num_repexp'] = 1
setting_dict['kl_option'] = True     # if adding X regularization in loss function
setting_dict['noise_err'] = 100.0    # placeholder; overwritten from Y.std() below
setting_dict['latent_dim'] = ds.latent_dim
setting_dict['N'] = ds.Y.shape[0]

if setting_dict['num_m'] == 1:
    model_name = f"ngsmRFLVM_SE_{setting_dict['num_sample_pt']}"
else:
    model_name = f"ngsmRFLVM_{setting_dict['num_m']}_{setting_dict['num_sample_pt']}"
res_dir = f'/mnt/NG-MV-RFLVM/SingleViewresults/{model_name}/{ds.name}/'
# viz = Visualizer(res_dir+'figures', ds)

# Scale the observations. Without a test split, the data are divided by their
# spectral norm (np.linalg.norm(., 2)); otherwise they are divided by 255.
# For a non-zero test_split the model is trained on the masked matrix ds.Y_ma.
if test_split is None:
    y_scale = np.linalg.norm(ds.Y, 2)
    Y = ds.Y / y_scale
else:
    y_scale = 255
    Y = (ds.Y if test_split == 0 else ds.Y_ma) / y_scale
Y_original = Y
# Full (unmasked) data on the SAME scale as Y. The reconstruction MSE is
# measured against this so that F and the reference are comparable for every
# test_split setting.
Y_full = ds.Y / y_scale
print(Y.shape)
setting_dict['noise_err'] = .05 * Y.std()
acc = []

for loop in range(5):
    print(f'\n It is the: {loop+1} loop' )
    GPLVM_model = NG_GPLVM(setting_dict['num_batch'],
                           setting_dict['num_sample_pt'],
                           setting_dict,
                           Y,
                           device=device, ifPCA=True).to(device)

    optimizer = torch.optim.Adam(GPLVM_model.parameters(), lr=setting_dict['lr_hyp'])
    epochs_iter = tqdm(range(setting_dict['iter'] + 1), desc="Epoch")
    for i in epochs_iter:
        GPLVM_model.train()
        optimizer.zero_grad()
        losstotal = GPLVM_model.compute_loss(batch_y=Y, kl_option=setting_dict['kl_option'])
        losstotal.backward()
        optimizer.step()

        if i % 500 == 0:
            # Diagnostics only: disable autograd so the evaluation forward
            # passes do not build (and hold memory for) a computation graph.
            with torch.no_grad():
                print(f'\nELBO: {losstotal.item()}')
                print(f"X_KL: {GPLVM_model._kl_div_qp().item()}")
                F, K = GPLVM_model.f_eval(batch_y=Y_original, x_star=None)
            F = F.detach().cpu().numpy()  # shape: N_star * obs_dim
            K = K.detach().cpu().numpy()
            mu_x = GPLVM_model.mu_x.detach().cpu().numpy()

            # Log metrics.
            # ------------
            mse_Y = mean_squared_error(F, Y_full)
            print(f'MSE Y:  {mse_Y}')
            knn_acc = knn_classify(mu_x, ds.labels, rng)
            print('\nKNN acc', knn_acc)

            if ds.has_true_F:
                mse_F = mean_squared_error(F, ds.F)
                print(f'MSE F:  {mse_F}')

            if ds.has_true_K:
                mse_K = mean_squared_error(K, ds.K)
                print(f'MSE K:  {mse_K}')

            if ds.has_true_X:
                r2_X = r_squared(mu_x, ds.X)
                print(f'R2 X: {r2_X}')

            print("\n")

    # Final k-NN accuracy of the learned latent means for this repetition.
    knn_acc = knn_classify(GPLVM_model.mu_x.detach().cpu().numpy(), ds.labels, rng)
    acc.append(knn_acc)
    for i, item in enumerate(acc):
        print(f"Loop {i + 1}: {item}")

    # Baseline: k-NN accuracy on the (scaled) observations themselves.
    knn_accY = knn_classify(Y, ds.labels, rng)
    print('\nKNN acc Y', knn_accY)

print("Final list:")
for i, item in enumerate(acc):
    print(f"Loop {i + 1}: {item}")







