import torch
import numpy as np
import matplotlib.pyplot as plt
from linear_model import LinearSystem

# Experiment constants. Of these, only d and T are read by the code in this file.
d = 100 # dimension of the model parameter theta (theta = torch.randn(d) in trainAFL)
n = 14400 # local dataset size (not read by the code in this file)
Q = 5 # presumably local steps per round; NOTE(review): trainAFL uses self.Q, not this constant
T = 10 # number of clients
M = 2 # number of clusters (not read by the code in this file)

class LinearSystemAFL(LinearSystem):
    """AFL (agnostic federated learning) training on top of ``LinearSystem``.

    The global linear model is updated with a weighted average of the clients'
    local models. The weights ``dynamic_lambdas`` are moved towards the clients
    with larger loss (projected gradient ascent onto the probability simplex).

    Relies on members provided by ``LinearSystem``, which are not visible here:
    ``self.device``, ``self.Q`` (local steps per round), ``self.sampling()``
    and ``self.BayesOptimal(X, Y)``.
    """

    def trainAFL(self, lr=0.01, lamb=0.01, n_epoch=300):
        """Train with AFL and record per-client losses of the averaged model.

        Args:
            lr: step size of the clients' local gradient steps.
            lamb: step size of the ascent on the client weights; also stored
                as ``self.lamb``.
            n_epoch: number of communication rounds (must be >= 1).

        Returns:
            ``(losses, opt_losses)``, each stacked along a leading epoch axis.
            ``losses[e]`` is the MSE (per client) of the running average of
            the global iterates on epoch ``e``'s batch. ``opt_losses[e]`` is
            ``self.BayesOptimal(X, Y)[1]`` on the same batch.

        Raises:
            ValueError: if ``n_epoch`` is smaller than 1.
        """
        if n_epoch < 1:
            raise ValueError("n_epoch must be at least 1")
        self.lamb = lamb
        theta = torch.randn(d, device=self.device) / d
        resulted_theta = theta.clone()
        losses = []
        opt_losses = []
        # start from uniform client weights
        self.dynamic_lambdas = torch.ones(T) * 1.0 / T
        for epoch in range(n_epoch):
            # assumes X is (T, n_samples, d) and Y is (T, n_samples) -- TODO confirm against LinearSystem.sampling
            X, Y = self.sampling()

            # loss of the current global model per client; drives the weight ascent below
            loss = self._client_mse(X, Y, theta)

            # every client starts from the SAME global theta (each works on its own copy)
            updated = [self._local_update(X[t], Y[t], theta, lr) for t in range(T)]

            # weighted average of the local models:
            #   theta <- theta - sum_t w_t * (theta - model_t)
            # all deltas are taken against the same pre-aggregation theta.
            agg_weights = self.dynamic_lambdas / torch.sum(self.dynamic_lambdas)
            deltas = [theta - model for model in updated]
            for delta, weight in zip(deltas, agg_weights):
                theta -= delta * weight

            # ascent step on the client weights, then back onto the simplex
            self.dynamic_lambdas = self.project(
                self.dynamic_lambdas + self.lamb * loss.detach().cpu()
            )

            # running average of the global iterates
            resulted_theta = (resulted_theta * epoch + theta) / (epoch + 1)

            losses.append(self._client_mse(X, Y, resulted_theta))
            opt_losses.append(self.BayesOptimal(X, Y)[1])

        losses = torch.stack(losses, dim=0)
        opt_losses = torch.stack(opt_losses, dim=0)
        return losses, opt_losses

    @staticmethod
    def _client_mse(X, Y, theta):
        """Return the squared error of ``X @ theta`` against ``Y``, averaged over the last axis."""
        residual = torch.matmul(X, theta.unsqueeze(-1)).squeeze(-1) - Y
        return (residual * residual).mean(dim=-1)

    def _local_update(self, X_t, Y_t, theta, lr):
        """Run ``self.Q`` full-batch MSE gradient steps starting from a copy of ``theta``."""
        # clone: the in-place steps below must not touch the global model
        local_theta = theta.clone()
        for _ in range(self.Q):
            residual = torch.matmul(X_t, local_theta) - Y_t
            gradient = 2 * torch.matmul(X_t.T, residual)
            local_theta -= lr * gradient / X_t.shape[0]
        return local_theta

    def project(self, p):
        """Return the Euclidean projection of ``p`` onto the probability simplex.

        Sort-based algorithm (Duchi et al., 2008):
        - Sort the entries in decreasing order to get ``u``.
        - Let ``rho`` be the largest ``k`` with ``u_k + (1 - sum_{r<=k} u_r) / k > 0``.
        - Shift every entry by ``(1 - sum_{r<=rho} u_r) / rho`` and clip at zero.

        Args:
            p: iterable of scalars (e.g. a list of 0-d tensors, or a 1-D tensor).

        Returns:
            A float64 CPU tensor of the same length as ``p``. Its entries are
            nonnegative and sum to 1 up to rounding.

        Raises:
            ValueError: if ``p`` is empty.
        """
        v = np.array([float(p_i) for p_i in p], dtype=np.float64)
        if v.size == 0:
            raise ValueError("cannot project an empty vector onto the simplex")
        u = np.sort(v)[::-1]
        cumsum = np.cumsum(u)
        k = np.arange(1, v.size + 1)
        candidates = np.nonzero(u + (1.0 - cumsum) / k > 0)[0]
        # k = 1 always satisfies the condition in exact arithmetic; fall back to it if rounding says otherwise
        rho = int(candidates[-1]) + 1 if candidates.size else 1
        shift = (1.0 - cumsum[rho - 1]) / rho
        return torch.from_numpy(np.maximum(v + shift, 0.0))

    def fairness(self, step, losses, opt_loss):
        """Plot the fairness gap for each epoch and print its final value.

        The gap at each epoch is the spread of the excess losses
        ``losses - opt_loss`` across the last axis: max minus min. This equals
        the largest pairwise difference. The figure is saved to
        ``figures/fairness_AFL``, and the directory is created if missing.

        Args:
            step: x-axis values, one per row of ``losses``.
            losses: per-epoch losses, expected 2-D (epochs, clients) as
                returned by ``trainAFL``.
            opt_loss: reference losses, broadcastable against ``losses``.
        """
        import os

        excess_loss = losses - opt_loss
        gap = excess_loss.max(dim=-1).values - excess_loss.min(dim=-1).values

        plt.figure(figsize=(5, 4), dpi=200)
        plt.xlabel("Epoch", fontsize=8)
        plt.ylabel("Fairness", fontsize=8)
        plt.plot(step, gap.cpu(), lw=1, label='AFL')
        plt.legend(loc=0, fontsize=10)
        os.makedirs("figures", exist_ok=True)
        plt.savefig("figures/fairness_AFL", bbox_inches='tight')
        # release the figure so repeated calls do not accumulate open figures
        plt.close()
        print("FAA: {:.3f}".format(float(gap[-1])))

# Train with AFL, then plot and report fairness of the averaged model.
system = LinearSystemAFL()
losses, opt_loss = system.trainAFL()
# one x-axis point per recorded epoch, so the plot stays in sync with trainAFL's n_epoch
step = np.arange(losses.shape[0])

system.fairness(step, losses, opt_loss)
# mean over clients of the final epoch's losses
print(losses[-1].mean())