import copy
import itertools
import multiprocessing
import os
import shutil

import numpy as np
import torch

from dtn import utils, data
from dtn.train import train


def wrapper(args):
    """Run a single training job inside a pool worker on one selected GPU.

    The GPU is chosen from the numeric suffix of the current process name,
    so this function is meant to be called from a ``multiprocessing.Pool``
    worker (whose default name ends in ``-<k>``, with ``k`` counting from 1).

    Args:
        args: Job configuration. Reads ``args.gpu`` (a sequence of GPU
            entries), ``args.threads`` and ``args.workers`` (number of
            workers sharing one GPU entry).

    Returns:
        Whatever ``train`` returns for a shallow copy of ``args`` whose
        ``gpu`` attribute is replaced by the single selected entry.
    """
    gpu_list = args.gpu
    torch.set_num_threads(args.threads)

    # Zero-based worker number taken from the process name suffix.
    worker_no = int(multiprocessing.current_process().name.split('-')[-1]) - 1

    # The name counter belongs to the parent process and keeps increasing
    # when a pool replaces a worker or when another pool was created before
    # this one, so it can exceed the number of slots. Wrap it around instead
    # of letting the GPU lookup fail with IndexError.
    slot = worker_no % (len(gpu_list) * args.workers)

    # Consecutive blocks of `args.workers` slots map to the same GPU entry.
    gpu = gpu_list[slot // args.workers]

    # Shallow copy so the caller's args object keeps its full GPU list.
    job_args = copy.copy(args)
    job_args.gpu = gpu
    return train(job_args)


def train_models(args, workers, threads):
    """Train one model per (dataset, seed) pair using a process pool.

    Args:
        args: Base configuration. ``args.data`` and ``args.seeds`` are
            iterated as a Cartesian product; ``args.gpu`` determines the pool
            size together with ``workers``.
        workers: Number of pool processes per entry of ``args.gpu``.
        threads: Value stored as ``threads`` on every job configuration.

    Returns:
        List with the result of ``wrapper`` for every job, in completion
        order (not submission order).
    """
    def make_job(dataset, seed):
        # Each job gets its own shallow copy with per-job fields overridden.
        job = copy.copy(args)
        job.seed = seed
        job.data = dataset
        job.workers = workers
        job.threads = threads
        return job

    jobs = [
        make_job(dataset, seed)
        for dataset, seed in itertools.product(args.data, args.seeds)
    ]

    pool_size = len(args.gpu) * workers
    with multiprocessing.Pool(pool_size) as pool:
        # Drain the iterator completely before the pool is torn down.
        return list(pool.imap_unordered(wrapper, jobs))


def write_results(res_list, num_seeds):
    """Write per-run and aggregated accuracy tables as TSV files.

    Two files are written into ``out_path`` of the first result's ``args``
    (the directory is created if it does not exist):

    * ``summary.tsv``: one row per result with the dataset, the seed and the
      entries of ``values``.
    * ``accuracy.tsv``: average and standard deviation of the last column of
      every consecutive run of ``num_seeds`` summary rows. A trailing run
      shorter than ``num_seeds`` is not written.

    Rows are ordered by ``n_data`` (descending), then dataset, then seed.

    Args:
        res_list: Non-empty list of dicts with the keys ``'args'`` (object
            with ``data``, ``seed`` and ``out_path`` attributes),
            ``'n_data'`` and ``'values'`` (an iterable of column values;
            presumably best epoch, training accuracy and test accuracy, as
            named in the header).
        num_seeds: Number of consecutive rows aggregated per line of
            ``accuracy.tsv``.

    Returns:
        Path of the written ``summary.tsv``.
    """
    out_path = res_list[0]['args'].out_path
    # The directory may not exist (yet); create it rather than failing on
    # the first open().
    os.makedirs(out_path, exist_ok=True)

    # Two stable sorts: (dataset, seed) is the tie-breaker, and n_data
    # descending is the primary key.
    ordered = sorted(res_list, key=lambda r: (r['args'].data, r['args'].seed))
    ordered = sorted(ordered, key=lambda r: r['n_data'], reverse=True)
    rows = [[r['args'].data, r['args'].seed, *r['values']] for r in ordered]

    summary_path = os.path.join(out_path, 'summary.tsv')
    with open(summary_path, 'w') as f:
        f.write('dataset\tseed\tbest_epoch\ttrn_acc\ttest_acc\n')
        f.writelines('\t'.join(str(e) for e in row) + '\n' for row in rows)

    with open(os.path.join(out_path, 'accuracy.tsv'), 'w') as f:
        f.write('average\tstd\n')
        group = []
        for row in rows:
            # The last column holds the accuracy that is aggregated.
            group.append(row[-1])
            if len(group) == num_seeds:
                f.write('{}\t{}\n'.format(np.average(group), np.std(group)))
                group = []

    return summary_path


def print_results(path):
    """Print accuracy statistics for fixed groups of datasets.

    Reads a tab-separated file with one header line, taking the first column
    as the dataset name and the last column as a float accuracy. Consecutive
    lines with the same dataset name form one row of a matrix. The matrix
    rows are then split by position into ``large`` (first 9), ``medium``
    (next 37) and ``small`` (the rest), plus ``all``. For each group the
    column-wise mean over its datasets is taken, and the mean and standard
    deviation of that vector are printed.

    NOTE(review): the fixed boundaries 9 and 46 presumably rely on the row
    order of ``summary.tsv`` (largest datasets first) and on the full
    benchmark being present; confirm before using on other inputs.

    Args:
        path: Path of the TSV file to read.
    """
    per_dataset = [[]]
    with open(path) as f:
        next(f)  # skip the header line
        previous = None
        for line in f:
            fields = line.strip().split('\t')
            name = fields[0]
            # A change of dataset name starts a new matrix row.
            if previous is not None and previous != name:
                per_dataset.append([])
            per_dataset[-1].append(float(fields[-1]))
            previous = name
    acc = np.array(per_dataset)

    groups = (
        ('large', acc[:9, :]),
        ('medium', acc[9:46, :]),
        ('small', acc[46:, :]),
        ('all', acc),
    )
    # Average over datasets first; the statistics below are over what is left
    # (one entry per line of a dataset).
    averaged = [(label, block.mean(0)) for label, block in groups]

    print('data\tacc_avg\tacc_std')
    for label, vector in averaged:
        print('{}\t{}\t{}'.format(label, np.mean(vector), np.std(vector)))


def main():
    """Run the experiment: parse arguments, train every job, write summaries.

    Steps:
      1. Parse the arguments; when no dataset is given, use the list returned
         by ``data.get_uci_datasets``; otherwise wrap the single dataset in a
         list.
      2. Pick a learning rate from the model name (``ValueError`` for an
         unknown model).
      3. Derive the output directory ``<out_path>/<model[-suffixes]>/uci``
         and delete it if it already exists.
      4. Train all (dataset, seed) jobs and write the TSV summaries.
      5. Print group statistics when exactly 121 datasets were used.
    """
    args = utils.parse_args()
    args.data = (
        data.get_uci_datasets(path=args.data_path)
        if args.data is None
        else [args.data]
    )

    model = args.model
    learning_rates = (
        (('SDT', 'DNDT'), 1e-2),
        (('DTN', 'DTN-D', 'DTN-S'), 5e-3),
        (('MLP',), 1e-3),
    )
    for names, lr in learning_rates:
        if model in names:
            args.lr = lr
            break
    else:
        raise ValueError(model)

    # Directory name: the model name followed by optional '-'-joined parts.
    parts = [args.model]
    if model != 'DNDT':
        parts.append(args.layers)
    # NOTE(review): 'DSN' is rejected by the learning-rate selection above
    # (ValueError), so this branch cannot run as written. Confirm whether it
    # is stale or should name a different model.
    if model == 'DSN':
        parts.append(args.activation)
    if args.prune is not None:
        parts.append('pruned')
    out = '-'.join('{}'.format(part) for part in parts)
    args.out_path = os.path.join(args.out_path, out, 'uci')

    # Start from a clean output directory; previous results are discarded.
    if os.path.exists(args.out_path):
        shutil.rmtree(args.out_path)

    workers, threads = 3, 8
    results = train_models(args, workers, threads)
    path = write_results(results, len(args.seeds))

    # 121 datasets is treated as the full experiment; only then are the
    # fixed-position group statistics printed.
    if len(args.data) == 121:
        print_results(path)


# Run the experiment only when this file is executed as a script, so that
# importing the module (for example by multiprocessing child processes under
# the spawn start method) does not start training again.
if __name__ == '__main__':
    main()
