import torch
import numpy as np
import matplotlib.pyplot as plt
from linear_model import LinearSystem

# Problem-size constants for the FFL linear-regression experiment.
# Only `d` and `T` are read by the code in this file; the others are kept
# for reference (NOTE(review): confirm whether they should mirror the
# settings used inside LinearSystem).
d = 100 # dimension of w; trainFFL initialises theta as a vector of this length
n = 14400 # local dataset size (not referenced in this file)
Q = 5 # not referenced in this file -- trainFFL reads self.Q instead; presumably the number of local steps, TODO confirm
T = 10 # number of clients; loop bound for the per-client updates in trainFFL
M = 2 # number of clusters (not referenced in this file)

class LinearSystemFFL(LinearSystem):
    """Linear-regression federated system trained with a loss-weighted
    (q-FFL / q-FedAvg style) aggregation rule.

    Relies on the base class for ``self.device``, ``self.Q``,
    ``self.sampling()`` and ``self.BayesOptimal()``, and on the module-level
    constants ``d`` (model dimension) and ``T`` (number of clients).
    """

    def trainFFL(self, lr=0.01, n_epoch=300, q=1):
        """Run the fair federated training loop and record per-epoch losses.

        Each epoch: draw data with ``self.sampling()``, evaluate every
        client's mean squared error at the current global model, let each
        client take ``self.Q`` full-batch gradient steps starting from a
        private copy of the global model, then aggregate the client deltas
        weighted by ``loss ** q`` and normalised by the sum of the ``h``
        terms.

        Args:
            lr: Step size of the local gradient steps; ``1 / lr`` is also
                used as the scaling constant in the aggregation.
            n_epoch: Number of communication rounds to run.
            q: Fairness exponent applied to the client losses. Stored on
                the instance as ``self.q``.

        Returns:
            Tuple ``(losses, opt_losses)``. ``losses`` stacks, along a new
            leading dimension, the per-client losses measured at the start
            of every epoch. ``opt_losses`` stacks element ``[1]`` of
            ``self.BayesOptimal(X, Y)`` for every epoch (presumably the
            optimal per-client loss -- confirm against LinearSystem).
        """
        self.q = q
        theta = torch.randn(d, device=self.device) / d
        losses = []
        opt_losses = []
        for _ in range(n_epoch):
            X, Y = self.sampling()

            # Per-client mean squared error of the current global model.
            residual = torch.matmul(X, theta.unsqueeze(-1)).squeeze(-1) - Y
            loss = (residual * residual).mean(dim=-1)

            updated = []
            for t in range(T):
                # Work on a copy: the in-place steps below must not touch
                # the global model, otherwise every client would start from
                # the previous client's result and the deltas computed
                # against `theta` afterwards would be wrong.
                local_theta = theta.clone()
                for _ in range(self.Q):
                    gradient = 2 * torch.matmul(
                        X[t].T, torch.matmul(X[t], local_theta) - Y[t]
                    )
                    local_theta -= lr * gradient / X[t].shape[0]
                updated.append(local_theta)

            # Scaled model differences between the global and local models.
            grads = [(theta - model) / lr for model in updated]
            deltas_coeff = [torch.pow(loss[t], self.q) for t in range(T)]
            hs = [
                self.q * torch.pow(loss[t], self.q - 1)
                * torch.pow(torch.linalg.norm(grads[t]), 2)
                + 1.0 / lr * torch.pow(loss[t], self.q)
                for t in range(T)
            ]
            denominator = torch.stack(hs, dim=0).sum()
            for t in range(T):
                theta -= deltas_coeff[t] / denominator * grads[t]

            losses.append(loss)
            opt_losses.append(self.BayesOptimal(X, Y)[1])

        losses = torch.stack(losses, dim=0)
        opt_losses = torch.stack(opt_losses, dim=0)
        return losses, opt_losses

    def fairness(self, step, losses, opt_loss):
        """Plot the per-epoch fairness gap and print its final value.

        The gap is the largest pairwise difference of excess loss
        (``losses - opt_loss``), taken over the last two dimensions of the
        pairwise-difference tensor. The inputs are presumably shaped
        (epochs, clients) -- TODO confirm; the final print requires the
        last entry of the result to be a scalar.

        Args:
            step: X-axis values for the plot, one per epoch.
            losses: Recorded losses, e.g. the first output of ``trainFFL``.
            opt_loss: Reference losses subtracted from ``losses``.

        Side effects:
            Saves the figure to ``figures/fairness_FFL`` (creating the
            ``figures`` directory if needed) and prints the final gap.
        """
        import os  # local import: only needed to prepare the output folder

        excess_loss = losses - opt_loss
        pairwise_gap = excess_loss.unsqueeze(1) - excess_loss.unsqueeze(2)
        fairness = torch.max(torch.max(pairwise_gap, dim=-1)[0], dim=-1)[0]

        plt.figure(figsize=(5, 4), dpi=200)
        plt.xlabel("Epoch", fontsize=8)
        plt.ylabel("Fairness", fontsize=8)
        plt.plot(step, fairness.cpu(), lw=1, label='FFL')
        plt.legend(loc=0, fontsize=10)
        # savefig does not create missing directories.
        os.makedirs("figures", exist_ok=True)
        plt.savefig("figures/fairness_FFL", bbox_inches='tight')
        print("FAA: {:.3f}".format(fairness[-1]))

# Train the FFL system, plot the fairness curve, and report the mean final loss.
system = LinearSystemFFL()
losses, opt_loss = system.trainFFL()
# One x-axis entry per recorded epoch; derived from the loss history so it
# cannot drift from the number of epochs trainFFL actually ran.
step = np.arange(losses.shape[0])

system.fairness(step, losses, opt_loss)
print(losses[-1].mean())