import datetime
import logging
import os
import shutil
import sys
from pathlib import Path
from typing import Optional

import numpy as np
import torch

import utils


def setup_dir(args, config, debug=False):
    """Create a fresh, timestamped run directory and snapshot the run into it.

    Directory layout::

        <experiments>/<model>/<dset_name>[/<version>]/{new,debug,paper}
        <...>/<new|debug>/<exp_name>/<trial>/<run_id>   <- run_path

    Side effects: sets ``args.run_id``; writes ``args.pt`` (via
    ``torch.save``) and ``code.zip`` (a zip of the source tree, see below)
    into ``run_path``.

    :param args: namespace-like object; reads ``model``, ``dset_name``,
        ``exp_name``, ``trial`` and, if present and not None, ``version``.
    :param config: object whose ``dirs['experiments']`` is the root folder.
    :param debug: if True the run is placed under ``debug`` instead of
        ``new``.
    :return: ``(args, run_path)``.
    :raises FileExistsError: if ``run_path`` already exists.
    """
    parts = [args.model, args.dset_name]
    version = getattr(args, 'version', None)
    if version is not None:
        parts.append(version)
    # str() so that e.g. integer trial/version values are accepted too.
    results_dir = os.path.join(config.dirs['experiments'],
                               *(str(p) for p in parts))

    for sub in ('new', 'debug', 'paper'):
        os.makedirs(os.path.join(results_dir, sub), exist_ok=True)

    base_dir = os.path.join(results_dir, 'debug' if debug else 'new')
    run_dir = os.path.join(base_dir, str(args.exp_name), str(args.trial))

    # ':' is not allowed in macOS file names, '/' not allowed in OneDrive
    # (isoformat() emits no '/'); 'T' is padded for readability.
    run_id = datetime.datetime.now().isoformat().replace(':', '-')
    args.run_id = run_id.replace('T', '_T_')
    run_path = os.path.join(run_dir, args.run_id)
    # exist_ok=False: never silently reuse (and overwrite) an existing run.
    os.makedirs(run_path, exist_ok=False)

    torch.save(args, os.path.join(run_path, 'args.pt'))
    # Save code: parents[1] is the parent of the directory containing this
    # file, i.e. this assumes <project_root>/<subdir>/<this_file>.py.
    # NOTE(review): if the experiments dir lives inside that tree, the
    # archive will also include previous runs -- confirm this is intended.
    code_dir = Path(__file__).parents[1]
    shutil.make_archive(os.path.join(run_path, 'code'), 'zip', str(code_dir))

    return args, run_path


def set_logger(verbosity: int, log_path: Optional[str] = None):
    """Configure and return the shared ``'custom'`` logger.

    Installs a stdout handler with a bare ``'%(message)s'`` format on the
    root logger via ``logging.basicConfig`` (a no-op if the root logger
    already has handlers); messages from ``'custom'`` reach it because
    loggers propagate by default. If ``log_path`` is given, a
    ``FileHandler`` appending to that file is attached as well -- unless
    one for the same file is already attached, so calling this function
    repeatedly does not duplicate every line in the log file.

    :param verbosity: logging level for the ``'custom'`` logger.
    :param log_path: optional path of a log file.
    :return: the ``'custom'`` logger.
    """
    logging.basicConfig(format='%(message)s',
                        handlers=[logging.StreamHandler(sys.stdout)])
    logger = logging.getLogger('custom')
    if log_path:
        # FileHandler stores os.path.abspath(filename) as baseFilename.
        target = os.path.abspath(log_path)
        already_attached = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == target
            for h in logger.handlers)
        if not already_attached:
            logger.addHandler(logging.FileHandler(log_path))
    logger.setLevel(verbosity)
    return logger


def close_logger():
    """Detach and close every handler attached to the ``'custom'`` logger.

    Handlers on the root logger (such as the stdout handler installed by
    ``logging.basicConfig``) are left untouched.
    """
    custom = logging.getLogger('custom')
    # Walk a snapshot: removeHandler mutates ``custom.handlers`` in place.
    for h in list(custom.handlers):
        h.close()
        custom.removeHandler(h)


class AverageMeter:
    """Running arithmetic mean of the values passed to ``update``."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget every value observed so far."""
        self._sum = 0
        self._count = 0

    def update(self, val):
        """Add a single observation ``val``."""
        self._sum += val
        self._count += 1

    @property
    def avg(self):
        """Mean of the observations, or ``0.`` when there are none."""
        return self._sum / self._count if self._count else 0.


class Meter:
    """ Computes average over iterations.

    Each ``add_dict`` call contributes one observation per leaf of a
    (possibly nested) dict; ``flush`` returns a nested dict of the same
    shape whose leaves hold the averages accumulated so far.
    """

    # Error text used when a list leaf contains a non-scalar entry.
    _LIST_LEAF_MSG = ('List entries must be leaf nodes to improve '
                      'clarity. You could, for example, use several '
                      'variables towards complying with this constraint.')

    def __init__(self):
        # NOTE(review): relies on utils.rec_defaultdict auto-creating missing
        # nested entries (presumably a recursive defaultdict -- confirm).
        self.data = utils.rec_defaultdict()

    def add_dict(self, log: dict, name: Optional[str] = None):
        """
        :param log: dictionary of arbitrary depth, leaves can be scalars or
        lists of scalars; ``None`` leaves / list entries are ignored
        :param name: what kind of logs to save; if given, ``log`` is stored
        under ``self.data[name]``
        """
        if not log:
            return
        data = self.data[name] if name else self.data
        self._add_dict(log, data)

    def _add_dict(self, log: dict, data: dict):
        """Recursively feed the leaves of ``log`` into meters in ``data``."""
        for k, v in log.items():
            if isinstance(v, dict):
                # data[k] is auto-created as a nested dict when missing.
                self._add_dict(v, data[k])

            # list of leaf nodes: one meter per list *position*
            elif isinstance(v, list):
                if k not in data:
                    data[k] = []
                meters = data[k]
                # Grow so every position of a longer list is tracked instead
                # of being silently dropped by zip().
                meters.extend(AverageMeter()
                              for _ in range(len(v) - len(meters)))
                for u, meter in zip(v, meters):
                    if u is None:
                        # Skip the missing value but keep positions aligned;
                        # filtering Nones out would shift later entries into
                        # the wrong meters.
                        continue
                    if not np.isscalar(u):
                        # Explicit raise (same type as the former assert) so
                        # the check also runs under ``python -O``.
                        raise AssertionError(self._LIST_LEAF_MSG)
                    meter.update(u)

            # scalar leaf nodes
            elif v is not None:
                if k not in data:
                    data[k] = AverageMeter()
                data[k].update(v)

    def _flush(self, data, log):
        """
        :param data: self.data at given level
        :param log: log containing averages (filled in place)
        :return: ``log``
        """
        for k, v in data.items():
            if isinstance(v, dict):
                log[k] = self._flush(v, log[k])
            elif isinstance(v, list):
                # A position that never received a value reports 0.
                log[k] = [cur.avg for cur in v]
            else:
                log[k] = v.avg
        return log

    def flush(self):
        """
        :return: average for every entry in self.data, as a nested
        ``utils.rec_defaultdict`` mirroring its structure
        """
        log = utils.rec_defaultdict()
        return self._flush(self.data, log)
