from abc import ABC, abstractclassmethod
import numpy as np
import torch
import logging
import wandb
import copy

from torch.utils.data import TensorDataset


class Client(object):
    """One federated-learning participant that trains on a local dataset.

    The server passes its model in with `receive_model`; `train` optimizes a
    deep copy of it and stores the flattened parameter delta
    (locally-updated minus received, trainable parameters only) in
    `self.gradient`. The base class leaves `full_batch` and
    `train_dataloader` as None; the subclasses in this file set them.
    """

    def __init__(self, cid, train_dataset, device, args, criterion, accuracy, test_dataset=None):
        self.cid = cid
        logging.info("Setting up client {}...".format(self.cid))

        self.train_dataset = train_dataset
        self.full_batch = None          # set by subclasses
        self.train_dataloader = None    # set by subclasses
        self.device = device
        self.args = args
        self.criterion = criterion
        self.accuracy = accuracy        # callable(preds, ys) -> (num_correct, num_total)
        self.test_dataset = test_dataset

        self.model = None               # model most recently passed to receive_model

        self.gradient = None            # flattened update produced by the last train()
        self.server_gradient = None     # reference update passed to receive_model

        self.train_loss = 0.
        self.train_acc = 0.

        self.bytes_cnt = 0              # bytes counted by send_back_gradient

    def _accumulate_gradient(self, model, parameters, batches, full_batch):
        """Leave in each `p.grad` the sample-weighted mean gradient over `batches`.

        Consumes every batch when `full_batch` is true, otherwise only the
        first one. The weighting assumes `self.criterion` returns a
        per-batch mean (TODO confirm for every criterion in use).

        Returns:
            (sum of loss * batch_size, summed correct count, summed total
            count), the counts being whatever `self.accuracy` reports.
        """
        model.zero_grad()
        seen_samples = 0.
        loss_sum, correct_sum, total_sum = 0., 0., 0.

        for xs, ys in batches:
            xs = xs.to(self.device)
            ys = ys.to(self.device)
            last_size = ys.size(0)
            seen_samples += last_size

            # p.grad holds the mean over the previous batches; scale it so
            # that after backward() adds this batch's mean, multiplying by
            # last_size / seen_samples gives the mean over all samples so far.
            for p in parameters:
                if p.grad is not None:
                    p.grad.detach().mul_((seen_samples - last_size) / last_size)

            preds = model(xs)
            loss = self.criterion(preds, ys)
            loss.backward()

            loss_sum += loss.item() * last_size
            batch_correct, batch_total = self.accuracy(preds, ys)
            correct_sum += batch_correct
            total_sum += batch_total

            for p in parameters:
                # Parameters unused in the forward pass have no gradient.
                if p.grad is not None:
                    p.grad.detach().mul_(last_size / seen_samples)

            if not full_batch:
                break

        return loss_sum, correct_sum, total_sum

    def _lookahead_gradient(self, model, parameters, full_batch, mirror):
        """Replace `p.grad` with the gradient taken at `p - mirror * p.grad`.

        Parameter values are restored afterwards; only the gradients change.
        Loss/accuracy of this temporary point are not recorded.
        """
        with torch.no_grad():
            old_parameters = [p.detach().clone() for p in parameters]
            for p in parameters:
                if p.grad is not None:
                    p.sub_(mirror * p.grad)

        # Use a fresh iterator: in full-batch mode the caller's pass has
        # already exhausted its iterator, which would leave this pass with
        # no data and every gradient cleared by zero_grad().
        self._accumulate_gradient(model, parameters, iter(self.train_dataloader), full_batch)

        with torch.no_grad():
            for p, old_p in zip(parameters, old_parameters):
                p.copy_(old_p)

    def train(self, full_batch, return_type='loss', mirror=0.):
        """Run local training on a deep copy of the received model.

        Each of `args.epochs` iterations:
        1. computes a sample-weighted mean gradient, over the whole loader
           when `full_batch` is true, otherwise over one batch;
        2. if `mirror > 0`, re-evaluates it at the look-ahead point
           `p - mirror * grad` using a freshly started loader iterator;
        3. takes one optimizer step;
        4. pulls the parameters towards the received model by
           `lambda_prox * lr * (p - p_server)`.

        `self.train_loss` / `self.train_acc` are averaged over the
        non-look-ahead passes only.

        Args:
            full_batch: iterate every batch per step instead of only the first.
            return_type: 'loss', 'gradient_norm' or 'gradient_variance'.
            mirror: look-ahead step size; 0 disables the look-ahead pass.

        Returns:
            The mean training loss, the L2 norm of the update, or the tuple
            (update . server_gradient, ||update - server_gradient||^2),
            depending on `return_type`.

        Raises:
            NotImplementedError: unknown `args.optimizer` or `return_type`.
        """
        device = self.device

        server_model = self.model
        model = copy.deepcopy(server_model)
        model.to(device)

        parameters = [p for p in model.parameters() if p.requires_grad]
        if self.args.optimizer == "sgd":
            optimizer = torch.optim.SGD(parameters, lr=self.args.lr, weight_decay=self.args.wd)
        elif self.args.optimizer == "adam":
            optimizer = torch.optim.Adam(parameters, lr=self.args.lr, weight_decay=self.args.wd)
        else:
            raise NotImplementedError("Invalid_Optimizer")

        server_model.train()
        server_parameters = [p for p in server_model.parameters() if p.requires_grad]
        model.train()

        self.train_loss = 0.
        self.train_acc = 0.
        train_total = 0.

        for _ in range(self.args.epochs):
            loss_sum, correct_sum, total_sum = self._accumulate_gradient(
                model, parameters, iter(self.train_dataloader), full_batch)
            self.train_loss += loss_sum
            self.train_acc += correct_sum
            train_total += total_sum

            if mirror > 0.:
                self._lookahead_gradient(model, parameters, full_batch, mirror)

            optimizer.step()

            lambda_prox = self.args.lambda_prox
            assert lambda_prox >= 0
            # Proximal pull towards the received (server) parameters.
            with torch.no_grad():
                for p_updated, p in zip(parameters, server_parameters):
                    p_updated.sub_(lambda_prox * self.args.lr * (p_updated - p))

        # Report the train loss/acc; skip normalization when nothing was seen
        # (e.g. epochs == 0) instead of dividing by zero.
        if train_total > 0:
            self.train_loss /= train_total
            self.train_acc /= train_total

        # no_grad keeps the stored update free of autograd history, which
        # would otherwise hold references to the local copy's parameters.
        with torch.no_grad():
            self.gradient = torch.cat(
                [(p_updated - p).reshape(-1) for p_updated, p in zip(parameters, server_parameters)])

        if return_type == 'loss':
            return_value = self.train_loss
        elif return_type == 'gradient_norm':
            return_value = self.gradient.norm(2).item()
        elif return_type == 'gradient_variance':
            return_value = (self.gradient.dot(self.server_gradient).item(), ((self.gradient-self.server_gradient).norm(2).item())**2 )
        else:
            raise NotImplementedError("Unsupported_Return_Type")

        return return_value

    def receive_model(self, server_model, server_gradient=None):
        """Store the server model (by reference) and an optional reference update."""
        self.model = server_model
        self.server_gradient = server_gradient

    def send_back_gradient(self):
        """Return the last update, counting 4 bytes per element (i.e. assuming float32)."""
        self.bytes_cnt += len(self.gradient) * 4
        return self.gradient

    def read_gradient(self):
        """Return the last update without any byte accounting."""
        return self.gradient


class InternalClient(Client):
    """A client that owns both a training split and a held-out test split.

    Batching mode comes from `args.internal_batch_train`; unlike the base
    class, `send_back_gradient` here does not add to `bytes_cnt`.
    """

    def __init__(self, cid, train_dataset, device, args, criterion, accuracy, test_dataset):
        super().__init__(cid, train_dataset, device, args, criterion, accuracy, test_dataset)
        make_loader = torch.utils.data.DataLoader
        use_full = self.args.internal_batch_train == 'full'
        self.full_batch = use_full
        # Shuffle only for minibatch training; full-batch training walks
        # every batch each step.
        self.train_dataloader = make_loader(self.train_dataset, batch_size=self.args.batch_size, shuffle=not use_full)
        self.test_dataloader = make_loader(self.test_dataset, batch_size=self.args.batch_size, shuffle=False)

    def train(self, return_type='loss', mirror=0.):
        """Train with this client's configured batching mode; see `Client.train`."""
        return super().train(self.full_batch, return_type=return_type, mirror=mirror)

    def test(self):
        """Evaluate the currently received model on this client's test split.

        Moves `self.model` to the device and switches it to eval mode in place.

        Returns:
            dict with 'test_correct', 'test_total' (summed as reported by
            `self.accuracy`) and 'test_loss' (loss times batch size, summed).
        """
        net = self.model
        net.to(self.device)
        net.eval()

        totals = {'test_correct': 0, 'test_total': 0, 'test_loss': 0}

        with torch.no_grad():
            for inputs, targets in self.test_dataloader:
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                outputs = net(inputs)
                batch_loss = self.criterion(outputs, targets)
                n_correct, n_seen = self.accuracy(outputs, targets)
                assert n_correct/n_seen <= 1

                totals['test_correct'] += n_correct
                totals['test_loss'] += batch_loss.item() * targets.size(0)
                totals['test_total'] += n_seen

        return totals

    def send_back_gradient(self):
        """Return the last update; internal clients are not charged for bytes."""
        return self.gradient


class ExternalClient(Client):
    """A training-only client (no test split); batching mode comes from `args.external_batch_train`."""

    def __init__(self, cid, train_dataset, device, args, criterion, accuracy):
        super().__init__(cid, train_dataset, device, args, criterion, accuracy, test_dataset=None)
        use_full = self.args.external_batch_train == 'full'
        self.full_batch = use_full
        # Shuffle only for minibatch training; full-batch training walks
        # every batch each step.
        self.train_dataloader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.args.batch_size,
            shuffle=not use_full,
        )

    def train(self, return_type='loss', mirror=0.):
        """Train with this client's configured batching mode; see `Client.train`."""
        return super().train(self.full_batch, return_type=return_type, mirror=mirror)


class Base_API(ABC):
    """Server-side driver: owns the global model and the client objects.

    Builds one `InternalClient` per internal client id (train and test
    splits) and one `ExternalClient` per external client id (train only),
    aggregates their updates into `self.model`, and reports metrics.
    """

    def __init__(self, data, device, args, model, criterion, accuracy):
        '''Init the api object.

        Require: `data` is a dict with keys 'internal_cid', 'internal_train_data',
        'internal_test_data', 'external_cid' and 'external_data'. The three
        *_data entries are indexed by client id, and each entry holds
        array-like 'x' and 'y'.
        '''
        self.internal_cid = data['internal_cid']
        self.internal_train_data = data['internal_train_data']
        self.internal_test_data = data['internal_test_data']
        self.external_cid = data['external_cid']
        self.external_data = data['external_data']

        self.internal_clients = []
        for cid in self.internal_cid:
            client_train_dataset = self._to_tensor_dataset(self.internal_train_data[cid], args.dataset)
            client_test_dataset = self._to_tensor_dataset(self.internal_test_data[cid], args.dataset)
            self.internal_clients.append(
                InternalClient(cid, client_train_dataset, device, args, criterion, accuracy, client_test_dataset))

        self.external_clients = []
        for cid in self.external_cid:
            client_dataset = self._to_tensor_dataset(self.external_data[cid], args.dataset)
            self.external_clients.append(
                ExternalClient(cid, client_dataset, device, args, criterion, accuracy))

        self.device = device
        self.args = args
        self.model = model
        self.model.to(self.device)

        self.criterion = criterion
        self.accuracy = accuracy

        self.train_metric = None
        self.test_metric = None

    @staticmethod
    def _to_tensor_dataset(client_data, dataset_name):
        """Wrap one client's {'x', 'y'} data in a TensorDataset.

        'x' becomes a LongTensor for 'shakespeare' and a FloatTensor
        otherwise; 'y' is always a LongTensor.
        """
        make_x = torch.LongTensor if dataset_name == 'shakespeare' else torch.FloatTensor
        return TensorDataset(make_x(client_data['x']), torch.LongTensor(client_data['y']))

    def train(self, weights, already_trained, already_received=False, mirror=0.):
        """Aggregate the engaged clients' updates into `self.model`.

        Clients are ordered internal first, then external. A client whose
        weight exceeds 1e-6 is engaged; the others get weight 0. Unless
        `already_trained`, each engaged client receives the current model
        and trains. Their updates are averaged with the normalized weights
        and added to the trainable parameters of `self.model`. The same
        weights average the clients' train loss/accuracy into
        `self.train_metric`.

        Args:
            weights: one weight per client (list, array or 1-D tensor). It is
                not modified.
            already_trained: skip local training and reuse each client's
                stored update.
            already_received: fetch the update with `read_gradient()` instead
                of `send_back_gradient()` (the latter may count bytes).
            mirror: forwarded to `client.train`.

        Raises:
            ValueError: if `weights` does not have one entry per client, if no
                weight exceeds 1e-6, or if the update size does not match the
                model's trainable parameters.
        """
        merged_clients = self.internal_clients + self.external_clients

        # Work on a local copy so the caller's weights are never mutated.
        raw_weights = list(weights)
        if len(raw_weights) != len(merged_clients):
            raise ValueError("Expected {} weights (one per client), got {}".format(
                len(merged_clients), len(raw_weights)))

        # take log
        engaged_id = [(idx, w) for idx, w in enumerate(raw_weights) if w > 1e-6]
        logging.info(engaged_id)
        if not engaged_id:
            raise ValueError("No client has a weight above 1e-6; nothing to aggregate")

        engaged = {idx for idx, _ in engaged_id}
        weight_vec = torch.tensor(
            [float(w) if idx in engaged else 0. for idx, w in enumerate(raw_weights)],
            dtype=torch.float)
        weight_vec = weight_vec / weight_vec.sum()

        # Weighted sum of the engaged updates, accumulated directly instead of
        # stacking an (n_clients x d) matrix.
        g_final = None
        for idx, client in enumerate(merged_clients):
            if idx not in engaged:
                continue
            if not already_trained:
                client.receive_model(self.model)
                client.train(mirror=mirror)
            gradient = client.read_gradient() if already_received else client.send_back_gradient()
            # detach(): never keep a client's autograd history in the aggregate.
            contribution = weight_vec[idx].item() * gradient.detach().to(device=self.device, dtype=torch.float)
            g_final = contribution if g_final is None else g_final + contribution

        train_loss = 0.
        train_acc = 0.
        for idx, _ in engaged_id:
            w = weight_vec[idx].item()
            train_loss += w * merged_clients[idx].train_loss
            train_acc += w * merged_clients[idx].train_acc
        self.train_metric = {'train_loss': train_loss, 'train_acc': train_acc}

        # filter the parameters requiring gradient
        parameters = [p for p in self.model.parameters() if p.requires_grad]
        expected_size = sum(p.numel() for p in parameters)
        if g_final.numel() != expected_size:
            raise ValueError("Aggregated update has {} elements but the model has {} trainable ones".format(
                g_final.numel(), expected_size))

        # update the merged gradient/update, slicing the flat vector per parameter
        offset = 0
        with torch.no_grad():
            for p in parameters:
                size = p.numel()
                p.add_(g_final[offset:offset + size].reshape(p.shape))
                offset += size

    def test(self):
        """Evaluate `self.model` on every internal client and sum the results into `self.test_metric`."""
        test_metric = {
            'test_correct': 0,
            'test_total': 0,
            'test_loss': 0
        }

        for client in self.internal_clients:
            client.receive_model(self.model)
            metric = client.test()
            for key in test_metric:
                test_metric[key] += metric[key]

        self.test_metric = test_metric

    def count_bytes(self):
        """Return the total bytes recorded by all clients' `bytes_cnt` counters."""
        all_clients = self.internal_clients + self.external_clients
        return sum(client.bytes_cnt for client in all_clients)

    def test_and_show(self, round_idx):
        """Log the latest train metrics and, every `args.test_interval` rounds, test metrics.

        Requires `train` to have run at least once (reads `self.train_metric`).
        """
        metric = self.train_metric
        train_loss = metric['train_loss']
        train_acc = metric['train_acc']

        wandb.log({"Train/Loss": train_loss, "round": round_idx+1})
        wandb.log({"Train/Acc": train_acc, "round": round_idx+1})
        logging.info("##########train_loss: {}, train_accuracy: {}".format(train_loss, train_acc))

        if (round_idx+1) % self.args.test_interval == 0:
            logging.info("##########Test After {}th Round".format(round_idx+1))

            total_bytes = self.count_bytes()
            self.test()
            metric = self.test_metric
            # Loss was summed as loss * batch size, so this is a per-sample mean.
            test_loss = metric['test_loss'] / metric['test_total']
            test_acc = metric['test_correct'] / metric['test_total']

            wandb.log({"Test/Loss": test_loss, "round": round_idx+1, "total_bytes": total_bytes})
            wandb.log({"Test/Acc": test_acc, "round": round_idx+1, "total_bytes": total_bytes})
            logging.info("##########test_loss: {}, test_accuracy: {}; total_bytes: {}".format(test_loss, test_acc, total_bytes))

