import functools
from collections import OrderedDict
from copy import deepcopy
from typing import Iterable

import torch
import torch.nn as nn
from mmcv.runner import build_optimizer
from mmcv.runner.optimizer.builder import OPTIMIZERS
from mmcv.utils.registry import build_from_cfg
from torch.autograd import Variable
from torch.optim.optimizer import Optimizer


# Base hyper-parameters that ``check_lamb_optimizer`` expects the optimizer
# defaults (and unscaled param groups) to use.
base_lr = 0.01  # base learning rate
base_wd = 1e-4  # base weight decay


def assert_equal(x, y):
    """Recursively assert that ``x`` and ``y`` hold equal values.

    Comparison rules, tried in order:

    * two tensors: ``torch.testing.assert_allclose`` after moving ``y`` to
      ``x``'s device;
    * two ``OrderedDict``s: same length, then values compared position by
      position (keys are intentionally not compared);
    * two dicts: identical key sets, then values compared key by key;
    * two sets/frozensets: plain set equality (iteration order of equal sets
      is not guaranteed to match, so element-wise zipping is unreliable);
    * two strings: ``==``;
    * two other iterables: same length, then element-wise recursion;
    * anything else: ``==``.

    Raises:
        AssertionError: if a compared pair differs.
    """
    if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
        torch.testing.assert_allclose(x, y.to(x.device))
    elif isinstance(x, OrderedDict) and isinstance(y, OrderedDict):
        # Without the length check, ``zip`` would silently ignore any extra
        # trailing entries in the longer mapping.
        assert len(x) == len(y)
        for x_value, y_value in zip(x.values(), y.values()):
            assert_equal(x_value, y_value)
    elif isinstance(x, dict) and isinstance(y, dict):
        assert x.keys() == y.keys()
        for key in x:
            assert_equal(x[key], y[key])
    elif isinstance(x, (set, frozenset)) and isinstance(y, (set, frozenset)):
        assert x == y
    elif isinstance(x, str) and isinstance(y, str):
        assert x == y
    elif isinstance(x, Iterable) and isinstance(y, Iterable):
        assert len(x) == len(y)
        for x_item, y_item in zip(x, y):
            assert_equal(x_item, y_item)
    else:
        assert x == y


class SubModel(nn.Module):
    """Nested toy module owning a depthwise conv, a GroupNorm and a Linear.

    ``forward`` is the identity, so the module only contributes parameters.
    ``parameters()`` yields the module's own ``param1`` first, followed by
    the ``conv1``, ``gn`` and ``fc`` weights and biases.
    """

    def __init__(self):
        super(SubModel, self).__init__()
        # Registration order determines parameter order; do not reorder.
        # groups == in_channels makes conv1 a depthwise convolution.
        self.conv1 = nn.Conv2d(
            in_channels=2, out_channels=2, kernel_size=1, groups=2)
        self.gn = nn.GroupNorm(num_groups=2, num_channels=2)
        self.fc = nn.Linear(in_features=2, out_features=2)
        self.param1 = nn.Parameter(torch.ones(1))

    def forward(self, x):
        """Return ``x`` unchanged."""
        return x


class ExampleModel(nn.Module):
    """Toy model whose parameters cover the cases ``check_lamb_optimizer``
    inspects: a bare parameter, a bias-free conv, a conv with bias, a
    BatchNorm, a nested ``SubModel`` and a Linear layer.

    ``forward`` is the identity; the model exists only for its parameters.
    """

    def __init__(self):
        super(ExampleModel, self).__init__()
        # Registration order determines ``parameters()`` order, which the
        # optimizer checks index into positionally; do not reorder.
        self.param1 = nn.Parameter(torch.ones(1))
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=4, kernel_size=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=4, out_channels=2, kernel_size=1)
        self.bn = nn.BatchNorm2d(num_features=2)
        self.sub = SubModel()
        self.fc = nn.Linear(in_features=2, out_features=1)

    def forward(self, x):
        """Return ``x`` unchanged."""
        return x


def check_lamb_optimizer(optimizer,
                         model,
                         bias_lr_mult=1,
                         bias_decay_mult=1,
                         norm_decay_mult=1,
                         dwconv_decay_mult=1):
    """Assert ``optimizer`` holds per-parameter groups for ``ExampleModel``.

    The optimizer's defaults must be ``lr == base_lr`` and
    ``weight_decay == base_wd``. There must be as many param groups as model
    parameters, and the first tensor of group ``i`` must equal the ``i``-th
    entry of ``model.parameters()``. Each group's ``lr``/``weight_decay``
    must equal the base value, scaled by the applicable multiplier:

    * biases of ``conv2``, ``sub.fc`` and ``fc``: ``bias_lr_mult`` /
      ``bias_decay_mult``;
    * BatchNorm / GroupNorm weights and biases: decay ``norm_decay_mult``;
    * depthwise ``sub.conv1``: decay ``dwconv_decay_mult`` for both weight
      and bias, and its bias additionally uses ``bias_lr_mult``.

    Args:
        optimizer (Optimizer): optimizer under test.
        model (nn.Module): presumably an ``ExampleModel``; the group indices
            below assume its 15-parameter layout.
        bias_lr_mult: expected learning-rate multiplier for biases.
        bias_decay_mult: expected weight-decay multiplier for biases.
        norm_decay_mult: expected weight-decay multiplier for norm layers.
        dwconv_decay_mult: expected weight-decay multiplier for the
            depthwise conv.
    """
    groups = optimizer.param_groups
    assert isinstance(optimizer, Optimizer)
    assert optimizer.defaults['lr'] == base_lr
    assert optimizer.defaults['weight_decay'] == base_wd

    params = list(model.parameters())
    assert len(groups) == len(params)
    for group, param in zip(groups, params):
        assert torch.equal(group['params'][0], param)

    same = object()  # sentinel: expect the base value with no multiplier
    # (lr multiplier, weight-decay multiplier) for each group, in order.
    expected_mults = [
        (same, same),                          # param1
        (same, same),                          # conv1.weight
        (same, same),                          # conv2.weight
        (bias_lr_mult, bias_decay_mult),       # conv2.bias
        (same, norm_decay_mult),               # bn.weight
        (same, norm_decay_mult),               # bn.bias
        (same, same),                          # sub.param1
        (same, dwconv_decay_mult),             # sub.conv1.weight
        (bias_lr_mult, dwconv_decay_mult),     # sub.conv1.bias
        (same, norm_decay_mult),               # sub.gn.weight
        (same, norm_decay_mult),               # sub.gn.bias
        (same, same),                          # sub.fc.weight
        (bias_lr_mult, bias_decay_mult),       # sub.fc.bias
        (same, same),                          # fc.weight
        (bias_lr_mult, bias_decay_mult),       # fc.bias
    ]
    for index, (lr_mult, decay_mult) in enumerate(expected_mults):
        group = groups[index]
        want_lr = base_lr if lr_mult is same else base_lr * lr_mult
        assert group['lr'] == want_lr
        want_wd = base_wd if decay_mult is same else base_wd * decay_mult
        assert group['weight_decay'] == want_wd


def _test_state_dict(weight, bias, input, constructor):
    """Check that an optimizer's ``state_dict`` round-trips faithfully.

    ``constructor(weight, bias)`` must return an optimizer over the two given
    tensors, which are trained on ``(weight.mv(input) + bias).pow(2).sum()``.
    The check:

    1. primes the optimizer for 20 steps, then loads a deep copy of its
       ``state_dict`` into a fresh optimizer built over cloned parameters;
    2. steps both optimizers 20 more times, requiring equal parameters after
       every step and an unmodified copy of the loaded state dict;
    3. requires the per-parameter state of both optimizers to match;
    4. requires duplicated param groups to be packed identically;
    5. requires ``deepcopy(optimizer)`` to keep every public attribute;
    6. if CUDA is available, loads the state into an optimizer over float
       CUDA copies of the parameters and repeats the lock-step comparison.

    Args:
        weight (torch.Tensor): initial weight matrix (used with ``mv``).
        bias (torch.Tensor): initial bias.
        input (torch.Tensor): fixed input vector.
        constructor (callable): ``constructor(weight, bias) -> Optimizer``.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    weight = Variable(weight, requires_grad=True)
    bias = Variable(bias, requires_grad=True)
    inputs = Variable(input)

    def fn_base(optimizer, weight, bias, data):
        # Closure for ``optimizer.step``: recompute loss and gradients.
        # ``data`` is passed explicitly so each closure uses an input on the
        # same device as its parameters.
        optimizer.zero_grad()
        loss = (weight.mv(data) + bias).pow(2).sum()
        loss.backward()
        return loss

    optimizer = constructor(weight, bias)
    fn = functools.partial(fn_base, optimizer, weight, bias, inputs)

    # Prime the optimizer
    for _ in range(20):
        optimizer.step(fn)
    # Clone the weights and construct new optimizer for them
    weight_c = Variable(weight.data.clone(), requires_grad=True)
    bias_c = Variable(bias.data.clone(), requires_grad=True)
    optimizer_c = constructor(weight_c, bias_c)
    fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c, inputs)
    # Load state dict
    state_dict = deepcopy(optimizer.state_dict())
    state_dict_c = deepcopy(optimizer.state_dict())
    optimizer_c.load_state_dict(state_dict_c)
    # Run both optimizations in parallel
    for _ in range(20):
        optimizer.step(fn)
        optimizer_c.step(fn_c)
        assert_equal(weight, weight_c)
        assert_equal(bias, bias_c)
    # Make sure state dict wasn't modified
    assert_equal(state_dict, state_dict_c)

    # Make sure state dict is deterministic with equal but not identical
    # parameters. NOTE: in PyTorch 1.5 optimizer state_dicts have random
    # keys, so per-parameter states are matched by position (across all
    # param groups) rather than by key.
    state_dict = deepcopy(optimizer.state_dict())
    state_dict_c = deepcopy(optimizer_c.state_dict())
    keys = [
        key for group in state_dict['param_groups'] for key in group['params']
    ]
    keys_c = [
        key for group in state_dict_c['param_groups']
        for key in group['params']
    ]
    assert len(keys) == len(keys_c)
    for key, key_c in zip(keys, keys_c):
        # ``.get`` avoids a KeyError for parameters without state; state
        # present on only one side still fails the comparison.
        assert_equal(state_dict['state'].get(key),
                     state_dict_c['state'].get(key_c))

    # Make sure repeated parameters have identical representation in state
    # dict: duplicate every param group and compare each with its copy.
    num_groups = len(optimizer_c.param_groups)
    optimizer_c.param_groups.extend(optimizer_c.param_groups)
    packed_groups = optimizer_c.state_dict()['param_groups']
    for group, repeated in zip(packed_groups[:num_groups],
                               packed_groups[num_groups:]):
        assert_equal(group, repeated)

    # Validate that deepcopy() copies all public attributes. This does not
    # need CUDA, so it runs before the CUDA-only section below.
    def get_public_attrs(obj):
        return {k for k in obj.__dict__ if not k.startswith('_')}

    assert (get_public_attrs(optimizer) ==
            get_public_attrs(deepcopy(optimizer)))

    # Check that state dict can be loaded even when we cast parameters
    # to a different type and move to a different device.
    if not torch.cuda.is_available():
        return

    input_cuda = Variable(inputs.data.float().cuda())
    weight_cuda = Variable(weight.data.float().cuda(), requires_grad=True)
    bias_cuda = Variable(bias.data.float().cuda(), requires_grad=True)
    optimizer_cuda = constructor(weight_cuda, bias_cuda)
    fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda,
                                bias_cuda, input_cuda)

    state_dict = deepcopy(optimizer.state_dict())
    state_dict_c = deepcopy(optimizer.state_dict())
    optimizer_cuda.load_state_dict(state_dict_c)

    # Make sure state dict wasn't modified
    assert_equal(state_dict, state_dict_c)

    for _ in range(20):
        optimizer.step(fn)
        optimizer_cuda.step(fn_cuda)
        assert_equal(weight, weight_cuda)
        assert_equal(bias, bias_cuda)


def _test_basic_cases_template(weight, bias, inputs, constructor,
                               scheduler_constructors):
    """Copied from PyTorch."""
    weight = Variable(weight, requires_grad=True)
    bias = Variable(bias, requires_grad=True)
    inputs = Variable(inputs)
    optimizer = constructor(weight, bias)
    schedulers = []
    for scheduler_constructor in scheduler_constructors:
        schedulers.append(scheduler_constructor(optimizer))

    # to check if the optimizer can be printed as a string
    optimizer.__repr__()

    def fn():
        optimizer.zero_grad()
        y = weight.mv(inputs)
        if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
            y = y.cuda(bias.get_device())
        loss = (y + bias).pow(2).sum()
        loss.backward()
        return loss

    initial_value = fn().item()
    for _ in range(200):
        for scheduler in schedulers:
            scheduler.step()
        optimizer.step(fn)

    assert fn().item() < initial_value


def _test_basic_cases(constructor,
                      scheduler_constructors=None,
                      ignore_multidevice=False):
    """Run the optimizer smoke tests for ``constructor``.

    Copied from PyTorch. Runs the state-dict round-trip check once, then the
    loss-decrease check on contiguous CPU tensors, on non-contiguous CPU
    tensors, and, when CUDA is available, on a single GPU and finally across
    two GPUs (skipped when fewer than two devices exist or
    ``ignore_multidevice`` is true).
    """
    schedulers = ([] if scheduler_constructors is None
                  else scheduler_constructors)

    def run(weight, bias, inputs):
        _test_basic_cases_template(weight, bias, inputs, constructor,
                                   schedulers)

    _test_state_dict(torch.randn(10, 5), torch.randn(10), torch.randn(5),
                     constructor)
    run(torch.randn(10, 5), torch.randn(10), torch.randn(5))
    # Non-contiguous parameters: slice away a trailing dimension.
    run(torch.randn(10, 5, 2)[..., 0], torch.randn(10, 2)[..., 0],
        torch.randn(5))
    if torch.cuda.is_available():
        run(torch.randn(10, 5).cuda(), torch.randn(10).cuda(),
            torch.randn(5).cuda())
        # Multi-GPU: weight and inputs on device 0, bias on device 1.
        if torch.cuda.device_count() > 1 and not ignore_multidevice:
            run(torch.randn(10, 5).cuda(0), torch.randn(10).cuda(1),
                torch.randn(5).cuda(0))

