"""Various utilities."""

import socket
import time
import datetime

import os
import sys
import shutil
import glob
import logging

import csv

import torch
import random
import numpy as np
from ConfigSpace.read_and_write import json as config_space_json_r_w
from .hyperparameters import Hyperparameters

from .consts import DEFAULT_SETUP, BENCHMARK, SHARING_STRATEGY, DISTRIBUTED_BACKEND
# Import-time side effects: configure cuDNN's benchmark flag and the tensor sharing strategy used by
# torch.multiprocessing from the project constants in .consts.
torch.backends.cudnn.benchmark = BENCHMARK
torch.multiprocessing.set_sharing_strategy(SHARING_STRATEGY)


def _print_run_header(args, defs, hardware_terminator):
    """Print the run banner: timestamp, ``args``, ``repr(defs)`` and a CPU/GPU/hostname summary.

    ``hardware_terminator`` is appended verbatim to the hardware summary line.
    """
    print('Currently evaluating -------------------------------:')
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    if args is not None:
        print(args)
    if defs is not None:
        print(repr(defs))
    print(f'CPUs: {torch.get_num_threads()}, GPUs: {torch.cuda.device_count()} on '
          f'{socket.gethostname()}{hardware_terminator}')


def system_startup(args=None, defs=None):
    """Decide and print GPU / CPU / hostname info. Generate distributed setting if running in dist. mode.

    Distributed mode is selected when ``args`` has a ``local_rank`` attribute that is not None. In that case
    the process binds to ``cuda:<local_rank>`` and initializes the default process group with
    ``DISTRIBUTED_BACKEND`` and ``init_method='env://'``. Otherwise ``cuda:0`` is used if CUDA is available,
    else the CPU.

    Args:
        args: parsed command-line namespace (or None); printed in the banner.
        defs: arbitrary run definitions (or None); their ``repr`` is printed in the banner.

    Returns:
        ``(setup, is_distributed)``, where ``setup`` is ``dict(device=..., dtype=..., memory_format=...)``
        with ``dtype`` and ``memory_format`` taken from ``DEFAULT_SETUP``.

    Raises:
        ValueError: in distributed mode, if ``args.local_rank >= torch.cuda.device_count()``.
    """
    if getattr(args, 'local_rank', None) is None:  # not in distributed mode!
        is_distributed = False
        device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
        setup = dict(device=device, dtype=DEFAULT_SETUP['dtype'],
                     memory_format=DEFAULT_SETUP['memory_format'])
        _print_run_header(args, defs, '.')
    else:
        is_distributed = True
        # CUDA ordinals run from 0 to device_count() - 1, so local_rank == device_count() is already invalid.
        if args.local_rank >= torch.cuda.device_count():
            raise ValueError(
                'Process invalid, oversubscribing to GPUs is not possible in this mode.')
        torch.cuda.set_device(args.local_rank)
        device = torch.device(f'cuda:{args.local_rank}')
        setup = dict(device=device, dtype=DEFAULT_SETUP['dtype'],
                     memory_format=DEFAULT_SETUP['memory_format'])
        torch.distributed.init_process_group(backend=DISTRIBUTED_BACKEND, init_method='env://')

        world_size = torch.distributed.get_world_size()
        if torch.distributed.get_rank() == 0:  # only the rank-0 process prints the banner
            _print_run_header(args, defs, '')
            print(f'Distributed mode launched on {world_size} GPUs'
                  f' with backend {DISTRIBUTED_BACKEND} with sharing strategy {SHARING_STRATEGY}')

    # In distributed mode every process prints the name of its own device.
    if torch.cuda.is_available():
        print(f'GPU : {torch.cuda.get_device_name(device=device)}')

    return setup, is_distributed


def get_config(args):
    """Return the default configuration of the ConfigSpace search space, with CLI overrides applied.

    The ConfigSpace JSON file is looked up in this order:

    1. ``args.config_path``;
    2. ``configs/darts_<args.dataset>_configspace_all_<n>.json``, where ``<n>`` is the part of
       ``str(args.noise_level)`` after the first ``.`` (e.g. 0.1 -> ``1``). This step is skipped when
       that string contains no ``.``;
    3. ``configs/DARTS_1_BSDS-Denoising_configspace_all_10.json``.

    Only FileNotFoundError triggers a fallback, and a message is printed when a fallback file is used.
    If the final default file is missing too, FileNotFoundError propagates.

    Each hyperparameter in the override list below whose attribute on ``args`` is not None replaces the
    default value. If ``args.bohb`` is truthy, the resulting dict is printed.

    Returns:
        dict mapping hyperparameter names to values.
    """
    def _read_configspace(path):
        # Read inside a context manager so the file handle is closed immediately.
        with open(path, 'r') as handle:
            text = handle.read()
        return config_space_json_r_w.read(text)

    config_path = args.config_path
    try:
        cs = _read_configspace(config_path)
    except FileNotFoundError:
        cs = None
        noise_parts = str(args.noise_level).split(".")
        # A noise level without a decimal point (e.g. None) cannot name a per-noise file: use the default.
        if len(noise_parts) > 1:
            config_file = f'configs/darts_{args.dataset}_configspace_all_{noise_parts[1]}.json'
            try:
                cs = _read_configspace(config_file)
                print(
                    f'Could not find {config_path}. Falling back to config for operator "all" with given dataset and noise level.!')
            except FileNotFoundError:
                cs = None
        if cs is None:
            cs = _read_configspace('configs/DARTS_1_BSDS-Denoising_configspace_all_10.json')
            print(f'Could not find {config_path}. Falling back to default config.!')
    config = cs.get_default_configuration().get_dictionary()

    overridable = ('alpha_scheduler', 'alpha_optimizer', 'alpha_lr', 'alpha_warmup', 'param_lr', 'param_warmup',
                   'param_weight_decay', 'alpha_weight_decay', 'epochs')
    for key in overridable:
        value = getattr(args, key)
        if value is not None:
            config[key] = value
    if args.bohb:
        print(config)
    return config


def load_hyperparameters(args, config):
    """Build the list of ``Hyperparameters`` search variants selected by ``args``.

    If ``args.test_all_variants`` is truthy, every known variant is returned in table order. Otherwise only
    the variant named by ``args.variant`` is returned. All variants share ``config['epochs']``, the Adam
    parameter optimizer with a linear scheduler, ``args.dryrun``, ``fuse=False`` and the alpha/param
    optimizer settings read from ``config``.

    Args:
        args: namespace with ``variant``, ``test_all_variants`` and ``dryrun`` attributes.
        config: dict with the keys ``epochs``, ``alpha_scheduler``, ``alpha_optimizer``, ``alpha_lr``,
            ``alpha_warmup``, ``param_lr``, ``param_warmup``, ``param_weight_decay`` and
            ``alpha_weight_decay``.

    Returns:
        list of ``Hyperparameters``.

    Raises:
        TypeError: if ``args.variant`` is unknown and ``args.test_all_variants`` is not set.
    """
    common_defs = dict(epochs=config['epochs'], param_optimizer='Adam',
                       param_scheduler='linear', dryrun=args.dryrun)
    # variant key -> (name, update, norm, mirror_descent, project, decay_entropy)
    # NOTE(review): 'MirrorAdam' and 'MirrorDescent' are configured identically — confirm this is intended.
    variant_table = {
        'DARTS_1': ('DARTS-1th', 'liu', 'none', False, False, False),
        'DARTS_1_single': ('DARTS-1th-single', 'liu-single', 'none', False, False, False),
        'DARTS_norm': ('DARTS-norm', 'liu', 'DARTS', False, False, False),
        'TemperatureDecay': ('TemperatureDecay', 'liu', 'none', False, False, True),
        'TemperatureDecay_norm': ('TemperatureDecay-norm', 'liu', 'DARTS', False, False, True),
        'TemperatureProject': ('TemperatureProject', 'liu', 'none', False, True, True),
        'MirrorAdam': ('MirrorAdam', 'liu', 'none', True, False, True),
        'MirrorDescent': ('MirrorDescent', 'liu', 'none', True, False, True),
        'MirrorDescent_norm': ('MirrorDescent-norm', 'liu', 'DARTS', True, False, True),
    }

    # Select every variant when testing all of them. An if/elif chain would stop at its first branch.
    if args.test_all_variants:
        selected = list(variant_table)
    elif args.variant in variant_table:
        selected = [args.variant]
    else:
        raise TypeError("Unknow Search Variant : {:}".format(args.variant))

    shared_defs = dict(fuse=False,
                       alpha_scheduler=config['alpha_scheduler'],
                       alpha_optimizer=config['alpha_optimizer'],
                       alpha_lr=config['alpha_lr'],
                       alpha_warmup=config['alpha_warmup'],
                       param_lr=config['param_lr'],
                       param_weight_decay=config['param_weight_decay'],
                       alpha_weight_decay=config['alpha_weight_decay'],
                       param_warmup=config['param_warmup'])

    variants = []
    for key in selected:
        name, update, norm, mirror_descent, project, decay_entropy = variant_table[key]
        variants.append(Hyperparameters(name=name, **common_defs, update=update, norm=norm,
                                        mirror_descent=mirror_descent, project=project,
                                        decay_entropy=decay_entropy, **shared_defs))
    return variants


def create_exp_dir(path, scripts_to_save=None):
    """Create the experiment directory ``path`` (if needed) and print its location.

    When ``scripts_to_save`` is not None (an iterable of file paths, possibly empty), a ``scripts``
    sub-directory is created and each listed file is copied into it under its base name.
    """
    os.makedirs(path, exist_ok=True)
    print('Experiment dir : {}'.format(path))

    if scripts_to_save is None:
        return

    scripts_dir = os.path.join(path, 'scripts')
    os.makedirs(scripts_dir, exist_ok=True)
    for source_file in scripts_to_save:
        target_file = os.path.join(scripts_dir, os.path.basename(source_file))
        shutil.copyfile(source_file, target_file)


def save_to_table(out_dir, table_name, dryrun, **kwargs):
    """Save keys to .csv files. Function adapted from Micah.

    Appends one tab-separated row to ``<out_dir>/table_<table_name>.csv``. The column names are the keys of
    ``kwargs`` in insertion order. A header row is written first when the table does not exist yet or is
    empty. The header of an existing table is not compared against ``kwargs``.

    Args:
        out_dir: directory holding the table; created if missing, also in dry-run mode.
        table_name: used to build the file name ``table_<table_name>.csv``.
        dryrun: if True, only print what would be written; the table file is not touched.
        **kwargs: column name -> value for the new row.
    """
    os.makedirs(out_dir, exist_ok=True)
    fname = os.path.join(out_dir, f'table_{table_name}.csv')
    fieldnames = list(kwargs.keys())

    # Probe for an existing header row. Only a missing or empty file leads to (re)writing the header.
    # Other read errors propagate instead of silently truncating an existing table.
    try:
        with open(fname, 'r', newline='') as f:
            has_header = next(csv.reader(f, delimiter='\t'), None) is not None
    except FileNotFoundError:
        has_header = False

    if not has_header:
        if not dryrun:
            print('Creating a new .csv table...')
            # newline='' lets the csv module control line endings (as the csv docs require).
            with open(fname, 'w', newline='') as f:
                writer = csv.DictWriter(f, delimiter='\t', fieldnames=fieldnames)
                writer.writeheader()
        else:
            print(f'Would create new .csv table {fname}.')
    if not dryrun:
        # Add row for this experiment
        with open(fname, 'a', newline='') as f:
            writer = csv.DictWriter(f, delimiter='\t', fieldnames=fieldnames)
            writer.writerow(kwargs)
        print('\nResults saved to ' + fname + '.')
    else:
        print(f'Would save results to {fname}.')


def start_logging(save_dir, dryrun=False, save_runs=True):
    """Configure root logging to stdout and, unless disabled, also to ``<run dir>/log.txt``.

    Args:
        save_dir: tag of the run. The run directory is ``search/search-<save_dir>-<YYYYmmdd-HHMMSS>``.
        dryrun: if True, nothing is written to disk.
        save_runs: if False, nothing is written to disk.

    Returns:
        The run directory path. It is returned even when the directory was not created.
    """
    run_dir = f'search/search-{save_dir}-{time.strftime("%Y%m%d-%H%M%S")}'
    log_format = '%(asctime)s %(message)s'
    date_format = '%m/%d %I:%M:%S %p'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt=date_format)
    if not dryrun and save_runs:
        # Creates run_dir and copies the *.py files of the current working directory into run_dir/scripts.
        create_exp_dir(run_dir, scripts_to_save=glob.glob('*.py'))
        file_handler = logging.FileHandler(os.path.join(run_dir, 'log.txt'))
        # Use the same date format as the stdout handler so both logs show identical timestamps.
        file_handler.setFormatter(logging.Formatter(log_format, datefmt=date_format))
        logging.getLogger().addHandler(file_handler)
    return run_dir


def set_random_seed(seed=233):
    """Seed the torch (CPU and CUDA), NumPy and Python ``random`` generators from one base seed.

    Each generator gets its own offset from ``seed``:
    torch ``seed + 1``, NumPy ``seed + 4``, all CUDA devices ``seed + 5``, ``random`` ``seed + 6``.
    """
    torch.manual_seed(seed + 1)
    np.random.seed(seed + 4)
    # Earlier CUDA seeding calls (offsets +2 and +3) were always overwritten by this one, so only it is kept.
    # The +5 offset keeps seeds identical to earlier runs.
    torch.cuda.manual_seed_all(seed + 5)
    random.seed(seed + 6)


def set_deterministic():
    """Put cuDNN into reproducible mode: deterministic kernels on, benchmark mode off."""
    cudnn_flags = (('deterministic', True), ('benchmark', False))
    for flag_name, flag_value in cudnn_flags:
        setattr(torch.backends.cudnn, flag_name, flag_value)


class DistributedDataParallel(torch.nn.parallel.DistributedDataParallel):
    """DDP wrapper that broadcasts architecture parameters and forwards unknown attributes to ``self.module``."""

    def __init__(self, *args, **kwargs):
        """Synchronize alphas manually.

        After the regular DDP construction, every tensor yielded by ``self.arch_parameters()`` is
        broadcast from rank 0. ``arch_parameters`` is resolved on the wrapped module through
        ``__getattr__``. This presumably requires the wrapped model to define it — TODO confirm.
        """
        super().__init__(*args, **kwargs)
        for alpha in self.arch_parameters():
            torch.distributed.broadcast(alpha, 0)

    def __getattr__(self, name):
        """Overwrite calls to unknown attributes with module attributes.

        We do this at our own risk :>
        """
        try:
            return super().__getattr__(name)
        except AttributeError:
            if name == 'module':
                # ``module`` itself is missing (e.g. on an instance whose DDP init has not assigned it yet).
                # Delegating would call __getattr__('module') again and recurse until RecursionError.
                raise
            return getattr(self.module, name)


def psnr_compute(img_batch, ref_batch, batched=False, factor=1.0, clip=False):
    """Standard PSNR, ``10 * log10(factor**2 / MSE)``.

    Args:
        img_batch: images to evaluate (detached before computing the error). When ``batched`` is False,
            dim 0 is treated as the example dimension.
        ref_batch: reference images; must broadcast against ``img_batch``.
        batched: if True, compute one MSE over all elements and return its PSNR. Otherwise compute a PSNR
            per example and return their mean.
        factor: maximum possible signal value used in the PSNR formula.
        clip: if True, clamp ``img_batch`` to [0, 1] first.

    Returns:
        A 0-dim tensor. Batched: ``nan`` if the MSE is non-finite, ``inf`` if it is zero.
        Per-example: ``inf`` if any example's MSE is zero (checked first), ``nan`` if any is non-finite,
        otherwise the mean per-example PSNR.
    """
    if clip:
        img_batch = torch.clamp(img_batch, 0, 1)

    squared_error = (img_batch.detach() - ref_batch) ** 2
    if batched:
        mse = squared_error.mean()
        if mse > 0 and torch.isfinite(mse):
            return 10 * torch.log10(factor ** 2 / mse)
        elif not torch.isfinite(mse):
            return torch.tensor(float('nan'), device=img_batch.device)
        else:
            return torch.tensor(float('inf'), device=img_batch.device)

    B = img_batch.shape[0]
    # reshape rather than view: the error tensor may be non-contiguous (e.g. channels_last inputs).
    mse_per_example = squared_error.reshape(B, -1).mean(dim=1)
    if torch.any(mse_per_example == 0):
        return torch.tensor(float('inf'), device=img_batch.device)
    if not torch.all(torch.isfinite(mse_per_example)):
        return torch.tensor(float('nan'), device=img_batch.device)
    return (10 * torch.log10(factor ** 2 / mse_per_example)).mean()
