import glob
import os
import shutil
import collections
import timeit
import random

import numpy as np
import torch

from gym.envs.registration import make as gym_make
from .make_agent import make_agent
from .filewriter import FileWriter
from envs.wrappers import ParallelVecEnv, VecMonitor, VecNormalize, \
    VecPreprocessImageWrapper, TimeLimit, MultiGridFullyObsWrapper, Seedable
from plr import LevelSampler, VecPLRWrapper


class DotDict(dict):
    """Dictionary whose items can also be read, set and deleted as attributes.

    Any value that exposes a ``keys`` attribute is converted recursively
    into a ``DotDict`` at construction time, so nested mappings support
    dotted access too (``cfg.a.b``).

    Attribute lookup only falls back to the dict items when normal lookup
    fails, so real ``dict`` methods (``keys``, ``items``, ...) always win
    over items stored under the same name.

    Reading or deleting a missing attribute raises ``AttributeError``, as
    the attribute protocol requires. This keeps ``hasattr``,
    ``getattr(obj, name, default)`` and ``copy.deepcopy`` working.
    Item access (``d['missing']``) still raises ``KeyError``.
    """

    # Attribute assignment writes straight into the dict items.
    __setattr__ = dict.__setitem__

    def __init__(self, dct):
        """Populate from the mapping ``dct``, wrapping nested mappings."""
        for key, value in dct.items():
            if hasattr(value, 'keys'):
                value = DotDict(value)
            self[key] = value

    def __getattr__(self, name):
        """Return the item stored under ``name``.

        Raises:
            AttributeError: if ``name`` is not a key of the dict.
        """
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __delattr__(self, name):
        """Delete the item stored under ``name``.

        Raises:
            AttributeError: if ``name`` is not a key of the dict.
        """
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __getstate__(self):
        return self

    def __setstate__(self, state):
        # Only merge the items. Assigning ``self.__dict__`` here would be
        # routed through ``__setattr__`` and insert a self-referential
        # '__dict__' key into the restored dict.
        self.update(state)

def cprint(condition, *args, **kwargs):
    """Forward ``*args``/``**kwargs`` to ``print`` only if ``condition`` is truthy."""
    if not condition:
        return
    print(*args, **kwargs)

def flatten_suffix_dict(d, suffix):
    """Average stats into a flat dict, tagging keys with ``suffix``.

    Args:
        d: A dict, or a list/tuple of dicts or of plain values. A single
            dict is treated as a one-element list. Any other type yields
            an empty result.
        suffix: For dict entries, appended to each key as ``_{suffix}``
            (omitted when falsy). For plain values, used as the single
            result key.

    Returns:
        A dict of ``np.mean`` values. For dict entries the keys of the
        first entry are used and every entry is assumed to share them.
    """
    entries = [d] if isinstance(d, dict) else d

    # Anything that is not a non-empty list/tuple has nothing to average.
    if not isinstance(entries, (list, tuple)) or len(entries) == 0:
        return {}

    head = entries[0]
    if not isinstance(head, dict):
        return {suffix: np.mean(list(entries))}

    tail = f'_{suffix}' if suffix else ''
    averaged = {}
    for key in head:
        averaged[f'{key}{tail}'] = np.mean([entry[key] for entry in entries])
    return averaged

def init(module, weight_init, bias_init, gain=1):
    """Initialise ``module`` in place and return it.

    ``weight_init`` is called on ``module.weight.data`` with the keyword
    ``gain``; ``bias_init`` is then called on ``module.bias.data``.
    """
    weights = module.weight.data
    weight_init(weights, gain=gain)

    biases = module.bias.data
    bias_init(biases)

    return module

def safe_checkpoint(state_dict, path, index=None, archive_interval=None):
    """Save ``state_dict`` to ``path`` via a temporary file, optionally archiving.

    The state is first written with ``torch.save`` to ``<stem>_tmp<ext>``
    and then moved over ``path`` with ``os.replace``, so ``path`` is only
    replaced once the save call has returned.

    When both ``index`` and ``archive_interval`` are given,
    ``archive_interval`` is positive and ``index`` is a multiple of it,
    the checkpoint is also copied to ``<stem>_<index><ext>``.
    """
    stem, extension = os.path.splitext(path)
    staging_path = f'{stem}_tmp{extension}'

    torch.save(state_dict, staging_path)
    os.replace(staging_path, path)

    archiving_enabled = (
        index is not None
        and archive_interval is not None
        and archive_interval > 0
    )
    if archiving_enabled and index % archive_interval == 0:
        shutil.copy(path, f'{stem}_{index}{extension}')

def cleanup_log_dir(log_dir, pattern='*'):
    """Create ``log_dir``, or clear matching entries if creation fails.

    If ``os.makedirs`` raises ``OSError`` (e.g. the directory already
    exists), every path matching ``pattern`` inside ``log_dir`` is
    removed with ``os.remove``.
    """
    stale_paths = []
    try:
        os.makedirs(log_dir)
    except OSError:
        stale_paths = glob.glob(os.path.join(log_dir, pattern))

    for stale_path in stale_paths:
        os.remove(stale_path)

def seed(seed):
    """Seed the ``random``, NumPy and torch RNGs with the same value.

    All CUDA devices are seeded as well when CUDA is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def str2bool(v):
    """Parse a command-line style boolean.

    Args:
        v: A ``bool`` (returned unchanged) or a string. Matching is
            case-insensitive.

    Returns:
        ``True`` for 'yes', 'true', 't', 'y', '1';
        ``False`` for 'no', 'false', 'f', 'n', '0'.

    Raises:
        argparse.ArgumentTypeError: if the string is not one of the
            recognised spellings.
    """
    # Imported here because this module has no top-level argparse import;
    # without it the error branch below would fail with a NameError.
    import argparse

    if isinstance(v, bool):
        return v

    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

def get_obs_at_index(obs, i):
    """Return entry ``i`` of ``obs``.

    For a dict observation the result is a dict with the same keys, each
    holding ``obs[key][i]``; otherwise it is simply ``obs[i]``.
    """
    if not isinstance(obs, dict):
        return obs[i]
    return {key: values[i] for key, values in obs.items()}

def set_obs_at_index(obs, obs_, i):
    """Write the observation ``obs_`` into slot ``i`` of ``obs`` in place.

    For a dict ``obs``, each ``obs[key][i]`` receives
    ``obs_[key].squeeze(0)``. Otherwise ``obs[i]`` receives
    ``obs_[0].squeeze(0)``.

    NOTE(review): the non-dict branch indexes ``obs_[0]`` before
    squeezing while the dict branch does not — confirm this asymmetry
    matches the shapes produced by callers.
    """
    if not isinstance(obs, dict):
        obs[i] = obs_[0].squeeze(0)
        return

    for key in obs:
        obs[key][i] = obs_[key].squeeze(0)

def is_discrete_actions(env, adversary=False):
    """Return whether the relevant action space's class is named 'Discrete'.

    Inspects ``env.adversary_action_space`` when ``adversary`` is truthy,
    otherwise ``env.action_space``. The check is by class name only.
    """
    space = env.adversary_action_space if adversary else env.action_space
    return space.__class__.__name__ == 'Discrete'

def _make_env(args):
    """Build a single environment from ``args`` via ``gym_make``.

    Keyword arguments for the environment constructor are assembled from
    ``args`` according to the environment family (name prefix
    'MultiGrid' or 'MiniHack') and the task named in ``args.env_name``
    ('MultiRoomBC', 'BinaryChoice', 'WeaponChoice'). After construction,
    MultiGrid envs are optionally wrapped in ``MultiGridFullyObsWrapper``
    and MiniHack envs are wrapped in ``Seedable`` and seeded with
    ``args.seed``.
    """
    env_kwargs = {}

    is_multigrid = args.env_name.startswith('MultiGrid')
    is_minihack = args.env_name.startswith('MiniHack')
    has_binary_choice = 'BinaryChoice' in args.env_name

    # NOTE(review): `eval` executes arbitrary code from these argument
    # strings (here and for the reward settings below); only safe while
    # the arguments come from a trusted source.
    env_kwargs['seed'] = args.seed
    env_kwargs['p'] = eval(args.p)
    env_kwargs['obl_correction'] = args.force_obl_correction
    env_kwargs['use_learned_beliefs'] = args.use_learned_beliefs

    if is_minihack:
        env_kwargs['observation_keys'] = ("glyphs", "blstats", "message")
        env_kwargs['fully_observable'] = args.fully_observable

    if args.singleton_env or args.use_plr:
        env_kwargs['fixed_environment'] = True

    if 'MultiRoomBC' in args.env_name or has_binary_choice:
        env_kwargs['rewards'] = eval(args.stochastic_choice_rewards)
        env_kwargs['reward_spreads'] = eval(args.stochastic_choice_reward_spreads)

    if is_multigrid and has_binary_choice:
        env_kwargs['use_walls'] = args.stochastic_choice_use_walls

    if is_minihack:
        if has_binary_choice:
            env_kwargs['reward_dist'] = args.reward_dist
            env_kwargs['goal_hint_p'] = args.goal_hint_p
        if 'WeaponChoice' in args.env_name:
            env_kwargs['reward_dist'] = args.reward_dist

    env = gym_make(args.env_name, **env_kwargs)

    # Family-specific wrappers.
    if is_multigrid:
        if args.fully_observable:
            env = MultiGridFullyObsWrapper(env)
    elif is_minihack:
        env = Seedable(env)
        env.seed(args.seed)

    return env

def create_parallel_env(args, adversary=True):
    """Build the wrapped, seeded vectorised environment described by ``args``.

    ``args.num_processes`` copies of the same env factory are handed to
    ``ParallelVecEnv``. The result is then wrapped, in order, by
    ``VecPLRWrapper`` (only when ``args.use_plr``), ``VecMonitor``,
    ``VecNormalize`` and ``VecPreprocessImageWrapper``. Finally each
    sub-environment index ``i`` is seeded: with ``args.seed`` when
    ``args.singleton_env`` is set, otherwise with ``i`` itself.
    """
    num_envs = args.num_processes
    venv = ParallelVecEnv([lambda: _make_env(args)] * num_envs, adversary=adversary)

    if args.use_plr:
        level_sampler = LevelSampler(
            **make_plr_args(args, venv.observation_space, venv.action_space))
        venv = VecPLRWrapper(venv, level_sampler=level_sampler)

    venv = VecMonitor(venv=venv, filename=None, keep_buf=100)
    venv = VecNormalize(venv=venv, ob=False, ret=args.normalize_returns)

    # Only MultiGrid envs get image preprocessing settings; every other
    # family passes None for all three.
    if args.env_name.startswith('MultiGrid'):
        obs_key, scale, transpose_order = 'image', 10.0, [2, 0, 1]
    else:
        obs_key, scale, transpose_order = None, None, None

    venv = VecPreprocessImageWrapper(
        venv=venv,
        obs_key=obs_key,
        transpose_order=transpose_order,
        scale=scale,
    )

    if args.singleton_env:
        env_seeds = [args.seed] * num_envs
    else:
        env_seeds = list(range(num_envs))

    for env_index, env_seed in enumerate(env_seeds):
        venv.seed(env_seed, env_index)

    return venv

def make_plr_args(args, obs_space, action_space):
    """Collect the keyword arguments passed to ``LevelSampler``.

    Returns a new dict holding an empty ``seeds`` list, the given
    observation and action spaces, and the level-replay settings read
    from ``args`` under the keyword names listed below.
    """
    plr_args = {
        'seeds': [],
        'obs_space': obs_space,
        'action_space': action_space,
    }

    # (keyword name in the result, attribute name on ``args``)
    arg_sources = (
        ('num_actors', 'num_processes'),
        ('strategy', 'level_replay_strategy'),
        ('replay_schedule', 'level_replay_schedule'),
        ('score_transform', 'level_replay_score_transform'),
        ('temperature', 'level_replay_temperature'),
        ('eps', 'level_replay_eps'),
        ('rho', 'level_replay_rho'),
        ('replay_prob', 'level_replay_prob'),
        ('alpha', 'level_replay_alpha'),
        ('staleness_coef', 'staleness_coef'),
        ('staleness_transform', 'staleness_transform'),
        ('staleness_temperature', 'staleness_temperature'),
        ('sample_full_distribution', 'train_full_distribution'),
        ('seed_buffer_size', 'level_replay_seed_buffer_size'),
        ('seed_buffer_priority', 'level_replay_seed_buffer_priority'),
    )
    for keyword, attribute in arg_sources:
        plr_args[keyword] = getattr(args, attribute)

    return plr_args
