import numpy as np 
import torch
import logging
import wandb

from api.base_api import Base_API
from utils import WeightsSolver


class Local_API(Base_API):
    """Baseline strategy in which only internal clients are aggregated.

    Every round, each internal client receives aggregation weight 1 and every
    external client weight 0. ``already_trained=False`` is passed to
    ``self.train``.
    """

    def run(self):
        """Run ``self.args.rounds`` rounds of internal-only training, evaluating after each."""
        for r in range(self.args.rounds):
            logging.info(f"#####Round : {r + 1}")

            internal_part = [1.] * len(self.internal_clients)
            external_part = [0.] * len(self.external_clients)
            self.train(weights=internal_part + external_part, already_trained=False)
            self.test_and_show(r)


class FedAvg_API(Base_API):
    """Federated averaging over the pooled internal + external clients.

    Clients are indexed internal-first, then external. Each round, up to
    ``args.expected_colaborators`` *distinct* clients are drawn uniformly at
    random. They get aggregation weight 1 and every other client gets 0.
    """

    def run(self, mirror=0.):
        """Run ``self.args.rounds`` rounds of FedAvg training and evaluation.

        Args:
            mirror: Value forwarded to ``self.train(mirror=...)`` only in rounds
                where ``round_idx > 0.5 * self.args.rounds``. Earlier rounds
                pass ``0.``. Its meaning is defined by ``Base_API.train``.
        """
        # self.pretrain(self.args.pretrain_rounds)

        for round_idx in range(self.args.rounds):
            logging.info("#####Round : {}".format(round_idx+1))

            n_clients = len(self.internal_clients) + len(self.external_clients)
            # Sample WITHOUT replacement. ``np.random.randint`` draws with
            # replacement, so duplicate indices silently shrank the cohort
            # below ``expected_colaborators``. Clamp so we never ask for more
            # distinct clients than exist.
            n_selected = min(self.args.expected_colaborators, n_clients)
            selected_clients = np.random.choice(n_clients, size=n_selected, replace=False)

            weights = [0.] * n_clients
            for cid in selected_clients.tolist():
                weights[cid] = 1.

            # ``mirror`` is only enabled for the second half of training.
            round_mirror = mirror if round_idx > 0.5 * self.args.rounds else 0.
            self.train(weights=weights, already_trained=False, mirror=round_mirror)
            self.test_and_show(round_idx)


class FedProx_API(FedAvg_API):
    """FedAvg variant that requires a strictly positive ``args.lambda_prox``.

    This class only validates the coefficient and otherwise behaves like
    ``FedAvg_API``. Presumably ``Base_API`` or the clients apply the proximal
    term using ``args.lambda_prox``; TODO confirm.
    """

    def __init__(self, data, device, args, model, criterion, accuracy):
        # Explicit raise instead of ``assert``: asserts are removed under
        # ``python -O``, which would silently accept an invalid config.
        # ``not (x > 0.)`` also rejects NaN, as the original assert did.
        if not (args.lambda_prox > 0.):
            raise ValueError(
                "FedProx requires args.lambda_prox > 0, got {!r}".format(args.lambda_prox)
            )
        super().__init__(data, device, args, model, criterion, accuracy)


class PerFL_API(FedAvg_API):
    """FedAvg run with a small non-zero ``mirror`` value (1e-5).

    The inherited ``run`` forwards ``mirror`` to ``train`` only during the
    second half of training, so the first half is plain FedAvg.
    """

    def __init__(self, data, device, args, model, criterion, accuracy):
        # Plain pass-through; kept so the constructor signature stays explicit.
        super().__init__(data, device, args, model, criterion, accuracy)

    def run(self):
        """Run the inherited FedAvg loop with ``mirror=1e-5``."""
        personalization_mirror = 1e-5
        super().run(mirror=personalization_mirror)


class Loss_API(Base_API):
    """Weight external clients according to their local training loss.

    Each round:

    * Every client receives the current global model and trains locally.
    * Internal clients always aggregate with weight 1.
    * External clients are ranked by the value returned from
      ``client.train(return_type='loss')``.
    * The top ``args.expected_colaborators`` of them (or all, if fewer exist)
      get weights from ``WeightsSolver``. All other external clients get 0.

    The policy is chosen by ``lower_weight_on_higher_loss``:

    * truthy (Song et al. 2021): keep the lowest-loss clients. ``objective``
      returns ``sum(w * loss)``, with constraint ``'eq'``.
    * falsy (Cho et al. 2020): keep the highest-loss clients. ``objective``
      returns ``-sum(w * loss)``, with constraint ``'le'``.

    (Assumes WeightsSolver minimises the objective; TODO confirm in utils.)
    """

    def __init__(self, data, device, args, model, criterion, accuracy, lower_weight_on_higher_loss):
        super().__init__(data, device, args, model, criterion, accuracy)
        # Losses of the currently selected external clients, in ranking order.
        # Rebuilt every round and read by ``objective`` while WeightsSolver runs.
        self.losses = []
        self.lower_weight_on_higher_loss = lower_weight_on_higher_loss

    def run(self):
        """Run ``self.args.rounds`` rounds of loss-based selection, training and evaluation."""
        # self.pretrain(self.args.pretrain_rounds)

        for round_idx in range(self.args.rounds):
            logging.info("#####Round : {}".format(round_idx+1))

            for client in self.internal_clients:
                client.receive_model(self.model)
                client.train()

            # (external client index, loss) pairs.
            records = []
            for cid, client in enumerate(self.external_clients):
                client.receive_model(self.model)
                records.append((cid, client.train(return_type='loss')))

            # Use truthiness (not ``== True``) so the ranking direction always
            # agrees with the branch taken in ``objective``.
            lower_weight = bool(self.lower_weight_on_higher_loss)
            records.sort(key=lambda record: record[1], reverse=not lower_weight)
            constr = 'eq' if lower_weight else 'le'

            # Never select more clients than exist. Otherwise ``self.losses``
            # would be shorter than the solver's weight vector and
            # ``weights * self.losses`` could not broadcast.
            K = min(self.args.expected_colaborators, len(records))
            selected = records[:K]

            self.losses.clear()
            self.losses.extend(loss for _, loss in selected)

            weights = [0.] * len(self.external_clients)
            if K > 0:
                optimal_weights = WeightsSolver(K, K, self.objective, constr).solve()
                for (cid, _), weight in zip(selected, optimal_weights):
                    weights[cid] = weight

            self.train([1.]*len(self.internal_clients)+weights, already_trained=True)
            self.test_and_show(round_idx)

    def objective(self, weights):
        """Objective handed to WeightsSolver for the selected clients' weights.

        ``weights`` is multiplied elementwise with the list ``self.losses``.
        Presumably it is a numpy array of the same length supplied by
        WeightsSolver; TODO confirm.
        """
        weighted_loss = (weights * self.losses).sum()
        if self.lower_weight_on_higher_loss:  # Song et al. 2021
            return weighted_loss
        else:  # Cho et al. 2020
            return -weighted_loss


class Cho_API(Loss_API):
    """``Loss_API`` fixed to the Cho et al. 2020 policy.

    This policy favours the external clients with the highest loss.
    """

    def __init__(self, data, device, args, model, criterion, accuracy):
        super().__init__(
            data, device, args, model, criterion, accuracy,
            lower_weight_on_higher_loss=False,
        )


class Song_API(Loss_API):
    """``Loss_API`` fixed to the Song et al. 2021 policy.

    This policy favours the external clients with the lowest loss.
    """

    def __init__(self, data, device, args, model, criterion, accuracy):
        super().__init__(
            data, device, args, model, criterion, accuracy,
            lower_weight_on_higher_loss=True,
        )


class Chen_API(Base_API):
    """Client weighting after Chen et al. 2020, driven by local gradient norms.

    Each round:

    * Every client receives the current global model and trains locally.
    * Internal clients always aggregate with weight 1.
    * External clients are ranked by the value returned from
      ``client.train(return_type='gradient_norm')``, largest first.
    * The top ``args.expected_colaborators`` (or all, if fewer exist) get
      weights from ``WeightsSolver`` with constraint ``'eq'``, optimising
      ``objective``. All other external clients get 0.
    """

    def __init__(self, data, device, args, model, criterion, accuracy):
        super().__init__(data, device, args, model, criterion, accuracy)
        # Gradient norms of the currently selected external clients, in
        # ranking order. Rebuilt every round and read by ``objective``.
        self.gradient_norms = []

    def run(self):
        """Run ``self.args.rounds`` rounds of gradient-norm-based selection, training and evaluation."""
        # self.pretrain(self.args.pretrain_rounds)

        for round_idx in range(self.args.rounds):
            logging.info("#####Round : {}".format(round_idx+1))

            for client in self.internal_clients:
                client.receive_model(self.model)
                client.train()

            # (external client index, gradient norm) pairs.
            records = []
            for cid, client in enumerate(self.external_clients):
                client.receive_model(self.model)
                records.append((cid, client.train(return_type='gradient_norm')))

            records.sort(key=lambda record: record[1], reverse=True)

            # Never select more clients than exist. Otherwise
            # ``self.gradient_norms`` would be shorter than the solver's weight
            # vector and the elementwise product in ``objective`` would fail.
            K = min(self.args.expected_colaborators, len(records))
            selected = records[:K]

            self.gradient_norms.clear()
            self.gradient_norms.extend(norm for _, norm in selected)

            weights = [0.] * len(self.external_clients)
            if K > 0:
                optimal_weights = WeightsSolver(K, K, self.objective, constr='eq').solve()
                for (cid, _), weight in zip(selected, optimal_weights):
                    weights[cid] = weight

            self.train([1.]*len(self.internal_clients)+weights, already_trained=True)
            self.test_and_show(round_idx)

    def objective(self, weights):  # Chen et al. 2020
        """Return ``sum_i (1 / w_i - 1) * g_i**2`` over the selected clients.

        Here ``g_i`` are the entries of ``self.gradient_norms``. ``weights`` is
        presumably a numpy array from WeightsSolver; TODO confirm. A zero
        weight makes its term infinite.
        """
        return ((1/weights-1)*self.gradient_norms*self.gradient_norms).sum()

