import unittest

import functools as ft
import itertools as it

from apex import amp
from apex.amp import _amp_state
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Parameter

from utils import common_init, HALF, FLOAT,\
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT

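# Toy model with one fp32 parameter and one fp16 parameter, so param handling
# is exercised across dtypes.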
class MyModel(torch.nn.Module):
    def __init__(self, unique):
        super().__init__()
        self.weight0 = Parameter(unique +
            torch.arange(2, device='cuda', dtype=torch.float32))
        self.weight1 = Parameter(1. + unique +
            torch.arange(2, device='cuda', dtype=torch.float16))

    @staticmethod
    def ops(input, weight0, weight1):
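        # Upcast both weights to fp32 so the product and reduction run in full precision.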
        return (input * weight0.float() * weight1.float()).sum()

    def forward(self, input):
        return self.ops(input, self.weight0, self.weight1)


# Abandon all hope, ye who enter here.


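# Verifies that parameters added via optimizer.add_param_group() after training
# has started end up with the same final values under amp (all opt levels,
# several grad-zeroing strategies, with and without grad accumulation) as in a
# pure fp32 reference run.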
class TestAddParamGroup(unittest.TestCase):
    def setUp(self):
        self.x = torch.ones(2, device='cuda', dtype=torch.float32)
        common_init(self)

    def tearDown(self):
        pass

    def zero_grad(self, models, optimizer, how_to_zero):
        """Zero gradients one of three ways: by setting .grad to None on each
        param, via Module.zero_grad(), or via Optimizer.zero_grad()."""
        if how_to_zero == "none":
            for model in models:
                for param in model.parameters():
                    param.grad = None
        elif how_to_zero == "model":
            for model in models:
                model.zero_grad()
        elif how_to_zero == "optimizer":
            optimizer.zero_grad()
        else:
            raise ValueError("Unknown how_to_zero: {}".format(how_to_zero))

    def test_add_param_group(self):
        for opt_level in ("O0", "O1", "O2", "O3"):
          for zero_before_add in (True, False):
            for try_accumulation in (True, False):
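              # Reference run: plain fp32 SGD with no amp. model1's params are
              # added to the optimizer mid-training via add_param_group.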
              model0 = MyModel(1)
              model1 = MyModel(2)

              optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],
                                          momentum=0.125)

              optimizer.zero_grad()
              loss = model0(self.x)
              loss.backward()
              optimizer.step()

              if zero_before_add:
                  optimizer.zero_grad()
              optimizer.add_param_group({'params' : model1.parameters(), 'lr' : 0.5})
              if not zero_before_add:
                  optimizer.zero_grad()

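              # Step with both models; optionally backward twice to exercise grad accumulation.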
              loss = model0(self.x) + model1(self.x)
              loss.backward(retain_graph=try_accumulation)
              if try_accumulation:
                  loss.backward()
              optimizer.step()

              # Once more to make sure the new params pick up momentums properly
              optimizer.zero_grad()
              loss = model0(self.x) + model1(self.x)
              loss.backward(retain_graph=try_accumulation)
              if try_accumulation:
                  loss.backward()
              optimizer.step()

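              # Snapshot the final parameter values from the reference run.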
              reference_params = [param.detach().clone() for param in model0.parameters()] + \
                                 [param.detach().clone() for param in model1.parameters()]

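              # Repeat the same schedule under amp for each grad-zeroing strategy
              # and compare the final params against the fp32 reference.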
              for how_to_zero in "none", "model", "optimizer":
                  model0 = MyModel(1)
                  model1 = MyModel(2)

                  optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],
                                              momentum=0.125)

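                  # MyModel intentionally holds an fp16 parameter, which
                  # amp.initialize() would otherwise flag; allow it for this test.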
                  _amp_state.allow_incoming_model_not_fp32 = True
                  [model0, model1], optimizer = amp.initialize([model0, model1],
                      optimizer,
                      opt_level=opt_level,
                      verbosity=0,
                      cast_model_type=False)
                  _amp_state.allow_incoming_model_not_fp32 = False

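                  # Use a fixed loss scale so the comparison doesn't depend on
                  # dynamic loss scaling.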
                  _amp_state.loss_scalers[0]._loss_scale = 4.0

                  self.zero_grad([model0, model1], optimizer, how_to_zero)
                  loss = model0(self.x)
                  with amp.scale_loss(loss, optimizer) as scaled_loss:
                      scaled_loss.backward()
                  optimizer.step()

                  if zero_before_add:
                      self.zero_grad([model0, model1], optimizer, how_to_zero)
                  optimizer.add_param_group({'params' : model1.parameters(), 'lr' : 0.5})
                  if not zero_before_add:
                      self.zero_grad([model0, model1], optimizer, how_to_zero)

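                  # Same step under amp: the loss must be scaled before each backward pass.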
                  loss = model0(self.x) + model1(self.x)
                  with amp.scale_loss(loss, optimizer) as scaled_loss:
                      scaled_loss.backward(retain_graph=try_accumulation)
                  if try_accumulation:
                      with amp.scale_loss(loss, optimizer) as scaled_loss:
                          scaled_loss.backward()
                  optimizer.step()

                  # Once more to make sure the new params pick up momentums properly
                  self.zero_grad([model0, model1], optimizer, how_to_zero)
                  loss = model0(self.x) + model1(self.x)
                  with amp.scale_loss(loss, optimizer) as scaled_loss:
                      scaled_loss.backward(retain_graph=try_accumulation)
                  if try_accumulation:
                      with amp.scale_loss(loss, optimizer) as scaled_loss:
                          scaled_loss.backward()
                  optimizer.step()

                  final_params = [param.detach().clone() for param in model0.parameters()] + \
                                 [param.detach().clone() for param in model1.parameters()]

                  for reference, final in zip(reference_params, final_params):
                      torch.testing.assert_close(
                          final, reference.to(final.dtype),
                          msg="opt_level = {}, how_to_zero = {}, zero_before_add = {}, "
                              "try_accumulation = {}".format(opt_level, how_to_zero,
                                                             zero_before_add, try_accumulation))


if __name__ == '__main__':
    unittest.main()
