import os
import json
import logging
from time import perf_counter

import numpy as np
from numpy.matlib import repmat
from scipy.io import savemat
import scipy.io as sio
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from sklearn.gaussian_process.kernels import PairwiseKernel
from wrench.dataset import load_dataset, BaseDataset, get_dataset_type, TorchDataset
from wrench.dataset import load_image_dataset
from wrench._logging import LoggingHandler
from wrench.labelmodel import Fable
from sklearn.calibration import calibration_curve
from sklearn.metrics import brier_score_loss
from sklearn.metrics import log_loss

# Configure root logging once at import time: INFO level and above, timestamped
# messages, emitted through wrench's LoggingHandler.
logging.basicConfig(
    level=logging.INFO,
    handlers=[LoggingHandler()],
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
# Module-level logger used by everything below.
logger = logging.getLogger(__name__)

def _fit_and_score_full(label_model, train_data, is_binary):
    """Fit ``label_model`` on all of ``train_data`` and score it on the same data.

    Returns a dict with keys ``'acc'``, ``'log_loss'``, ``'brier'``,
    ``'elapsed'`` (seconds spent in ``fit``) and, when ``is_binary`` is true,
    ``'f1'``.
    """
    start_time = perf_counter()
    label_model.fit(train_data)
    elapsed_time = perf_counter() - start_time

    scores = {'acc': label_model.test(train_data, 'acc')}
    if is_binary:
        scores['f1'] = label_model.test(train_data, 'f1_binary')

    # probabilistic predictions are needed for the log loss and Brier score
    pred_probs = label_model.predict_proba(train_data)
    true_labels = np.squeeze(train_data.labels)
    scores['log_loss'] = log_loss(true_labels, pred_probs)
    scores['brier'] = multi_brier(true_labels, pred_probs)
    scores['elapsed'] = elapsed_time
    return scores


def _fit_and_score_batched(label_model, train_data, dataset_name, is_binary,
                           batch_size=10000):
    """Fit and score ``label_model`` separately on shuffled batches of ``train_data``.

    The model is re-fit on every batch and evaluated on that same batch.
    Returns the same keys as ``_fit_and_score_full``: the metrics are averaged
    over batches, while ``'elapsed'`` is the *sum* of the per-batch fit times.
    """
    per_batch = {'acc': [], 'log_loss': [], 'brier': [], 'elapsed': []}
    if is_binary:
        per_batch['f1'] = []

    loader = DataLoader(TorchDataset(train_data), shuffle=True,
                        batch_size=batch_size)
    for batch in tqdm(loader):
        # wrap the raw batch in a dataset object the label model can consume
        batch_set = create_dataset(batch, train_data, dataset_name)

        # fit to that batch and time only the fit
        start_time = perf_counter()
        label_model.fit(batch_set)
        per_batch['elapsed'].append(perf_counter() - start_time)

        # evaluate on the batch that was just fitted
        per_batch['acc'].append(
            label_model.test(batch_set, 'acc', batch_learning=True))
        if is_binary:
            per_batch['f1'].append(
                label_model.test(batch_set, 'f1_binary', batch_learning=True))

        pred_probs = label_model.predict_proba(batch_set)
        true_labels = np.squeeze(batch_set.labels)
        per_batch['log_loss'].append(log_loss(true_labels, pred_probs))
        per_batch['brier'].append(multi_brier(true_labels, pred_probs))

    scores = {key: np.mean(values)
              for key, values in per_batch.items() if key != 'elapsed'}
    scores['elapsed'] = np.sum(per_batch['elapsed'])
    return scores


def run_Fable(
        dataset_prefix,
        dataset_name=None,
        use_test=False,
        save_path=None,
        n_runs=10,
        merge=None,
        ):
    """Fit the Fable label model ``n_runs`` times on one dataset and collect metrics.

    Parameters
    ----------
    dataset_prefix : str
        Root data directory; datasets are loaded from
        ``<dataset_prefix>/datasets_with_features``.
    dataset_name : str
        Name of the dataset to load.  It is also used to build the result
        file name (``Fable_<dataset_name>.mat``).
    use_test : bool
        Currently unused: only training-set metrics are computed.
    save_path : str or None
        Directory the ``.mat`` result file is written to.  When ``None`` the
        results are only returned, not saved.
    n_runs : int
        Number of independent fits.  Mean/std summaries are added to the
        result only when ``n_runs > 1``.
    merge : bool or None
        Whether to append the validation and test splits to the training
        split before fitting.  ``None`` (the default) keeps the historical
        behaviour of reading a module-level ``merge`` flag (set by the script
        entry point); if no such flag exists the splits are not merged.

    Returns
    -------
    dict
        Per-run lists (``log_loss_train``, ``brier_score_train``,
        ``acc_train``, ``err_train``, ``fit_elapsed_time`` and, for binary
        problems, ``f1_score_train``), ``num_rule``, and -- when
        ``n_runs > 1`` -- ``*_mean`` / ``*_std`` summaries of those lists.
    """
    if merge is None:
        # Backward compatibility: this function used to read a global `merge`
        # that only exists when the file is run as a script.
        merge = bool(globals().get('merge', False))

    #### Load dataset
    # NOTE: the first CUDA device is hard-coded; a GPU is required as written.
    device = torch.device('cuda:0')

    dataset_path = os.path.join(dataset_prefix, 'datasets_with_features')
    if dataset_name in ['commercial', 'tennis']:
        # Numeric datasets: no text encoder (extract_fn) is requested; the
        # intent is to use the original features.
        # NOTE(review): extract_feature=True is still passed -- confirm
        # against wrench's load_dataset that this keeps the raw features.
        train_data, valid_data, test_data = load_dataset(
            dataset_path,
            dataset_name,
            extract_feature=True
        )
    else:
        train_data, valid_data, test_data = load_dataset(
            dataset_path,
            dataset_name,
            extract_feature=True,
            extract_fn='bert',
            model_name='roberta-base',  # roberta-base, roberta; bert-base-uncased, bert
            cache_name='roberta',
            normalize=True,
            device=device
        )

    if merge:
        train_data = concat(train_data, valid_data, dataset_name)
        train_data = concat(train_data, test_data, dataset_name)

    train_data = train_data.get_covered_subset()
    n_classes = np.max(train_data.labels) + 1
    # F1 is only computed and reported for binary problems
    is_binary = n_classes == 2

    # Result container.  Built before the loop so that it exists even when
    # n_runs < 1.  Predictions are deliberately not stored: with batched
    # fitting they could blow up the file size.
    mdic = {
            "log_loss_train": [],
            "brier_score_train": [],
            "acc_train": [],
            "err_train": [],
            "fit_elapsed_time": [],
            }
    if is_binary:
        mdic['f1_score_train'] = []
    mdic["num_rule"] = len(train_data.weak_labels[0])

    # datasets with more than this many points are fit batch by batch
    batch_threshold = 20000

    #### Run label model: Fable
    # TODO: test-set evaluation and calibration curves are not implemented;
    # `use_test` is accepted but has no effect.

    for run_no in range(n_runs):
        if n_runs > 1:
            logger.info('--------Run Number %d--------', run_no + 1)

        label_model = Fable(
            num_groups=3, # hard coded from paper suggestion
            inference_iter=10,
            # num_groups * num_correct = 3 * 1000 respectively from paper
            a_v=train_data.n_class * 3 * 1000,
            b_v=1,
            empirical_prior=True,
            kernel_function=PairwiseKernel('cosine'),
            desired_rank=50,
            device=device
        )

        if len(train_data.ids) > batch_threshold:
            scores = _fit_and_score_batched(
                label_model, train_data, dataset_name, is_binary)
        else:
            scores = _fit_and_score_full(label_model, train_data, is_binary)

        acc_train = scores['acc']
        logloss_train = scores['log_loss']
        brier_score_train = scores['brier']
        elapsed_time = scores['elapsed']

        mdic["log_loss_train"].append(logloss_train)
        mdic["brier_score_train"].append(brier_score_train)
        if is_binary:
            mdic["f1_score_train"].append(scores['f1'])
        mdic["acc_train"].append(acc_train)
        mdic["err_train"].append(1 - acc_train)
        mdic["fit_elapsed_time"].append(elapsed_time)

        ### report results
        logger.info('----------------Results----------------')
        logger.info('time to fit: %.1f seconds', elapsed_time)
        logger.info('train acc (train err): %.4f (%.4f)',
                acc_train, 1 - acc_train)
        logger.info('train log loss: %.4f', logloss_train)
        logger.info('train brier score: %.4f', brier_score_train)
        if is_binary:
            logger.info('train f1 score: %.4f', scores['f1'])

    # if number of runs is >1, report and store mean results and standard
    # deviations
    if n_runs > 1:
        f1_key = ["f1_score_train"] if is_binary else []
        # key orders mirror the order in which the summaries were
        # historically inserted into the result dict
        mean_keys = (["log_loss_train", "brier_score_train"] + f1_key
                     + ["acc_train", "err_train", "fit_elapsed_time"])
        std_keys = (["fit_elapsed_time", "log_loss_train", "brier_score_train"]
                    + f1_key + ["acc_train", "err_train"])
        for key in mean_keys:
            mdic[key + "_mean"] = np.mean(mdic[key])
        for key in std_keys:
            mdic[key + "_std"] = np.std(mdic[key])

        logger.info('================Aggregated Results================')
        logger.info('Total number of runs: %d', n_runs)
        logger.info('Average time to fit: %.1f seconds (std: %.1f)',
                mdic['fit_elapsed_time_mean'], mdic['fit_elapsed_time_std'])
        logger.info('train mean acc +- std (mean train err):'
                ' %.4f +- %.4f (%.4f)', mdic['acc_train_mean'],
                mdic['acc_train_std'], mdic['err_train_mean'])
        logger.info('train mean log loss +- std: %.4f +- %.4f',
                mdic['log_loss_train_mean'], mdic['log_loss_train_std'])
        logger.info('train mean brier score +- std: %.4f +- %.4f',
                mdic['brier_score_train_mean'], mdic['brier_score_train_std'])
        if is_binary:
            logger.info('train mean f1 score +- std: %.4f +- %.4f',
                    mdic['f1_score_train_mean'], mdic['f1_score_train_std'])

    # Only write to disk when a target directory was given; previously a
    # missing save_path crashed here after all runs had finished.
    if save_path is not None:
        result_filename = get_result_filename(dataset_name)
        savemat(os.path.join(save_path, result_filename), mdic)

    return mdic

def concat(d1: BaseDataset, d2: BaseDataset, name: str) -> BaseDataset:
    """Return a new dataset holding the entries of ``d1`` followed by those of ``d2``.

    The dataset class is looked up from ``name`` via ``get_dataset_type``.
    ``n_class`` and ``n_lf`` are copied from ``d1``; ``ids``, ``labels``,
    ``examples`` and ``weak_labels`` are joined with ``+`` and the feature
    matrices are stacked row-wise.
    """
    merged = get_dataset_type(name)()

    # per-example fields: d1's entries first, then d2's
    for field in ('ids', 'labels', 'examples', 'weak_labels'):
        setattr(merged, field, getattr(d1, field) + getattr(d2, field))

    # dataset-level metadata is taken from the first dataset
    merged.n_class = d1.n_class
    merged.n_lf = d1.n_lf

    merged.features = np.vstack([d1.features, d2.features])

    return merged

def get_result_filename(dataset_name):
    """Build the ``.mat`` result file name for ``dataset_name``."""
    # fixed method prefix + dataset name + MATLAB file extension
    return "".join(("Fable_", dataset_name, ".mat"))

def multi_brier(labels, pred_probs):
    """
    multiclass brier score: the mean over datapoints of the squared distance
    between the one-hot true label and the predicted probability vector

    labels: array-like of class indices, one per datapoint, with values in
        {0, 1, ..., n_class - 1}
    pred_probs: array-like of shape (n_datapoints, n_class)

    The number of classes is taken from the columns of pred_probs rather than
    from max(labels) + 1, so the score stays correct when the highest class
    does not occur in labels.

    raises ValueError if pred_probs is not 2D, if the number of labels does
    not match the number of rows of pred_probs, or if a label is >= n_class
    """
    labels = np.asarray(labels).reshape(-1)
    pred_probs = np.asarray(pred_probs)

    if pred_probs.ndim != 2:
        raise ValueError(
            "pred_probs must be 2D (n_datapoints, n_class), got shape "
            + str(pred_probs.shape))
    if labels.shape[0] != pred_probs.shape[0]:
        raise ValueError(
            "got " + str(labels.shape[0]) + " labels for "
            + str(pred_probs.shape[0]) + " rows of pred_probs")

    n_class = pred_probs.shape[1]
    if labels.size and np.max(labels) >= n_class:
        raise ValueError(
            "label " + str(np.max(labels)) + " is out of range for "
            + str(n_class) + " predicted classes")

    # one-hot encode: row i has a 1 in column labels[i]
    one_hot = (np.arange(n_class) == labels[..., None]).astype(int)
    sq_diff = np.square(one_hot - pred_probs)
    datapoint_loss = np.sum(sq_diff, axis=1)

    return np.mean(datapoint_loss)

def create_dataset(batch: dict, dataset: BaseDataset, name: str) -> BaseDataset:
    """Wrap one batch in a fresh dataset object of the type registered for ``name``.

    ``batch`` must provide the keys ``'ids'``, ``'labels'``, ``'data'``,
    ``'weak_labels'`` and ``'features'``.  ``n_class`` and ``n_lf`` are
    copied from ``dataset``, the full dataset the batch was drawn from.
    """
    subset = get_dataset_type(name)()

    # these entries expose .tolist() (presumably tensors produced by the
    # DataLoader -- confirm against the caller) and are stored as plain lists
    for key in ('ids', 'labels'):
        setattr(subset, key, batch[key].tolist())
    subset.examples = batch['data']
    subset.weak_labels = batch['weak_labels'].tolist()

    # dataset-level metadata comes from the parent dataset
    subset.n_class = dataset.n_class
    subset.n_lf = dataset.n_lf

    # features are passed through unchanged
    subset.features = batch['features']

    return subset

# pylint: disable=C0103
if __name__ == '__main__':
    # create results folder if it doesn't exist
    results_folder_path = './results'
    os.makedirs(results_folder_path, exist_ok=True)

    # root folder that holds the datasets
    dataset_prefix = './datasets/'

    # whether to put the train/valid/test data together.
    # NOTE: run_Fable reads this flag from module scope, so it must stay a
    # module-level name.
    merge = True
    merged_txt = '_merged' if merge else ''

    # wrench datasets
    datasets = ['imdb', 'youtube', 'sms', 'cdr', 'yelp', 'commercial',
            'tennis', 'trec', 'semeval', 'chemprot', 'agnews']

    # the log format is the same for every dataset, so build it once
    formatter = logging.Formatter('%(asctime)s - %(message)s',
            '%Y-%m-%d %H:%M:%S')

    for dataset in datasets:
        # results/<dataset>[_merged]/Fable -- makedirs also creates the
        # intermediate dataset folder
        method_result_path = os.path.join(
                results_folder_path, dataset + merged_txt, 'Fable')
        os.makedirs(method_result_path, exist_ok=True)

        # detach the previous dataset's log file and close it so the file
        # handle is released instead of staying open until interpreter exit
        for handler in logger.handlers[:]:
            logger.removeHandler(handler)
            handler.close()

        # log file sits next to the .mat result and shares its base name
        log_filename = get_result_filename(dataset)[:-4] + '.log'
        log_filename_full = os.path.join(method_result_path,
                log_filename)
        file_handler = logging.FileHandler(log_filename_full, 'w')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
        # log all the run parameters
        logger.info('============Running Fable============')
        logger.info('dataset: %s', dataset)

        run_Fable(
                dataset_prefix,
                dataset_name=dataset,
                save_path=method_result_path,
                )
