import copy

import torch
from torch import nn
from torchmeta.modules import MetaModule
from torchmeta.utils.gradient_based import gradient_update_parameters


class PerFedAvgSSLOptimizer:
    """
    MAML-style meta-optimizer used for Per-FedAvg, in first- or second-order form.

    The inner adaptation step is taken with ``local_lr`` (alpha) via torchmeta's
    ``gradient_update_parameters``. The outer update is an SGD step with ``global_lr``
    (beta) on ``global_model``. It is applied once every ``accumulation_steps`` calls
    to :meth:`step`.

    ``kwargs`` is a plain dict (passed positionally) that must contain the keys
    ``global_lr``, ``local_lr``, ``momentum``, ``wd``, ``is_first_order`` and
    ``accumulation_steps``.
    """

    def __init__(self, device, global_model: MetaModule, kwargs):
        self.device = device
        self.global_model = global_model

        # beta drives the outer (meta) update; alpha drives the inner adaptation step.
        self.global_lr = kwargs['global_lr']
        self.local_lr = kwargs['local_lr']

        self.momentum = kwargs['momentum']
        self.wd = kwargs['wd']
        self.is_first_order = kwargs['is_first_order']
        self.accumulation_steps = kwargs['accumulation_steps']

        self.optimizer = torch.optim.SGD(
            self.global_model.parameters(),
            lr=self.global_lr,
            momentum=self.momentum,
            weight_decay=self.wd,
        )
        self.criterion = nn.CrossEntropyLoss().to(device)

    def step(self, batch_idx, input):
        """
        Run one meta-training step on a (inputs, labels) pair.

        The first half of the batch (``chunk(2)`` along dim 0) is used for the inner
        adaptation step. The second half is evaluated with the adapted parameters,
        and its loss is backpropagated into ``global_model``'s ``.grad`` buffers.
        The outer optimizer steps and clears gradients when ``batch_idx + 1`` is a
        multiple of ``accumulation_steps``.

        Returns the outer (meta-test) loss tensor.
        """
        # NOTE: the paper uses three batches and the intention is unclear; standard MAML
        # only needs meta-train and meta-test. The effective batch size would need scaling
        # by (3 * inner_steps); it is currently scaled by 2.
        inputs, targets = input
        support_x, query_x = inputs.chunk(2)
        support_y, query_y = targets.chunk(2)

        # Inner loop: adapt on the support half.
        # TODO(ZY): check the correctness of accumulating gradient
        inner_loss = self.criterion(self.global_model(support_x), support_y)
        adapted_params = gradient_update_parameters(
            self.global_model,
            inner_loss,
            step_size=self.local_lr,
            first_order=self.is_first_order,
        )

        # Outer loop: evaluate the adapted parameters on the query half.
        # TODO(ZY): may be stop accumulating stats in meta test?
        query_logits = self.global_model(query_x, params=adapted_params)
        outer_loss = self.criterion(query_logits, query_y)
        outer_loss.backward()

        flush_now = (batch_idx + 1) % self.accumulation_steps == 0
        if flush_now:
            self.optimizer.step()
            self.optimizer.zero_grad()

        return outer_loss


class FOPerFedAvg:
    """
    First-order MAML (first-order Per-FedAvg) built from two plain SGD optimizers.

    Each local step does the following:
      1. Adapt the model on the first half of the batch with ``local_lr``.
      2. Compute the loss of the adapted model on the second half and backpropagate it.
      3. Restore the pre-adaptation weights, keeping the gradients from step 2.
      4. Apply those gradients with ``global_lr``.

    This is the first-order approximation: no gradient flows through the inner update.

    ``model`` must be an ``nn.Module`` that also exposes ``loss_fn(output) -> scalar``.
    """

    def __init__(self, device, model, local_lr, global_lr, local_steps, **kwargs):
        # ``device`` and extra ``kwargs`` are stored/accepted for interface
        # compatibility; they are not used by this class.
        self.device = device
        self.model = model
        self.local_lr = local_lr
        self.global_lr = global_lr
        self.local_steps = local_steps
        self.local_opt = torch.optim.SGD(self.model.parameters(), lr=local_lr)
        self.global_opt = torch.optim.SGD(self.model.parameters(), lr=global_lr)

    def step(self, x):
        """
        Run ``local_steps`` first-order meta updates on the batch ``x``.

        Args:
            x: batch tensor, split along dim 0 into two halves. The first half is
               used for the inner/adaptation step and the second half for the
               meta-gradient.

        Returns:
            The second-half (meta) loss from the last local step, or ``None`` if
            ``local_steps`` is 0.

        Raises:
            ValueError: if ``x`` has fewer than 2 samples along dim 0.
        """
        if x.size(0) < 2:
            raise ValueError("FOPerFedAvg.step needs at least 2 samples to split the batch in two")
        # chunk(2) yields two halves; split(2) would yield pieces of size 2.
        x1, x2 = x.chunk(2)
        params = list(self.model.parameters())
        global_loss = None

        for _ in range(self.local_steps):
            # Detached snapshot of the weights before the inner update.
            snapshot = [p.detach().clone() for p in params]

            self.local_opt.zero_grad()
            local_loss = self.model.loss_fn(self.model(x1))
            local_loss.backward()
            self.local_opt.step()

            # Gradient of the second-half loss evaluated at the adapted weights.
            self.global_opt.zero_grad()
            global_loss = self.model.loss_fn(self.model(x2))
            global_loss.backward()

            # Restore the pre-adaptation weights in place; p.grad is left untouched,
            # so the outer step below applies the adapted-point gradient to them.
            with torch.no_grad():
                for p, saved in zip(params, snapshot):
                    p.copy_(saved)

            self.global_opt.step()

        return global_loss
