import torch
import torch.nn as nn

from utils import unravel_model_params
from utils import AverageMeter, ravel_model_params, unravel_model_params
from anchored_sgd import AnchoredSGD


class ServerClientBase:
    """Loss and prediction helpers shared by federated servers and clients.

    Subclasses must set ``self.args`` (an object with a ``dataset`` string
    attribute) and ``self.criterion`` before calling these methods.
    """

    def loss(self, model, output, target):
        """Evaluate ``self.criterion`` on one batch.

        For the ``"vehicle"`` dataset and any dataset whose name starts with
        ``"logreg"``, the criterion is called as
        ``criterion(model, output, target)``. Otherwise it is called as
        ``criterion(output, target)``.
        """
        name = self.args.dataset
        criterion_takes_model = name == "vehicle" or name.startswith("logreg")
        if not criterion_takes_model:
            return self.criterion(output, target)
        return self.criterion(model, output, target)

    def pred(self, output):
        """Convert raw model output into predicted labels.

        - ``nn.BCELoss`` criterion: ``output`` rounded to the nearest integer.
        - ``"vehicle"`` dataset: the sign of ``output``.
        - Otherwise: the argmax index along dim 1, keeping that dim.
        """
        if isinstance(self.criterion, nn.BCELoss):
            return output.round()
        if self.args.dataset == "vehicle":
            return output.sign()
        _, indices = output.max(1, keepdim=True)
        return indices


class Client(ServerClientBase):
    """Base federated client that owns a local training set and its loader.

    Subclasses implement :meth:`step` with their own local update rule.
    """

    def __init__(self, id_, trainset, model, criterion, args):
        """Store client state and build the local training loader.

        Args:
            id_: Client identifier, stored as ``self.id_``.
            trainset: Local training dataset.
            model: Accepted so subclasses share this signature; not used by
                this base class.
            criterion: Loss callable used by ``ServerClientBase.loss``.
            args: Configuration namespace. ``client_batch_size == -1`` selects
                a single full batch via ``FullBatchDataLoader``; any other
                value builds a shuffled ``DataLoader`` of that batch size,
                with extra kwargs chosen from ``args.device``.
        """
        self.id_ = id_
        self.criterion = criterion
        self.args = args

        if self.args.client_batch_size == -1:
            self.train_loader = FullBatchDataLoader(trainset)
        else:
            self.train_loader = torch.utils.data.DataLoader(
                trainset,
                batch_size=self.args.client_batch_size,
                shuffle=True,
                **get_dataloader_kwargs(args.device),
            )

        self.trainset = trainset

    def step(self, model):
        """Run one local update. Subclasses must override this method."""
        raise NotImplementedError


class InexactFedDRClient(Client):
    """FedDR client whose local subproblem is solved inexactly with AnchoredSGD.

    Between rounds the client keeps two raveled parameter vectors:

    * ``s_prev``: the point the local model is reset to, and the optimizer
      is anchored at, at the start of each :meth:`step`;
    * ``xbar_prev``: the local parameters reached after the most recent
      pass of local training.
    """

    def __init__(self, id_, trainset, model, criterion, args):
        """Initialize both state vectors from ``model``'s current parameters.

        NOTE(review): ``model`` is kept by reference (not copied) as
        ``s_prev_model``, and :meth:`step` overwrites its parameters in place.
        If one model object is shared by several clients, they will clobber
        each other's weights. Confirm that callers pass a per-client copy.
        """
        super().__init__(id_, trainset, model, criterion, args)
        self.s_prev_model = model
        self.s_prev = ravel_model_params(self.s_prev_model)
        self.xbar_prev = ravel_model_params(self.s_prev_model)

    def step(self, alpha, phat):
        """Run one round of local training, yielding progressively refined results.

        This is a generator, so nothing runs until the first ``next()``. Each
        pass of the outer loop trains for ``args.client_num_epochs`` more
        epochs, continuing from where the previous pass stopped and keeping the
        same anchor, and then yields a snapshot.

        Args:
            alpha: Step size scaling the ``s_prev`` update.
            phat: Raveled parameter vector combined with ``xbar_prev`` in the
                ``s_prev`` update (presumably the server aggregate; confirm
                against the caller).

        Yields:
            ``(loss, xbar, xbar_grad, s_prev)``:

            * ``loss``: the loss tensor on the last processed batch;
            * ``xbar``: the raveled model parameters;
            * ``xbar_grad``: the raveled gradient of that loss at ``xbar``;
            * ``s_prev``: the anchor used for this round.

        Raises:
            AssertionError: if ``args.client_optimizer`` is not ``"sgd"``.
            ValueError: if no batch is processed before the first snapshot,
                because ``args.client_num_epochs < 1`` or the training loader
                is empty. The original code failed here with an opaque
                ``UnboundLocalError``.
        """
        assert self.args.client_optimizer == "sgd", "Only SGD is supported for FedDR"

        # s <- s - lambda * alpha * (xbar_prev - phat)
        self.s_prev = self.s_prev - self.args.server_lmbd * alpha * (self.xbar_prev - phat)

        # Reset the local model to the new s_prev; training starts from here.
        unravel_model_params(self.s_prev_model, self.s_prev)
        model = self.s_prev_model
        model.train()

        # The optimizer stays anchored at s_prev for the whole round.
        optimizer = AnchoredSGD(
            model.parameters(),
            lr=self.args.client_lr,
            gamma=self.args.client_gamma,
            weight_decay=self.args.client_weight_decay,
        )
        optimizer.update_anchor(self.s_prev)

        # Running training metrics. They are reset every epoch and are not
        # currently returned to the caller.
        loss_avg = AverageMeter()
        acc_avg = AverageMeter()

        # Last batch seen; reused below to evaluate the gradient at xbar.
        data = target = None
        while True:
            for _ in range(self.args.client_num_epochs):
                loss_avg.reset()
                acc_avg.reset()
                for data, target in self.train_loader:
                    data, target = data.to(self.args.device), target.to(self.args.device)
                    batch_size = data.shape[0]

                    # For fmnist and mnist, add the channel dimension.
                    # NOTE(review): view(batch, 1, -1) also flattens the
                    # spatial dims, giving (batch, 1, H*W) rather than
                    # (batch, 1, H, W). Confirm this matches the "cnn" model.
                    if self.args.model == "cnn" and data.ndim == 3:
                        data = data.view(batch_size, 1, -1)

                    model.zero_grad()
                    output = model(data)
                    loss = self.loss(model, output, target)
                    loss.backward()
                    loss_avg.update(loss.item())
                    pred = self.pred(output)
                    acc = torch.sum(pred == target.view_as(pred)) / batch_size
                    # Store a Python float, as is done for the loss.
                    acc_avg.update(acc.item())

                    optimizer.step()

            if data is None:
                raise ValueError(
                    "InexactFedDRClient.step processed no training batch; "
                    "check that args.client_num_epochs >= 1 and that the "
                    "client's training set is non-empty"
                )

            # Snapshot the local parameters reached so far.
            xbar = ravel_model_params(model)
            self.xbar_prev = xbar

            # Gradient of self.loss alone at xbar, evaluated on the last
            # training batch. This is exact in the full-batch setting and a
            # minibatch estimate otherwise.
            model.zero_grad()
            output = model(data)
            loss = self.loss(model, output, target)
            loss.backward()
            xbar_grad = ravel_model_params(model, grads=True)

            # Raveled parameters are returned.
            yield loss, xbar, xbar_grad, self.s_prev


def get_dataloader_kwargs(device, *, num_workers=4):
    """Return extra ``DataLoader`` keyword arguments suited to ``device``.

    On CUDA devices, worker processes and pinned host memory are enabled to
    speed up host-to-GPU transfer. Other devices keep the ``DataLoader``
    defaults.

    Args:
        device: A device string such as ``"cuda"``, ``"cuda:1"`` or ``"cpu"``,
            or a ``torch.device``.
        num_workers: Number of loader worker processes used on CUDA devices.

    Returns:
        ``{'num_workers': num_workers, 'pin_memory': True}`` for CUDA devices,
        otherwise an empty dict.
    """
    # A torch.device never compares equal to a plain str, and "cuda:N"
    # carries an index. Normalize to the bare device type before comparing.
    device_type = str(device).partition(":")[0]
    if device_type == "cuda":
        return {'num_workers': num_workers, 'pin_memory': True}
    return {}


class FullBatchDataLoader(object):
    """Thin wrapper for the full-batch setting that avoids ``DataLoader``.

    Iterating yields exactly one ``(data, targets)`` pair covering the whole
    dataset. ``len()`` reports that single batch, so code written against a
    ``DataLoader`` (``len(loader)``, ``for x, y in loader``) works unchanged.
    """

    def __init__(self, trainset):
        """Materialize ``trainset`` as one in-memory batch.

        Args:
            trainset: Dataset exposing ``data`` and ``targets`` attributes.
                If it also has a ``transform`` attribute, every sample is
                fetched through ``trainset[i]`` so the transform is applied.

        Raises:
            AssertionError: if ``trainset`` lacks ``data`` or ``targets``.
        """
        assert hasattr(trainset, 'data')
        assert hasattr(trainset, 'targets')

        # For `InMemoryDataset` transforms.
        if hasattr(trainset, 'transform'):
            # Buffers are preallocated from the raw shapes. torch.zeros uses
            # the default float dtype, while targets keep the dtype of
            # trainset.targets.
            # NOTE(review): each transformed sample is written into a slice of
            # the raw sample shape (leading size-1 dims are dropped by the
            # assignment). A transform that otherwise changes the shape will
            # fail here; confirm this for new datasets.
            self.data = torch.zeros(trainset.data.shape)
            self.targets = torch.zeros_like(trainset.targets)
            for i in range(len(trainset)):
                self.data[i], self.targets[i] = trainset[i]
        else:
            self.data = trainset.data
            self.targets = trainset.targets

    def __len__(self):
        """Return the number of batches per pass, which is always one."""
        return 1

    def __iter__(self):
        """Yield the single full batch ``(data, targets)``."""
        yield (self.data, self.targets)

