import argparse
import filecmp
import json
import os
import re
import subprocess
from collections import defaultdict, namedtuple
from pathlib import Path

import pandas as pd

import parse_results

# One experiment configuration. process_experiments builds it positionally as
# (task, num_type, iid_type, n, attack, f).
Experiment = namedtuple('Experiment', ('task num_type iid_type N attack F'))

# All inputs and outputs live under the user's home directory.
ROOT_DIR = Path.home() / 'cofl/aggregation'
JOBS_DIR = ROOT_DIR / 'jobs'
PICKLES_DIR = ROOT_DIR / 'pickles'

# load settings file
# NOTE: this runs at import time, so importing the module fails if the
# settings file does not exist.
SETTINGS = None
SETTINGS_FILE = ROOT_DIR / 'process_experiments/all_experiment_settings.json'
with open(SETTINGS_FILE, 'r') as f:
    SETTINGS = json.load(f)

# DEBUG
# Default for the debug switch; main() overwrites it from --debug. Defining it
# here keeps the functions that read DEBUG usable when the module is imported
# rather than run as a script.
DEBUG = False
# When True (and DEBUG is on), record the learning rate found in each debug file.
LR_CHECK = False
lr_dict = defaultdict(set)
# When True (and DEBUG is on), record how many lines each results file has.
LINES_CHECK = True
lines_dict = defaultdict(set)
# Record types kept by read_json.
RESULTS_TYPES = ['validation'] #, ['validation', 'train']


def remove_old_jobs(tasks, aggs):
    """Delete previously generated job scripts for the given tasks/aggregators.

    Every file directly inside JOBS_DIR whose name matches
    ``{task}-{agg}-job_*.sh`` for any requested (task, agg) pair is removed.
    """
    print("Removing old job files...")
    patterns = [f'{task}-{agg}-job_*.sh' for task in tasks for agg in aggs]
    for pattern in patterns:
        for stale_script in Path(JOBS_DIR).glob(pattern):
            os.remove(stale_script)
    print("Old job files removed.")


def read_json(path):
    """Parse an experimental results file into a list of record dicts.

    Each line is rewritten into JSON before decoding: single quotes become
    double quotes and every occurrence of ``nan`` is wrapped in double quotes
    (so NaN values come back as the string ``"nan"``). Only records whose
    ``_meta.type`` is listed in RESULTS_TYPES are kept, and a type of
    ``'validation'`` is relabelled ``'test'``.

    Args:
        path: Path of the results file to read.

    Returns:
        The kept records, in file order.

    Raises:
        json.JSONDecodeError: If a rewritten line is not valid JSON. The
            offending line is printed before the error is re-raised.
    """
    # NOTE(review): the textual quote/nan substitution also touches any
    # apostrophe or "nan" substring inside string values — confirm the
    # results files never contain such values.
    records = []
    with open(path, "r") as f:
        for line in f:
            line = line.strip().replace("'", '"')
            line = line.replace("nan", '"nan"')
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                # Show the line that could not be decoded, then let the
                # caller decide how to handle the failure.
                print(line)
                raise
            if data['_meta']['type'] in RESULTS_TYPES:
                if data['_meta']['type'] == 'validation':
                    data['_meta']['type'] = 'test'
                records.append(data)
    return records


def parse_results_file(filepath, mapping_function, seed):
    """Read one results file and map each kept record through a function.

    Args:
        filepath: Location of the ``seed{seed}_stats`` results file.
        mapping_function: Callable applied to every record from read_json.
        seed: Seed number; used only to locate the ``seed{seed}_debug`` file
            for the learning-rate check.

    Returns:
        A list with one mapped entry per record, or None when the file does
        not exist or contains a line that cannot be decoded as JSON.

    Side effects (only when DEBUG is on):
        With LR_CHECK, the last token of the first line in the debug file
        whose first token is ``lr`` is recorded in lr_dict. With LINES_CHECK,
        the line count of the results file is recorded in lines_dict.
    """
    if not os.path.isfile(filepath):
        return None

    try:
        entries = read_json(filepath)
        result = list(map(mapping_function, entries))
    except json.decoder.JSONDecodeError:
        # Undecodable files are treated like missing ones.
        return None

    if DEBUG:
        if LR_CHECK:
            # Path() lets this work for str paths as well as Path objects.
            debug_filepath = Path(filepath).parent / f"seed{seed}_debug"
            with open(debug_filepath, "r") as fp:
                for line in fp:
                    splits = line.split()
                    if splits and splits[0] == 'lr':
                        lr_dict[float(splits[-1])].add(filepath)
                        break
        if LINES_CHECK:
            # Use a context manager so the handle is closed promptly.
            with open(filepath) as fp:
                lines_count = sum(1 for _ in fp)
            lines_dict[lines_count].add(filepath)

    return result


def create_job_file(task, agg, job, i):
    """Write one job command into its numbered script inside JOBS_DIR.

    The script is named ``{task}-{agg}-job_{i // 1000}_{i % 1000}.sh`` and an
    existing file with that name is overwritten.
    """
    group, slot = divmod(i, 1000)
    script_path = JOBS_DIR / f"{task}-{agg}-job_{group}_{slot}.sh"
    with open(script_path, 'w') as job_file:
        job_file.write(job)


def process_agg_results(task, agg, agg_results, output_type):
    """Persist everything collected for one (task, aggregator) pair.

    None entries in ``agg_results`` are ignored. For ``'jobs'`` each remaining
    entry is written to its own job script (numbered from 0) and the matching
    ``sbatch`` command is printed when at least one script was written. For
    ``'dataframes'`` each remaining entry is a list of rows; all rows are
    combined into one DataFrame that is pickled under PICKLES_DIR. Any other
    ``output_type`` does nothing.
    """
    usable = [entry for entry in agg_results if entry is not None]

    if output_type == 'jobs':
        for index, job in enumerate(usable):
            create_job_file(task, agg, job, index)
        count = len(usable)
        if count:
            print(f"{count} {task}-{agg} jobs created")
            print(f"\nsbatch --array=0-{count - 1} run_{task}_{agg}.sh\n")
        return

    if output_type == 'dataframes':
        rows = []
        for entry in usable:
            rows.extend(entry)
        agg_df = pd.DataFrame(rows)
        print(f"\n{task}-{agg} df len:", len(agg_df))

        filepath = f'{PICKLES_DIR}/{task}-{agg}.pkl'
        agg_df.to_pickle(filepath)
        print("Saved dataframe to:", filepath, '\n')


def process_rsa(experiment, seed, output_type, base_filepath, base_job):
    """Yield one output per RSA lambda configured for the experiment's task.

    In ``'jobs'`` mode the output is the command line for a run whose stats
    file is missing, or None when the file already exists. In
    ``'dataframes'`` mode it is whatever parse_results_file returns for that
    stats file.
    """
    task_settings = SETTINGS['tasks'][experiment.task]
    rsa_lambdas = [float(value) for value in task_settings['rsa_lambdas']]

    for p in [1]:
        for rsa_lambda in rsa_lambdas:
            def to_row(entry):
                # Reads the current loop values when called.
                return parse_results.rsa(entry, experiment, p, rsa_lambda, seed)

            filepath = Path(base_filepath,
                            f"lambda{rsa_lambda}",
                            f"seed{seed}_stats")

            if output_type == 'dataframes':
                output = parse_results_file(filepath, to_row, seed)
            elif output_type == 'jobs':
                if os.path.isfile(filepath):
                    output = None
                else:
                    output = (f"{base_job} "
                              f"--agg rsa "
                              f"--p_norm {p} "
                              f"--rsa_lambda {rsa_lambda} ")

            yield output

def process_fltrust(experiment, seed, output_type, base_filepath, base_job):
    """Yield one output per FLTrust ``param1`` value configured for the task.

    In ``'jobs'`` mode the output is the command line for a run whose stats
    file is missing, or None when the file already exists. In
    ``'dataframes'`` mode it is whatever parse_results_file returns for that
    stats file.
    """
    task_settings = SETTINGS['tasks'][experiment.task]
    param1s = [float(value) for value in task_settings['fltrust_param1s']]

    for p in [1]:
        for param1 in param1s:
            def to_row(entry):
                # Reads the current loop values when called.
                return parse_results.fltrust(entry, experiment, p, param1, seed)

            filepath = Path(base_filepath,
                            f"param{param1}",
                            f"seed{seed}_stats")

            if output_type == 'dataframes':
                output = parse_results_file(filepath, to_row, seed)
            elif output_type == 'jobs':
                if os.path.isfile(filepath):
                    output = None
                else:
                    output = (f"{base_job} "
                              f"--agg fltrust "
                              f"--p_norm {p} "
                              f"--fltrust_param1 {param1} ")

            yield output


def process_krum(experiment, seed, output_type, base_filepath, base_job):
    """Yield one output per Krum hyper-parameter combination for the task.

    Combinations are the product of ``krum_m``, ``krum_dim_reduce_methods``
    and ``krum_n_reduced_dims`` from the task settings. When the reduce
    method is ``'None'`` the results directory uses ``krum_n_dims`` instead
    of the reduced dimension, and the job omits the reduction flags.

    In ``'jobs'`` mode the output is the command line for a run whose stats
    file is missing, or None when the file exists or the same run was already
    yielded. In ``'dataframes'`` mode it is whatever parse_results_file
    returns for the stats file.
    """
    # The lambda reads the loop variables below at call time; it is only
    # called while the loop that defines them is suspended on that iteration.
    mapping_function = lambda x: parse_results.krum(
        x, experiment, p, m, reduce_method, n_dims, dims, seed
    )

    task_settings = SETTINGS['tasks'][experiment.task]
    krum_m = list(map(int, task_settings['krum_m']))
    n_dims = int(task_settings['krum_n_dims'])
    n_reduced_dims = list(map(int, task_settings['krum_n_reduced_dims']))
    dim_reduce_methods = task_settings['krum_dim_reduce_methods']

    # Stats files for which a job has already been yielded. Without
    # reduction, every `dims` value maps to the same file and the same
    # command, so without this guard the identical job would be emitted once
    # per reduced-dims value.
    queued = set()

    p_norm = [1]
    for p in p_norm:
        for m in krum_m:
            for reduce_method in dim_reduce_methods:
                no_reduction = str(reduce_method) == 'None'
                for dims in n_reduced_dims:
                    dir_dims = n_dims if no_reduction else dims
                    filepath = Path(base_filepath,
                                    f"m{m}_p{p}_reduce-{reduce_method}_dims{dir_dims}",
                                    f"seed{seed}_stats")

                    if output_type == 'jobs':
                        if os.path.isfile(filepath) or filepath in queued:
                            output = None
                        else:
                            queued.add(filepath)
                            job = (f"{base_job} "
                                   f"--agg krum "
                                   f"--p_norm {p} "
                                   f"--krum_m {m} ")

                            if not no_reduction:
                                job += (f"--dim_reduce_method {reduce_method} "
                                        f"--n_reduced_dims {dims} ")
                            output = job

                    elif output_type == 'dataframes':
                        # NOTE(review): without reduction the same stats file
                        # is parsed once per `dims` value, each time tagged
                        # with a different `dims`. Left unchanged in case the
                        # downstream analysis relies on it — confirm.
                        output = parse_results_file(filepath, mapping_function, seed)

                    yield output


def process_cclip(experiment, seed, output_type, base_filepath, base_job):
    """Yield one output per centered-clipping (norm, tau, momentum) combination.

    Norms are fixed to 2 and 'inf'; taus and momentums come from the task
    settings. In ``'jobs'`` mode the output is the command line for a run
    whose stats file is missing, or None when the file already exists. In
    ``'dataframes'`` mode it is whatever parse_results_file returns for that
    stats file.
    """
    task_settings = SETTINGS['tasks'][experiment.task]
    taus = [float(value) for value in task_settings['cclip_taus']]
    momentums = [float(value) for value in task_settings['cclip_momentums']]

    for p in [2, 'inf']:
        for tau in taus:
            for momentum in momentums:
                def to_row(entry):
                    # Reads the current loop values when called.
                    return parse_results.cclip(
                        entry, experiment, p, tau, momentum, seed
                    )

                filepath = Path(base_filepath,
                                f"p{p}_tau{tau}_momentum{momentum}",
                                f"seed{seed}_stats")

                if output_type == 'dataframes':
                    output = parse_results_file(filepath, to_row, seed)
                elif output_type == 'jobs':
                    if os.path.isfile(filepath):
                        output = None
                    else:
                        output = (f"{base_job} "
                                  f"--agg cclip "
                                  f"--p_norm {p} "
                                  f"--cclip_tau {tau} "
                                  f"--cclip_momentum {momentum} ")

                yield output


def process_avg(experiment, seed, output_type, base_filepath, base_job):
    """Yield the single output for plain averaging (no hyper-parameters).

    In ``'jobs'`` mode the output is the command line for the run when its
    stats file is missing, or None when the file already exists. In
    ``'dataframes'`` mode it is whatever parse_results_file returns for that
    stats file.
    """
    filepath = Path(base_filepath, 'avg', f"seed{seed}_stats")

    if output_type == 'dataframes':
        def to_row(entry):
            return parse_results.avg(entry, experiment, seed)

        output = parse_results_file(filepath, to_row, seed)
    elif output_type == 'jobs':
        output = None if os.path.isfile(filepath) else f"{base_job} --agg avg "

    yield output


def process_experiments(tasks, aggs, output_type):
    """Enumerate every experiment configuration and process its outputs.

    For each task and aggregator, every combination of number type, iid type,
    attack, number of attackers ``f`` and seed is handed to the aggregator's
    ``process_*`` generator. The non-None outputs are collected and passed to
    process_agg_results once per (task, aggregator) pair.

    When DEBUG is on, a learning-rate and/or line-count report is printed
    after each task, based on what parse_results_file recorded in lr_dict and
    lines_dict.

    Args:
        tasks: Task names; each must be a key of ``SETTINGS['tasks']``.
        aggs: Aggregator names; each must be a key of ``agg_functions`` below.
        output_type: ``'jobs'`` or ``'dataframes'``; forwarded unchanged to
            the aggregator generators and to process_agg_results.
    """
    # NOTE(review): FILES_NOT_FOUND is never assigned or read in the visible
    # code — this declaration looks like a leftover.
    global FILES_NOT_FOUND

    # Dispatch table: aggregator name -> generator yielding its outputs.
    agg_functions = {'rsa': process_rsa,
                     'krum': process_krum,
                     'cclip': process_cclip,
                     'avg': process_avg,
                     'fltrust': process_fltrust}

    batch_size = int(SETTINGS['batch_size'])
    #num_types = SETTINGS['num_types']
    num_types = ["fixed", "floating"]
    iid_types = SETTINGS['iid_types']
    # This initial value is overwritten for every aggregator in the loop below.
    n = int(SETTINGS['n'])
    attacks = SETTINGS['attacks']
    F = SETTINGS['f']
    seeds = int(SETTINGS['seeds'])

    for task in tasks:
        print(f"{'-'*30} {task} {'-'*30} ")

        epochs = SETTINGS['tasks'][task]['epochs']
        lr = SETTINGS['tasks'][task]['lr']

        # Command-line prefix shared by every job of this task.
        base_job = (f"python3 run_experiment.py --use_cuda --batch_size {batch_size} "
                    f" --lr {lr} --epochs {epochs} --task {task} ")

        for agg in aggs:
            if agg == "fltrust":
                n = int(SETTINGS['n']) + 1 # +1 because fltrust requires server gradient computation.
            else:
                n = int(SETTINGS['n'])
            agg_results = []

            for num_type in num_types:
                for iid_type in iid_types:
                    for attack in attacks:
                        # The no-attack baseline needs '0' among the f values.
                        if attack == 'None':
                            assert '0' in F
                        for f in F:
                            # Attack 'None' is paired only with f == '0', and
                            # f == '0' only with attack 'None'; skip the rest.
                            if ((attack == 'None' and f != '0') or
                                (f == '0' and attack != 'None')):
                                continue

                            experiment = Experiment(
                                task, num_type, iid_type, n, attack, f
                            )

                            # Directory holding this configuration's results;
                            # each aggregator appends its own sub-directories.
                            base_filepath = Path(ROOT_DIR,
                                                "results",
                                                f"{task}",
                                                f"{epochs}",
                                                f"{num_type}_point",
                                                iid_type,
                                                'max_examples',
                                                agg,
                                                attack,
                                                f"n{n}",
                                                f"f{f}")

                            for seed in range(seeds):
                                experiment_job = (f"{base_job} "
                                                    f"--n {n} "
                                                    f"--attack {attack} "
                                                    f"--f {f} "
                                                    f"--seed {seed} ")
                                if iid_type == 'noniid':
                                    experiment_job += '--noniid '

                                # Keep only real outputs; the generators yield
                                # None for runs that need no job or have no
                                # parsable results.
                                agg_results += [job for job in agg_functions[agg](
                                        experiment,
                                        seed,
                                        output_type,
                                        base_filepath,
                                        experiment_job
                                    )
                                    if job is not None
                                ]

            process_agg_results(task, agg, agg_results, output_type)

        # NOTE(review): lr_dict and lines_dict are module-level and are not
        # cleared in the visible code, so with several tasks each report also
        # covers files recorded for earlier tasks.
        if DEBUG:
            if LR_CHECK:
                print()
                print("---------- LR CHECK ----------")
                for key, val in lr_dict.items():
                    print(f"{key} lr: {len(val)} files")
                # The check passes only if exactly one learning rate was seen.
                if len(lr_dict) == 1:
                    print("Passed")
                else:
                    print("Failed")

            if LINES_CHECK:
                print()
                print("---------- LINES CHECK ----------")
                print('line_count - num_files')
                # for key in sorted(lines_dict.keys()):
                #     print(f"{key} lines: {len(lines_dict[key])} files")

                # print line counts
                for key in lines_dict:
                    # Files with exactly 2 * epochs lines are skipped.
                    if key == int(epochs) * 2:
                        continue
                    for file in lines_dict[key]:

                        # print('rm', file)
                        # Inspect only the last line of the results file.
                        last_line = subprocess.check_output(['tail', '-1', file], universal_newlines=True)
                        try:
                            # Each .group() raises AttributeError when its
                            # pattern does not match; the file is then skipped.
                            epoch = str(re.search(r"'E': (.*?), 'Length'",
                                                last_line).group(1))
                            loss = float(re.search(r"'validation'.*'Loss': (.*?), 'top1'",
                                                last_line).group(1))
                            top1 = float(re.search(r"'top1': (.*?)\}",
                                                last_line).group(1))

                            # NOTE(review): loss is a float here, so comparing
                            # it with the strings 'nan'/'inf' is always True;
                            # NaN and infinity are excluded by the `< 1e+22`
                            # test instead.
                            if loss != 'nan' and loss!= 'inf':
                                if loss < 1e+22:
                                    # Only prints a suggested removal command;
                                    # nothing is deleted.
                                    print('rm', file)
                                    # print(epoch, loss, top1)
                        except AttributeError:
                            pass
                            # print('-'*80)
                            # print(last_line)

def main():
    """Command-line entry point.

    Parses ``--tasks``, ``--aggs``, the mutually exclusive ``--jobs`` /
    ``--dataframes`` switches and ``--debug``, stores the debug flag in the
    module-level DEBUG, and runs process_experiments when an output mode was
    selected. Without a mode only the closing message is printed.
    """
    global DEBUG

    task_names = ['mnist', 'emnist', 'cifar100']
    agg_names = ['rsa', 'krum', 'cclip', 'avg', 'fltrust']

    parser = argparse.ArgumentParser()
    parser.add_argument('--tasks', nargs='+', required=True,
                        choices=['all', *task_names])
    parser.add_argument('--aggs', nargs='+', required=True,
                        choices=['all', *agg_names])

    mode_group = parser.add_mutually_exclusive_group(required=False)
    mode_group.add_argument('--jobs', action='store_true')
    mode_group.add_argument('--dataframes', action='store_true')

    parser.add_argument('--debug', action='store_true')

    args = parser.parse_args()
    DEBUG = args.debug

    # 'all' anywhere in the list selects every known task / aggregator.
    tasks = task_names if 'all' in args.tasks else args.tasks
    aggs = agg_names if 'all' in args.aggs else args.aggs

    if args.jobs:
        output_type = 'jobs'
    else:
        output_type = 'dataframes' if args.dataframes else None

    # Removing old job files (remove_old_jobs) is currently disabled.
    if output_type:
        process_experiments(tasks, aggs, output_type)

    print(f"{'-'*79}\nAll experiments processed!")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
