import unittest
import torch
from torch import nn
from torch.nn import functional as F
from torchmeta.modules import *
from per_fed_avg import PerFedAvgSSLOptimizer, FOPerFedAvg
from functools import partial
from collections import OrderedDict


class PerFedAvgTestCases(unittest.TestCase):
    """Numerical regression tests for the Per-FedAvg optimizers.

    A tiny two-layer linear network (1 -> 2 -> 1) with fixed weights is
    updated by two optimizer ``step`` calls, and the resulting parameters are
    compared against an independent reference:

    * first order: ``FOPerFedAvg`` on a plain ``nn.Sequential`` versus
      ``PerFedAvgSSLOptimizer(first_order=True)`` on a torchmeta
      ``MetaSequential`` (exact equality);
    * second order: ``PerFedAvgSSLOptimizer(first_order=False)`` versus a
      hand-written MAML update that uses the exact Hessian (small tolerance).

    Checks use ``self.assert*`` instead of bare ``assert`` so they still run
    under ``python -O``.
    """

    GLOBAL_LR = 0.2
    LOCAL_LR = 0.1
    LOCAL_STEPS = 3
    # Number of ``opt.step(mini_batch)`` calls made in each test.
    N_GLOBAL_STEPS = 2
    # Max L2 norm of (optimizer result - manual reference) per parameter tensor.
    SECOND_ORDER_TOL = 1e-5

    @staticmethod
    def _init_toy(m):
        """Overwrite ``m.fc1`` / ``m.fc2`` in place with the weights 1..7.

        ``copy_`` under ``no_grad`` keeps the existing ``Parameter`` objects
        (and their ``requires_grad`` flag) rather than swapping ``.data``.
        """
        with torch.no_grad():
            m.fc1.weight.copy_(torch.tensor([[1.0], [2.0]]))
            m.fc1.bias.copy_(torch.tensor([3.0, 4.0]))
            m.fc2.weight.copy_(torch.tensor([[5.0, 6.0]]))
            m.fc2.bias.copy_(torch.tensor([7.0]))

    @staticmethod
    def _mini_batch():
        """Return the 4x1 input batch shared by both tests."""
        return torch.tensor([[1], [1], [2], [2]], dtype=torch.float)

    @classmethod
    def _ssl_config(cls, first_order):
        """Build the hyper-parameter dict passed to ``PerFedAvgSSLOptimizer``."""
        return dict(global_lr=cls.GLOBAL_LR, local_lr=cls.LOCAL_LR, local_steps=cls.LOCAL_STEPS,
                    momentum=0, wd=0, first_order=first_order)

    @staticmethod
    def _manual_forward(x, w):
        """Mean output of the toy net on ``x``, with all 7 parameters packed in ``w``.

        Packing order: fc1.weight (2), fc1.bias (2), fc2.weight (2), fc2.bias (1).
        """
        fc1_weight = w[0:2].view(2, 1)
        fc1_bias = w[2:4]
        fc2_weight = w[4:6].view(1, 2)
        fc2_bias = w[6]
        return ((x @ fc1_weight.T + fc1_bias) @ fc2_weight.T + fc2_bias).mean()

    @staticmethod
    def _unflatten(w):
        """Split the flat 7-vector ``w`` into tensors shaped like the toy net's parameters."""
        # w[6:7] (shape [1]) rather than w[6] (0-dim) so it matches fc2.bias exactly.
        return [w[0:2].view(2, 1), w[2:4], w[4:6].view(1, 2), w[6:7]]

    def setUp(self) -> None:
        self.toy_net = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(1, 2)),
            ('fc2', nn.Linear(2, 1)),
        ]))
        # Presumably read by the optimizers as the training loss -- confirm in per_fed_avg.
        self.toy_net.loss_fn = lambda y: y.mean()

        self.meta_toy_net = MetaSequential(OrderedDict([
            ('fc1', MetaLinear(1, 2)),
            ('fc2', MetaLinear(2, 1)),
        ]))
        self.meta_toy_net.loss_fn = lambda y: y.mean()

        self._init_toy(self.toy_net)
        self._init_toy(self.meta_toy_net)

        self.toy_net.train()
        self.meta_toy_net.train()

    def test_first_order_PerFedAvg(self):
        """FOPerFedAvg and first-order PerFedAvgSSLOptimizer must produce identical weights."""
        toy_net = self.toy_net
        meta_toy_net = self.meta_toy_net
        mini_batch = self._mini_batch()

        fopfa_opt = FOPerFedAvg('cpu', toy_net, global_lr=self.GLOBAL_LR,
                                local_lr=self.LOCAL_LR, local_steps=self.LOCAL_STEPS)
        for _ in range(self.N_GLOBAL_STEPS):
            fopfa_opt.step(mini_batch)

        pfa_opt = PerFedAvgSSLOptimizer('cpu', meta_toy_net, self._ssl_config(first_order=True))
        for _ in range(self.N_GLOBAL_STEPS):
            pfa_opt.step(mini_batch)

        toy_params = list(toy_net.named_parameters())
        meta_params = list(meta_toy_net.parameters())
        # zip() would silently truncate on a mismatch, hiding missing parameters.
        self.assertEqual(len(toy_params), len(meta_params))
        for (name, toy_p), meta_toy_p in zip(toy_params, meta_params):
            self.assertTrue(torch.equal(toy_p, meta_toy_p),
                            msg=f'{name}: {toy_p} != {meta_toy_p}')

    def test_second_order_PerFedAvg(self):
        """Second-order PerFedAvgSSLOptimizer must match a manual exact-Hessian MAML update."""
        meta_toy_net = self.meta_toy_net
        mini_batch = self._mini_batch()

        pfa_opt = PerFedAvgSSLOptimizer('cpu', meta_toy_net, self._ssl_config(first_order=False))
        for _ in range(self.N_GLOBAL_STEPS):
            pfa_opt.step(mini_batch)

        # Manual reference: the same values _init_toy writes, flattened.
        w = torch.tensor([1, 2, 3, 4, 5, 6, 7], dtype=torch.float)
        # First half: inner adaptation step + Hessian; second half: outer (meta) gradient.
        x1, x2 = mini_batch.split(2)
        local_loss_fn = partial(self._manual_forward, x1)
        eye = torch.eye(w.numel())

        # Mirrors the pfa_opt.step calls above, each assumed to perform LOCAL_STEPS
        # MAML updates:
        #   w <- w - beta * (I - alpha * H_x1(w)) @ grad_x2(w - alpha * grad_x1(w))
        for _ in range(self.N_GLOBAL_STEPS):
            for _ in range(self.LOCAL_STEPS):
                # Fresh leaf each iteration so autograd history does not accumulate.
                w = w.detach().requires_grad_(True)

                local_grad = torch.autograd.grad(local_loss_fn(w), w)[0]
                local_hess = torch.autograd.functional.hessian(local_loss_fn, w)

                adapted = (w.detach() - self.LOCAL_LR * local_grad).requires_grad_(True)
                global_loss = self._manual_forward(x2, adapted)
                global_grad = torch.autograd.grad(global_loss, adapted)[0]

                w = w.detach() - self.GLOBAL_LR * (eye - self.LOCAL_LR * local_hess) @ global_grad

        meta_params = list(meta_toy_net.named_parameters())
        manual_params = self._unflatten(w)
        self.assertEqual(len(meta_params), len(manual_params))
        for (name, meta_toy_p), manual_p in zip(meta_params, manual_params):
            diff = torch.norm(meta_toy_p.detach() - manual_p).item()
            self.assertLess(diff, self.SECOND_ORDER_TOL, msg=f'{name}: |diff| = {diff}')


if __name__ == "__main__":
    # Run as a script: hand off to unittest's CLI runner, which collects the
    # TestCase classes defined in this module (passed explicitly as the default).
    unittest.main(module="__main__")
