from datasets.datasets import Datasets
from datasets.custom_dataset import CustomValidDataset
import torch.backends.cudnn as cudnn
import random
from torch.utils.data.dataloader import DataLoader
import metrics
from models.autoencoder_som import AutoEncoderSOM
import torch
import torch.nn as nn
from loss.custom_loss import WeightedMSELoss, ReconstructionLoss, SOMLoss
from os.path import join
from utils import model_loader
from models.init.init_weights import init_w
import wandb
import os
from tqdm import trange
from sklearn import metrics as sk_metrics
import numpy as np
import loss_landscapes
import loss_landscapes.metrics
from analysis.plot import PlotLossLandscape
import copy


def train_model(root, dataset_path, parameters, device, use_cuda, out_folder, debug=False, n_samples=None,
                coil20_unprocessed=False, batch_size=None, epochs=None, alpha=None, beta=None, som_in=None,
                seed=None, lr_decay=None, save=True, load=False, model=None, semi=False, custom_model='autoencoder',
                labels_sampling='fixed', n_labels=None, evaluate=False, balanced=False,
                use_wandb=False, wandb_project=None, landscape=False):
    """Train one AutoEncoder+SOM model per row of ``parameters``.

    The dataset is built once. Then, for every row yielded by
    ``parameters.itertuples()``, a model is created (or loaded), the RNGs are
    seeded, the model is trained with :func:`train`, results are written by
    :func:`finish` and, optionally, loss landscapes are plotted and a
    checkpoint is saved.

    Args:
        root, dataset_path, debug, n_samples, coil20_unprocessed,
        labels_sampling, n_labels, balanced: forwarded to ``Datasets``.
            ``dataset_path`` is also used as the prefix of the wandb run name
            and of the output file names.
        parameters: object exposing ``itertuples()`` and ``len()``
            (presumably a pandas DataFrame of hyper-parameters -- confirm
            against the caller).
        device: forwarded to the model builder / loader and to :func:`train`.
        use_cuda: when true, seeds CUDA, moves the model with ``.cuda()`` and
            enables ``cudnn.benchmark``.
        out_folder: directory receiving the checkpoint; also passed to
            :func:`finish` and :func:`eval_landscapes`. The name of its parent
            directory is used as the wandb group.
        batch_size, epochs, seed: when not ``None`` they override the value
            of the current parameter row.
        alpha, beta: when ``None`` the value is taken from the parameter row
            if the row has such a field; otherwise the given value is used.
        som_in: forwarded to :func:`load_model` as ``som_input``.
        lr_decay: gamma of the SOM learning-rate scheduler. Falls back to the
            row's ``lr_decay`` field and then to ``0.723``.
        save: write a ``.pth`` checkpoint for every parameter row.
        load, model: when ``load`` is true and ``model`` is not ``None`` the
            model is restored through ``model_loader`` instead of being built.
        semi: enables the semi-supervised forward pass and logging.
        custom_model: label used in log output.
        evaluate: forwarded to :func:`finish`.
        use_wandb, wandb_project: a wandb run is started only when both are set.
        landscape: when true, :func:`eval_landscapes` is run after training.
    """
    dataset = Datasets(dataset=dataset_path,
                       root_folder=root, debug=debug, n_samples=n_samples, coil20_unprocessed=coil20_unprocessed,
                       labels_sampling=labels_sampling, n_labels=n_labels, balanced=balanced)

    # Enabled globally for the whole process; it stays on after this function returns.
    torch.autograd.set_detect_anomaly(True)

    for cont, param_set in enumerate(parameters.itertuples()):

        if load and model is not None:
            # The loaded hyper-parameters replace the corresponding arguments.
            # NOTE(review): the loaded seed (`manual_seed`) is never used below -- the seed
            # still comes from `seed` / `param_set.seed`. Confirm whether that is intended.
            combined_model, epochs, manual_seed, alpha, beta, lr_decay = model_loader.load_autoencodersom_model(model, device)
        else:
            combined_model = load_model(dataset, param_set, semi, device, custom_model=custom_model, som_input=som_in)

        # Explicit arguments take precedence over the values of the parameter row.
        use_batch_size = param_set.batch_size if batch_size is None else batch_size
        use_epochs = param_set.epochs if epochs is None else epochs
        use_manual_seed = param_set.seed if seed is None else seed

        if alpha is None and 'alpha' in param_set._fields:
            use_alpha = param_set.alpha
        else:
            use_alpha = alpha

        if beta is None and 'beta' in param_set._fields:
            use_beta = param_set.beta
        else:
            use_beta = beta

        # Snapshot of the untrained model, used as the start point of the loss-landscape plots.
        combined_model_initial = copy.deepcopy(combined_model)
        random.seed(use_manual_seed)
        torch.manual_seed(use_manual_seed)

        if use_cuda:
            torch.cuda.manual_seed_all(use_manual_seed)
            combined_model.cuda()
            cudnn.benchmark = True

        train_loader, valid_loader, test_loader = initialize_loaders(dataset, use_batch_size)
        optimizer = torch.optim.Adam(combined_model.parameters())
        # Separate optimizer over the SOM's lr0 only; it is what the exponential scheduler decays.
        som_lr_optimizer = torch.optim.Adam([combined_model.som.lr0])

        if lr_decay is None and 'lr_decay' in param_set._fields:
            use_lr_decay = param_set.lr_decay
        elif lr_decay is None:
            use_lr_decay = 0.723
        else:
            use_lr_decay = lr_decay

        som_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=som_lr_optimizer, gamma=use_lr_decay)

        mseLoss = nn.MSELoss(reduction='mean')
        weighted_mseLoss = WeightedMSELoss(reduction='mean')

        if use_wandb and wandb_project is not None:
            wandb.init(project=wandb_project,
                       reinit=True,
                       name=dataset_path + "_" + str(param_set.Index),
                       group=os.path.basename(os.path.dirname(out_folder)))

        autoencoder_loss, som_loss, combined_loss = train(combined_model, train_loader, valid_loader, parameters,
                                                          use_epochs, mseLoss, weighted_mseLoss, optimizer,
                                                          som_lr_optimizer,  som_lr_scheduler, use_alpha, use_beta,
                                                          cont, device, semi, custom_model, use_wandb, debug)

        if landscape:
            eval_landscapes(combined_model_initial, combined_model, train_loader,
                            plot=True, save_plots=True, out_folder=out_folder)

        finish(evaluate, combined_model, test_loader, semi, dataset_path, out_folder, param_set.Index, use_wandb)

        if save:
            torch.save({
                'model_state_dict': combined_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': combined_loss,
                'som_loss': som_loss,
                'autoencoderloss': autoencoder_loss,
                'alpha': use_alpha,
                'beta': use_beta,
                'd_in': dataset.d_in,
                'hw_in': dataset.hw_in,
                # Record the epoch count the run actually used (override / loaded value
                # included), consistent with the other `use_*` entries.
                'epochs': use_epochs,
                'input_size': combined_model.som.input_size,
                'n_max': combined_model.som.n_max,
                'at': combined_model.som.at,
                'ds_beta': combined_model.som.ds_beta,
                'lr0': combined_model.som.lr0.item(),
                'lr': combined_model.som.lr,
                'lr_push': combined_model.som.lr_push,
                'lr_decay': use_lr_decay,
                'eps_ds': combined_model.som.eps_ds,
                'ld': combined_model.som.ld,
                'gamma': combined_model.som.gamma,
                'seed': use_manual_seed,
                'semi': semi
            }, join(out_folder, dataset_path + '_' + str(param_set.Index) + '.pth'))


def train(model, train_loader, valid_loader, parameters, epochs, mseLoss, weighted_mseLoss, optimizer,
          som_lr_optimizer, som_lr_scheduler, alpha, beta, cont, device, semi, custom_model, use_wandb, debug):
    """Run the training loop of the combined autoencoder + SOM model.

    For every batch the model is put in train mode, a forward pass is made,
    the loss ``reconstruction + alpha * som`` is built, the model is scored
    on ``valid_loader`` (NMI and purity), the values are optionally logged to
    wandb, and finally the loss is back-propagated and ``optimizer`` steps.
    After every epoch the SOM learning-rate optimizer and scheduler step and
    the scheduler's last learning rate is written to ``model.som.lr``.

    ``beta`` is accepted but not used by the current loss. ``cont`` and
    ``parameters`` are only used to build the progress-bar text.

    Returns:
        ``(reconstruction_loss, som_loss, combined_loss)`` of the last batch
        processed; each stays ``0`` when no batch was processed.
    """
    recon_loss = 0
    som_loss = 0
    total_loss = 0

    progress = trange(epochs, desc="Epochs")
    for _ in progress:
        for batch_idx, (sample, target) in enumerate(train_loader):
            # The validation pass below switches to eval mode, so re-enable train mode per batch.
            model.train()

            progress.set_description(
                "Experiment: {}/{} | Batch idx: {}/{} | Epochs".format(cont + 1,
                                                                       len(parameters),
                                                                       batch_idx, len(train_loader)))
            progress.refresh()

            sample = sample.to(device)
            target = target.to(device)

            optimizer.zero_grad()

            # The semi-supervised forward pass additionally receives the targets.
            forward_args = (sample, target) if semi else (sample,)
            encoded_features, decoded_features, som_output = model(*forward_args)

            # Only the winner indices (last of the five SOM outputs) are needed here.
            _, _, _, _, final_winners = som_output

            som_loss = weighted_mseLoss(encoded_features,
                                        model.som.weights[final_winners],
                                        model.som.relevance[final_winners])
            recon_loss = mseLoss(sample.flatten(1), decoded_features)

            scaled_som_loss = alpha * som_loss
            total_loss = recon_loss + scaled_som_loss

            # Validation scoring happens before the backward pass, on every batch.
            valid_scores, _ = eval(model=model, loader=valid_loader, evaluate=True,
                                   eval_metrics=['nmi', 'pur'], debug=debug)

            if use_wandb:
                node_control = model.som.node_control
                active_nodes = node_control[node_control == 1].size(0)

                # At most the first ten items of the batch are logged as images.
                decoded_images = [wandb.Image(decoded_features[i].view(1, 28, 28),
                                              caption="Target:{}".format(target[i].item()))
                                  for i in range(min(10, len(decoded_features)))]
                input_images = [wandb.Image(sample[i].view(1, 28, 28),
                                            caption="Target:{}".format(target[i].item()))
                                for i in range(min(10, len(sample)))]

                payload = {"Nodes": active_nodes}
                payload["{0}-recon-loss".format(custom_model)] = recon_loss
                payload["alpha-som-win-weighted-loss"] = scaled_som_loss
                payload["som-win-weighted-loss"] = som_loss
                payload["combined-loss"] = total_loss
                payload["NMI Valid"] = valid_scores['nmi']
                payload["PUR Valid"] = valid_scores['pur']
                payload["Decoded Features"] = decoded_images
                payload["Input"] = input_images
                wandb.log(payload)

                if semi:
                    # Class value 999 is what separates unlabeled from labeled nodes.
                    node_classes = model.som.classes[node_control.bool()]
                    labeled_count = (node_classes != 999).nonzero().size(0)
                    unlabeled_count = (node_classes == 999).nonzero().size(0)
                    wandb.log({"Labeled-Nodes": labeled_count,
                               "Unlabeled-Nodes": unlabeled_count})

            total_loss.backward()
            optimizer.step()

        # Once per epoch: step the SOM lr optimizer/scheduler and publish the new rate.
        som_lr_optimizer.step()
        som_lr_scheduler.step()
        model.som.lr = som_lr_scheduler.get_last_lr()[0]

    return recon_loss, som_loss, total_loss


def eval(model, loader, evaluate=False, eval_metrics=None, debug=False):
    """Cluster ``loader`` with ``model`` and optionally score the clustering.

    Args:
        model: object exposing ``eval()`` and ``cluster(loader)``; the latter
            must return ``(predicted_clusters, <ignored>, true_labels,
            cluster_result)``.
        loader: data passed to ``model.cluster``.
        evaluate: metrics are computed only when this is true.
        eval_metrics: iterable of metric keys among ``'nmi'``, ``'pur'``,
            ``'ari'``, ``'ce'`` and ``'c_acc'``. Unknown keys are ignored.
            ``None`` disables scoring.
        debug: print every computed metric.

    Returns:
        ``(eval_dict, cluster_result)`` where ``eval_dict`` maps each
        computed metric key to its value. When ``model`` or ``loader`` is
        ``None`` the result is ``({}, None)``, so callers can always unpack
        two values.
    """
    eval_dict = {}
    if model is None or loader is None:
        # Same 2-tuple shape as every other exit path.
        return eval_dict, None

    model.eval()
    predicted_clusters, _, true_labels, cluster_result = model.cluster(loader)

    if eval_metrics is None or not evaluate:
        return eval_dict, cluster_result

    # metric key -> (function name in `metrics.cluster`, debug message template).
    # Functions are looked up lazily so only requested metrics are resolved.
    supported = {
        'nmi': ('nmi', "Normalized Mutual Information (NMI): %0.3f"),
        'pur': ('purity', "Purity: %0.3f"),
        'ari': ('ari', "Adjusted Rand Index (ARI): %0.3f"),
        'ce': ('predict_to_clustering_error', "Clustering Error (CE): %0.3f"),
        'c_acc': ('acc', "Clustering Accuracy: %0.3f"),
    }

    for metric in eval_metrics:
        if metric in supported:
            scorer = getattr(metrics.cluster, supported[metric][0])
            eval_dict[metric] = scorer(true_labels, predicted_clusters)

    if debug:
        # Report every computed metric, not just the first one found.
        for metric, (_, template) in supported.items():
            if metric in eval_dict:
                print(template % eval_dict[metric])

    return eval_dict, cluster_result


def initialize_loaders(dataset, batch_size):
    """Create the train, validation and test loaders for ``dataset``.

    The train loader shuffles ``dataset.train_data`` in batches of
    ``batch_size``. The test loader iterates ``dataset.test_data`` in order
    with the default batch size. The validation loader shuffles a
    ``CustomValidDataset`` built on top of the test loader, one item at a time.

    Returns:
        ``(train_loader, valid_loader, test_loader)``.
    """
    train_loader = DataLoader(dataset.train_data, shuffle=True, batch_size=batch_size)
    test_loader = DataLoader(dataset.test_data, shuffle=False)

    # The validation data is derived from the test loader, not from a separate split.
    valid_loader = DataLoader(CustomValidDataset(data_loader=test_loader), shuffle=True, batch_size=1)

    return train_loader, valid_loader, test_loader


def load_model(dataset, param_set, semi, device, custom_model='autoencoder', som_input=None):
    """Build a freshly initialised ``AutoEncoderSOM`` for one parameter row.

    Input dimensions come from ``dataset`` and the SOM hyper-parameters from
    ``param_set``. ``som_input`` overrides ``param_set.som_in`` when given.
    ``custom_model`` is only printed; add branches here when experiments with
    new custom models become needed. The model weights are initialised by
    applying ``init_w``.

    Returns:
        The initialised combined model.
    """
    print("Combined Mode: " + custom_model)
    print("Learning Mode: Semi-Supervised" if semi else "Learning Mode: Unsupervised")

    som_in = som_input if som_input is not None else param_set.som_in

    model_kwargs = {'d_in': dataset.d_in, 'hw_in': dataset.hw_in, 'som_input': som_in}

    # SOM hyper-parameters that are passed through under their own name.
    for field in ('n_max', 'at', 'lr', 'lr_push', 'ds_beta', 'eps_ds', 'ld', 'gamma'):
        model_kwargs[field] = getattr(param_set, field)

    # lr0 starts from the same value as lr.
    model_kwargs['lr0'] = model_kwargs['lr']
    model_kwargs['semi'] = semi
    model_kwargs['device'] = device

    combined_model = AutoEncoderSOM(**model_kwargs)
    combined_model.apply(init_w)

    return combined_model


def finish(evaluate, combined_model, test_loader, semi, dataset_path, out_folder, index, use_wandb):
    """Cluster the test data, write the results file and optionally report metrics.

    The clustering result is always written to
    ``<out_folder>/<dataset_path>_<index>.results`` through
    ``combined_model.write_output``. When ``evaluate`` is true the clustering
    metrics are printed and, if ``use_wandb`` is set, logged to wandb together
    with node statistics. ``Acc`` is logged as ``0`` when the model returns no
    predicted labels. With ``semi`` the labeled/unlabeled node counts are
    logged as well.
    """
    combined_model.eval()

    predicted_clusters, predicted_labels, true_labels, cluster_result = combined_model.cluster(test_loader)
    results_path = join(out_folder, dataset_path + "_" + str(index) + ".results")
    combined_model.write_output(results_path, cluster_result)

    if not evaluate:
        return

    nmi = metrics.cluster.nmi(true_labels, predicted_clusters)
    ari = metrics.cluster.ari(true_labels, predicted_clusters)
    purity = metrics.cluster.purity(true_labels, predicted_clusters)
    ce = metrics.cluster.predict_to_clustering_error(true_labels, predicted_clusters)
    c_acc = metrics.cluster.acc(true_labels, predicted_clusters)

    for template, value in (("Normalized Mutual Information (NMI): %0.3f", nmi),
                            ("Adjusted Rand Index (ARI): %0.3f", ari),
                            ("Purity: %0.3f", purity),
                            ("Clustering Error (CE): %0.3f", ce),
                            ("Clustering Accuracy: %0.3f", c_acc)):
        print(template % value)

    # Classification accuracy is only available when the model predicts labels.
    acc = 0
    if predicted_labels is not None:
        acc = sk_metrics.accuracy_score(predicted_labels, true_labels)
        print("Acc: %0.3f" % acc)

    if not use_wandb:
        return

    n_max = combined_model.som.n_max
    node_control = combined_model.som.node_control
    active_nodes = node_control[node_control == 1].size(0)

    summary = {"N_Max": n_max, "Nodes": active_nodes}
    summary.update({"NMI": nmi, "ARI": ari, "Purity": purity})
    summary.update({"CE": ce, "Clus_Acc": c_acc, "Acc": acc})
    wandb.log(summary)

    if semi:
        # Class value 999 is what separates unlabeled from labeled nodes.
        node_classes = combined_model.som.classes[node_control.bool()]
        labeled_count = (node_classes != 999).nonzero().size(0)
        unlabeled_count = (node_classes == 999).nonzero().size(0)
        wandb.log({"N_Labeled": labeled_count,
                   "N_Unlabeled": unlabeled_count})


def eval_landscapes(model_initial, model_current, loader, steps=40, plot=True, save_plots=False, out_folder="./"):
    """Compute and plot loss landscapes for the reconstruction and SOM losses.

    For each of the two losses, using the first batch of ``loader``:

    * a linear interpolation between ``model_initial`` and a deep copy of
      ``model_current`` is evaluated with ``loss_landscapes``;
    * a random 2-D plane around the copied model is evaluated;
    * an interpolation plot, a 2-D contour plot and a 3-D surface plot are
      produced through ``PlotLossLandscape``.

    Args:
        model_initial: model at the start of the interpolation. It is called
            on the batch, so it must accept a single input tensor and return
            ``(encoded_features, decoded_features, som_output)``.
        model_current: model at the end of the interpolation; it is
            deep-copied and not modified through that copy.
        loader: iterable yielding ``(x, y)`` batches; only the first is used.
        steps: resolution passed to ``loss_landscapes`` and to the plotter.
        plot, save_plots: forwarded to the plotter.
        out_folder: directory of the generated ``.png`` files; a trailing
            ``/`` is removed before the paths are built.

    NOTE(review): the batch is used as delivered by ``loader`` and is never
    moved to a device -- confirm that both models live on the same device
    as the data when this is called.
    """
    if out_folder.endswith("/"):
        out_folder = out_folder[:-1]

    plotter = PlotLossLandscape()
    combined_model_final = copy.deepcopy(model_current)

    # Plot Reconstruction Loss
    # Data that the evaluator will use when evaluating the loss (labels are not needed).
    x, _ = next(iter(loader))

    # Get decoded features from the initial combined model
    _, decoded_features, _ = model_initial(x)

    # Calculate reconstruction loss
    reconstruction_loss = ReconstructionLoss(loss_fn=nn.MSELoss(), inputs=decoded_features, target=x)

    # Compute linear interpolation of combined model with respect to the reconstruction loss
    reconstruction_loss_data = loss_landscapes.linear_interpolation(model_start=model_initial,
                                                                    model_end=combined_model_final,
                                                                    metric=reconstruction_loss, steps=steps,
                                                                    deepcopy_model=True)

    # Calculate random plane for the 2d/3d plots with respect to the reconstruction loss
    reconstruction_loss_data_fin = loss_landscapes.random_plane(model=combined_model_final,
                                                                metric=reconstruction_loss, distance=10, steps=steps,
                                                                normalization='filter', deepcopy_model=True)

    # Plot Reconstruction Loss Landscape
    loss_name = 'Reconstruction Loss'
    plotter.plot_loss_landscape_interpolation(loss_data=reconstruction_loss_data, steps=steps,
                                              ylabel=loss_name, title='Linear Interpolation of ' + loss_name,
                                              plot=plot, save_plots=save_plots,
                                              save_path=join(out_folder, "plot_recon_loss_landscape_interpolation.png"))

    plotter.plot_loss_landscape_2d(loss_data_fin=reconstruction_loss_data_fin,
                                   title=loss_name + ' Contours around Trained Model',
                                   plot=plot, save_plots=save_plots,
                                   save_path=join(out_folder, "plot_recon_loss_landscape_2d.png"))

    plotter.plot_loss_landscape_3d(loss_data_fin=reconstruction_loss_data_fin, projection="3d", steps=steps,
                                   title='Surface Plot of ' + loss_name,
                                   plot=plot, save_plots=save_plots,
                                   save_path=join(out_folder, "plot_recon_loss_landscape_3d.png"))

    # Plot SOM Loss
    # Calculate the unique weights (target of the som loss).
    # The model returns three values; only the SOM output is needed here.
    _, _, som_output = model_initial(x)
    _, weights_unique_nodes_high_at, relevances, _, _ = som_output
    weights_unique_nodes_high_at = weights_unique_nodes_high_at.view(-1, model_initial.som_input_size)

    # Calculate som loss
    som_loss = SOMLoss(loss_fn=WeightedMSELoss(), som_input_size=model_initial.som_input_size,
                       inputs=x, target=weights_unique_nodes_high_at, weights=relevances)

    # Compute linear interpolation of combined model with respect to the som loss
    som_loss_data = loss_landscapes.linear_interpolation(model_start=model_initial, model_end=combined_model_final,
                                                         metric=som_loss, steps=steps, deepcopy_model=True)

    # Calculate random plane for the 2d/3d plots with respect to the som loss
    som_loss_data_fin = loss_landscapes.random_plane(model=combined_model_final, metric=som_loss, distance=10,
                                                     steps=steps, normalization='filter', deepcopy_model=True)

    # Plot SOM Loss Landscape
    loss_name = 'SOM Loss'
    plotter.plot_loss_landscape_interpolation(loss_data=som_loss_data, steps=steps,
                                              ylabel=loss_name, title='Linear Interpolation of ' + loss_name,
                                              plot=plot, save_plots=save_plots,
                                              save_path=join(out_folder, "plot_som_loss_landscape_interpolation.png"))
    plotter.plot_loss_landscape_2d(loss_data_fin=som_loss_data_fin,
                                   title=loss_name + ' Contours around Trained Model',
                                   plot=plot, save_plots=save_plots,
                                   save_path=join(out_folder, "plot_som_loss_landscape_2d.png"))

    plotter.plot_loss_landscape_3d(loss_data_fin=som_loss_data_fin, projection="3d", steps=steps,
                                   title='Surface Plot of ' + loss_name, plot=plot, save_plots=save_plots,
                                   save_path=join(out_folder, "plot_som_loss_landscape_3d.png"))
