import os
import json
import logging
from time import perf_counter

import torch
import numpy as np
from numpy.matlib import repmat
from scipy.io import savemat
import scipy.io as sio
import matplotlib.pyplot as plt

from wrench.dataset import load_image_dataset
from wrench.dataset import load_dataset, BaseDataset, get_dataset_type, TorchDataset
from wrench._logging import LoggingHandler
from wrench.classification import WeaSEL
from sklearn.calibration import calibration_curve
from sklearn.metrics import brier_score_loss
from sklearn.metrics import log_loss

# Configure root logging once for the whole script: INFO level, timestamped
# records, emitted through wrench's LoggingHandler (where that handler writes
# to is defined in wrench — presumably the console; verify).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[LoggingHandler()],
)
logger = logging.getLogger(__name__)

def run_WeaSEL(
        dataset_prefix,
        dataset_name=None,
        use_test=False,
        save_path=None,
        n_runs=10,
        merge=None,
        ):
    """Fit the WeaSEL label model on a wrench dataset and record train metrics.

    The model is fit ``n_runs`` times on the covered subset of the training
    split. Per run it records train log loss, multiclass Brier score,
    accuracy, error, fit time, and (for binary tasks) F1. When
    ``n_runs > 1`` it also stores ``<metric>_mean`` / ``<metric>_std``
    entries. The result dict is written to
    ``<save_path>/WeaSEL_<dataset_name>.mat``.

    Args:
        dataset_prefix: root directory containing ``datasets_with_features``.
        dataset_name: wrench dataset name (e.g. ``'sms'``).
        use_test: currently unused; kept for interface compatibility
            (test-set evaluation is not implemented).
        save_path: directory for the ``.mat`` results. If ``None``, the
            results are only returned, not written to disk.
        n_runs: number of independent fits; must be >= 1.
        merge: if truthy, concatenate train/valid/test before fitting.
            ``None`` falls back to a module-level ``merge`` flag (as set by
            the ``__main__`` block), or ``False`` if there is none.

    Returns:
        dict of collected metrics (the same dict saved to disk).

    Raises:
        ValueError: if ``n_runs < 1``.
    """
    if n_runs < 1:
        raise ValueError('n_runs must be at least 1, got %r' % (n_runs,))
    if merge is None:
        # Backward compatibility: this function used to read the module-level
        # ``merge`` flag directly, which only exists when run as a script.
        merge = globals().get('merge', False)

    #### Load dataset (only examples covered by at least one LF are kept)
    train_data = _load_train_data(dataset_prefix, dataset_name, merge)
    n_classes = np.max(train_data.labels) + 1
    is_binary = n_classes == 2

    mdic = {
            "log_loss_train": [],
            "brier_score_train": [],
            "acc_train": [],
            "err_train": [],
            "fit_elapsed_time": [],
            }
    if is_binary:
        mdic['f1_score_train'] = []
    # length of one example's weak-label vector, i.e. the number of LFs;
    # loop-invariant, so computed once
    mdic["num_rule"] = len(train_data.weak_labels[0])

    #### Run label model: WeaSEL
    for run_no in range(n_runs):
        if n_runs > 1:
            logger.info('--------Run Number %d--------', run_no + 1)

        label_model = WeaSEL()
        start_time = perf_counter()
        label_model.fit(train_data)
        elapsed_time = perf_counter() - start_time

        ### make predictions and compute losses
        Y_p_train = label_model.predict_proba(train_data)
        true_labels_train = np.squeeze(train_data.labels)
        brier_score_train = multi_brier(true_labels_train, Y_p_train)
        # pass the full class list so log_loss does not fail when some class
        # never occurs in the (covered) ground-truth labels
        logloss_train = log_loss(true_labels_train, Y_p_train,
                                 labels=np.arange(np.shape(Y_p_train)[1]))
        acc_train = label_model.test(train_data, 'acc')
        if is_binary:
            f1_score_train = label_model.test(train_data, 'f1_binary')

        mdic["log_loss_train"].append(logloss_train)
        mdic["brier_score_train"].append(brier_score_train)
        if is_binary:
            mdic["f1_score_train"].append(f1_score_train)
        mdic["acc_train"].append(acc_train)
        mdic["err_train"].append(1 - acc_train)
        mdic["fit_elapsed_time"].append(elapsed_time)

        ### report results
        logger.info('----------------Results----------------')
        logger.info('time to fit: %.1f seconds', elapsed_time)
        logger.info('train acc (train err): %.4f (%.4f)',
                acc_train, 1 - acc_train)
        logger.info('train log loss: %.4f', logloss_train)
        logger.info('train brier score: %.4f', brier_score_train)
        if is_binary:
            logger.info('train f1 score: %.4f', f1_score_train)

    # if number of runs is >1, report and store mean results and standard
    # deviations
    if n_runs > 1:
        _aggregate_runs(mdic, n_runs, is_binary)

    if save_path is None:
        logger.info('save_path is None; results are not written to disk')
    else:
        savemat(os.path.join(save_path, get_result_filename(dataset_name)),
                mdic)

    return mdic


def _load_train_data(dataset_prefix, dataset_name, merge):
    """Load a wrench dataset with features and return the covered train split.

    When ``merge`` is truthy, the valid and test splits are appended to the
    train split before the covered subset is taken.
    """
    dataset_path = os.path.join(dataset_prefix, 'datasets_with_features')
    if dataset_name in ['commercial', 'tennis']:
        # numeric datasets: no extract_fn is given, so wrench's default
        # feature handling is used (presumably the original numeric features)
        train_data, valid_data, test_data = load_dataset(
            dataset_path,
            dataset_name,
            extract_feature=True
        )
    else:
        # fall back to CPU so the script also runs on machines without CUDA
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        train_data, valid_data, test_data = load_dataset(
            dataset_path,
            dataset_name,
            extract_feature=True,
            extract_fn='bert',
            model_name='roberta-base',  # roberta-base, roberta; bert-base-uncased, bert
            cache_name='roberta',
            normalize=True,
            device=device
        )

    if merge:
        train_data = concat(train_data, valid_data, dataset_name)
        train_data = concat(train_data, test_data, dataset_name)

    return train_data.get_covered_subset()


def _aggregate_runs(mdic, n_runs, is_binary):
    """Add ``<metric>_mean`` / ``<metric>_std`` entries to ``mdic`` and log them."""
    metrics = ["log_loss_train", "brier_score_train", "acc_train",
               "err_train", "fit_elapsed_time"]
    if is_binary:
        metrics.append("f1_score_train")
    for key in metrics:
        mdic[key + "_mean"] = np.mean(mdic[key])
        mdic[key + "_std"] = np.std(mdic[key])

    logger.info('================Aggregated Results================')
    logger.info('Total number of runs: %d', n_runs)
    logger.info('Average time to fit: %.1f seconds (std: %.1f)',
            mdic['fit_elapsed_time_mean'], mdic['fit_elapsed_time_std'])
    logger.info('train mean acc +- std (mean train err):'
            ' %.4f +- %.4f (%.4f)', mdic['acc_train_mean'],
            mdic['acc_train_std'], mdic['err_train_mean'])
    logger.info('train mean log loss +- std: %.4f +- %.4f',
            mdic['log_loss_train_mean'], mdic['log_loss_train_std'])
    logger.info('train mean brier score +- std: %.4f +- %.4f',
            mdic['brier_score_train_mean'], mdic['brier_score_train_std'])
    if is_binary:
        logger.info('train mean f1 score +- std: %.4f +- %.4f',
                mdic['f1_score_train_mean'], mdic['f1_score_train_std'])

def concat(d1: BaseDataset, d2: BaseDataset, name: str) -> BaseDataset:
    """Return a new dataset of type ``name`` holding d1's examples followed by d2's.

    ``ids``, ``labels``, ``examples`` and ``weak_labels`` are concatenated as
    Python lists. Each field is passed through ``list()`` first, so
    array-valued fields are joined rather than added element-wise.
    ``n_class`` and ``n_lf`` are copied from ``d1``. Feature matrices are
    stacked row-wise when both datasets have them; if neither has features,
    ``features`` is ``None``.

    Raises:
        ValueError: if the datasets disagree on the number of labeling
            functions, or only one of them has features.
    """
    if d1.n_lf != d2.n_lf:
        raise ValueError('cannot concat datasets with different n_lf: %r vs %r'
                         % (d1.n_lf, d2.n_lf))

    dataset = get_dataset_type(name)()
    dataset.ids = list(d1.ids) + list(d2.ids)
    dataset.labels = list(d1.labels) + list(d2.labels)
    dataset.examples = list(d1.examples) + list(d2.examples)
    dataset.weak_labels = list(d1.weak_labels) + list(d2.weak_labels)
    dataset.n_class = d1.n_class
    dataset.n_lf = d1.n_lf

    # np.vstack([None, None]) would silently build an object array, so handle
    # the "no features" case explicitly
    if d1.features is None and d2.features is None:
        dataset.features = None
    elif d1.features is None or d2.features is None:
        raise ValueError('cannot concat a dataset with features and one without')
    else:
        dataset.features = np.vstack([d1.features, d2.features])

    return dataset

def get_result_filename(dataset_name):
    """Return the results filename for *dataset_name*, i.e. ``WeaSEL_<name>.mat``."""
    return "".join(("WeaSEL_", dataset_name, ".mat"))

def multi_brier(labels, pred_probs):
    """
    multiclass brier score

    Mean over datapoints of the squared Euclidean distance between the
    one-hot encoding of the true label and the predicted probability vector.

    labels: 1D integer array-like with values in {0, 1, ..., n_class - 1}.
    pred_probs: array-like of shape (n_datapoints, n_class). The number of
        classes is taken from its second dimension, so ``labels`` need not
        contain every class.

    Raises ValueError if a label falls outside [0, n_class).
    """
    labels = np.asarray(labels)
    pred_probs = np.asarray(pred_probs)
    # derive the class count from the predictions, not max(labels): a class
    # missing from labels would otherwise give a too-narrow one-hot matrix
    n_class = pred_probs.shape[1]
    if labels.size and (labels.min() < 0 or labels.max() >= n_class):
        raise ValueError('labels must lie in [0, %d)' % n_class)
    one_hot = (np.arange(n_class) == labels[..., None]).astype(int)
    datapoint_loss = np.sum(np.square(one_hot - pred_probs), axis=1)

    return np.mean(datapoint_loss)

# pylint: disable=C0103
if __name__ == '__main__':
    # create the top-level results folder if it doesn't exist
    results_folder_path = './results'
    os.makedirs(results_folder_path, exist_ok=True)

    # root directory that holds the 'datasets_with_features' folder
    dataset_prefix = './datasets/'

    # wrench benchmark datasets to evaluate
    datasets = ['imdb', 'youtube', 'sms', 'cdr', 'yelp', 'commercial',
                'tennis', 'trec', 'semeval', 'chemprot', 'agnews']

    # whether to put the train/valid/test data together; run_WeaSEL reads this
    # module-level flag
    merge = True
    merged_txt = '_merged' if merge else ''

    formatter = logging.Formatter('%(asctime)s - %(message)s',
            '%Y-%m-%d %H:%M:%S')

    for dataset in datasets:
        # results/<dataset>[_merged]/WeaSEL, created (with parents) if missing
        method_result_path = os.path.join(
                results_folder_path, dataset + merged_txt, 'WeaSEL')
        os.makedirs(method_result_path, exist_ok=True)

        # detach any leftover handlers from this module's logger and close
        # them so their file descriptors are released
        for handler in logger.handlers[:]:
            logger.removeHandler(handler)
            handler.close()

        # per-dataset log file next to the .mat results: WeaSEL_<dataset>.log
        log_filename = get_result_filename(dataset)[:-4] + '.log'
        log_filename_full = os.path.join(method_result_path,
                log_filename)
        file_handler = logging.FileHandler(log_filename_full, 'w')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
        try:
            # log all the run parameters
            logger.info('============Running WeaSEL============')
            logger.info('dataset: %s', dataset)

            run_WeaSEL(
                    dataset_prefix,
                    dataset_name = dataset,
                    save_path=method_result_path,
                    )
        finally:
            # always release this dataset's log file, even if the run fails
            logger.removeHandler(file_handler)
            file_handler.close()
