"""
Filtering and dataset mapping methods based on training dynamics.
By default, this module reads training dynamics from a given trained model and
computes the metrics---confidence, variability, correctness,
as well as baseline metrics of forgetfulness and threshold closeness
for each instance in the training data.
If specified, data maps can be plotted with respect to confidence and variability.
Moreover, datasets can be filtered with respect to any of the other metrics.


# Plot Data Maps
To plot data maps for a trained $MODEL (e.g. RoBERTa-Large) on a given $TASK (e.g. SNLI, MNLI, QNLI or WINOGRANDE):

python -m cartography.selection.train_dy_filtering \
    --plot \
    --task_name $TASK \
    --output_dir $PATH_TO_MODEL_OUTPUT_DIR_WITH_TRAINING_DYNAMICS \
    --model $MODEL_NAME

"""
import argparse
import json
import logging
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import seaborn as sns
import torch
import tqdm

from collections import defaultdict
from typing import List

import pickle

# Root logging setup: INFO and above, each record stamped with time, level and logger name.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)

# Module-scoped logger used by every function in this file.
logger = logging.getLogger(__name__)

def _variability(scores, include_ci):
    """Return the spread of one sample's per-epoch gold-label scores.

    Without ``include_ci`` this is the plain standard deviation. With it, the
    variance is widened by ``var**2 / (n - 1)`` before taking the square root
    (the confidence-interval variant from Active Bias,
    https://arxiv.org/abs/1704.07433).
    """
    if not include_ci:
        return np.std(scores)
    variance = np.var(scores)
    if len(scores) < 2:
        # The correction divides by (n - 1); with one epoch there is nothing to correct.
        return np.sqrt(variance)
    return np.sqrt(variance + variance * variance / (len(scores) - 1))


def compute_my_dynamics(output_dir=None, burn_out=None, include_ci=None):
    """
    Compute data-map coordinates for every training sample from a dumped run.

    Reads ``model_train_cartography_dump.pkl`` from ``output_dir``. The dump is a
    dict from which two entries are used:

    - ``'global_train_dataset_cartography_list'``: indexed as ``[epoch][sample_idx]``,
      one score per sample per epoch.
    - ``'my_dataset_samples'``: sequence of ``(image_path, target)`` pairs with
      ``target`` equal to 0 or 1.

    For each sample the per-epoch score is converted to the score of the gold
    label (``score`` when target is 1, ``1 - score`` when target is 0), and then:

    - confidence  = mean of the gold-label scores over epochs,
    - variability = their standard deviation (or the confidence-interval
      variant when ``include_ci`` is true),
    - correctness = number of epochs whose gold-label score exceeds 0.5.

    Args:
        output_dir: Directory containing the dump. Defaults to ``args.output_dir``.
        burn_out: Use at most this many leading epochs. Defaults to ``args.burn_out``.
        include_ci: Use the confidence-interval variability. Defaults to
            ``args.include_ci``.

    Returns:
        DataFrame with columns ``index``, ``confidence``, ``variability`` and
        ``correctness``; one row per entry of ``my_dataset_samples``, in order.

    Raises:
        ValueError: If ``burn_out`` is smaller than 1, the dump holds no epochs,
            or a sample's target is neither 0 nor 1.
    """
    # Any argument left as None falls back to the module-level CLI namespace,
    # so the script's zero-argument call keeps working.
    if output_dir is None:
        output_dir = args.output_dir
    if burn_out is None:
        burn_out = args.burn_out
    if include_ci is None:
        include_ci = args.include_ci

    if burn_out < 1:
        raise ValueError(f"burn_out must be at least 1, got {burn_out}")

    dump_path = os.path.join(output_dir, 'model_train_cartography_dump.pkl')
    # NOTE: unpickling can execute arbitrary code; only load dumps from a trusted source.
    with open(dump_path, "rb") as pkl_file:
        load_data = pickle.load(pkl_file)
    epoch_scores = load_data['global_train_dataset_cartography_list']
    my_dataset_samples = load_data['my_dataset_samples']

    num_tot_epochs = len(epoch_scores)
    if num_tot_epochs == 0:
        raise ValueError(f"No epochs recorded in {dump_path}")

    if burn_out < num_tot_epochs:
        logger.info(f"Computing training dynamics. Burning out at {burn_out} of {num_tot_epochs}. ")
        epoch_scores = epoch_scores[:burn_out]
        num_tot_epochs = len(epoch_scores)
        logger.info(f"Cut down to {num_tot_epochs} epochs. ")
    else:
        logger.info(f"Computing training dynamics across {num_tot_epochs} epochs")
    logger.info("Metrics computed: confidence, variability, correctness")

    rows = []
    for idx, sample in enumerate(my_dataset_samples):
        _image_path, target = sample
        if target != 0 and target != 1:
            raise ValueError(f"Sample {idx} has target {target!r}; expected 0 or 1")

        # Score assigned to the gold label at each epoch.
        if target == 1:
            gold_scores = [scores[idx] for scores in epoch_scores]
        else:
            gold_scores = [1 - scores[idx] for scores in epoch_scores]

        rows.append([
            idx,
            np.mean(gold_scores),
            _variability(gold_scores, include_ci),
            np.sum(np.array(gold_scores) > 0.5),
        ])

    return pd.DataFrame(rows, columns=['index', 'confidence', 'variability', 'correctness'])

def plot_data_map(dataframe: pd.DataFrame,
                  plot_dir: str,
                  hue_metric: str = 'correct.',
                  title: str = '',
                  model: str = 'ResNet-18',
                  show_hist: bool = False,
                  max_instances_to_plot: int = 55000) -> None:
    """
    Draw a data map (variability on x, confidence on y) and save it as a PDF.

    Args:
        dataframe: Must contain ``confidence``, ``variability`` and ``correctness``
            columns. It is not modified; a sampled copy is plotted.
        plot_dir: Directory the PDF is written to (created if missing).
        hue_metric: Column used for both marker colour and marker style. The
            default ``'correct.'`` column is derived here: correctness divided
            by its maximum, rounded to one decimal.
        title: Task name; used in the plot title, the log line and the file name.
        model: Model name; used in the plot title, the log line and the file name.
        show_hist: If true, add confidence / variability / correctness side
            panels and save as ``{title}_{model}.pdf``; otherwise save only the
            scatter plot as ``compact_{title}_{model}.pdf``.
        max_instances_to_plot: Upper bound on the number of rows plotted; larger
            frames are randomly subsampled.
    """
    # Set style.
    sns.set(style='whitegrid', font_scale=1.6, font='Georgia', context='paper')
    logger.info(f"Plotting figure for {title} using the {model} model ...")

    # Subsample data to plot, so the plot is not too busy.
    dataframe = dataframe.sample(n=min(max_instances_to_plot, len(dataframe)))

    # Normalize correctness to a value between 0 and 1. When no instance was
    # ever correct the maximum is 0; divide by 1 instead so the result is 0, not NaN.
    max_correctness = dataframe['correctness'].max()
    denominator = max_correctness if max_correctness > 0 else 1
    dataframe = dataframe.assign(corr_frac=dataframe['correctness'] / denominator)
    dataframe['correct.'] = [float(f"{x:.1f}") for x in dataframe['corr_frac']]

    main_metric = 'variability'
    other_metric = 'confidence'

    hue = hue_metric
    num_hues = len(dataframe[hue].unique().tolist())
    style = hue_metric

    if not show_hist:
        fig, ax0 = plt.subplots(1, 1, figsize=(8, 6))
    else:
        fig = plt.figure(figsize=(14, 10), )
        # Wide left column for the scatter plot, narrow right column for three stacked panels.
        gs = fig.add_gridspec(3, 2, width_ratios=[5, 1])
        ax0 = fig.add_subplot(gs[:, 0])

    # Make the scatterplot, with one palette entry per distinct hue value.
    pal = sns.diverging_palette(260, 15, n=num_hues, sep=10, center="dark")

    plot = sns.scatterplot(x=main_metric,
                           y=other_metric,
                           ax=ax0,
                           data=dataframe,
                           hue=hue,
                           palette=pal,
                           style=style,
                           legend='brief',
                           s=30)

    # Annotate regions; positions are fractions of the scatter axes.
    def annotate_region(text, xyc, bbc):
        return ax0.annotate(text,
                            xy=xyc,
                            xycoords="axes fraction",
                            fontsize=15,
                            color='black',
                            va="center",
                            ha="center",
                            rotation=350,
                            bbox=dict(boxstyle="round,pad=0.3", ec=bbc, lw=2, fc="white"))

    annotate_region("ambiguous", xyc=(0.9, 0.5), bbc='black')
    annotate_region("easy-to-learn", xyc=(0.27, 0.85), bbc='r')
    annotate_region("hard-to-learn", xyc=(0.35, 0.25), bbc='b')

    if not show_hist:
        plot.legend(ncol=1, bbox_to_anchor=[0.175, 0.5], loc='right')
    else:
        plot.legend(fancybox=True, shadow=True,  ncol=1)
    plot.set_xlabel('variability')
    plot.set_ylabel('confidence')

    if show_hist:
        plot.set_title(f"{title}-{model} Data Map", fontsize=17)

        # Make the histograms.
        ax1 = fig.add_subplot(gs[0, 1])
        ax2 = fig.add_subplot(gs[1, 1])
        ax3 = fig.add_subplot(gs[2, 1])

        plott0 = dataframe.hist(column=['confidence'], ax=ax1, color='#622a87')
        plott0[0].set_title('')
        plott0[0].set_xlabel('confidence')
        plott0[0].set_ylabel('density')

        plott1 = dataframe.hist(column=['variability'], ax=ax2, color='teal')
        plott1[0].set_title('')
        plott1[0].set_xlabel('variability')
        plott1[0].set_ylabel('density')

        plot2 = sns.countplot(x="correct.", data=dataframe, ax=ax3, color='#86bf91')
        ax3.xaxis.grid(True) # Show the vertical gridlines

        plot2.set_xticklabels(plot2.get_xticklabels(), rotation=90, ha="right")

        plot2.set_title('')
        plot2.set_xlabel('correctness')
        plot2.set_ylabel('density')

    fig.tight_layout()

    # Both variants go to plot_dir; the compact one is told apart by its file-name prefix.
    os.makedirs(plot_dir, exist_ok=True)
    basename = f'{title}_{model}.pdf' if show_hist else f'compact_{title}_{model}.pdf'
    filename = os.path.join(plot_dir, basename)
    fig.savefig(filename, dpi=300)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    logger.info(f"Plot saved to {filename}")

if __name__ == "__main__":
    # Command-line entry point: compute the data-map metrics from the dump in
    # --output_dir and always write the full data map (with histograms) to --plots_dir.
    #
    # --filter, --plot, --metric, --worst, --both_ends and --filtering_output_dir
    # are still accepted so existing command lines keep parsing, but nothing in
    # this script reads them.
    parser = argparse.ArgumentParser()
    parser.add_argument("--filter",
                        action="store_true",
                        help="Whether to filter data subsets based on specified `metric`.")
    parser.add_argument("--plot",
                        action="store_true",
                        help="Whether to plot data maps and save as `pdf`.")
    parser.add_argument("--plots_dir",
                        default="./plotout/",
                        type=os.path.abspath,
                        help="Directory where plots are to be saved.")
    parser.add_argument("--task_name",
                        "-t",
                        default="MetaDataset[Cat&Dog]",
                        help="Which task are we plotting or filtering for.")
    parser.add_argument('--metric',
                        choices=('threshold_closeness',
                                 'confidence',
                                 'variability',
                                 'correctness',
                                 'forgetfulness'),
                        help="Metric to filter data by.",)
    parser.add_argument("--include_ci",
                        action="store_true",
                        help="Compute the confidence interval for variability.")
    parser.add_argument("--filtering_output_dir",
                        "-f",
                        default="./filtered/",
                        type=os.path.abspath,
                        help="Output directory where filtered datasets are to be written.")
    parser.add_argument("--worst",
                        action="store_true",
                        help="Select from the opposite end of the spectrum acc. to metric, "
                             "for baselines")
    parser.add_argument("--both_ends",
                        action="store_true",
                        help="Select from both ends of the spectrum acc. to metric,")
    parser.add_argument("--burn_out",
                        type=int,
                        default=20,
                        help="# Epochs for which to compute train dynamics.")
    parser.add_argument("--model",
                        default="ResNet18",
                        help="Model for which data map is being plotted")
    parser.add_argument('--output_dir', default='outputTmp',
                        help='Directory containing model_train_cartography_dump.pkl '
                             '(the training-dynamics dump to read).')

    # `args` must stay a module-level name: compute_my_dynamics() reads it as a global.
    args = parser.parse_args()

    train_dy_metrics = compute_my_dynamics()
    os.makedirs(args.plots_dir, exist_ok=True)
    plot_data_map(train_dy_metrics, args.plots_dir, title=args.task_name, show_hist=True, model=args.model)