import argparse
import torch
import random
import numpy as np
from tqdm import tqdm
from dataset_mine import load_dataset
from numpy.random import RandomState
from visualizer import Visualizer
from sklearn.svm import SVC
from metrics import (knn_classify,
                     mean_squared_error,
                     r_squared)
import matplotlib.pylab as plt

#from utility.eval_metric import _evaluate_metric
from models.NG_GPLVM import NG_GPLVM
from models.NG_MVLVM import NG_MVLVM
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
def save_models(model, optimizer, loop, epoch, losses, result_dir, data_name, save_model=True):
    '''
    Write a training checkpoint to ``<result_dir>/<data_name>_loop<loop>_epoch<epoch>.pt``.

    Parameters
    ----------
    model       :           object whose ``state_dict()`` is stored under the key 'model'
    optimizer   :           object whose ``state_dict()`` is stored under the key 'optimizer'
    loop        :           index of the experiment repetition (used in the file name)
    epoch       :           epoch index (stored under 'epoch' and used in the file name)
    losses      :           loss value(s), stored unchanged under the key 'losses'
    result_dir  :           result saving path; created if missing, with or without a
                            trailing path separator. An empty value means the current
                            working directory.
    data_name   :           data name (prefix of the checkpoint file name)
    save_model  :           indication if to saving model; when false nothing is written

    Returns
    -------
    None
    '''
    if not save_model:
        return

    # Local import: the module header does not import ``os``.
    import os

    state = {'model': model.state_dict(),
             'optimizer': optimizer.state_dict(),
             'epoch': epoch,
             'losses': losses}

    # exist_ok avoids the check-then-create race; skip for '' because
    # os.makedirs('') raises even though the current directory exists.
    if result_dir:
        os.makedirs(result_dir, exist_ok=True)

    # os.path.join keeps the file inside result_dir whether or not the caller
    # passed a trailing separator (plain string concatenation did not).
    checkpoint_path = os.path.join(result_dir, f"{data_name}_loop{loop}_epoch{epoch}.pt")
    torch.save(state, checkpoint_path)

# Seed shared by the global generators below and by the dataset RandomState.
random_seed = 8


def reset_seed(seed: int) -> None:
    """Seed the Python, NumPy, PyTorch and PyTorch-CUDA random generators with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)


reset_seed(random_seed)
#device = "cuda:1"
# Use the GPU when one is visible to PyTorch; otherwise fall back to the CPU so
# the script does not fail on the first ``.to(device)`` on CUDA-less machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
""" ##-------------------------------- data loader ------------------------------------------------"""
# Load Dataset
# Dedicated NumPy generator, seeded like the global generators; it is handed to
# load_dataset here and to knn_classify later in the script.
rng = RandomState(random_seed)

# Dataset names listed by the original author (presumably the values accepted
# by load_dataset — confirm against dataset_mine):
## ['bridges', 'congress', 's-curve', 'mnist', 'pm25', 'exchange', 'fiji', 'highschool','cifar', 'cmu', 'cmu1', 'cmu2', 'cmu3', 'cmu4', 'hippo', 'montreal', 'newsgroups',
# 'simdata1','spam', 'spikes', 'yale']

# test_split also selects how the observations are scaled further down
# (None / 0 / anything else).
test_split = 0
ds = load_dataset(rng, 'mnist', 'gaussian', test_split=test_split)

""" ##-------------------------------- Parameters Settings ------------------------------------------------"""
# Experiment configuration; the whole dict is also passed to the model constructor.
setting_dict = {
    'num_m': [2, 2],            # if num_m = 1, it is using SE kernel
    'num_sample_pt': 50,
}
# One entry per element of num_m: that element times num_sample_pt.
setting_dict['num_total_pt'] = [
    m * setting_dict['num_sample_pt'] for m in setting_dict['num_m']
]
setting_dict.update({
    'num_batch': 1,
    'lr_hyp': .005,             # learning rate given to the Adam optimizer
    'iter': 10000,              # the epoch loop runs iter + 1 steps
    'num_repexp': 1,
    'kl_option': True,          # if adding X regularization in loss function
    'noise_err': 100.0,         # placeholder; overwritten once the data is scaled
    'latent_dim': ds.latent_dim,
})
print('ds.latent_dim:', ds.latent_dim)
setting_dict['N'] = ds.Y.shape[0]


# Checkpoints are written below res_dir, one sub-directory per model
# configuration and dataset name.
model_name = f"NG_MVLVM{setting_dict['latent_dim']}_{setting_dict['num_m']}_{setting_dict['num_sample_pt']}"
res_dir = f'/mnt/NG_MVLVM/MVresults/{model_name}/{ds.name}/'

# Scale the observations; the divisor depends on test_split:
#   None      -> ds.Y divided by np.linalg.norm(ds.Y, 2)
#   0         -> ds.Y divided by 255
#   otherwise -> ds.Y_ma divided by 255
if test_split is None:
    Y = ds.Y / np.linalg.norm(ds.Y, 2)
elif test_split == 0:
    Y = ds.Y / 255  # np.linalg.norm(ds.Y,2)
else:
    Y = ds.Y_ma / 255  # np.linalg.norm(ds.Y_ma)
Y_original = Y

train_data = torch.tensor(Y)
labels = torch.tensor(ds.labels)
# Column vector of integer class ids, the index layout scatter_ needs below.
train_labels = labels.unsqueeze(1).long()

# Tensor.max() reduces natively; the builtin max() walks the tensor element by
# element in Python and was previously evaluated twice.
max_label = labels.max()
print("max labels", max_label)
batch_size_train = train_labels.size(0)
num_classes = int(max_label) + 1
# One-hot encode: write a 1 in column `label` of each row.
train_labels_one_hot = torch.zeros(batch_size_train, num_classes)
train_labels_one_hot.scatter_(1, train_labels, 1)

# Check the new shapes
# (the original author expected torch.Size([6000, 784]); depends on the dataset)
print('\ntrain_data.shape ',train_data.shape)

# Baseline: kNN on the raw (scaled) observations.
# NOTE(review): labels are passed here as the (N, 1) column `train_labels`,
# while the later calls pass the 1-D labels — confirm knn_classify accepts both.
knn_acctrain = knn_classify(train_data, train_labels, rng)
print('\nknn_acctrain ', knn_acctrain)

# Baseline: multinomial logistic regression on the raw (scaled) observations.
X = train_data
log_reg_sim = LogisticRegression(max_iter=1000, solver='lbfgs', multi_class='multinomial')
log_reg_sim.fit(X, labels)
# Predict on the same samples the classifier was fitted on (training accuracy,
# not a held-out test set).
y_pred = log_reg_sim.predict(X)
# Calculate the accuracy
accuracy = accuracy_score(labels, y_pred)
# Display results (the printed label says "linear regression", but the
# classifier is LogisticRegression).
print("linear regression on orginal data : ", accuracy)

# Replace the placeholder noise level with 5% of the scaled data's standard deviation.
setting_dict['noise_err'] = .05 * Y.std()
# Per-repetition results collected by the training loop.
accknn = []
acclr = []
mse_Flist =[]
mse_Flist =[]

# The model is given a list of two tensors: the scaled observations and the
# one-hot encoded labels.
Y = [train_data, train_labels_one_hot]
for loop in range(5):
    print(f'\n It is the: {loop+1} loop')
    # Fresh model and optimizer for every repetition.
    GPLVM_model = NG_MVLVM(setting_dict['num_batch'],
                           setting_dict['num_sample_pt'],
                           setting_dict,
                           Y,
                           device=device, ifPCA=True).to(device)
    optimizer = torch.optim.Adam(GPLVM_model.parameters(), lr=setting_dict['lr_hyp'])

    epochs_iter = tqdm(range(setting_dict['iter'] + 1), desc="Epoch")
    for i in epochs_iter:
        GPLVM_model.train()
        optimizer.zero_grad()
        losstotal = GPLVM_model.compute_loss(batch_y=Y, kl_option=setting_dict['kl_option'])
        losstotal.backward()
        optimizer.step()

        # Every 500 steps (including step 0): report, checkpoint and evaluate.
        if i % 500 == 0:
            print(f'\nELBO: {losstotal.item()}')
            print(f"X_KL: {GPLVM_model._kl_div_qp().item()}")
            save_models(model=GPLVM_model, optimizer=optimizer, loop=loop, epoch=i, losses=losstotal,
                        result_dir=res_dir, data_name=ds.name, save_model=True)

            # Log metrics.
            # ------------
            # Fetch the latent means once and reuse them for both classifiers
            # instead of copying them to the host twice. The script already
            # shares train_data between knn_classify and later use in the same way.
            X = GPLVM_model.mu_x.cpu().detach().numpy()
            knn_acc = knn_classify(X, labels, rng)
            print('\nKNN acc', knn_acc)

            log_reg_sim = LogisticRegression(max_iter=1000, solver='lbfgs', multi_class='multinomial')
            log_reg_sim.fit(X, labels)
            # Predict on the same latent points the classifier was fitted on
            # (training accuracy, not a held-out test set).
            y_pred = log_reg_sim.predict(X)
            # Calculate the accuracy
            accuracy = accuracy_score(labels, y_pred)
            # Display results (classifier is LogisticRegression despite the label text)
            print("linear regression on pred latent : ", accuracy)

    # Final evaluation of this repetition on the learned latent means.
    X = GPLVM_model.mu_x.cpu().detach().numpy()

    knn_acc = knn_classify(X, ds.labels, rng)
    accknn.append(knn_acc)
    # Re-print the results of all repetitions so far; a dedicated index avoids
    # reusing the epoch counter `i`.
    for run_idx, item in enumerate(accknn, start=1):
        print(f"accknn Loop {run_idx}: {item}")

    log_reg_sim = LogisticRegression(max_iter=1000, solver='lbfgs', multi_class='multinomial')
    log_reg_sim.fit(X, labels)
    # Predict on the same latent points the classifier was fitted on.
    y_pred = log_reg_sim.predict(X)
    # Calculate the accuracy
    accuracy = accuracy_score(labels, y_pred)
    acclr.append(accuracy)
    for run_idx, item in enumerate(acclr, start=1):
        print(f"acclr Loop {run_idx}: {item}")







