import argparse
import os
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
from joblib import Parallel, delayed
from omegaconf import OmegaConf

from icpe.plotter import plot_param, plot_validation
from icpe.train import train
from icpe.utils import make_exp_name

def parse_args(argv=None):
    """Parse the experiment command line.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` with ``cfg``, ``seed``, ``verbose``,
        ``test`` and ``suffix`` attributes.
    """
    parser = argparse.ArgumentParser(description="experiment script")

    # optional overrides as dotlist
    parser.add_argument(
        "--cfg",
        nargs="*",
        default=[],
        help="Override config values using dot notation (e.g. lr=0.01)"
    )
    parser.add_argument('--seed', type=int, nargs='+',
                        help='random seed', default=list(range(20)))
    parser.add_argument('-v', '--verbose', action='store_true',
                        help='print training details')
    parser.add_argument('--test', action='store_true',
                        help='run a test trial on a smaller scale')
    parser.add_argument('--suffix', help='suffix for the experiment name',
                        default='')
    return parser.parse_args(argv)


def load_config(overrides, *, test=False):
    """Load the base YAML config and apply dot-notation CLI overrides.

    The YAML files are resolved relative to the current working directory
    (``test_config.yaml`` in test mode, ``config.yaml`` otherwise).

    Args:
        overrides: List of ``key=value`` strings, as given to ``--cfg``.
        test: Load the small-scale test config instead of the full one.

    Returns:
        The merged OmegaConf config; CLI overrides win over the YAML values.
    """
    base_config = OmegaConf.load("test_config.yaml" if test else "config.yaml")
    cli_overrides = OmegaConf.from_dotlist(overrides)
    return OmegaConf.merge(base_config, cli_overrides)


def make_trial_name(config_dict, *, test=False, suffix=''):
    """Build a unique, timestamped trial name for the log directory.

    Format: ``<YYYY-mm-dd_HH-MM-SS>_<exp name>[_testrun][_<suffix>]``.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    trial_name = f'{timestamp}_{make_exp_name(config_dict)}'
    if test:
        trial_name += '_testrun'
    if suffix:
        trial_name += f'_{suffix}'
    return trial_name


def list_logged_steps(param_dir):
    """Return the sorted training steps that have a ``params_<step>.npy`` file.

    Files in ``param_dir`` that do not match that pattern are ignored.
    """
    prefix, ext = "params_", ".npy"
    return sorted(
        int(name[len(prefix):-len(ext)])
        for name in os.listdir(param_dir)
        if name.startswith(prefix) and name.endswith(ext)
    )


def load_seed_results(seed_dir):
    """Load the parameter and validation logs written for a single seed.

    Expects ``<seed_dir>/params/params_<step>.npy`` (an object array that
    unpacks into ``P, Q``) and ``<seed_dir>/validation/val_<step>.npy`` (an
    object array that unpacks into ``ns, msves``) for every logged step.

    Returns:
        ``(steps, Ps, Qs, msves, ns)``: the sorted step list, the stacked
        ``P``/``Q``/``msves`` arrays (step is the leading axis), and the
        ``ns`` of the last step (``None`` if no step was logged).
    """
    param_dir = os.path.join(seed_dir, 'params')
    val_dir = os.path.join(seed_dir, 'validation')
    steps = list_logged_steps(param_dir)

    Ps, Qs, msves_by_step = [], [], []
    ns = None
    for step in steps:
        # allow_pickle is needed because these are object arrays; acceptable
        # only because the files are this run's own outputs, not external data.
        P, Q = np.load(os.path.join(param_dir, f'params_{step}.npy'),
                       allow_pickle=True)
        Ps.append(P)
        Qs.append(Q)
        ns, msves = np.load(os.path.join(val_dir, f'val_{step}.npy'),
                            allow_pickle=True)
        msves_by_step.append(msves)
    return steps, np.array(Ps), np.array(Qs), np.array(msves_by_step), ns


def collect_results(log_dir, seeds):
    """Load every seed's logs from ``log_dir`` and aggregate them across seeds.

    Args:
        log_dir: Trial directory containing one sub-directory per seed.
        seeds: Non-empty sequence of seeds to aggregate.

    Returns:
        ``(steps, mean_Ps, mean_Qs, mean_msves, ste_msves, ns)``. The means
        are taken over seeds. ``ste_msves`` is the standard error of the mean
        of the validation errors. ``ns`` comes from the last seed loaded.

    Raises:
        ValueError: If ``seeds`` is empty or the seeds logged different steps
            (averaging would otherwise mix up or misalign steps).
    """
    if not seeds:
        raise ValueError("at least one seed is required")

    steps = None
    ns = None
    all_Ps, all_Qs, all_msves = [], [], []
    for s in seeds:
        seed_steps, Ps, Qs, msves, ns = load_seed_results(
            os.path.join(log_dir, str(s)))
        if steps is None:
            steps = seed_steps
        elif seed_steps != steps:
            raise ValueError(
                f"seed {s} logged steps {seed_steps}, but earlier seeds "
                f"logged {steps}; cannot average across seeds")
        all_Ps.append(Ps)
        all_Qs.append(Qs)
        all_msves.append(msves)

    all_msves = np.array(all_msves)
    n_seeds = len(seeds)
    # Standard error of the mean uses the sample std (ddof=1). A single seed
    # has no spread to estimate, so keep the previous zero-width error bar.
    ddof = 1 if n_seeds > 1 else 0
    ste_msves = np.std(all_msves, axis=0, ddof=ddof) / np.sqrt(n_seeds)

    return (steps,
            np.mean(np.array(all_Ps), axis=0),
            np.mean(np.array(all_Qs), axis=0),
            np.mean(all_msves, axis=0),
            ste_msves,
            ns)


def save_figures(fig_dir, steps, mean_Ps, mean_Qs, mean_msves, ste_msves, ns):
    """Write one parameter figure and one validation figure per logged step.

    The figures are saved to ``fig_dir`` as ``params_<step>.png`` and
    ``val_<step>.png``.
    """
    for step, P, Q in zip(steps, mean_Ps, mean_Qs):
        plot_param(P, Q)
        plt.savefig(os.path.join(fig_dir, f'params_{step}.png'), dpi=300)
        plt.close()

    for step, msve, ste in zip(steps, mean_msves, ste_msves):
        # The trailing False is passed through unchanged from the original
        # call; its meaning is defined by icpe.plotter.plot_validation.
        plot_validation(msve, ste, ns, False)
        plt.savefig(os.path.join(fig_dir, f'val_{step}.png'), dpi=300)
        plt.close()


def main(argv=None):
    """Train all seeds in parallel, then average their logs and plot them."""
    args = parse_args(argv)
    # Test mode always runs two seeds. Otherwise drop duplicate seeds (keeping
    # order) so two workers never write into the same seed directory.
    seeds = [0, 1] if args.test else list(dict.fromkeys(args.seed))
    final_config = load_config(args.cfg, test=args.test)

    if args.verbose:
        print(OmegaConf.to_yaml(final_config))
        print('seed:', seeds)

    final_config_dict = OmegaConf.to_container(final_config,
                                               resolve=True,    # resolves interpolations and references
                                               throw_on_missing=True  # raises an error if some values are missing
                                               )
    trial_name = make_trial_name(final_config_dict, test=args.test,
                                 suffix=args.suffix)

    Parallel(n_jobs=-1)(
        delayed(train)(seed, trial_name, final_config_dict)
        for seed in seeds
    )

    # Assumes train() writes its logs under ./log/<trial_name>/<seed>/.
    log_dir = os.path.join('.', 'log', trial_name)
    fig_dir = os.path.join(log_dir, 'figures')
    os.makedirs(fig_dir, exist_ok=True)

    steps, mean_Ps, mean_Qs, mean_msves, ste_msves, ns = collect_results(
        log_dir, seeds)
    save_figures(fig_dir, steps, mean_Ps, mean_Qs, mean_msves, ste_msves, ns)


if __name__ == '__main__':
    main()
