from datasets.datasets import Datasets
import torch.backends.cudnn as cudnn
import random
from torch.utils.data.dataloader import DataLoader
import metrics
from sklearn import metrics as sk_metrics
import torch
import numpy as np
from analysis.plot import *
from os.path import join
from tqdm import trange
import time
import os
from utils import model_loader


def _seed_everything(seed, use_cuda):
    """Seed Python's ``random``, torch's CPU RNG and, if requested, every CUDA device."""
    random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.manual_seed_all(seed)


def _fit_timed(som, train_loader, epochs, device, use_cuda, semi):
    """Train ``som`` for ``epochs`` passes over ``train_loader``.

    Returns:
        Elapsed training time in **seconds** for both the CUDA and CPU paths.
    """
    if use_cuda:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
    else:
        start = time.perf_counter()

    for _ in trange(epochs):
        for sample, target in train_loader:
            sample, target = sample.to(device), target.to(device)
            # Semi-supervised models also consume the labels during training.
            if semi:
                som(sample, target)
            else:
                som(sample)

    if use_cuda:
        end.record()
        # Wait for all queued kernels to finish before reading the events.
        torch.cuda.synchronize()
        # Event.elapsed_time reports milliseconds; convert so that both
        # branches return the same unit.
        return start.elapsed_time(end) / 1000.0
    return time.perf_counter() - start


def _save_checkpoint(som, epochs, seed, path):
    """Write the model weights plus its hyper-parameters and seed to ``path``."""
    torch.save({
        'model_state_dict': som.state_dict(),
        'epochs': epochs,
        'input_size': som.input_size,
        'n_max': som.n_max,
        'at': som.at,
        'ds_beta': som.ds_beta,
        'lr': som.lr,
        'eps_ds': som.eps_ds,
        'ld': som.ld,
        'gamma': som.gamma,
        'seed': seed
    }, path)


def train_som(root, train_path, test_root, test_path, norm_type,
              parameters, out_folder, batch_size,
              device, use_cuda, workers,
              evaluate=False, summ_writer=None, coil20_unprocessed=False,
              save=False, load=False, model=None, semi=False,
              labels_sampling='fixed', n_labels=None):
    """Train, cluster and optionally evaluate one SOM per row of ``parameters``.

    For every row of ``parameters`` this function:

    1. Builds a fresh SOM with ``model_loader.choose_som``, or loads ``model``
       when ``load`` is set.
    2. Seeds the RNGs and trains the SOM on the training split.
    3. Clusters the test split.
    4. Writes ``<out_folder>/<name>_<row index>.results``.

    Args:
        root, train_path, test_root, test_path, norm_type, coil20_unprocessed,
        labels_sampling, n_labels: forwarded to ``Datasets`` (data is flattened).
        parameters: iterated with ``itertuples()``. Each row must expose
            ``Index``, ``n_max``, ``epochs`` and ``seed``, plus whatever
            ``choose_som`` reads. Presumably a pandas DataFrame; verify
            against callers.
        out_folder: directory that receives ``.results`` (and ``.pth``) files.
        batch_size: training batch size; ``None`` means 32.
        device: device that batches are moved to with ``.to(device)``.
        use_cuda: enables CUDA seeding, ``som.cuda()``, ``cudnn.benchmark``
            and CUDA-event timing.
        workers: ``num_workers`` for the training ``DataLoader``.
        evaluate: compute and print the clustering error, and the accuracy
            when ``semi`` is set.
        summ_writer: when given together with ``evaluate``, the metrics are
            plotted through ``HParams.plot_tensorboard_x_y``.
        save: also write a ``.pth`` checkpoint next to the results.
        load, model: when both are set, the SOM, its epoch count and its seed
            come from ``model_loader.load_som_model(model, ...)``. The model is
            reloaded for every row, so every row starts from the same
            checkpoint.
        semi: semi-supervised mode. Targets are fed during training and
            ``som.cluster`` is expected to also return predicted labels.

    Note:
        The recorded elapsed training time is always in seconds.
    """
    dataset = Datasets(dataset=train_path, root_folder=root, test_root_folder=test_root, test_dataset=test_path,
                       norm=norm_type, flatten=True, coil20_unprocessed=coil20_unprocessed,
                       labels_sampling=labels_sampling, n_labels=n_labels)

    plots = HParams()
    clustering_errors = []
    accuracies = []

    # Loop-invariant: resolve the default batch size once.
    if batch_size is None:
        batch_size = 32

    for param_set in parameters.itertuples():
        if load and model is not None:
            som, som_epochs, manual_seed = model_loader.load_som_model(model, device, semi)
            _seed_everything(manual_seed, use_cuda)
        else:
            som_epochs = param_set.epochs
            manual_seed = param_set.seed
            # Seed *before* construction so that any random initialisation
            # done inside choose_som is reproducible too.
            _seed_everything(manual_seed, use_cuda)
            # param_set.n_max is a divisor of the training-set size, not an
            # absolute node count; keep at least 2 nodes.
            n_max_som = max(int(len(dataset.train_data) / param_set.n_max), 2)
            som = model_loader.choose_som(dataset.dim_flatten, n_max_som, param_set, device, semi)

        if use_cuda:
            som.cuda()
            cudnn.benchmark = True

        train_loader = DataLoader(dataset.train_data, batch_size=batch_size, shuffle=True, num_workers=workers)
        test_loader = DataLoader(dataset.test_data, shuffle=False)

        print(os.path.splitext(train_path)[0] + "_" + str(param_set.Index))

        elapsed_time = _fit_timed(som, train_loader, som_epochs, device, use_cuda, semi)
        print(elapsed_time)

        if semi:
            predicted_clusters, predicted_labels, true_labels, cluster_result = som.cluster(test_loader)
        else:
            predicted_clusters, true_labels, cluster_result = som.cluster(test_loader)

        # Name outputs after the test file when one is given, else the train file.
        source_path = train_path if test_path is None else test_path
        filename = os.path.splitext(source_path)[0] + "_" + str(param_set.Index)
        som.write_output(join(out_folder, filename + ".results"), cluster_result, elapsed_time=elapsed_time)

        if save:
            _save_checkpoint(som, som_epochs, manual_seed, join(out_folder, filename + '.pth'))

        if evaluate:
            ce = metrics.cluster.predict_to_clustering_error(true_labels, predicted_clusters)
            clustering_errors.append(ce)
            print('CE: {:.3f}'.format(ce))

            if semi:
                # sklearn signature is accuracy_score(y_true, y_pred).
                acc = sk_metrics.accuracy_score(true_labels, predicted_labels)
                accuracies.append(acc)
                print('ACC: {:.3f}'.format(acc))

        print("")

    if evaluate and summ_writer is not None:
        run_name = train_path.split(".arff")[0]
        plots.plot_tensorboard_x_y(parameters, 'CE', np.array(clustering_errors), summ_writer, run_name)

        if semi:
            plots.plot_tensorboard_x_y(parameters, 'ACC', np.array(accuracies), summ_writer, run_name)




