# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import argparse
from logging import getLogger
import pickle
import os
import pathlib
import random

import numpy as np
import torch

import matplotlib
import matplotlib.pyplot as plt

from .logger import create_logger, PD_Stats

import torch.distributed as dist

import sys
sys.path.insert(1, os.path.join(sys.path[0], '../../..'))
sys.path.insert(1, os.path.join(sys.path[0], '../..'))
from examples.pretraining.swav.src.config import DATASET_DEFAULTS
from examples.configs.utils import populate_config


# Select the Agg backend for every figure created through this module.
# NOTE(review): presumably chosen so plots can be written to disk on nodes
# without a display - confirm.
matplotlib.use('Agg')

# Lower-case spellings accepted by bool_flag() for False and True respectively.
FALSY_STRINGS = {"off", "false", "0"}
TRUTHY_STRINGS = {"on", "true", "1"}


# getLogger() without a name returns the root logger.
logger = getLogger()


def bool_flag(s):
    """
    Convert a command-line string into a boolean.

    The comparison is case-insensitive. Strings listed in TRUTHY_STRINGS map
    to True, strings listed in FALSY_STRINGS map to False, and anything else
    raises argparse.ArgumentTypeError.
    """
    lowered = s.lower()
    if lowered in TRUTHY_STRINGS:
        return True
    if lowered in FALSY_STRINGS:
        return False
    raise argparse.ArgumentTypeError("invalid value for a boolean flag")


def init_distributed_mode(args):
    """
    Initialize the distributed process group and select the GPU for this process.

    Nothing happens when ``args.cpu_only`` is set. Otherwise the following
    attributes are written on ``args``:
        - is_slurm_job
        - rank
        - world_size
        - gpu_to_work_on

    ``args`` must provide ``cpu_only``, ``is_not_slurm_job`` and ``dist_url``.
    A KeyError is raised when a required environment variable is missing.
    """
    if args.cpu_only:
        return

    # SLURM detection can be switched off explicitly through the arguments.
    args.is_slurm_job = (not args.is_not_slurm_job) and "SLURM_JOB_ID" in os.environ

    if args.is_slurm_job:
        args.rank = int(os.environ["SLURM_PROCID"])
        # SLURM_TASKS_PER_NODE looks like "8", "16(x2)" or "4,2". Keep the whole
        # leading integer rather than its first character, so that ten or more
        # tasks per node are counted correctly.
        tasks_spec = os.environ["SLURM_TASKS_PER_NODE"]
        tasks_per_node = int(tasks_spec.split(",")[0].split("(")[0])
        args.world_size = int(os.environ["SLURM_NNODES"]) * tasks_per_node
    else:
        # multi-GPU job (local or multi-node) - jobs started with torch.distributed.launch
        # read environment variables
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])

    logger.info("\n" + "=" * 50)
    logger.info(f"rank={args.rank}, world_size={args.world_size}")
    logger.info("=" * 50)

    # prepare distributed
    dist.init_process_group(
        backend="nccl",
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )

    # set cuda device: ranks are spread round-robin over the visible GPUs
    args.gpu_to_work_on = args.rank % torch.cuda.device_count()
    torch.cuda.set_device(args.gpu_to_work_on)
    return


def initialize_exp(params, *args, dump_params=True):
    """
    Initialize the experiment:
    - dump parameters
    - create checkpoint repo
    - create a logger
    - create a panda object to keep track of the training statistics

    ``params`` must provide ``log_dir`` and ``rank``; ``params.dump_checkpoints``
    is set here. The extra positional ``args`` are forwarded as a tuple to
    PD_Stats (presumably the names of the tracked statistics - confirm against
    PD_Stats).

    Returns the tuple ``(logger, training_stats)``.
    """

    # dump parameters; the context manager guarantees the file is flushed and closed
    if dump_params:
        with open(os.path.join(params.log_dir, "params.pkl"), "wb") as params_file:
            pickle.dump(params, params_file)

    # create repo to store checkpoints (only the process with a falsy rank creates it)
    params.dump_checkpoints = os.path.join(params.log_dir, "checkpoints")
    if not params.rank and not os.path.isdir(params.dump_checkpoints):
        os.mkdir(params.dump_checkpoints)

    # create a panda object to log loss and acc (one stats file per rank)
    training_stats = PD_Stats(
        os.path.join(params.log_dir, "stats" + str(params.rank) + ".pkl"), args
    )

    # create a logger
    logger = create_logger(
        os.path.join(params.log_dir, "train.log"), rank=params.rank
    )
    logger.info("============ Initialized logger ============")
    logger.info(
        "\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(params)).items()))
    )
    logger.info("The experiment will be stored in %s\n" % params.log_dir)
    logger.info("")
    return logger, training_stats


def restart_from_checkpoint(ckp_paths, run_variables=None, **kwargs):
    """
    Re-start from checkpoint.

    ckp_paths: a single path, or a list of candidate paths of which the first
        existing file is used.
    run_variables: optional dict; every key that is also present in the
        checkpoint is overwritten in place with the checkpoint value.
    kwargs: maps a checkpoint key to the object whose ``load_state_dict`` should
        receive that entry, e.g. ``state_dict=model``.

    Returns None. When no checkpoint file exists (including an empty list of
    candidates) the function returns without touching anything.
    """
    # look for a checkpoint in exp repository
    if isinstance(ckp_paths, list):
        ckp_path = next((path for path in ckp_paths if os.path.isfile(path)), None)
    else:
        ckp_path = ckp_paths if os.path.isfile(ckp_paths) else None

    if ckp_path is None:
        return

    logger.info("Found checkpoint at {}".format(ckp_path))

    # Map tensors onto this process' GPU. Fall back to rank 0 when no process
    # group is initialised and to the CPU when CUDA is unavailable, so that
    # single-process / cpu-only runs can resume as well.
    if torch.cuda.is_available():
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            rank = torch.distributed.get_rank()
        else:
            rank = 0
        map_location = "cuda:" + str(rank % torch.cuda.device_count())
    else:
        map_location = "cpu"

    # open checkpoint file
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(ckp_path, map_location=map_location)

    # key is what to look for in the checkpoint file
    # value is the object to load
    # example: {'state_dict': model}
    for key, value in kwargs.items():
        if key in checkpoint and value is not None:
            try:
                msg = value.load_state_dict(checkpoint[key], strict=False)
                print(msg)
            except TypeError:
                # objects whose load_state_dict does not accept ``strict``
                msg = value.load_state_dict(checkpoint[key])
            logger.info("=> loaded {} from checkpoint '{}'".format(key, ckp_path))
        else:
            logger.warning(
                "=> failed to load {} from checkpoint '{}'".format(key, ckp_path)
            )

    # re load variable important for the run
    if run_variables is not None:
        for var_name in run_variables:
            if var_name in checkpoint:
                run_variables[var_name] = checkpoint[var_name]


def fix_random_seeds(seed=31):
    """
    Seed the torch (CPU and all CUDA devices), numpy and python generators
    with the same value.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)


class AverageMeter:
    """Keeps the most recent value together with a running weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the last value, the average, the weighted sum and the count to 0."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Register `val` as observed `n` times and recompute the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # running mean over everything seen since the last reset
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """
    Computes the accuracy over the k top predictions for the specified values of k.

    ``output`` is indexed as (sample, class) and ``target`` holds one class
    index per sample. Returns a list with one single-element tensor per entry
    of ``topk``, each holding the accuracy as a percentage.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # indices of the maxk highest scores per sample, best first
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # ``correct`` derives from a transposed tensor and may be
            # non-contiguous; reshape copies when needed, whereas view raises
            # a RuntimeError for k > 1.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


# Taken from https://sumit-ghosh.com/articles/parsing-dictionary-key-value-pairs-kwargs-argparse-python/
class ParseKwargs(argparse.Action):
    """
    argparse action that turns ``KEY=VALUE`` tokens into a dict stored on the
    namespace under ``dest``.

    Each value is converted to int, then float (scientific notation included),
    then bool for ``True``/``true``/``False``/``false``; anything else is kept
    as a string. Only the first ``=`` separates key and value, so the value
    itself may contain ``=``.
    """

    _BOOL_LITERALS = {'True': True, 'true': True, 'False': False, 'false': False}

    @staticmethod
    def _convert(value_str):
        """Return ``value_str`` as int/float/bool when it parses as one, else unchanged."""
        # Only strings made of digits, signs, a dot or an exponent marker are
        # tried as numbers, so words such as "nan" or "inf" stay strings.
        looks_numeric = any(ch.isdigit() for ch in value_str) and all(
            ch.isdigit() or ch in '+-.eE' for ch in value_str
        )
        if looks_numeric:
            for cast in (int, float):
                try:
                    return cast(value_str)
                except ValueError:
                    continue
        return ParseKwargs._BOOL_LITERALS.get(value_str, value_str)

    def __call__(self, parser, namespace, values, option_string=None):
        parsed = {}
        for value in values:
            key, separator, value_str = value.partition('=')
            if not separator:
                raise ValueError(f"expected KEY=VALUE but got '{value}'")
            parsed[key] = self._convert(value_str)
        setattr(namespace, self.dest, parsed)


def save_plot(df, key_name, plot_name, save_folder):
    '''
    Plots one statistic of the dataframe and writes the figure to a PNG file.

    Nothing is plotted or written when `key_name` is not in `df`.

    Parameters
    ----------
    df : pandas.DataFrame
        DataFrame whose columns are statistics from an experiment and whose
        rows are epochs.
    key_name : str
        Name of the column holding the statistic to plot.
    plot_name : str
        Used as the plot title, the y-axis label and the stem of the filename.
    save_folder : Union[str, pathlib.Path]
        Directory in which `<plot_name>.png` is written.
    '''
    if key_name not in df:
        return

    axes = df[key_name].plot()
    axes.set_title(plot_name)
    axes.set_xlabel('Epoch')
    axes.set_ylabel(plot_name)

    target_path = pathlib.Path(save_folder) / f'{plot_name}.png'
    axes.get_figure().savefig(target_path)


def plot_experiment(log_dir):
    '''
    Plots some statistics from the specified experiment and saves the plots.

    Every `stats*.pkl` file in `log_dir` is loaded, the loaded objects are
    averaged element-wise, and one plot per known statistic is saved into
    `log_dir`. When no stats file is found a warning is logged and nothing
    is plotted.

    Parameters
    ----------
    log_dir : Union[str, pathlib.Path]
        Path containing the results of the experiment. Should have files of the
        form stats*.pkl.
    '''
    log_dir = pathlib.Path(log_dir)
    # sorted so that the averaging order does not depend on directory listing order
    stats_paths = sorted(
        path for path in log_dir.iterdir()
        if path.name.startswith('stats') and path.name.endswith('.pkl')
    )

    df_list = []
    for filepath in stats_paths:
        # NOTE: pickle.load can execute arbitrary code; only use on trusted log dirs.
        with open(filepath, 'rb') as open_file:
            df_list.append(pickle.load(open_file))

    if not df_list:
        # avoids a ZeroDivisionError in the average below
        logger.warning('No stats*.pkl files found in %s; nothing to plot.', log_dir)
        return

    avg_df = sum(df_list) / len(df_list)

    # (column name in the stats, title used for the plot)
    stat_names = (
        ('loss', 'Training Loss'),
        ('prec1', 'Training Accuracy'),
        ('prec1_val', 'Source Validation Accuracy'),
        ('prec1_tgt', 'Target Accuracy'),
    )
    for key_name, plot_name in stat_names:
        save_plot(avg_df, key_name, plot_name, log_dir)
        # release the current figure so the next statistic starts on a fresh one
        plt.close()


def populate_defaults_for_swav(config):
    """
    Populate defaults for SwAV pretraining.

    Looks up the defaults registered for ``config.dataset`` in DATASET_DEFAULTS,
    hands them to populate_config and returns the resulting config. Raises
    AssertionError when no dataset is set or when ``warmup_epochs`` is not
    strictly smaller than ``n_epochs``.
    """
    assert config.dataset is not None, 'dataset must be specified'
    config = populate_config(config, DATASET_DEFAULTS[config.dataset])

    # Sanity checks: the comparison is strict, so the message says so.
    assert config.warmup_epochs < config.n_epochs, \
        f'The number of warmup_epochs ({config.warmup_epochs}) must be smaller than n_epochs ({config.n_epochs}).'

    return config