"""Differentiable optimizers (SGD, RMSprop, Adam).
"""
import sys
import pdb
import copy

import torch
import torch.nn.functional as F
from torch.autograd import Variable

import numpy as np

# Local imports
import utils

# Sentinel meaning "no default value": each optimizer's add_param_group raises
# ValueError when a default is this object and the param group does not
# provide that key itself.
required = object()


class MySGDMwd:
  """SGD with momentum and weight decay whose learning rate is a tensor.

  The learning rate is stored in log space (``group['lr'] == log(lr)``) in a
  float32 tensor created with ``requires_grad=True``, so the look-ahead model
  returned by ``update(take_step=False)`` can be differentiated w.r.t. it.
  ``momentum`` and ``wdecay`` are stored per group as tensors too.

  For a parameter ``p`` with gradient ``grad`` the step is::

      g = grad + wdecay * p
      m = momentum * m + g
      p = p - exp(lr) * (momentum * m + g)

  Args:
    model: Module whose parameters are optimized.
    optim_parameters_fn: Callable mapping a model to an iterable of tensors or
      of param-group dicts. Defaults to ``model.parameters()``. It is also
      applied to the deep copy made by the look-ahead update, and the two
      results are zipped, so it must yield parameters in the same order.
    lr: Initial learning rate (stored as ``np.log(lr)``).
    momentum: Default momentum coefficient.
    wdecay: Default L2 weight-decay coefficient added to the gradient.
    cuda: Create buffers and hyperparameter tensors on ``'cuda'`` instead of
      ``'cpu'``.

  Raises:
    ValueError: If ``optim_parameters_fn`` yields no parameters.
  """

  def __init__(self, model, optim_parameters_fn=None, lr=0.001, momentum=0, wdecay=0.0, cuda=False):
    self.cuda = cuda

    # Raw defaults; add_param_group wraps each one in a float32 tensor with
    # requires_grad=True for groups that do not specify it themselves.
    self.defaults = {
        'lr': np.log(lr),
        'momentum': momentum,
        'wdecay': wdecay
    }
    # Momentum buffers: one list per param group, one tensor per parameter.
    self.m = []

    if optim_parameters_fn is None:
      def optim_parameters_fn(model):
        return model.parameters()

    self.optim_parameters_fn = optim_parameters_fn

    self.param_groups = []

    param_groups = list(optim_parameters_fn(model))
    if not param_groups:
      raise ValueError("optimizer got an empty parameter list")
    if not isinstance(param_groups[0], dict):
      param_groups = [{'params': param_groups}]

    device = 'cuda' if cuda else 'cpu'
    for param_group in param_groups:
      self.add_param_group(param_group)
      self.m.append([
          Variable(torch.zeros(param.size(), device=device, dtype=torch.float32))
          for param in param_group['params']
      ])

    self.model = model

  def parameters(self):
    """Return the per-group log-learning-rate tensors."""
    return [group['lr'] for group in self.param_groups]

  def parameters_with_names(self, apply_transforms=True):
    """Return one ``{'lr': tensor}`` dict per param group.

    Args:
      apply_transforms: If True, the learning rate is returned as
        ``exp(group['lr'])``; otherwise the raw log-space tensor is returned.
    """
    param_list = []
    for group in self.param_groups:
      lr = group['lr']
      param_list.append({'lr': torch.exp(lr) if apply_transforms else lr})
    return param_list

  def update(self, take_step=False):
    """Apply one step in place, or return a differentiable look-ahead copy.

    ``param.grad`` is never modified: weight decay is added out of place, so
    a look-ahead followed by a real step sees the same gradients.

    Args:
      take_step: If True, update the parameters of ``self.model`` and the
        momentum buffers in place. If False, leave them untouched and apply
        the step to a deep copy of the model instead. The copy's parameters
        have ``requires_grad=False`` but are updated in place with terms built
        from the ``lr`` and ``momentum`` tensors, so they stay attached to
        those tensors' autograd graph.

    Returns:
      The updated model copy when ``take_step`` is False, otherwise None.
    """
    if take_step:
      for group, m in zip(self.param_groups, self.m):
        # Detached: the in-place step is not differentiated through.
        step_size = torch.exp(group['lr'].detach())
        momentum = group['momentum'].item()
        wdecay = group['wdecay'].item()
        for param, m_i in zip(group['params'], m):
          d_p = param.grad.data + wdecay * param.data
          m_i.data.mul_(momentum).add_(d_p)

          delta = -step_size * (momentum * m_i.data + d_p)
          param.data = param.data + delta
      return None

    updated_model = copy.deepcopy(self.model)
    for param in updated_model.parameters():
      param.requires_grad = False

    updated_param_groups = list(self.optim_parameters_fn(updated_model))
    if not isinstance(updated_param_groups[0], dict):
      updated_param_groups = [{'params': updated_param_groups}]

    for group, updated_group, m in zip(self.param_groups, updated_param_groups, self.m):
      lr = group['lr']
      momentum = group['momentum']
      wdecay = group['wdecay'].item()
      for param, updated_param, m_i in zip(group['params'], updated_group['params'], m):
        d_p = param.grad.data + wdecay * param.data
        # Prospective momentum buffer; self.m is left unchanged.
        m_tmp = m_i.data * momentum + d_p
        delta = -torch.exp(lr) * (m_tmp * momentum + d_p)
        updated_param += delta

    return updated_model

  def add_param_group(self, param_group):
    r"""Add a param group to the :class:`Optimizer` s `param_groups`.

    Missing hyperparameters are filled from ``self.defaults``, each wrapped in
    a float32 tensor with ``requires_grad=True``.

    Arguments:
        param_group (dict): Specifies what Tensors should be optimized along with group
        specific optimization options.

    Raises:
        TypeError: If ``params`` is a set or contains a non-tensor.
        ValueError: If a parameter does not require grad, is not a leaf,
            already belongs to another group, or a required option is missing.
    """
    params = param_group['params']
    if isinstance(params, torch.Tensor):
      param_group['params'] = [params]
    elif isinstance(params, set):
      raise TypeError('optimizer parameters need to be organized in ordered collections, but '
                      'the ordering of tensors in sets will change between runs. Please use a list instead.')
    else:
      param_group['params'] = list(params)

    for param in param_group['params']:
      if not isinstance(param, torch.Tensor):
        raise TypeError("optimizer can only optimize Tensors, "
                        "but one of the params is " + torch.typename(param))
      if not param.requires_grad:
        raise ValueError("optimizing a parameter that doesn't require gradients")
      if not param.is_leaf:
        raise ValueError("can't optimize a non-leaf Tensor")

    for name, default in self.defaults.items():
      if name in param_group:
        continue
      if default is required:
        raise ValueError("parameter group didn't specify a value of required optimization parameter " +
                         name)
      # Only build the tensor when it is actually needed (setdefault would
      # build it eagerly, which fails for the `required` sentinel).
      param_group[name] = torch.tensor([default], device='cuda' if self.cuda else 'cpu',
                                       requires_grad=True, dtype=torch.float32)

    param_set = set()
    for group in self.param_groups:
      param_set.update(set(group['params']))

    if not param_set.isdisjoint(set(param_group['params'])):
      raise ValueError("some parameters appear in more than one parameter group")

    self.param_groups.append(param_group)

  def zero_grad(self):
    r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
    for group in self.param_groups:
      for p in group['params']:
        if p.grad is not None:
          p.grad.detach_()
          p.grad.zero_()

  def factor_lr(self, lr_factor):
    """Multiply every group's learning rate by ``lr_factor``.

    NOTE(review): this replaces ``group['lr']`` with a new tensor created with
    ``requires_grad=False``, so tensors obtained earlier from
    :meth:`parameters` no longer track the learning rate. Confirm intended.
    """
    for group in self.param_groups:
      group['lr'] = torch.tensor(
          np.log(lr_factor * torch.exp(group['lr']).item()),
          device='cuda' if self.cuda else 'cpu',
          requires_grad=False)


class MyAdam:
  """Adam whose hyperparameters are tensors that can be tuned by gradient.

  Each param group stores ``lr``, ``rho`` and ``epsilon`` in an unconstrained
  space. ``self.transforms`` maps them back before use: ``exp`` for ``lr`` and
  ``epsilon``, ``sigmoid`` for ``rho``. ``self.inverse_transforms`` maps the
  constructor arguments into that space. For a parameter ``p`` with gradient
  ``grad`` the step is::

      g = grad + weight_decay * p
      m = beta1 * m + (1 - beta1) * g
      v = beta2 * v + (1 - beta2) * g**2
      p = p - lr * m_hat / (v_hat ** rho + epsilon)

  where ``m_hat``/``v_hat`` are the bias-corrected moments. ``rho = 0.5``
  gives the square-root denominator of standard Adam.

  Args:
    model: Module whose parameters are optimized.
    optim_parameters_fn: Callable mapping a model to an iterable of tensors or
      of param-group dicts. Defaults to ``model.parameters()``. It is also
      applied to the deep copy made by the look-ahead update, and the two
      results are zipped, so it must yield parameters in the same order.
    lr, rho, epsilon: Initial hyperparameter values (constrained space).
    betas: ``(beta1, beta2)`` decay rates of the first and second moments.
    weight_decay: L2 coefficient added to the gradient.
    tune: Names of the hyperparameters exposed by :meth:`parameters` and
      :meth:`parameters_with_names`.
    device: Device for the moment buffers and hyperparameter tensors.

  Raises:
    ValueError: If ``optim_parameters_fn`` yields no parameters.
  """

  def __init__(self, model, optim_parameters_fn=None, lr=0.00001, rho=0.5,
               epsilon=1e-8, betas=(0.9, 0.99), weight_decay=0, tune=['lr'],
               device='cuda:0'):

    self.model = model

    if optim_parameters_fn is None:
      def optim_parameters_fn(model):
        return model.parameters()

    self.optim_parameters_fn = optim_parameters_fn
    # Copy so instances never alias (or mutate) the shared default list.
    self.tune = list(tune)
    self.weight_decay = weight_decay
    self.device = device

    # Unconstrained -> constrained hyperparameter maps.
    self.transforms = {
        'lr': torch.exp,
        'rho': torch.sigmoid,
        'epsilon': torch.exp
    }

    # Constrained -> unconstrained maps, applied to constructor arguments.
    self.inverse_transforms = {
        'lr': np.log,
        'rho': utils.logit,
        'epsilon': np.log
    }

    self.defaults = {
        'lr': self.inverse_transforms['lr'](lr),
        'rho': self.inverse_transforms['rho'](rho),
        'epsilon': self.inverse_transforms['epsilon'](epsilon),
    }

    self.beta1 = betas[0]
    self.beta2 = betas[1]

    self.param_groups = []
    param_groups = list(optim_parameters_fn(model))
    if not param_groups:
      raise ValueError("optimizer got an empty parameter list")
    if not isinstance(param_groups[0], dict):
      param_groups = [{'params': param_groups}]

    # Running first (m) and second (v) moment estimates: one list per param
    # group, one zero-initialised buffer per parameter.
    self.m = []
    self.v = []
    for param_group in param_groups:
      self.add_param_group(param_group)
      self.m.append([Variable(torch.zeros(param.size(), device=self.device))
                     for param in param_group['params']])
      self.v.append([Variable(torch.zeros(param.size(), device=self.device))
                     for param in param_group['params']])

    # Number of in-place steps taken; used for bias correction.
    self.step = 0

  def set_lr(self, lr):
    """Set every group's learning rate to ``lr``, updating the tensor in place."""
    for group in self.param_groups:
      group['lr'].data.copy_(
          torch.tensor([self.inverse_transforms['lr'](lr)], device=self.device)
      )

  def factor_lr(self, factor):
    """Multiply every group's learning rate by ``factor``.

    The learning rate is stored as a log, so this adds ``log(factor)`` to the
    tensor's data.
    """
    for group in self.param_groups:
      group['lr'].data = group['lr'].data + \
                         torch.tensor([self.inverse_transforms['lr'](factor)], device=self.device)

  def parameters(self):
    """Return the unconstrained tensors named in ``self.tune``, grouped by name."""
    tune_params = []
    for key in self.tune:
      tune_params += [group[key] for group in self.param_groups]
    return tune_params

  def parameters_with_names(self, apply_transforms=True):
    """Return one dict per param group mapping each tuned name to its tensor.

    Args:
      apply_transforms: If True, values are mapped through ``self.transforms``;
        otherwise the raw unconstrained tensors are returned.
    """
    param_list = []
    for group in self.param_groups:
      d = {}
      for key in self.tune:
        if apply_transforms:
          d[key] = self.transforms[key](group[key])
        else:
          d[key] = group[key]
      param_list.append(d)
    return param_list

  def update(self, take_step=False):
    """Apply one step in place, or return a differentiable look-ahead copy.

    ``param.grad`` is never modified: weight decay is added out of place, so
    a look-ahead followed by a real step sees the same gradients.

    Args:
      take_step: If True, update the parameters of ``self.model`` and the
        moment buffers in place and increment ``self.step``. If False, leave
        them untouched and apply the prospective step to a deep copy of the
        model. The copy's parameters have ``requires_grad=False`` but are
        updated in place with terms built from the hyperparameter tensors, so
        they stay attached to those tensors' autograd graph.

    Returns:
      The updated model copy when ``take_step`` is False, otherwise None.
    """
    beta1 = self.beta1
    beta2 = self.beta2
    # Bias-correction denominators for the step about to be taken.
    bias1 = 1 - beta1 ** (self.step + 1)
    bias2 = 1 - beta2 ** (self.step + 1)

    if take_step:
      for group, m, v in zip(self.param_groups, self.m, self.v):
        # Detached: the in-place step is not differentiated through.
        lr = self.transforms['lr'](group['lr']).detach()
        rho = self.transforms['rho'](group['rho']).detach()
        epsilon = self.transforms['epsilon'](group['epsilon']).detach()

        for param, m_i, v_i in zip(group['params'], m, v):
          d_p = param.grad.data

          if self.weight_decay != 0:
            d_p = d_p + self.weight_decay * param.data

          m_i.data = (1 - beta1) * d_p + beta1 * m_i.data  # First moment estimate
          v_i.data = (1 - beta2) * (d_p ** 2) + beta2 * v_i.data  # Second moment estimate
          m_i_hat = m_i.data / bias1
          v_i_hat = v_i.data / bias2

          delta = -lr * m_i_hat / (v_i_hat ** rho + epsilon)
          param.data = param.data + delta

      self.step += 1
      return None

    updated_model = copy.deepcopy(self.model)
    for param in updated_model.parameters():
      param.requires_grad = False

    updated_param_groups = list(self.optim_parameters_fn(updated_model))
    if not isinstance(updated_param_groups[0], dict):
      updated_param_groups = [{'params': updated_param_groups}]

    for group, updated_group, m, v in zip(self.param_groups, updated_param_groups, self.m, self.v):
      lr = self.transforms['lr'](group['lr'])
      rho = self.transforms['rho'](group['rho'])
      epsilon = self.transforms['epsilon'](group['epsilon'])

      for param, updated_param, m_i, v_i in zip(group['params'], updated_group['params'], m, v):
        d_p = param.grad

        if self.weight_decay != 0:
          d_p = d_p + self.weight_decay * param.data

        # Prospective moments; self.m / self.v are left unchanged.
        m_tmp = (1 - beta1) * d_p + beta1 * m_i
        v_tmp = (1 - beta2) * (d_p ** 2) + beta2 * v_i
        m_tmp_hat = m_tmp / bias1
        v_tmp_hat = v_tmp / bias2

        delta = -lr * m_tmp_hat / (v_tmp_hat ** rho + epsilon)
        updated_param += delta

    return updated_model

  def add_param_group(self, param_group):
    r"""Add a param group to the :class:`Optimizer` s `param_groups`.

    Missing hyperparameters are filled from ``self.defaults``, each wrapped in
    a float32 tensor with ``requires_grad=True`` on ``self.device``.

    Arguments:
        param_group (dict): Specifies what Tensors should be optimized along with group
        specific optimization options.

    Raises:
        TypeError: If ``params`` is a set or contains a non-tensor.
        ValueError: If a parameter does not require grad, is not a leaf,
            already belongs to another group, or a required option is missing.
    """
    params = param_group['params']
    if isinstance(params, torch.Tensor):
      param_group['params'] = [params]
    elif isinstance(params, set):
      raise TypeError('optimizer parameters need to be organized in ordered collections, but '
                      'the ordering of tensors in sets will change between runs. Please use a list instead.')
    else:
      param_group['params'] = list(params)

    for param in param_group['params']:
      if not isinstance(param, torch.Tensor):
        raise TypeError("optimizer can only optimize Tensors, "
                        "but one of the params is " + torch.typename(param))
      if not param.requires_grad:
        raise ValueError("optimizing a parameter that doesn't require gradients")
      if not param.is_leaf:
        raise ValueError("can't optimize a non-leaf Tensor")

    for name, default in self.defaults.items():
      if name in param_group:
        continue
      if default is required:
        raise ValueError("parameter group didn't specify a value of required optimization parameter " + name)
      param_group[name] = torch.tensor([default], device=self.device, dtype=torch.float32, requires_grad=True)

    param_set = set()
    for group in self.param_groups:
      param_set.update(set(group['params']))

    if not param_set.isdisjoint(set(param_group['params'])):
      raise ValueError("some parameters appear in more than one parameter group")

    self.param_groups.append(param_group)

  def zero_grad(self):
    r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
    for group in self.param_groups:
      for p in group['params']:
        if p.grad is not None:
          p.grad.detach_()
          p.grad.zero_()


class MyRMSprop:
  """RMSprop whose hyperparameters are tensors that can be tuned by gradient.

  Each param group stores ``lr``, ``rho``, ``gamma`` and ``epsilon`` in an
  unconstrained space. ``self.transforms`` maps them back before use: ``exp``
  for ``lr`` and ``epsilon``, ``softplus`` for ``rho``, ``sigmoid`` for
  ``gamma``. For a parameter ``p`` with gradient ``grad`` the step is::

      g = grad + weight_decay * p
      s = gamma * s + (1 - gamma) * g**2
      p = p - lr * g / (s ** rho + epsilon)

  The running average ``s`` is initialised to ones.

  Args:
    model: Module whose parameters are optimized.
    optim_parameters_fn: Callable mapping a model to an iterable of tensors or
      of param-group dicts. Defaults to ``model.parameters()``. It is also
      applied to the deep copy made by the look-ahead update, and the two
      results are zipped, so it must yield parameters in the same order.
    lr, rho, gamma, epsilon: Initial hyperparameter values (constrained space).
    tune: Names of the hyperparameters exposed by :meth:`parameters` and
      :meth:`parameters_with_names`.
    weight_decay: L2 coefficient added to the gradient.
    device: Device for the running averages and hyperparameter tensors.
    requires_grad: ``requires_grad`` flag for the hyperparameter tensors that
      :meth:`add_param_group` creates.

  Raises:
    ValueError: If ``optim_parameters_fn`` yields no parameters.
  """

  def __init__(self, model, optim_parameters_fn=None, lr=1e-4, rho=0.5,
               gamma=0.9, epsilon=1e-8, tune=['lr'], weight_decay=0,
               device='cuda:0', requires_grad=True):

    self.model = model
    # Not used inside this class; presumably read by external code — TODO confirm.
    self.model_copy = copy.deepcopy(model)

    if optim_parameters_fn is None:
      def optim_parameters_fn(model):
        return model.parameters()

    self.optim_parameters_fn = optim_parameters_fn
    # Copy so instances never alias (or mutate) the shared default list.
    self.tune = list(tune)
    self.weight_decay = weight_decay
    self.device = device
    self.requires_grad = requires_grad

    # Unconstrained -> constrained hyperparameter maps.
    self.transforms = {
        'lr': torch.exp,
        'rho': F.softplus,
        'gamma': torch.sigmoid,
        'epsilon': torch.exp
    }

    # Constrained -> unconstrained maps, applied to constructor arguments.
    self.inverse_transforms = {
        'lr': np.log,
        'rho': utils.inverse_softplus,
        'gamma': utils.logit,
        'epsilon': np.log
    }

    self.defaults = {
        'lr': self.inverse_transforms['lr'](lr),
        'rho': self.inverse_transforms['rho'](rho),
        'gamma': self.inverse_transforms['gamma'](gamma),
        'epsilon': self.inverse_transforms['epsilon'](epsilon),
    }

    self.param_groups = []
    param_groups = list(optim_parameters_fn(model))
    if not param_groups:
      raise ValueError("optimizer got an empty parameter list")
    if not isinstance(param_groups[0], dict):
      param_groups = [{'params': param_groups}]

    # Running averages of squared gradients: one list per param group, one
    # ones-initialised buffer per parameter.
    self.s = []
    for param_group in param_groups:
      self.add_param_group(param_group)
      self.s.append([Variable(torch.ones(param.size(), device=device, dtype=torch.float32))
                     for param in param_group['params']])

  def set_lr(self, lr):
    """Set every group's learning rate to ``lr``, updating the tensor in place."""
    for group in self.param_groups:
      group['lr'].data.copy_(
          torch.tensor([self.inverse_transforms['lr'](lr)], device=self.device)
      )

  def set_rho(self, rho):
    """Set every group's ``rho`` to ``rho``, updating the tensor in place."""
    for group in self.param_groups:
      group['rho'].data.copy_(
          torch.tensor([self.inverse_transforms['rho'](rho)], device=self.device)
      )

  def set_parameters(self, param_list):
    """Copy raw (unconstrained) values into the hyperparameter tensors.

    Args:
      param_list: One dict per param group mapping hyperparameter names to
        tensors, e.g. as returned by ``parameters_with_names(False)``.
    """
    for (group, params) in zip(self.param_groups, param_list):
      for key in params:
        group[key].data.copy_(params[key].data)

  def factor_lr(self, factor):
    """Multiply every group's learning rate by ``factor``.

    The learning rate is stored as a log, so this adds ``log(factor)``.

    NOTE(review): ``group['lr']`` is replaced by a new tensor built from
    ``detach()``, so it no longer requires grad, and tensors obtained
    earlier from :meth:`parameters` no longer track it. Confirm intended.
    """
    for group in self.param_groups:
      group['lr'] = group['lr'].detach() + \
                    torch.tensor([self.inverse_transforms['lr'](factor)], device=self.device)

  def parameters(self):
    """Return the unconstrained tensors named in ``self.tune``, grouped by name."""
    tune_params = []
    for key in self.tune:
      tune_params += [group[key] for group in self.param_groups]
    return tune_params

  def parameters_with_names(self, apply_transforms=True):
    """Return one dict per param group mapping each tuned name to its tensor.

    Args:
      apply_transforms: If True, values are mapped through ``self.transforms``;
        otherwise the raw unconstrained tensors are returned.
    """
    param_list = []
    for group in self.param_groups:
      d = {}
      for key in self.tune:
        if apply_transforms:
          d[key] = self.transforms[key](group[key])
        else:
          d[key] = group[key]
      param_list.append(d)
    return param_list

  def get_effective_lr(self):
    """Return one ``{'effective_lr': tensor}`` dict per param group.

    The value is ``-lr / (s_tmp ** rho + epsilon)``, where ``s_tmp`` is the
    prospective running average given the current gradients.

    NOTE(review): ``d['effective_lr']`` is overwritten on every iteration, so
    only the last parameter of each group is reported, and weight decay is
    not included. Confirm this is intended.
    """
    effective_lr_list = []
    for group, s in zip(self.param_groups, self.s):
      d = {}

      lr = self.transforms['lr'](group['lr'])
      rho = self.transforms['rho'](group['rho'])
      gamma = self.transforms['gamma'](group['gamma'])
      epsilon = self.transforms['epsilon'](group['epsilon'])

      for param, s_i in zip(group['params'], s):
        s_tmp = gamma * s_i + (1 - gamma) * (param.grad ** 2)
        effective_lr = -lr / (s_tmp ** rho + epsilon)
        d['effective_lr'] = effective_lr

      effective_lr_list.append(d)
    return effective_lr_list

  def update(self, take_step=False):
    """Apply one step in place, or return a differentiable look-ahead copy.

    ``param.grad`` is never modified: weight decay is added out of place, so
    a look-ahead followed by a real step sees the same gradients.

    Args:
      take_step: If True, update the parameters of ``self.model`` and the
        running averages in place, using the hyperparameters' ``.data``. If
        False, leave them untouched and apply the prospective step to a deep
        copy of the model. The copy's parameters have ``requires_grad=False``
        but are updated in place with terms built from the hyperparameter
        tensors, so they stay attached to those tensors' autograd graph.

    Returns:
      The updated model copy when ``take_step`` is False, otherwise None.
    """
    if take_step:
      for group, s in zip(self.param_groups, self.s):
        lr = self.transforms['lr'](group['lr'])
        rho = self.transforms['rho'](group['rho'])
        gamma = self.transforms['gamma'](group['gamma'])
        epsilon = self.transforms['epsilon'](group['epsilon'])

        for param, s_i in zip(group['params'], s):
          d_p = param.grad.data

          if self.weight_decay != 0:
            d_p = d_p + self.weight_decay * param.data

          s_i.data = gamma.data * s_i.data + (1 - gamma.data) * (d_p ** 2)
          delta = -lr.data * (d_p / (s_i.data ** rho.data + epsilon.data))
          param.data = param.data + delta.data
      return None

    updated_model = copy.deepcopy(self.model)
    for param in updated_model.parameters():
      param.requires_grad = False

    updated_param_groups = list(self.optim_parameters_fn(updated_model))
    if not isinstance(updated_param_groups[0], dict):
      updated_param_groups = [{'params': updated_param_groups}]

    for group, updated_group, s in zip(self.param_groups, updated_param_groups, self.s):

      lr = self.transforms['lr'](group['lr'])
      rho = self.transforms['rho'](group['rho'])
      gamma = self.transforms['gamma'](group['gamma'])
      epsilon = self.transforms['epsilon'](group['epsilon'])

      for param, updated_param, s_i in zip(group['params'], updated_group['params'], s):
        d_p = param.grad

        if self.weight_decay != 0:
          d_p = d_p + self.weight_decay * param.data

        # Prospective running average; self.s is left unchanged.
        s_tmp = gamma * s_i + (1 - gamma) * (d_p ** 2)
        delta = -lr * (d_p / (s_tmp ** rho + epsilon))
        updated_param += delta

    return updated_model

  def add_param_group(self, param_group):
    r"""Add a param group to the :class:`Optimizer` s `param_groups`.

    Missing hyperparameters are filled from ``self.defaults``, each wrapped in
    a float32 tensor on ``self.device`` with
    ``requires_grad=self.requires_grad``.

    Arguments:
        param_group (dict): Specifies what Tensors should be optimized along with group
        specific optimization options.

    Raises:
        TypeError: If ``params`` is a set or contains a non-tensor.
        ValueError: If a parameter does not require grad, is not a leaf,
            already belongs to another group, or a required option is missing.
    """
    params = param_group['params']
    if isinstance(params, torch.Tensor):
      param_group['params'] = [params]
    elif isinstance(params, set):
      raise TypeError('optimizer parameters need to be organized in ordered collections, but '
                      'the ordering of tensors in sets will change between runs. Please use a list instead.')
    else:
      param_group['params'] = list(params)

    for param in param_group['params']:
      if not isinstance(param, torch.Tensor):
        raise TypeError("optimizer can only optimize Tensors, "
                        "but one of the params is " + torch.typename(param))
      if not param.requires_grad:
        raise ValueError("optimizing a parameter that doesn't require gradients")
      if not param.is_leaf:
        raise ValueError("can't optimize a non-leaf Tensor")

    for name, default in self.defaults.items():
      if name in param_group:
        continue
      if default is required:
        raise ValueError("parameter group didn't specify a value of required optimization parameter " +
                         name)
      param_group[name] = torch.tensor([default], device=self.device, dtype=torch.float32,
                                       requires_grad=self.requires_grad)

    param_set = set()
    for group in self.param_groups:
      param_set.update(set(group['params']))

    if not param_set.isdisjoint(set(param_group['params'])):
      raise ValueError("some parameters appear in more than one parameter group")

    self.param_groups.append(param_group)

  def zero_grad(self):
    r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
    for group in self.param_groups:
      for p in group['params']:
        if p.grad is not None:
          p.grad.detach_()
          p.grad.zero_()
