import argparse
import json
import os
import random
from datetime import datetime
from pathlib import Path

import numpy as np
import torch

import data
import train


def test(args, algorithm, seed, eval_on):
    """Evaluate ``algorithm`` on each requested data split.

    Args:
        args: parsed command-line namespace. Must provide ``cuda``; it is also
            forwarded to ``data.get_loaders`` and ``train.eval_latent_tasks``.
        algorithm: the model object to evaluate.
        seed: base seed for the run. Every split is evaluated after reseeding
            with ``seed + 10``.
        eval_on: iterable of split names, each one of ``'train'``, ``'val'``
            or ``'test'``.

    Returns:
        dict mapping each split name in ``eval_on`` to the value returned by
        ``train.eval_latent_tasks`` for that split.

    Raises:
        KeyError: if ``eval_on`` names a split other than the three above.
    """
    # The plain training loader (first element) is not needed for evaluation.
    _, train_eval_loader, val_loader, test_loader = data.get_loaders(args)
    split_to_loader = dict(train=train_eval_loader,
                           val=val_loader,
                           test=test_loader)

    results = {}
    for split_name in eval_on:
        # Reseed before every split so each evaluation starts from the same
        # RNG state, whichever splits came before it.
        set_seed(seed + 10, args.cuda)
        results[split_name] = train.eval_latent_tasks(
            args, algorithm, split_to_loader[split_name], split=split_name)
    return results


def get_parser():
    """Build the command-line argument parser for training / evaluation runs.

    Returns:
        argparse.ArgumentParser: parser whose ``parse_args()`` result is the
        ``args`` namespace used throughout this script.
    """
    parser = argparse.ArgumentParser()

    # Train / test
    parser.add_argument('--train', type=int, default=1, help="Train models")
    parser.add_argument('--test', type=int, default=1, help="Test models")
    # Only applicable when train is 0 and test is 1
    parser.add_argument('--ckpt_folders', type=str, nargs='+',
                        help='Checkpoint folder names under '
                             '<output_path>/checkpoints to evaluate '
                             '(used with --train 0 --test 1)')

    parser.add_argument('--progress_bar', type=int,
                        default=0, help="Show progress bars (0 or 1)")

    # Training / Optimization args
    parser.add_argument('--num_epochs', type=int,
                        default=200, help='Number of epochs')
    parser.add_argument('--optimizer', type=str, default='adam',
                        choices=['sgd', 'adam'])
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--weight_decay', type=float, default=0)

    # Data args
    parser.add_argument('--dataset', type=str, default='femnist',
                        choices=['femnist', 'cifar-c', 'tinyimg'])
    parser.add_argument('--data_dir', type=str, default='data/')

    # Data sampling
    parser.add_argument('--sampler', type=str, default='group',
                        choices=['standard', 'group'],
                        help='Standard or group sampler')
    parser.add_argument('--uniform_over_groups', type=int, default=1,
                        help='Sample across groups uniformly')
    parser.add_argument('--meta_batch_size', type=int,
                        default=5, help='Number of domains')
    parser.add_argument('--support_size', type=int,
                        default=20, help='Support size')
    # Implicit string concatenation (not a backslash continuation) keeps the
    # source indentation out of the help text.
    parser.add_argument('--shuffle_train', type=int, default=1,
                        help='Shuffle the training set; only relevant when '
                             'group sampling is off and '
                             '--uniform_over_groups 0')
    parser.add_argument('--drop_last', type=int, default=0)
    parser.add_argument('--loading_type', type=str, choices=['PIL', 'jpeg'],
                        default='jpeg',
                        help='Whether to use PIL or jpeg4py when loading '
                             'images. Jpeg is faster.')

    parser.add_argument('--num_workers', type=int, default=4,
                        help='Num workers for pytorch data loader')
    parser.add_argument('--pin_memory', type=int, default=1,
                        help='Pytorch loader pin memory. '
                             'Best practice is to use this')

    # Model args
    parser.add_argument('--model', type=str, default='convnet',
                        choices=['resnet50', 'convnet'])
    parser.add_argument('--pretrained', type=int,
                        default=1, help='Pretrained resnet')

    # Method
    parser.add_argument('--algorithm', type=str, default='ERM',
                        choices=['ERM', 'ARM-CML', 'ARM-BN', 'CXDA'])

    # ARM-CML
    parser.add_argument('--n_context_channels', type=int,
                        default=3, help='Used when using a convnet/resnet')
    parser.add_argument('--context_net', type=str, default='convnet')

    # Evaluation
    parser.add_argument('--epochs_per_eval', type=int, default=1)

    # Test
    # Restrict to the splits that test() can look up. A typo is then rejected
    # at parse time instead of raising KeyError after training finishes.
    parser.add_argument('--eval_on', type=str, nargs="*", default=['test'],
                        choices=['train', 'val', 'test'],
                        help='Splits to evaluate on')

    # Logging
    parser.add_argument('--seeds', type=int, nargs="*",
                        default=[0], help='Seeds')
    parser.add_argument('--plot', type=int, default=0, help='Plot or not')
    parser.add_argument('--exp_name', type=str, default='')
    parser.add_argument('--debug', type=int, default=0)
    parser.add_argument('--output_path', type=str, default='output')

    # CXDA
    parser.add_argument('--supervised', type=int, default=0)

    return parser


def set_seed(seed, cuda):
    """Seed the Python, NumPy and PyTorch RNGs, plus CUDA when ``cuda`` is truthy.

    This makes runs as reproducible as possible. PyTorch does not guarantee
    fully reproducible results across machines; see
    https://pytorch.org/docs/stable/notes/randomness.html
    """
    print('setting seed', seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not cuda:
        return
    torch.cuda.manual_seed(seed)


class ScoreKeeper:
    """Collect evaluation metrics over several runs and report mean and standard error.

    ``results`` maps split name -> metric name -> list of values, with one
    value appended per call to :meth:`log`.
    """

    def __init__(self, splits, n_seeds):
        """
        Args:
            splits: names of the splits to report in :meth:`print_stats`.
            n_seeds: number of runs expected to be logged. Kept for reference;
                the standard error uses the number of values actually logged.
        """
        self.splits = splits
        self.n_seeds = n_seeds

        self.results = {split: {} for split in splits}

    def log(self, stats):
        """Append one run's metrics.

        Args:
            stats: mapping split -> {key: value}. Keys are expected to look like
                ``'<prefix>/<metric_name>'``. The second ``/``-separated
                component is used as the metric name, or the whole key when it
                contains no ``/``.
        """
        for split, split_stats in stats.items():
            # Splits not declared at construction time are still recorded,
            # but print_stats only reports the declared ones.
            split_results = self.results.setdefault(split, {})
            for key, value in split_stats.items():
                parts = key.split('/')
                metric_name = parts[1] if len(parts) > 1 else key
                split_results.setdefault(metric_name, []).append(value)

    @staticmethod
    def _mean_and_stderr(values):
        """Return ``(mean, standard_error)`` of ``values``.

        The standard error is NaN when fewer than two values are given, because
        it is undefined for a single sample.
        """
        values = np.asarray(values)
        n = values.size
        mean = float(np.mean(values))
        if n < 2:
            return mean, float('nan')
        # np.std defaults to ddof=0. Dividing it by sqrt(n - 1) is algebraically
        # equal to std(ddof=1) / sqrt(n), the standard error of the mean.
        return mean, float(np.std(values) / np.sqrt(n - 1))

    def print_stats(self, metric_names=('worst_case_acc', 'average_acc', 'empirical_acc')):
        """Print mean and standard error of each metric for every declared split.

        Metrics that were never logged for a split (e.g. because testing was
        disabled) are reported as missing instead of raising ``KeyError``.
        """
        for split in self.splits:
            print("Split: ", split)

            split_results = self.results.get(split, {})
            for metric_name in metric_names:
                values = split_results.get(metric_name)
                if not values:
                    print(f"{metric_name}: no results logged")
                    continue

                avg, standard_error = self._mean_and_stderr(values)
                print(f"{metric_name}: {avg}, standard error: {standard_error}")


if __name__ == '__main__':

    start_time = datetime.now()

    parser = get_parser()
    args = parser.parse_args()

    # Use the GPU when one is visible, otherwise the CPU.
    args.cuda = torch.cuda.is_available()
    args.device = torch.device('cuda' if args.cuda else 'cpu')

    # For reproducibility
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Everything collected here is written to
    # <output_path>/results/<exp_name>.json at the end of the run.
    stats_dict = {'num_epochs': args.num_epochs,
                  'epochs_per_eval': args.epochs_per_eval}

    # NOTE(review): checkpoints are whole pickled model objects. On
    # PyTorch >= 2.6 torch.load defaults to weights_only=True, which may
    # reject them; pass weights_only=False there (trusted checkpoints only).
    # map_location lets GPU-trained checkpoints load on a CPU-only machine.

    if args.train:

        score_keeper = ScoreKeeper(args.eval_on, len(args.seeds))
        print("args seeds: ", args.seeds)
        ckpt_dirs = []

        for seed in args.seeds:
            print("seed: ", seed)
            set_seed(seed, args.cuda)

            # Save folder: one timestamped checkpoint directory per seed.
            datetime_now = datetime.now().strftime("%Y%m%d-%H%M%S")
            name = args.dataset + args.exp_name + '_' + str(seed)
            args.ckpt_dir = Path(args.output_path) / \
                'checkpoints' / f'{name}_{datetime_now}'
            ckpt_dirs.append(args.ckpt_dir)
            print("CKPT DIR: ", args.ckpt_dir)

            stats_per_epoch = train.train(args)

            # Store the stats for each seed separately.
            stats_dict[seed] = {'train_evaluation': stats_per_epoch}

            # Test the model just trained on
            if args.test:
                args.ckpt_path = args.ckpt_dir / 'best.pkl'
                algorithm = torch.load(
                    args.ckpt_path, map_location=args.device).to(args.device)
                stats = test(args, algorithm, seed, eval_on=args.eval_on)
                score_keeper.log(stats)
                stats_dict[seed]['test_evaluation'] = stats

        print("Ckpt dirs: \n ", ckpt_dirs)
        if args.test:
            score_keeper.print_stats()

    elif args.test and args.ckpt_folders:  # Test a set of already trained models

        # Each checkpoint folder is evaluated with the seed at the same position.
        if len(args.seeds) < len(args.ckpt_folders):
            parser.error(
                f"--seeds needs one seed per checkpoint folder "
                f"({len(args.ckpt_folders)} folders, {len(args.seeds)} seeds)")

        ckpt_paths = [Path(args.output_path) / 'checkpoints' / folder / 'best.pkl'
                      for folder in args.ckpt_folders]

        # Check that all checkpoints exist before spending time on evaluation.
        missing = [str(path) for path in ckpt_paths if not path.is_file()]
        if missing:
            parser.error("Checkpoint(s) not found: " + ", ".join(missing))
        for path in ckpt_paths:
            print("Found: ", path)

        score_keeper = ScoreKeeper(args.eval_on, len(args.ckpt_folders))
        for seed, ckpt_path in zip(args.seeds, ckpt_paths):

            # Test algorithm
            args.ckpt_path = ckpt_path
            algorithm = torch.load(
                ckpt_path, map_location=args.device).to(args.device)
            stats = test(args, algorithm, seed, eval_on=args.eval_on)
            score_keeper.log(stats)
            # Record every checkpoint's results, not just the last one.
            stats_dict[seed] = {'test_evaluation': stats}

        score_keeper.print_stats()

    elif args.test:
        print("Nothing to evaluate: --train 0 requires --ckpt_folders for testing.")

    end_time = datetime.now()
    runtime = (end_time - start_time).total_seconds() / 60.0
    print("\nTotal runtime: ", runtime)

    stats_dict['run_time'] = runtime

    # Store statistics under the configured output path. With the default
    # --output_path this is output/results/, as before. The folder is created
    # if missing, so results are not lost after a long run.
    results_dir = Path(args.output_path) / 'results'
    results_dir.mkdir(parents=True, exist_ok=True)
    with open(results_dir / (args.exp_name + '.json'), 'w') as f:
        json.dump(stats_dict, f)