import logging
import torch
import numpy as np

import wandb

from api.base_api import Base_API
from utils import WeightsSolver


class VARSEL_API(Base_API):
    """Variance-based selection and weighting of external clients (VARSEL).

    Each round, every internal client trains and reports a gradient. The
    averaged internal gradient is sent to the external clients, which report
    a gradient variance. The ``expected_colaborators`` external clients with
    the smallest (corrected) variance are candidates; a first solve of
    ``objective`` decides which of them are engaged, and a second solve of
    ``objective_refined`` assigns the final weights of the engaged clients.

    Relies on attributes that are expected to be provided by ``Base_API``
    (not visible in this file): ``internal_clients``, ``external_clients``,
    ``model``, ``device``, ``args``, ``train`` and ``test_and_show``.
    """

    # First-stage weights must be strictly greater than this value for an
    # external client to be engaged in the second (refined) stage.
    MIN_ENGAGED_WEIGHT = 0.01

    def __init__(self, data, device, args, model, criterion, accuracy, deterministic=True):
        """Forward the common arguments to ``Base_API`` and reset VARSEL state.

        Args:
            data, device, args, model, criterion, accuracy: passed unchanged
                to ``Base_API.__init__``.
            deterministic: stored on the instance; not read anywhere in this
                class.
        """
        super().__init__(data, device, args, model, criterion, accuracy)
        self.deterministic = deterministic
        # Corrected variances of the current candidate external clients,
        # read by ``objective``.
        self.external_variances = []
        # Scaled sum of squared deviations of the internal gradients from
        # their mean, read by both objectives.
        self.internal_variance_sum = 0.
        # NumPy matrix of the quadratic form used by ``objective_refined``;
        # set in ``run`` once at least one external client is engaged.
        self.Sigma_SigmaT__mu_Eye = None

    def run(self):
        """Run ``self.args.rounds`` rounds of training with VARSEL weighting.

        Raises:
            AssertionError: if fewer than two internal clients are available
                (the variance estimate divides by ``num_internal_clients - 1``).
        """
        # NOTE: pre-training is currently disabled:
        # self.pretrain(self.args.pretrain_rounds)

        for round_idx in range(self.args.rounds):
            logging.info("#####Round : %s", round_idx + 1)

            assert len(self.internal_clients) >= 2
            num_internal_clients = len(self.internal_clients)

            # --- Phase 1: internal clients train and report gradients. ---
            internal_gradients = self._stack_gradients(
                self._train_internal_clients(), num_internal_clients
            )

            with torch.no_grad():
                averaged_weights = torch.tensor(
                    [1 / num_internal_clients for _ in range(num_internal_clients)],
                    device=self.device,
                )
                # Uniform average of the internal gradients.
                g_int = torch.matmul(averaged_weights, internal_gradients)
                # Sum over clients of ||g_i - g_int||^2, scaled by M / (M - 1).
                self.internal_variance_sum = (
                    torch.norm(internal_gradients - g_int, p='fro') ** 2
                    * num_internal_clients / (num_internal_clients - 1)
                ).item()

            internal_variance_mean = self.internal_variance_sum / num_internal_clients
            logging.info("Internal averaged variance: %s", internal_variance_mean)

            # --- Phase 2: rank external clients, keep the K best. ---
            K = self.args.expected_colaborators
            candidates = self._rank_external_clients(g_int, num_internal_clients)[:K]
            self.external_variances = [variance for _, variance in candidates]
            logging.info(self.external_variances)

            # --- Phase 3: first solve decides which candidates are engaged. ---
            weightsSolver = WeightsSolver(
                len(self.external_variances), K, self.objective, constr='le'
            )
            optimal_weights = weightsSolver.solve()
            engaged_clients = [
                client_idx
                for (client_idx, _), weight in zip(candidates, optimal_weights)
                if weight > self.MIN_ENGAGED_WEIGHT
            ]
            num_engaged_clients = len(engaged_clients)

            # --- Phase 4: refined solve over the engaged clients only. ---
            weights = [0.] * len(self.external_clients)
            if num_engaged_clients > 0:
                external_gradients = self._stack_gradients(
                    (
                        self.external_clients[client_idx].send_back_gradient()
                        for client_idx in engaged_clients
                    ),
                    num_engaged_clients,
                )

                with torch.no_grad():
                    gram = self._gram_matrix(external_gradients - g_int)
                    correction = (
                        self.internal_variance_sum / num_internal_clients**2
                        * torch.eye(num_engaged_clients, device=self.device)
                    )
                    # The solver objective works on NumPy data, so move to CPU.
                    self.Sigma_SigmaT__mu_Eye = np.array((gram - correction).cpu())

                weightsSolver = WeightsSolver(
                    num_engaged_clients, K, self.objective_refined, constr='relaxed'
                )
                optimal_weights = weightsSolver.solve()

                for client_idx, weight in zip(engaged_clients, optimal_weights):
                    weights[client_idx] = weight

            wandb.log({"Weight": sum(weights), "round": round_idx + 1})

            # --- Phase 5: aggregate; internal clients always have weight 1. ---
            weights = [1.] * num_internal_clients + weights
            self.train(weights, already_trained=True, already_received=True)
            self.test_and_show(round_idx)

    def _train_internal_clients(self):
        """Yield each internal client's gradient after it receives the model and trains."""
        for client in self.internal_clients:
            client.receive_model(self.model)
            client.train()
            yield client.send_back_gradient()

    def _stack_gradients(self, gradients, count):
        """Copy ``count`` gradients into the rows of one tensor on ``self.device``.

        The buffer is allocated lazily because the gradient length is only
        known once the first gradient arrives. Returns ``None`` when
        ``gradients`` yields nothing.
        """
        stacked = None
        for row, gradient in enumerate(gradients):
            if stacked is None:
                with torch.no_grad():
                    stacked = torch.zeros([count, gradient.size(0)], device=self.device)
            stacked[row] = gradient
        return stacked

    def _rank_external_clients(self, g_int, num_internal_clients):
        """Return ``(client_index, corrected_variance)`` pairs, smallest variance first.

        Every external client receives the model together with ``g_int`` and
        trains with ``return_type='gradient_variance'``.
        """
        records = []
        for client_idx, client in enumerate(self.external_clients):
            client.receive_model(self.model, g_int)
            # The first returned value (an inner product) is currently unused;
            # a filter keeping only clients with a positive inner product was
            # tried previously and is disabled.
            _inner_product, gradient_variance = client.train(return_type='gradient_variance')
            # NOTE: alternative correction that was tried previously:
            #   gradient_variance - self.internal_variance_sum / num_internal_clients**2
            corrected = gradient_variance * num_internal_clients / (num_internal_clients + 1)
            records.append((client_idx, corrected))

        records.sort(key=lambda record: record[1], reverse=False)
        return records

    def _gram_matrix(self, deviations):
        """Return the matrix of pairwise dot products of the rows of ``deviations``.

        Only the lower triangle is computed; each value is mirrored, so the
        result is symmetric by construction.
        """
        n = deviations.size(0)
        gram = torch.zeros([n, n], device=self.device)
        for i in range(n):
            for j in range(i + 1):
                gram[i, j] = gram[j, i] = deviations[i].dot(deviations[j])
        return gram

    def objective(self, weights):  # VARSEL
        """First-stage objective: ``(I + sum_k w_k^2 * v_k) / (M + S)^2``.

        ``I`` is ``self.internal_variance_sum``, ``v`` is
        ``self.external_variances``, ``M`` is the number of internal clients
        and ``S`` is the sum of ``weights``.

        Args:
            weights: candidate weights, one per entry of
                ``self.external_variances``; must support ``.sum()`` and
                elementwise ``**`` (presumably a NumPy array supplied by
                ``WeightsSolver`` -- confirm).
        """
        S = weights.sum()
        M = len(self.internal_clients)
        I = self.internal_variance_sum
        # NOTE: alternative objective that was tried previously:
        #   1/(M**2+(2*M+1)*S) * I
        #   + sum((weights/((2*M+3)*S-(2*M+2)*weights+(M+1)**2)) * self.external_variances)
        return (I + sum(weights**2 * self.external_variances)) / (M+S)**2

    def objective_refined(self, weights):  # VARSEL
        """Second-stage objective: ``(I + w^T A w) / (M + S)^2``.

        ``A`` is ``self.Sigma_SigmaT__mu_Eye``, which ``run`` sets before the
        refined solve; ``I``, ``M`` and ``S`` are as in ``objective``.

        Args:
            weights: candidate weights, one per engaged external client; must
                support ``.sum()``, ``.T`` and ``.dot`` (presumably a NumPy
                array supplied by ``WeightsSolver`` -- confirm).
        """
        S = weights.sum()
        M = len(self.internal_clients)
        I = self.internal_variance_sum
        return (I + weights.T.dot(self.Sigma_SigmaT__mu_Eye).dot(weights)) / (M+S)**2

        