import argparse
import collections
import glob
import itertools
import json
import logging
import math
import os
import pdb
import re
import statistics
import sys
import traceback

import matplotlib.pyplot as plt
import torch
import tensorflow as tf

# Configure the root logger at import time (timestamped, INFO and above) and
# create the module-level logger used throughout this file.
logging.basicConfig(
    format='%(asctime)s [%(levelname)s] (%(name)s):  %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

def overlap_coefficient(set1, set2):
    """Szymkiewicz-Simpson overlap coefficient of two sets.

    The size of the intersection divided by the size of the smaller set.
    Returns NaN when either set is empty.
    """
    smaller_size = min(len(set1), len(set2))
    if smaller_size == 0:
        return float('nan')
    shared = set1.intersection(set2)
    return len(shared) / smaller_size

def jaccard(set1, set2):
    """Jaccard overlap: |intersection| / |union| of two sets.

    Returns NaN when both sets are empty (the union is empty, so the ratio is
    undefined).  This mirrors how `overlap_coefficient` reports the undefined
    case instead of raising ZeroDivisionError.
    """
    union_size = len(set1.union(set2))
    if union_size == 0:
        return float('nan')
    return len(set1.intersection(set2)) / union_size

def overlap_number(set1, set2):
    """Count of elements present in both sets (raw intersection size)."""
    shared = set1.intersection(set2)
    return len(shared)

# Module-wide memo: data directory -> number of lines in its train.jsonl.
dsize_cache = {}

class Trials:
    """Repeated trials (one directory each) of a single experiment.

    On construction the per-trial event files are located and the tags of
    interest are parsed into `self.taglists`, a dict keyed by
    `(event_filename, tag)` as returned by `files_to_taglists`.  Acquisition
    logs, training args, config and data dirs are loaded lazily and cached on
    first access.
    """

    def __init__(self, experiment_id, experiment_dirs, tolerate_errors=False):
        """Set up the trial collection and parse the event files.

        Args:
            experiment_id: identifier of the experiment (stored as-is).
            experiment_dirs: list of trial directories for the experiment.
            tolerate_errors: if true, trial directories whose training args
                cannot be loaded are dropped before anything else is read.

        Raises:
            ValueError: if some trial directory does not contain exactly one
                event file.
        """
        self.experiment_id = experiment_id
        self.experiment_dirs = experiment_dirs
        self.tolerate_errors = tolerate_errors

        self._cached_acquisition_data = dict()
        self._training_args = None
        self._config = None
        self._data_dirs = None
        self._event_files = None
        self._acquisition_logs = None

        if self.tolerate_errors:
            self._prune_bad_experiment_dirs()

        # get_event_files() returns None (after logging the offending dir) if
        # any trial lacks a unique event file.  Fail with a clear message here
        # rather than with an opaque "NoneType is not iterable" further down.
        event_files = self.get_event_files()
        if event_files is None:
            raise ValueError("Could not find exactly one event file in every experiment dir of experiment {}".format(self.experiment_id))

        want_tags = ['eval_eval_acc', 'test_eval_acc', 'bysize_test_acc', 'bysize_avg_train_loss']
        self.taglists = files_to_taglists(event_files, want_tags, aggregators={'default': max})

    def get_acquisition_logs(self):
        """Return {experiment_dir: list of parsed JSONL records} from each
        trial's `output/selected_indexes.jsonl`, or None if any trial lacks
        that file.  Only a complete result is cached."""
        if self._acquisition_logs is None:
            logs = dict()
            for d in self.experiment_dirs:
                try:
                    with open(os.path.join(d, 'output', 'selected_indexes.jsonl')) as f:
                        logs[d] = [json.loads(line) for line in f]
                except FileNotFoundError:
                    # All-or-nothing: a single missing log disables the
                    # log-based code paths for the whole experiment.
                    return None
            self._acquisition_logs = logs
        return self._acquisition_logs

    def get_event_files(self):
        """Return the list of event file paths (one per trial directory), or
        None if any trial does not have exactly one event file."""
        if self._event_files is None:
            self._event_files = event_files_from_experiment_dirs(self.experiment_dirs)
            if any(ev is None for ev in self._event_files):
                logger.error("Expected exactly one event file but found 0 or >1 for dir {}".format(self.experiment_dirs[self._event_files.index(None)]))
                self._event_files = None
        return self._event_files

    def _prune_bad_experiment_dirs(self):
        """Drop trial directories whose training args cannot be loaded."""
        # Simple way to attempt to tolerate errors is by removing specific
        # experiment directories that don't match expectations.  This could
        # potentially be replaced or supplemented with case-specific error
        # correction in each call (but that could create inconsistencies if
        # some exp dirs are used for one call but not others)
        new_experiment_dirs = []
        for expdir in self.experiment_dirs:
            try:
                if os.path.exists(os.path.join(expdir, 'output')):
                    tfile = os.path.join(expdir, 'output', 'training_args.bin')
                else:
                    tfile = os.path.join(expdir, 'training_args.bin')
                # Loading and touching `.data_dir` is the health check; the
                # value itself is discarded.
                torch.load(tfile).data_dir
            except Exception:
                logger.warning('Problem with "{}".  Pruning it.'.format(expdir))
                continue

            new_experiment_dirs.append(expdir)
        self.experiment_dirs = new_experiment_dirs

    def get_bysize_final_accuracies(self):
        """Return the last recorded `bysize_test_acc` value of every trial, or
        None if any trial has no such tag."""
        event_files = self.get_event_files()
        if not all((filename, 'bysize_test_acc') in self.taglists for filename in event_files):
            return None

        return [self.taglists[(filename, 'bysize_test_acc')][-1][1] for filename in event_files]

    def get_bysize_test_areas_under_curves(self, adjustment='none'):
        """Return one area-under-curve value per trial for the test accuracy
        versus labeled-dataset-size curve.

        Args:
            adjustment: 'none' for the plain trapezoid AUC, 'monotonic' for
                `trapezoid_auc_fixfail`, or 'convex' for the trapezoid AUC of
                `get_convex_hull` of the curve.

        Returns:
            A list of AUCs, or None if neither the acquisition logs nor the
            event files provide the curve for every trial.

        Raises:
            ValueError: if `adjustment` is not one of the values above.
        """
        # First try using acquisition log since that's more reliable
        curvedata = None
        logs = self.get_acquisition_logs()
        if logs:
            curvedata = []
            for log in logs.values():
                cdata = []
                for rec in log:
                    # Test accuracy of the training run that scored best on dev.
                    best_run = max(rec['all_train_results'], key=lambda x: x['best_dev_acc'])
                    cdata.append((rec['labeled_dataset_size'], best_run['test_acc']))
                curvedata.append(cdata)
        else:
            # Fall back to event files
            if not all((filename, 'bysize_test_acc') in self.taglists for filename in self.get_event_files()):
                return None
            curvedata = [self.taglists[(filename, 'bysize_test_acc')] for filename in self.get_event_files()]
        assert len(curvedata) == len(self.get_event_files())

        aucs = []
        for cdata in curvedata:
            if adjustment == 'monotonic':
                aucs.append(trapezoid_auc_fixfail(cdata))
            elif adjustment == 'convex':
                aucs.append(trapezoid_auc(get_convex_hull(cdata)))
            elif adjustment == 'none':
                aucs.append(trapezoid_auc(cdata))
            else:
                raise ValueError("Invalid adjustment")

        return aucs

    def get_per_train_curves(self, keys):
        """Return, per trial directory, the curves of each individual training
        run found in its `tmp_train_output_<trainnum>[_r<trial>]` subdirs.

        Returns:
            A list (one entry per trial directory) of nested dicts
            `{trainnum: {trialnum: {key: taglist}}}`, or None if some training
            directory does not contain exactly one event file.

        Raises:
            ValueError: if a training directory name does not match the
                expected pattern.
        """
        all_curves = []
        for expdir in self.experiment_dirs:
            event_files = dict()
            if os.path.exists(os.path.join(expdir, 'output')):
                traindirs = glob.glob(os.path.join(expdir, 'output', 'tmp_train_output_*'))
            else:
                traindirs = glob.glob(os.path.join(expdir, 'tmp_train_output_*'))
            for traindir in traindirs:
                m = re.match(r'tmp_train_output_(?P<trainnum>[0-9]+)(_r(?P<trial>[0-9]+))?', os.path.basename(traindir))
                if not m:
                    raise ValueError("Bad train dir format")
                trainnum = int(m.group('trainnum'))
                # A missing "_r<N>" suffix means trial 0.
                trialnum = int(m.group('trial')) if m.group('trial') else 0
                evfiles = glob.glob(os.path.join(traindir, 'tfevents', 'ev*'))

                by_trial = event_files.setdefault(trainnum, dict())
                assert trialnum not in by_trial

                if len(evfiles) != 1:
                    logger.error("Expected exactly one event file but found 0 or >1 for dir {} train {} trial {}".format(expdir, trainnum, trialnum))
                    return None

                taglists = files_to_taglists([evfiles[0]], keys, aggregators={'default': max})
                by_trial[trialnum] = {k: taglists[(evfiles[0], k)] for k in keys}

            all_curves.append(event_files)

        return all_curves

    def get_curves(self, key='bysize_test_acc'):
        """Return a copy of the `key` taglist of every trial's event file."""
        return [self.taglists[(filename, key)].copy() for filename in self.get_event_files()]

    def get_elapsed_times(self):
        """Return each trial's `total_runtime_seconds` from its
        `post_run_info.json` (None where the field is absent)."""
        exptimes = []
        for expdir in self.experiment_dirs:
            with open(os.path.join(expdir, 'post_run_info.json')) as f:
                exptimes.append(json.load(f).get('total_runtime_seconds', None))
        return exptimes

    def _get_acquisition_data(self, exp_dir, use_cached=True):
        """Return the parsed `selected_indexes.jsonl` records of `exp_dir`,
        re-reading the file when `use_cached` is false."""
        if exp_dir not in self._cached_acquisition_data or not use_cached:
            # Directories whose path contains 'multichoice_swag' keep the file
            # at the top level instead of under output/.
            sel_idx_path = 'selected_indexes.jsonl' if 'multichoice_swag' in exp_dir else os.path.join('output', 'selected_indexes.jsonl')
            with open(os.path.join(exp_dir, sel_idx_path)) as f:
                self._cached_acquisition_data[exp_dir] = [json.loads(line) for line in f]
        return self._cached_acquisition_data[exp_dir]

    def _get_final_selected_indexes(self, exp_dir):
        """Return the largest `selected_indexes` set recorded for `exp_dir`."""
        # Note: assumes the largest selected index set is the final one.  This
        # is likely a durable assumption but technically could be inaccurate.
        sel_idx_data = self._get_acquisition_data(exp_dir)
        return set(max([x['selected_indexes'] for x in sel_idx_data], key=len))

    def get_selected_example_overlaps(self, overlap_fn=overlap_number):
        """Return `overlap_fn(a, b)` for every unordered pair of trials, where
        a and b are the trials' final selected index sets."""
        selected_index_sets = [self._get_final_selected_indexes(exp_dir) for exp_dir in self.experiment_dirs]
        return [overlap_fn(set1, set2) for set1, set2 in itertools.combinations(selected_index_sets, 2)]

    def get_training_args(self):
        """Return (and cache) the training args of the first trial."""
        if self._training_args is None:
            self._training_args = self._load_training_args(0)
        return self._training_args

    def _load_training_args(self, exp_id):
        """Load `training_args.bin` of the trial at index `exp_id` of
        `self.experiment_dirs`, preferring the copy under `output/`."""
        # Index by exp_id: always reading experiment_dirs[0] would make every
        # trial report the first trial's data_dir.
        expdir = self.experiment_dirs[exp_id]
        path = os.path.join(expdir, 'output', 'training_args.bin')
        if not os.path.exists(path):
            path = os.path.join(expdir, 'training_args.bin')
        # NOTE(review): torch.load unpickles the file, so only point this at
        # trusted experiment directories.  Newer PyTorch releases appear to
        # default to weights_only=True, which may reject arbitrary pickled
        # objects -- confirm against the installed version.
        return torch.load(path)

    def get_config(self):
        """Return (and cache) the parsed `config.json` of the first trial."""
        if self._config is None:
            with open(os.path.join(self.experiment_dirs[0], 'config.json')) as f:
                self._config = json.load(f)
        return self._config

    def _get_data_dirs_via_training_args(self):
        """Return the `data_dir` attribute of each trial's training args."""
        return [self._load_training_args(i).data_dir for i in range(len(self.experiment_dirs))]

    def _get_data_dirs_via_config(self):
        """Return the `data_dir` entry of each trial's `config.json`."""
        data_dirs = []
        for expdir in self.experiment_dirs:
            with open(os.path.join(expdir, 'config.json')) as f:
                data_dirs.append(json.load(f)['data_dir'])
        return data_dirs

    def get_data_dirs(self):
        """Return (and cache) the data directory of every trial."""
        if self._data_dirs is None:
            self._data_dirs = self._get_data_dirs_via_training_args()
        return self._data_dirs

    def get_data_dir(self):
        """Return the common path prefix of all trials' data directories.

        Raises:
            ValueError: if the data directories share no common path.
        """
        data_dir = os.path.commonpath(self.get_data_dirs())
        if data_dir == '':
            raise ValueError("Couldn't figure out common data dir")

        return data_dir

    def get_dataset_size(self):
        """Return the number of lines in `train.jsonl`, which must be the same
        for every trial's data directory.

        Line counts are memoized per directory in the module-level
        `dsize_cache`.

        Raises:
            ValueError: if the trials' data directories differ in size.
        """
        data_sizes = []
        for ddir in self.get_data_dirs():
            if ddir not in dsize_cache:
                with open(os.path.join(ddir, 'train.jsonl'), 'r') as f:
                    # Count lines without holding the whole file in memory.
                    dsize_cache[ddir] = sum(1 for _ in f)
            data_sizes.append(dsize_cache[ddir])

        data_size = data_sizes[0]
        if not all(ds == data_size for ds in data_sizes):
            raise ValueError("Data sizes should be same")

        return data_size

class ExperimentDirectoryFinder:
    """Finds the trial directories of an experiment directly under `base_dir`.

    The immediate subdirectory names of `base_dir` are listed once, at
    construction time.
    """

    def __init__(self, base_dir):
        """List the immediate subdirectories of `base_dir`.

        Raises:
            OSError: (e.g. FileNotFoundError, NotADirectoryError) if `base_dir`
                cannot be listed.
        """
        self.base_dir = base_dir

        def _reraise(err):
            raise err

        # os.walk ignores listing errors by default, which would make next()
        # fail with a bare StopIteration for a missing directory; surface the
        # underlying OSError instead.
        self._dir_listing = next(os.walk(self.base_dir, onerror=_reraise))[1]

    def experiment_dirs_by_slug(self, experiment_slug):
        """Return the paths of subdirectories named `<slug>_r<digits>`,
        optionally followed by `_<anything>`."""
        pat = re.compile(r'^' + re.escape(experiment_slug) + r'_r[0-9]+(_.*)?$')
        return [os.path.join(self.base_dir, d) for d in self._dir_listing if pat.match(d)]

def trapezoid_auc(points):
    """Trapezoidal area under the curve through `points`, normalized by the
    x-distance between the first and last point.

    `points` is expected to be a list of (x, y) pairs.  An empty list raises
    IndexError and a zero x-span between the endpoints raises
    ZeroDivisionError.
    """
    total = 0
    for idx in range(len(points) - 1):
        left, right = points[idx], points[idx + 1]
        width = right[0] - left[0]
        mean_height = (right[1] + left[1]) / 2
        total += width * mean_height
    x_span = points[-1][0] - points[0][0]
    return total / x_span


def signed_point_to_line_distance(p, p1, p2):
    """Signed distance from point `p` to the line through `p1` and `p2`.

    With `p1` to the left of `p2` (low x to high x), the result is positive
    when `p` lies below the line and negative when it lies above.  Raises
    ZeroDivisionError if `p1` and `p2` coincide.
    """
    dx = p2[0] - p1[0]
    dy = p2[1] - p1[1]
    cross = (dx * (p1[1] - p[1])) - ((p1[0] - p[0]) * dy)
    length = math.sqrt(dx**2 + dy**2)
    return cross / length

def get_convex_hull(points):
    """Return the x-sorted subset of `points` kept by recursively retaining the
    point farthest above the chord between the two endpoints.

    Points on or below the chord of their segment are dropped, so only the
    upper boundary is retained.  Inputs of two or fewer points are returned
    sorted by x.
    """
    ordered = sorted(points, key=lambda pt: pt[0])
    if len(ordered) < 3:
        return ordered

    first, last = ordered[0], ordered[-1]

    def offset_from_chord(idx):
        # Negative when the point is above the chord first -> last.
        return signed_point_to_line_distance(ordered[idx], first, last)

    # The interior point farthest above the chord is guaranteed to be part of
    # the hull, because no hull made from the other points could contain it.
    peak_idx = min(range(1, len(ordered) - 1), key=offset_from_chord)

    # Nothing strictly above the chord: the two endpoints are the whole hull.
    if offset_from_chord(peak_idx) >= 0:
        return [first, last]

    # Otherwise split at the peak and solve each side separately.
    lower_part = get_convex_hull(ordered[:peak_idx + 1])
    upper_part = get_convex_hull(ordered[peak_idx:])

    # The peak ends one part and starts the other; keep a single copy.
    return lower_part + upper_part[1:]

def trapezoid_auc_linear_convex(points):
    """Trapezoidal AUC of the hull returned by `get_convex_hull(points)`."""
    hull = get_convex_hull(points)
    return trapezoid_auc(hull)

def trapezoid_auc_fixfail(points):
    """Trapezoidal area under the curve with "fails-to-train" dips lifted to
    the best y value seen so far, normalized by the endpoints' x-distance.

    `points` is expected to be a list of (x, y) pairs.  The running best
    starts at -1 and is updated from the second point onwards, so the first
    point's y value never raises later points.
    """
    total = 0
    running_best = -1
    for idx in range(len(points) - 1):
        left, right = points[idx], points[idx + 1]

        # Lift the left end of the segment to the best value seen so far.
        left_y = running_best if running_best > left[1] else left[1]

        # Lift the right end too, or record it as the new best.
        if running_best > right[1]:
            right_y = running_best
        else:
            running_best = right[1]
            right_y = right[1]

        width = right[0] - left[0]
        mean_height = (right_y + left_y) / 2
        total += width * mean_height
    return total / (points[-1][0] - points[0][0])

def experiment_id_from_slug(experiment_slug):
    """Extract the numeric experiment id (as a string) from a slug.

    The id is the run of digits at the very end of the slug, which must be
    immediately preceded by `auto`, `test` or `_v`.

    Raises:
        ValueError: if the slug does not end in such an id.
    """
    m = re.match(r'^.*(auto|test|_v)(?P<exp_id>[0-9]+)$', experiment_slug)
    if m is None:
        # Without this check a malformed slug surfaced as an AttributeError
        # on None, which gave no hint about the actual problem.
        raise ValueError('Could not extract experiment id from slug "{}"'.format(experiment_slug))
    return m.group('exp_id')

def files_to_taglists(event_files, want_tags, aggregators=None):
    """Read the wanted tags from event files into per-file, per-tag curves.

    Only the first value of each summary record is examined.  Consecutive
    records of one (file, tag) that share a step are collapsed into a single
    point using the aggregator registered for the tag, falling back to
    `aggregators['default']`.  When `aggregators` is None the default is the
    arithmetic mean.

    Returns:
        A dict mapping `(filename, tag)` to a list of `(step, value)` pairs.
    """
    if aggregators is None:
        aggregators = {'default': lambda x: sum(x)/len(x)}

    raw_points = dict()
    for filename in event_files:
        for event in tf.compat.v1.train.summary_iterator(filename):
            values = event.summary.value
            if values is None or len(values) == 0:
                continue
            tag = values[0].tag
            if tag not in want_tags:
                continue

            point = (event.step, values[0].simple_value)
            raw_points.setdefault((filename, tag), []).append(point)

    # Aggregate values that occurred in the same step
    taglists = dict()
    for key, points in raw_points.items():
        _filename, tag = key
        aggregate = aggregators.get(tag, aggregators['default'])
        taglists[key] = [
            (step, aggregate([val for _step, val in same_step]))
            for step, same_step in itertools.groupby(points, key=lambda pt: pt[0])
        ]

    return taglists

def get_tag_loss_lists(loss_point_lists):
    """Collect, per tag, the maximum and the final value of every curve.

    `loss_point_lists` maps `(filename, tag)` to a list of `(x, y)` pairs.
    The result maps each tag to `{'max_vals': [...], 'final_vals': [...]}`
    with one entry per curve of that tag, in iteration order.
    """
    per_tag = dict()
    for (_filename, tag), loss_points in loss_point_lists.items():
        _xs, ys = zip(*loss_points)
        entry = per_tag.setdefault(tag, {'max_vals': [], 'final_vals': []})
        entry['max_vals'].append(max(ys))
        entry['final_vals'].append(ys[-1])
    return per_tag

def get_tagstats(loss_point_lists):
    """Return `{tag: {statkey: {'mean': ..., 'std': ...}}}` over the per-tag
    value lists produced by `get_tag_loss_lists`.

    The standard deviation is NaN when fewer than two values are available.
    """
    def summarize(vals):
        mean = statistics.mean(vals)
        std = float('nan') if len(vals) < 2 else statistics.stdev(vals)
        return {'mean': mean, 'std': std}

    return {
        tag: {statkey: summarize(vals) for statkey, vals in by_stat.items()}
        for tag, by_stat in get_tag_loss_lists(loss_point_lists).items()
    }
    

def event_files_from_experiment_dirs(experiment_dirs):
    """Locate the single event file of each experiment directory.

    Files matching `tfevents/ev*` are looked up under `<dir>/output` when that
    path exists, otherwise directly under `<dir>`.  The returned list has one
    entry per input directory: the matching path, or None when there is not
    exactly one match.
    """
    located = []
    for exp_dir in experiment_dirs:
        root = os.path.join(exp_dir, 'output')
        if not os.path.exists(root):
            root = exp_dir
        matches = glob.glob(os.path.join(root, 'tfevents', 'ev*'))
        located.append(matches[0] if len(matches) == 1 else None)
    return located


def _mean_and_std(values):
    """Return (mean, sample stdev) of `values`; both are NaN when `values` is
    None, and the stdev is NaN when fewer than two values are given."""
    if values is None:
        return float('nan'), float('nan')
    mean = statistics.mean(values)
    std = statistics.stdev(values) if len(values) >= 2 else float('nan')
    return mean, std


def print_slug_record(experiment_slug, output_format, finder, tolerate_errors=False):
    """Summarize all trials of one experiment and print the summary record.

    Args:
        experiment_slug: experiment name; its trial directories are looked up
            through `finder.experiment_dirs_by_slug`.
        output_format: 'gnurec' prints one "key: value" line per field after a
            blank line, 'jsonl' prints the record as a single JSON line.  Any
            other value prints nothing.
        finder: an `ExperimentDirectoryFinder`.
        tolerate_errors: forwarded to `Trials`.

    Raises:
        ValueError: if the trials disagree on the active-learning score type
            or on the MC-dropout override.
    """
    experiment_id = experiment_id_from_slug(experiment_slug)

    experiment_dirs = finder.experiment_dirs_by_slug(experiment_slug)
    if len(experiment_dirs) == 0:
        logger.error("Failed to find experiment dirs for slug {}".format(experiment_slug))
        return

    trials = Trials(experiment_id, experiment_dirs, tolerate_errors=tolerate_errors)

    with open(os.path.join(experiment_dirs[0], 'meta_config.json')) as f:
        run_filename = json.load(f)['main_pyfile']

    # NOTE(review): None / -1 double as "not seen yet" markers below, so a
    # first trial that genuinely has those values does not take part in the
    # consistency check -- confirm whether that is intended.
    al_score_type = None
    mc_dropout_override = -1
    for expdir in experiment_dirs:
        with open(os.path.join(expdir, 'config.json')) as f:
            config = json.load(f)
        al_config = config['run_config'].get('al_config', dict())

        new_al_score_type = al_config.get('score_method', None)
        if run_filename == 'run_bertcls_svm.py':
            new_al_score_type = al_config.get('selection_type', None)
        if new_al_score_type == 'greedy_coreset' and al_config.get('coreset_retraining', False):
            new_al_score_type = 'greedy_coreset_retraining'
        if not (al_score_type is None or al_score_type == new_al_score_type):
            raise ValueError("Inconsistent al score type")
        al_score_type = new_al_score_type

        new_mcoverride = al_config.get('mcdropout_dropout_override', -1)
        if not (mc_dropout_override == -1 or mc_dropout_override == new_mcoverride):
            raise ValueError("Inconsistent mc_dropout_override")
        mc_dropout_override = new_mcoverride

    # The total runtime is only reported when every trial recorded one.
    total_exptime = 0
    exptimes = trials.get_elapsed_times()
    if all(x is not None for x in exptimes):
        total_exptime = sum(exptimes)

    # Each statistic falls back to NaN when its source data is unavailable.
    # (Previously the fallbacks for the fixfail/convex AUCs were bound to
    # differently spelled names, so a missing curve ended in a NameError.)
    aucs = trials.get_bysize_test_areas_under_curves()
    test_auc_mean, test_auc_std = _mean_and_std(aucs)

    aucs_fixfail = trials.get_bysize_test_areas_under_curves(adjustment='monotonic')
    test_fixfail_auc_mean, test_fixfail_auc_std = _mean_and_std(aucs_fixfail)

    aucs_convex = trials.get_bysize_test_areas_under_curves(adjustment='convex')
    test_convex_auc_mean, test_convex_auc_std = _mean_and_std(aucs_convex)

    bysize_finalaccs = trials.get_bysize_final_accuracies()
    bysize_test_finalacc_mean, bysize_test_finalacc_std = _mean_and_std(bysize_finalaccs)

    model_config = trials.get_config()
    model_name_or_path = model_config['model_name_or_path']
    if model_name_or_path is None:
        model_name_or_path = model_config['model_type']
    learning_rate = model_config['learning_rate']
    max_steps = model_config['max_steps']
    do_lower_case = model_config['do_lower_case']

    # Overlap statistics stay None (not NaN) when there is only one trial.
    overlap_mean, overlap_std, overlap_median = None, None, None
    overlaps = trials.get_selected_example_overlaps()
    if len(overlaps) != 0:
        overlap_mean, overlap_std = _mean_and_std(overlaps)
        overlap_median = statistics.median(overlaps)

    # NOTE: json.dumps writes NaN values as the bare token NaN, which strict
    # JSON parsers reject.
    record = collections.OrderedDict()
    record['experiment_id'] = experiment_id
    record['num_trials'] = len(trials.experiment_dirs)
    record['run_filename'] = run_filename
    record['total_exptime'] = total_exptime
    record['all_test_aucs'] = aucs
    record['all_test_convex_aucs'] = aucs_convex
    record['bysize_test_finalacc_mean'] = bysize_test_finalacc_mean
    record['bysize_test_finalacc_std'] = bysize_test_finalacc_std
    record['test_trapauc_mean'] = test_auc_mean
    record['test_trapauc_std'] = test_auc_std
    record['test_trapfixfailauc_mean'] = test_fixfail_auc_mean
    record['test_trapfixfailauc_std'] = test_fixfail_auc_std
    record['test_convex_auc_mean'] = test_convex_auc_mean
    record['test_convex_auc_std'] = test_convex_auc_std
    record['overlap_mean'] = overlap_mean
    record['overlap_std'] = overlap_std
    record['overlap_median'] = overlap_median
    record['model'] = model_name_or_path
    record['freeze_core_weights'] = model_config['run_config'].get('optimizer', dict()).get('freeze_core_weights', False)
    record['learning_rate'] = learning_rate
    record['al_score_type'] = al_score_type
    record['mc_dropout_override'] = mc_dropout_override
    record['max_steps'] = max_steps
    record['do_lower_case'] = do_lower_case
    record['al_batch_size'] = model_config['run_config'].get('al_config', dict()).get('refill_increment', None)
    record['dataset'] = os.path.relpath(trials.get_data_dir(), finder.base_dir)
    record['dataset_size'] = trials.get_dataset_size()
    record['num_train_retries'] = model_config.get('num_train_retries', 1)

    if output_format == 'gnurec':
        print()
        for key in record.keys():
            print('{}: {}'.format(key, str(record[key])))
    elif output_format == 'jsonl':
        print(json.dumps(record))

if __name__ == '__main__':
    # Command line: one or more experiment slugs, each summarized on stdout.
    cli = argparse.ArgumentParser()
    cli.add_argument('--experiments_dir', required=True)
    cli.add_argument(
        '--output_format',
        default='jsonl',
        choices=['gnurec', 'jsonl'],
    )
    cli.add_argument(
        '--attempt_error_correction',
        default=False,
        action='store_true',
        help='Attempt to correct for errors in parsing/summarizing experiment results.  Less likely to crash, but more likely to cause silent inaccuracies or omit some experiment trials',
    )
    cli.add_argument('experiment_slug', nargs='+')
    args = cli.parse_args(sys.argv[1:])

    finder = ExperimentDirectoryFinder(args.experiments_dir)

    # A failure in one slug is logged and does not stop the remaining ones.
    for experiment_slug in args.experiment_slug:
        try:
            print_slug_record(
                experiment_slug,
                args.output_format,
                finder,
                tolerate_errors=args.attempt_error_correction,
            )
        except Exception as exc:
            reason = traceback.format_exception_only(exc.__class__, exc)[0].strip('\n')
            logger.error("Error processing {}: {}".format(experiment_slug, reason))
    # NOTE(review): nothing visible in this file plots anything, so this call
    # looks like a leftover from an earlier plotting step -- confirm and remove.
    plt.legend()
    #plt.savefig('/dev/shm/tmp.svg')

