import os

import torch
import numpy as np
import matplotlib.pyplot as plt

# Experiment configuration, read as module-level globals by LinearSystem.
T = 10      # number of clients
M = 2       # number of clusters (mixture components)
d = 100     # dimension of the weight vector w
n = 14400   # local dataset size drawn per client in each round
Q = 3       # local gradient steps per client in each M-step / FedAvg round

class LinearSystem():
    """Synthetic federated linear-regression benchmark with one outlier client.

    Clients ``0..T-2`` use ground-truth weights built from ``weights_normal``
    (the zero vector); client ``T-1`` uses a random unit-norm
    ``weights_outlier``. All client weights are then perturbed by a small
    random deviation. The class runs a federated EM procedure (soft cluster
    assignments ``pi`` of shape [T, M], cluster models ``theta`` of shape
    [M, d]) and a plain FedAvg baseline, and plots their behaviour.

    Relies on the module-level constants ``d``, ``Q``, ``n``, ``T`` and ``M``.
    """

    def __init__(self):
        """Build the per-client ground-truth weights and initialise EM state.

        Tensors are placed on ``cuda:0`` when CUDA is available, otherwise on
        the CPU.
        """
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        # Outlier setup: clients 0..T-2 share weights_normal, client T-1 is the outlier.
        self.sigma = 0.1            # std of the additive label noise in sampling()
        self.weights_normal = torch.zeros(d)
        self.weights_outlier = torch.randn(d)
        self.weights_outlier /= torch.norm(self.weights_outlier)  # unit norm
        self.weights = self.weights_normal.repeat(T).reshape(T, -1)
        self.weights[T - 1] = self.weights_outlier
        deviation = torch.randn([T, d])
        self.r = 0.01
        self.Delta0 = 0.2
        # Rescale so the Frobenius norm of the whole [T, d] perturbation is r
        # (the rows are not normalised individually).
        deviation /= (torch.norm(deviation) / self.r)
        self.weights += deviation
        self.delta = 1              # scale of the features drawn in sampling()
        self.Q = Q                  # local gradient steps per client per round
        self.eta = 0.5              # M-step learning rate used by train()

        # Initialize parameters on the target device.
        self.weights = self.weights.to(self.device)
        self.weights_normal = self.weights_normal.to(self.device)
        self.weights_outlier = self.weights_outlier.to(self.device)
        self.pi = (torch.ones([T, M]) / M).to(self.device)   # uniform soft assignments
        self.theta = torch.randn([M, d]).to(self.device) / d

    def pi_evolve(self, R, Delta0, n):
        """Return ``1 / (1 + (M-1) * exp(-2 * R * Delta0 * delta**2 * n))``.

        Presumably a closed-form (theoretical) prediction of the assignment
        weight as a function of separation ``R``, margin ``Delta0`` and sample
        count ``n`` -- TODO confirm intended use; it is not called in this file.
        """
        return 1/(1 + (M-1)*np.exp(-2*R*Delta0*self.delta*self.delta*n))

    def sampling(self):
        """Draw a fresh batch for every client.

        Returns:
            X: [T, n, d] standard-normal features scaled by ``delta``.
            Y: [T, n] targets, ``Y[t] = X[t] @ weights[t] + sigma * noise``.
        """
        X = torch.randn([T, n, d], device=self.device) * self.delta
        Y = torch.stack(
            [torch.matmul(X[i], self.weights[i])
             + self.sigma * torch.randn(n, device=self.device)
             for i in range(T)],
            dim=0,
        )
        return X, Y

    def BayesOptimal(self, X, Y):
        """Fit ordinary least squares independently for every client.

        Solves the normal equations ``(X_t^T X_t) w = X_t^T Y_t`` per client.

        Returns:
            opt: [T, d] least-squares weights, one row per client.
            opt_loss: [T] mean squared residual of each client's fit.
        """
        Xt = X.transpose(1, 2)
        gram_inv = torch.inverse(torch.matmul(Xt, X))                        # [T, d, d]
        opt = torch.matmul(gram_inv, torch.matmul(Xt, Y.unsqueeze(-1))).squeeze(-1)
        residual = torch.matmul(X, opt.unsqueeze(-1)).squeeze(-1) - Y
        opt_loss = (residual * residual).mean(dim=-1)
        return opt, opt_loss

    def _cluster_losses(self, X, Y):
        """Return the [T, M] mean squared error of every cluster model on every client."""
        pred = torch.matmul(X.unsqueeze(1), self.theta.unsqueeze(-1)).squeeze(-1)  # [T, M, n]
        residual = pred - Y.unsqueeze(1)
        return (residual * residual).mean(dim=-1)

    def _local_sgd(self, theta, X_t, Y_t, lr):
        """Run ``self.Q`` full-batch gradient steps on one client's squared loss.

        Starts from a copy of ``theta`` and never writes into it, so callers can
        pass a view of shared state safely. Returns the updated vector.
        """
        local_theta = theta.clone()
        for _ in range(self.Q):
            gradient = 2 * torch.matmul(X_t.T, torch.matmul(X_t, local_theta) - Y_t)
            local_theta = local_theta - lr * gradient / X_t.shape[0]
        return local_theta

    def Estep(self, X, Y):
        """Update soft assignments: ``pi[t, m] <- pi[t, m] * exp(-loss[t, m])``, row-normalised.

        Computed as ``softmax(log(pi) - loss)`` so that large losses cannot make
        every entry underflow to zero (which would turn a row into NaN).
        """
        losses = self._cluster_losses(X, Y)
        self.pi = torch.softmax(torch.log(self.pi) - losses, dim=-1)
        return

    def Mstep(self, X, Y, eta):
        """Federated M-step: refit every cluster model from local client updates.

        For each cluster ``m``, every client runs ``Q`` local gradient steps
        (learning rate ``eta``) starting from the current ``theta[m]``. The new
        ``theta[m]`` is the average of those local models, weighted by
        ``pi[:, m] / pi[:, m].sum()``.
        """
        for m in range(self.theta.shape[0]):
            # Every client starts from the same global theta[m]; _local_sgd copies
            # it, so one client's update does not leak into the next client's start.
            local_models = torch.stack(
                [self._local_sgd(self.theta[m], X[t], Y[t], eta) for t in range(X.shape[0])],
                dim=0,
            )                                                   # [T, d]
            weights = self.pi[:, m] / self.pi[:, m].sum()       # [T]
            self.theta[m] = torch.matmul(weights, local_models)
        return

    # Alternative M-step (single pi-weighted gradient step), kept for reference:
    # def Mstep(self, X, Y, eta):
    #     for m in range(M):
    #         gradient_sum = 0
    #         for t in range(T):
    #             gradient = 2 * torch.matmul(X[t].T, torch.matmul(X[t],self.theta[m]) - Y[t])
    #             gradient_sum += self.pi[t,m]*gradient/X[t].shape[0]
    #         self.theta[m] -= eta * gradient_sum
    #     return

    def total_loss(self, X, Y):
        """Return the scalar ``sum_t sum_m pi[t, m] * MSE(theta_m on client t)``."""
        return (self.pi * self._cluster_losses(X, Y)).sum()

    def partial_loss(self, X, Y):
        """Return the [T] per-client loss ``sum_m pi[t, m] * MSE(theta_m on client t)``."""
        return (self.pi * self._cluster_losses(X, Y)).sum(dim=-1)

    def train(self, n_epoch=300):
        """Run federated EM for ``n_epoch`` rounds, sampling fresh data each round.

        Each round records the current state, then performs an E-step followed
        by an M-step and records the resulting losses.

        Returns:
            step: list of epoch indices.
            h: [n_epoch, T] history of ``pi[:, 0]`` recorded before each E-step.
            losses: [n_epoch, T] per-client pi-weighted loss after each round.
            loss: [n_epoch] total pi-weighted loss after each round.
            opt_losses: [n_epoch, T] per-client least-squares loss on the same batch.
            dist: list of ``||theta[0]||`` recorded before each round.
            thetas: [n_epoch, M, d] snapshots of ``theta`` before each round.
        """
        step, h, loss, losses, opt_losses, dist, thetas = [], [], [], [], [], [], []
        for epoch in range(n_epoch):
            X, Y = self.sampling()
            step.append(epoch)
            h.append(self.pi[:, 0].clone())
            thetas.append(self.theta.clone())
            dist.append(torch.norm(self.theta[0]).item())

            self.Estep(X, Y)
            self.Mstep(X, Y, self.eta)
            loss.append(self.total_loss(X, Y))
            losses.append(self.partial_loss(X, Y))
            opt_losses.append(self.BayesOptimal(X, Y)[1])

        return (step, torch.stack(h, dim=0), torch.stack(losses, dim=0),
                torch.stack(loss, dim=0), torch.stack(opt_losses, dim=0),
                dist, torch.stack(thetas, dim=0))

    def train_avg(self, n_epoch=300, lr=0.01):
        """FedAvg baseline: a single global model averaged uniformly over clients.

        Each round, every client starts from the current global ``theta``, runs
        ``Q`` local gradient steps on its own batch, and the global model becomes
        the plain mean of the local models.

        Returns:
            losses: [n_epoch, T] per-client MSE of the averaged model.
            opt_losses: [n_epoch, T] per-client least-squares MSE on the same batch.
        """
        theta = torch.randn(d, device=self.device) / d
        losses, opt_losses = [], []
        for _ in range(n_epoch):
            X, Y = self.sampling()
            # _local_sgd copies theta, so every client starts from the same global model.
            local_models = [self._local_sgd(theta, X[t], Y[t], lr) for t in range(X.shape[0])]
            theta = torch.stack(local_models, dim=0).mean(dim=0)
            residual = torch.matmul(X, theta.unsqueeze(-1)).squeeze(-1) - Y
            losses.append((residual * residual).mean(dim=-1))
            opt_losses.append(self.BayesOptimal(X, Y)[1])
        return torch.stack(losses, dim=0), torch.stack(opt_losses, dim=0)

    @staticmethod
    def _save_figure(path):
        """Save the current figure to ``path`` (creating its folder), then close it."""
        folder = os.path.dirname(path)
        if folder:
            os.makedirs(folder, exist_ok=True)
        plt.savefig(path, bbox_inches='tight')
        plt.close()

    def visualize_pi(self, step, h):
        """Plot the ``pi[:, 0]`` trajectory of client 1 and of the last (outlier) client.

        ``h`` is the [epochs, T] history returned by ``train``; at most the first
        200 epochs are shown.
        """
        h = h.cpu()
        length = min(len(step), 200)
        plt.figure(figsize=(5,4),dpi=200)
        plt.xlabel("Epoch", fontsize=8)
        plt.ylabel(r'$\pi_{tk}$',fontsize=8)
        plt.ylim(0,1)
        plt.xlim(0,int(length)+1)
        plt.plot(step[:length],h[:length,1], markersize=1,lw=1,label=r'$t\in S_1$')
        plt.plot(step[:length],h[:length,-1],markersize=1,lw=1,label=r'$t\in S_2$')
        plt.legend(loc=0, fontsize=10)
        self._save_figure("figures/2M20T5G")

    def visualize_loss(self, step, losses, opt_loss):
        """Plot client 0's excess loss (loss minus least-squares loss) per epoch."""
        plt.figure(figsize=(5,4),dpi=100)
        losses = losses.cpu()
        opt_loss = opt_loss.cpu()
        excess_loss = losses[:,0] - opt_loss[:,0]
        plt.plot(step, excess_loss, 'k-',lw=1)
        self._save_figure("figures/linear_model")

    # Weight-distance plot, kept for reference (disabled):
    # def visualize_weight(self, step, thetas):
    #     plt.figure(figsize=(5,4),dpi=200)
    #     distance = torch.norm(thetas[:,0,:]-self.weights_normal, dim = -1)
    #     plt.xlabel("Epoch", fontsize=8)
    #     plt.xlim(0,len(step)+1)
    #     plt.ylabel(r'$||w_k - \mu_k^*||_2$', fontsize=8)
    #     plt.plot(step, distance.cpu(),'k-',lw=1,label='weight optimization')
    #     plt.legend(loc=0,fontsize=10)
    #     plt.savefig("figures/weights", bbox_inches = 'tight')

    @staticmethod
    def _excess_loss_gap(losses, opt_loss):
        """Return, per epoch, ``max_t e_t - min_t e_t`` with ``e = losses - opt_loss``."""
        excess = losses - opt_loss
        return excess.max(dim=-1)[0] - excess.min(dim=-1)[0]

    def fairness(self, step, losses, opt_loss, losses_avg, opt_loss_avg):
        """Plot and print the per-epoch fairness gap of Fed EM versus FedAvg.

        The gap is the spread (max minus min) of the clients' excess losses over
        their least-squares fits; smaller means the clients are treated more evenly.
        """
        gap = self._excess_loss_gap(losses, opt_loss)
        gap_avg = self._excess_loss_gap(losses_avg, opt_loss_avg)
        plt.figure(figsize=(5,4),dpi=200)
        plt.xlabel("Epoch",fontsize=8)
        plt.ylabel("Fairness",fontsize=8)
        plt.plot(step, gap.cpu(),lw=1,label='Fed EM')
        plt.plot(step, gap_avg.cpu(),lw=1, label='Fed Avg')
        plt.legend(loc = 0, fontsize=10)
        self._save_figure("figures/fairness")
        print("FAA\nFedAvg: {:.3f}, FOCUS: {:.3f}".format(gap_avg[-1].item(), gap[-1].item()))

def main():
    """Train the FedAvg baseline and federated EM on one system, then report.

    FedAvg runs first, then federated EM; the EM assignment trajectories and
    the fairness comparison are plotted, and the mean final per-client losses
    of both methods are printed.
    """
    system = LinearSystem()
    avg_losses, avg_opt_losses = system.train_avg()
    step, h, em_losses, _total_loss, em_opt_losses, _dist, _thetas = system.train()
    system.visualize_pi(step, h)
    print(em_losses[-1].mean(), avg_losses[-1].mean())
    # system.visualize_loss(step, em_losses, em_opt_losses)
    # system.visualize_weight(step, _thetas)
    system.fairness(step, em_losses, em_opt_losses, avg_losses, avg_opt_losses)


if __name__ == '__main__':
    main()