import numpy as np
import torch
import random
import yaml
from copy import deepcopy
from munch import Munch

def random_sample_configs(config, max_search=20000, verbose=True):
    """Randomly draw distinct configurations from a search-space description.

    Every list value in ``config`` (at any depth of nested dict/``Munch``
    mappings) is a set of candidate values, and one element is drawn
    uniformly using NumPy's global RNG (so ``np.random.seed`` controls it).
    Non-list values are copied through unchanged. Draws that duplicate an
    earlier configuration are discarded.

    Args:
        config: mapping describing the search space.
        max_search: number of draws to attempt; this bounds the result size.
        verbose: print the resulting list (the historical behavior).

    Returns:
        list of ``Munch`` configurations, in the order they were first drawn.
    """
    def recursion_sample(cfg):
        sampled = {}
        for key, value in cfg.items():
            if isinstance(value, list):
                # Index into the list instead of np.random.choice(value):
                # choice converts the list to an ndarray, which coerces mixed
                # types (e.g. [1, 'a'] -> strings), returns numpy scalars and
                # fails on lists of unequal-length lists.
                sampled[key] = deepcopy(value[np.random.randint(len(value))])
            elif isinstance(value, dict):  # Munch is a dict subclass
                sampled[key] = recursion_sample(value)
            else:
                sampled[key] = deepcopy(value)
        return sampled

    search_configs = []
    for _ in range(max_search):
        candidate = Munch.fromDict(recursion_sample(config))
        if candidate not in search_configs:
            search_configs.append(candidate)
    if verbose:
        print(search_configs)
    return search_configs

def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU and every GPU).

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every visible GPU (torch.cuda.manual_seed only
    # seeds the current one); it is a no-op-safe lazy call without CUDA.
    torch.cuda.manual_seed_all(seed)

def get_config(config):
    """Read a YAML file and return the parsed document.

    Parsing uses ``yaml.safe_load``, so only standard YAML tags are
    accepted; an empty file yields ``None``.

    Args:
        config: path of the YAML file.
    """
    with open(config, 'r') as handle:
        parsed = yaml.safe_load(handle)
    return parsed

def save_model(parent_net, child_net, parent_optim, child_optim, PATH):
    """Create a deferred checkpoint writer.

    Nothing is written when this is called. The returned zero-argument
    callable, each time it is invoked, collects the current state dicts of
    both networks and both optimizers. It stores them under the keys
    "parent", "parent_optimizer", "child" and "child_optimizer" and writes
    them to ``PATH`` with ``torch.save``.
    """
    def _write_checkpoint():
        checkpoint = dict(
            parent=parent_net.state_dict(),
            parent_optimizer=parent_optim.state_dict(),
            child=child_net.state_dict(),
            child_optimizer=child_optim.state_dict(),
        )
        torch.save(checkpoint, PATH)

    return _write_checkpoint
def load_model(parent_net, child_net, parent_optim, child_optim, PATH, map_location=None):
    """Create a deferred checkpoint loader.

    The returned zero-argument callable does the following when invoked:
    prints the path, reads ``PATH`` with ``torch.load``, and restores the
    state dicts stored under "parent", "parent_optimizer", "child" and
    "child_optimizer" into the given networks and optimizers.

    Args:
        parent_net, child_net: modules to restore in place.
        parent_optim, child_optim: optimizers to restore in place.
        PATH: checkpoint file path.
        map_location: forwarded to ``torch.load`` (e.g. ``'cpu'`` to load a
            GPU-saved checkpoint on a CPU-only machine). ``None`` keeps the
            previous behavior.

    Note:
        ``torch.load`` unpickles the file; only load trusted checkpoints.
    """
    def f():
        print('[*] loading models from %s' % PATH)
        checkpoint = torch.load(PATH, map_location=map_location)
        parent_net.load_state_dict(checkpoint["parent"])
        parent_optim.load_state_dict(checkpoint["parent_optimizer"])
        child_net.load_state_dict(checkpoint["child"])
        child_optim.load_state_dict(checkpoint["child_optimizer"])
    return f

import pickle
def save_dict(di_, filename_):
    """Serialise ``di_`` with :mod:`pickle` into ``filename_``, overwriting the file."""
    with open(filename_, mode='wb') as out_file:
        pickle.dump(di_, out_file)

def load_dict(filename_):
    """Return the object pickled in ``filename_``.

    Unpickling can execute arbitrary code, so only use this on trusted files.
    """
    with open(filename_, mode='rb') as in_file:
        return pickle.load(in_file)

# code borrowed from https://raw.githubusercontent.com/nicola-decao/BNAF/master/optim/lr_scheduler.py

import math
import torch


class PolyakAdam(torch.optim.Optimizer):
    """Adam (AMSGrad by default) that also keeps a Polyak average of the parameters.

    After every update, each stepped parameter's running average is updated
    as ``avg <- polyak * avg + (1 - polyak) * param`` and stored in
    ``self.state[p]["exp_avg_param"]``. ``swap()`` exchanges live
    parameters with their averages (calling it twice restores them), and
    ``substitute()`` overwrites the live parameters with a copy of the
    averages.

    Args:
        params: iterable of parameters or param-group dicts.
        lr: learning rate (>= 0).
        betas: decay rates for the first/second moment averages, each in [0, 1).
        eps: term added to the denominator (>= 0).
        weight_decay: L2 penalty added to the gradient (>= 0).
        amsgrad: use the AMSGrad variant (running max of second moments).
        polyak: decay of the parameter average, in [0, 1]; can be
            overridden per param group.
    """

    def __init__(
            self,
            params,
            lr=1e-3,
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=0,
            amsgrad=True,
            polyak=0.998,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= polyak <= 1.0:
            raise ValueError("Invalid polyak decay term: {}".format(polyak))

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            polyak=polyak,
        )
        super(PolyakAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(PolyakAdam, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsgrad", False)
            # step() reads polyak per group, so make sure every group has one.
            group.setdefault("polyak", self.defaults.get("polyak", 0.998))

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (callable, optional): a closure that re-evaluates the
                model and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None``.

        Raises:
            RuntimeError: if a gradient is sparse.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            amsgrad = group["amsgrad"]
            beta1, beta2 = group["betas"]
            # Per-group value: honours param groups that override polyak.
            polyak = group["polyak"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        "Adam does not support sparse gradients, please consider SparseAdam instead"
                    )

                state = self.state[p]

                # Lazy state initialization on this parameter's first step.
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data)
                    # Running average of the parameter. It starts at the
                    # current weights: the average has no bias correction,
                    # so a zero start would pull it towards 0 until
                    # polyak**step becomes small.
                    state["exp_avg_param"] = p.data.clone()
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                state["step"] += 1

                if group["weight_decay"] != 0:
                    # Out of place, so p.grad itself is left unmodified.
                    grad = grad.add(p.data, alpha=group["weight_decay"])

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    max_exp_avg_sq = state["max_exp_avg_sq"]
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group["eps"])
                else:
                    denom = exp_avg_sq.sqrt().add_(group["eps"])

                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]
                step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

                state["exp_avg_param"] = (
                        polyak * state["exp_avg_param"] + (1 - polyak) * p.data
                )

        return loss

    def swap(self):
        """Exchange each parameter with its running (Polyak) average.

        Calling it twice restores the original parameters. Parameters that
        have never been stepped (no average yet) are left untouched.
        """
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                if "exp_avg_param" not in state:
                    continue
                p.data, state["exp_avg_param"] = state["exp_avg_param"], p.data

    def substitute(self):
        """Overwrite each stepped parameter with a copy of its running average."""
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                if "exp_avg_param" not in state:
                    continue
                # Clone so later in-place updates of p (e.g. addcdiv_ in
                # step) do not also overwrite the stored average.
                p.data = state["exp_avg_param"].clone()



import torch


class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
    """PyTorch ``ReduceLROnPlateau`` with early stopping and event callbacks.

    Args:
        *args, **kwargs: forwarded to ``torch.optim.lr_scheduler.ReduceLROnPlateau``.
        early_stopping: number of consecutive non-improving ``step`` calls
            after which ``step`` starts returning True; ``None`` disables it.
    """

    def __init__(self, *args, early_stopping=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.early_stopping = early_stopping
        # Consecutive non-improving steps; unlike num_bad_epochs it is not
        # cleared by cooldown or by an LR reduction.
        self.early_stopping_counter = 0

    def step(self, metrics, epoch=None, callback_best=None, callback_reduce=None):
        """Record ``metrics`` and reduce the learning rate on a plateau.

        Args:
            metrics: scalar quantity being monitored (float or 0-d tensor).
            epoch: explicit epoch index; defaults to ``last_epoch + 1``.
            callback_best: called with no arguments when ``metrics`` improves
                on the best value so far.
            callback_reduce: called with no arguments just before the
                learning rate is reduced.

        Returns:
            True once early stopping has triggered (and on every later call
            while there is still no improvement), otherwise False.
        """
        # Mirror the float() conversion in torch's own step(), so a tensor
        # metric is stored in self.best as a plain float.
        current = float(metrics)
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch

        if self.is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
            self.early_stopping_counter = 0
            if callback_best is not None:
                callback_best()
        else:
            self.num_bad_epochs += 1
            self.early_stopping_counter += 1

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

        if self.num_bad_epochs > self.patience:
            if callback_reduce is not None:
                callback_reduce()
            self._reduce_lr(epoch)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0

        # Keep get_last_lr() in sync with the optimizer.
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

        # >= rather than ==: keep signalling after the threshold is passed.
        return (
            self.early_stopping is not None
            and self.early_stopping_counter >= self.early_stopping
        )

import torch.nn.init as init
def weights_init(init_type='gaussian', gain=0.02):
    """Return a function for ``nn.Module.apply`` that re-initialises weights.

    Modules whose class name starts with "Conv", "Linear" or "Embedding"
    (and does not contain "Norm") get ``weight`` re-initialised according to
    ``init_type``; their biases are left unchanged. Modules whose class name
    contains "Norm" get ``weight ~ N(1, 0.02)`` and ``bias = 0``. Weights
    or biases that are missing or ``None`` (e.g. affine-free norm layers)
    are skipped.

    Args:
        init_type: 'gaussian' (N(0, gain)), 'xavier' (xavier_normal_ with
            ``gain``), 'xavier_uniform', 'kaiming' (kaiming_uniform_),
            'orthogonal' (gain sqrt(2)), or 'default' (leave as is).
        gain: std for 'gaussian'; gain for 'xavier'.

    Raises:
        AssertionError: from the returned function, for an unsupported
            ``init_type`` when it reaches a Conv/Linear/Embedding module.
    """
    def init_fun(m):
        classname = m.__class__.__name__
        weight = getattr(m, 'weight', None)

        if 'Norm' in classname:
            # Norm-layer weights are not matrices; only a normal init around 1 applies.
            if weight is not None:
                init.normal_(weight.data, 1.0, 0.02)
            bias = getattr(m, 'bias', None)
            if bias is not None:
                init.constant_(bias.data, 0.0)
            return

        if weight is None or not classname.startswith(('Conv', 'Linear', 'Embedding')):
            return

        if init_type == 'gaussian':
            init.normal_(weight.data, 0.0, gain)
        elif init_type == 'xavier':
            init.xavier_normal_(weight.data, gain=gain)
        elif init_type == 'xavier_uniform':
            init.xavier_uniform_(weight.data)
        elif init_type == 'kaiming':
            init.kaiming_uniform_(weight.data)
        elif init_type == 'orthogonal':
            init.orthogonal_(weight.data, gain=math.sqrt(2))
        elif init_type != 'default':
            # Explicit raise instead of `assert 0`, which `python -O` strips.
            raise AssertionError("Unsupported initialization: {}".format(init_type))

    return init_fun

class InfiniteIterator:
    """Yield items from ``loader`` forever, restarting it whenever it is exhausted.

    The underlying iterator is created lazily on the first ``next`` call.
    An empty loader raises ``StopIteration`` rather than looping forever,
    and errors raised while iterating the loader propagate to the caller.
    """

    def __init__(self, loader):
        self.loader = loader
        self.iter = None  # created lazily on the first __next__

    def __iter__(self):
        return self

    def __next__(self):
        if self.iter is None:
            self.iter = iter(self.loader)
        try:
            return next(self.iter)
        except StopIteration:
            # One pass finished: start a new one. Only StopIteration is
            # handled so real failures (e.g. from a dataset) are not hidden.
            self.iter = iter(self.loader)
            return next(self.iter)