import logging
import math
import pickle
from typing import Optional, Tuple

import numpy as np
import torch

from node import BaseNode, DVRGTFW_Node, DeFW_Node, compute_loss, DsgFw_Node




def get_lambda(W: torch.Tensor) -> torch.Tensor:
    """Return the second largest singular value of the gossip matrix ``W``.

    Computed as the square root of the second largest eigenvalue of the
    symmetric positive semi-definite matrix ``W @ W.T``.

    Args:
        W (torch.Tensor): the square gossip matrix (needs at least 2 rows).

    Returns:
        torch.Tensor: 0-d tensor holding lambda (never NaN for real ``W``).
    """
    WWT = W @ W.T
    eigenvals = torch.linalg.eigvalsh(WWT)
    sorted_eigvals, _ = torch.sort(eigenvals, descending=True)
    # WWT is PSD, but round-off can make a (near-)zero eigenvalue slightly
    # negative; clamp so that sqrt does not turn it into NaN.
    return sorted_eigvals[1].clamp(min=0).sqrt()

def get_K(lambdaval):
    """Return the fast-mix round count ``int(3 / sqrt(1 - lambdaval))``.

    Args:
        lambdaval: the gossip-matrix lambda (float or 0-d tensor).

    Returns:
        int: the truncated value of ``3 / sqrt(1 - lambdaval)``.
    """
    gap_root = math.sqrt(1 - lambdaval)
    ratio = 3 / gap_root
    return int(ratio)

class BaseServer:
    """Base server that owns the full dataset and coordinates a set of nodes.

    Subclasses build ``self.node_list`` (via ``init_nodes``) and implement one
    round of their algorithm in ``step``.
    """

    def __init__(
        self,
        data: torch.Tensor,
        label: torch.Tensor,
        scale: float,
        n_nodes: int,
        gossip_matrix: torch.Tensor,
        batch_size: int = 1,
        base_lr: float = 1.0,
        is_convex: bool = True
    ) -> None:
        """Basic server class for decentralized learning.

        Args:
            data (torch.Tensor): x with the shape (m, d)
            label (torch.Tensor): y with the shape (m, )
            scale (float): the radius of the l1 region, i.e., lambda
            n_nodes (int): the number of nodes in the graph
            gossip_matrix (torch.Tensor): the gossip matrix; either a tensor or
                anything ``torch.tensor`` accepts (e.g. a numpy array). It is
                copied and stored as float32 on ``data.device``.
            batch_size (int, optional): batch size
            base_lr (float, optional): the basic learning rate
            is_convex (bool, optional): if true, then for logistic regression,
                else for the non-convex problem
        """
        self.data = data
        self.label = label
        self.scale = scale
        self.n_nodes = n_nodes
        if isinstance(gossip_matrix, torch.Tensor):
            # torch.tensor() on an existing tensor emits a UserWarning; make
            # the detached copy explicitly instead.
            gm = gossip_matrix.detach().clone()
        else:
            gm = torch.tensor(gossip_matrix)
        self.gm = gm.to(device=data.device, dtype=torch.float32)
        self.batch_size = batch_size
        self.base_lr = base_lr
        self.is_convex = is_convex
        self.num_sample = data.shape[0]
        self.dimension = data.shape[1]
        self.gm_lambda = get_lambda(self.gm)
        self.ifo_num = 0

    def split_data(self):
        """Randomly partition sample indices into ``n_nodes`` equal shards.

        Uses numpy's global RNG. The ``len(label) % n_nodes`` samples left
        over after shuffling are not assigned to any node.

        Returns:
            list[np.ndarray]: ``n_nodes`` index arrays of equal length.
        """
        total_num = len(self.label)
        average_num = total_num // self.n_nodes
        idx_list = np.arange(total_num)
        np.random.shuffle(idx_list)

        idx_split = [idx_list[i * average_num:(i + 1) * average_num] for i in range(self.n_nodes)]
        return idx_split

    def init_nodes(self):
        """Create one ``BaseNode`` per data shard.

        Stores the list in ``self.node_list`` and also returns it, matching
        the subclasses' ``init_nodes`` which return the node list.
        """
        idx_split = self.split_data()

        nodes = [BaseNode(
            self.data[idx_split[i]],
            self.label[idx_split[i]],
            self.scale,
            self.base_lr,
            self.is_convex,
        ) for i in range(self.n_nodes)]

        self.node_list = nodes
        return nodes

    def collect_nodes_param(self):
        """Concatenate every node's ``x`` and reshape to (n, d)."""
        tensors = [node.x for node in self.node_list]
        return torch.cat(tensors).view(-1, self.dimension)  # (n, d)

    def collect_nodes_loss(self):
        """Stack every node's ``loss()`` into one row per node.

        Returns:
            torch.Tensor: shape (n_nodes, k), where k is the number of
            elements in each node's loss (k == 1 for scalar losses).
        """
        tensors = [node.loss() for node in self.node_list]
        # One row per node; the previous view(-1, self.dimension) mixed up
        # nodes with model dimensions and failed unless n_nodes % d == 0.
        return torch.stack(tensors).view(self.n_nodes, -1)

    def get_average_x(self):
        """Return the mean of all node parameters, shape (d,)."""
        X = self.collect_nodes_param()
        average_x = X.mean(dim=0)
        return average_x

    def assign_nodes_param(self, matrix):
        """Set node i's ``x`` to a copy of ``matrix[i]``."""
        for i, node in enumerate(self.node_list):
            node.x = matrix[i].clone()

    def get_probability(self):
        """Draw one uniform [0, 1) sample as a 1-element tensor."""
        return torch.rand(1)

    def step(self):
        """One round of the algorithm; implemented by subclasses."""
        pass

    def compute_loss(self, norm_x=None):
        """Evaluate the loss on the full dataset.

        Args:
            norm_x (torch.Tensor, optional): the point to evaluate; defaults to
                the average of the node parameters.
        """
        x = self.get_average_x() if norm_x is None else norm_x
        # `compute_loss` here resolves to the module-level function from
        # `node`, not to this method.
        return compute_loss(x, self.data, self.label, self.is_convex)

    def test_consensus(self):
        """Return the Frobenius norm of the node parameters' deviation from their mean."""
        X = self.collect_nodes_param()
        average_x = X.mean(dim=0, keepdim=True)
        return (X - average_x).norm()
        
class DeFW_Server(BaseServer):
    """Server for DeFW: gossip-average the parameters, then gossip-average the
    local gradients and hand each node its mixed gradient row."""

    def __init__(
        self,
        data: torch.Tensor,
        label: torch.Tensor,
        scale: float,
        n_nodes: int,
        gossip_matrix: torch.Tensor,
        batch_size: int = 1,
        base_lr: float = 1,
        is_convex: bool = True,
    ) -> None:
        super().__init__(
            data=data,
            label=label,
            scale=scale,
            n_nodes=n_nodes,
            gossip_matrix=gossip_matrix,
            batch_size=batch_size,
            base_lr=base_lr,
            is_convex=is_convex,
        )
        self.node_list = self.init_nodes()

    def init_nodes(self):
        """Build one ``DeFW_Node`` per data shard and return the list."""
        shards = self.split_data()
        created = []
        for k in range(self.n_nodes):
            shard = shards[k]
            created.append(
                DeFW_Node(
                    self.data[shard],
                    self.label[shard],
                    self.scale,
                    self.base_lr,
                    self.is_convex,
                )
            )
        return created

    def collect_nodes_grad(self):
        """Stack each node's gradient on its own shard into (n_nodes, -1)."""
        per_node = [member.get_grad(member.data, member.label) for member in self.node_list]
        flat = torch.cat(per_node)
        return flat.view(self.n_nodes, -1)

    def update_nodes_param(self, grads):
        """Call ``step`` on node k with row k of ``grads``."""
        for k, member in enumerate(self.node_list):
            member.step(grads[k])

    def step(self):
        """One DeFW round: mix parameters, then mix gradients and update."""
        mixed_params = self.gm @ self.collect_nodes_param()
        self.assign_nodes_param(mixed_params)

        # Gradients are taken at the freshly mixed parameters.
        mixed_grads = self.gm @ self.collect_nodes_grad()
        self.update_nodes_param(mixed_grads)

        self.ifo_num = self.node_list[0].ifo_num

class DsgFw_Server(BaseServer):
    """Server for DsgFw: each round mixes parameters, lets every node take a
    local update with its tracker ``y``, steps the nodes, then mixes the
    trackers ``y``."""

    def __init__(
        self,
        data: torch.Tensor,
        label: torch.Tensor,
        scale: float,
        n_nodes: int,
        gossip_matrix: torch.Tensor,
        batch_size: int = 1,
        base_lr: float = 1.0,
        is_convex: bool = True
    ) -> None:
        """Create the server and its ``DsgFw_Node`` list.

        Args:
            data (torch.Tensor): x with the shape (m, d)
            label (torch.Tensor): y with the shape (m, )
            scale (float): the radius of the l1 region, i.e., lambda
            n_nodes (int): the number of nodes in the graph
            gossip_matrix (torch.Tensor): the gossip matrix of the graph
            batch_size (int): batch size (stored on the server; not passed
                to the nodes here)
            base_lr (float, optional): the basic learning rate
            is_convex (bool, optional): if true, then for logistic regression,
                else for the non-convex problem
        """
        super().__init__(data, label, scale, n_nodes, gossip_matrix, batch_size, base_lr, is_convex)
        self.node_list = self.init_nodes()

    def init_nodes(self):
        """Build one ``DsgFw_Node`` per data shard and return the list."""
        shards = self.split_data()
        built = []
        for k in range(self.n_nodes):
            rows = shards[k]
            built.append(
                DsgFw_Node(
                    self.data[rows],
                    self.label[rows],
                    self.scale,
                    self.base_lr,
                    self.is_convex,
                )
            )
        return built

    def collect_nodes_gt(self):
        """Concatenate every node's ``y`` and reshape to (n, d)."""
        trackers = torch.cat([member.y for member in self.node_list])
        return trackers.view(-1, self.dimension)

    def assign_nodes_gt(self, matrix):
        """Set node k's ``y`` to a copy of ``matrix[k]``."""
        for k, member in enumerate(self.node_list):
            member.y = matrix[k].clone()

    def step(self):
        """One DsgFw round (see class docstring for the phase order)."""
        mixed_params = self.gm @ self.collect_nodes_param()
        for k, member in enumerate(self.node_list):
            member.local_update_param(member.y, mixed_params[k])

        for member in self.node_list:
            member.step()

        mixed_trackers = self.gm @ self.collect_nodes_gt()
        self.assign_nodes_gt(mixed_trackers)
        self.ifo_num = self.node_list[0].ifo_num

class DVRGTFW_Server(BaseServer):
    """Server for DVRGTFW.

    Each round: every node takes a local update using its tracker ``y``, the
    parameters are mixed with ``fastmix``, every node steps with a shared
    random draw ``p``, and the trackers ``y`` are mixed with ``fastmix``.
    """

    def __init__(
        self,
        data: torch.Tensor,
        label: torch.Tensor,
        scale: float,
        n_nodes: int,
        gossip_matrix: torch.Tensor,
        rho: float,
        batch_size: int = 1,
        base_lr: float = 1.0,
        is_convex: bool = True,
        K: Optional[int] = 1,
    ) -> None:
        """Create the server, its nodes, and fast-mix the initial trackers.

        Args:
            data (torch.Tensor): x with the shape (m, d)
            label (torch.Tensor): y with the shape (m, )
            scale (float): the radius of the l1 region, i.e., lambda
            n_nodes (int): the number of nodes in the graph
            gossip_matrix (torch.Tensor): the gossip matrix of the graph
            rho (float): the threshold for the loopless variance reduction
                (forwarded to every node)
            batch_size (int): batch size (forwarded to every node)
            base_lr (float, optional): the basic learning rate
            is_convex (bool, optional): if true, then for logistic regression,
                else for the non-convex problem
            K (int or None, optional): fast-mix parameter; ``fastmix`` runs
                ``K + 1`` gossip rounds. Defaults to 1 (the previous
                hard-coded value); ``None`` uses ``get_K(self.gm_lambda)``.
        """
        super().__init__(data=data, label=label, scale=scale, n_nodes=n_nodes, gossip_matrix=gossip_matrix, batch_size=batch_size, base_lr=base_lr, is_convex=is_convex)
        self.rho = rho
        # float32 round-off can push gm_lambda marginally above 1, which would
        # make math.sqrt raise a domain error; clamp the gap at 0.
        gap_sq = float(1 - self.gm_lambda ** 2)
        gap = math.sqrt(max(0.0, gap_sq))
        self.eta_miu = (1 - gap) / (1 + gap)
        self.K = get_K(self.gm_lambda) if K is None else K
        self.node_list = self.init_nodes()
        Y = self.collect_nodes_gt()
        new_Y = self.fastmix(Y)
        self.assign_nodes_gt(new_Y)

    def init_nodes(self):
        """Build one ``DVRGTFW_Node`` per data shard and return the list."""
        idx_split = self.split_data()

        nodes = [DVRGTFW_Node(
            self.data[idx_split[i]],
            self.label[idx_split[i]],
            self.scale,
            self.rho,
            self.batch_size,
            self.base_lr,
            self.is_convex,
        ) for i in range(self.n_nodes)]

        return nodes

    def fastmix(self, g: torch.Tensor) -> torch.Tensor:
        """Apply ``K + 1`` rounds of momentum-accelerated gossip to ``g``.

        Recursion: ``g_{k+1} = (1 + eta) * W @ g_k - eta * g_{k-1}`` with
        ``g_{-1} = g_0 = g`` and ``eta = self.eta_miu``.
        NOTE(review): some FastMix descriptions use exactly K rounds; confirm
        that K + 1 is intended.

        Args:
            g (torch.Tensor): the stacked params or gradients, one row per node

        Returns:
            torch.Tensor: the same shape as ``g``
        """
        g_prev = g.clone()
        g_next = g.clone()
        for _ in range(self.K + 1):
            t = g_next.clone()
            g_next = (1 + self.eta_miu) * self.gm @ g_next - self.eta_miu * g_prev
            g_prev = t
        return g_next

    def collect_nodes_gt(self):
        """Concatenate every node's ``y`` and reshape to (n, d)."""
        return torch.cat(
            [node.y for node in self.node_list]
        ).view(-1, self.dimension)  # (n, d)

    def assign_nodes_gt(self, matrix):
        """Set node i's ``y`` to a copy of ``matrix[i]``."""
        for i, node in enumerate(self.node_list):
            node.y = matrix[i].clone()

    def step(self):
        """One DVRGTFW round (see class docstring for the phase order)."""
        # A single draw shared by all nodes in this round.
        p = self.get_probability()

        # compute d and update x
        for node in self.node_list:
            node.local_update_param(node.y)

        X = self.collect_nodes_param()
        X = self.fastmix(X)
        self.assign_nodes_param(X)

        for node in self.node_list:
            node.step(p)

        Y = self.collect_nodes_gt()
        Y = self.fastmix(Y)
        self.assign_nodes_gt(Y)
        self.ifo_num = self.node_list[0].ifo_num