import logging
import pickle
from typing import Tuple
import math
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

logger = logging.getLogger(__name__)
def euclidean_proj_simplex(v, s=1):
    """Compute the Euclidean projection of a vector on the positive simplex.

    Solves the optimisation problem
        min_w 0.5 * || w - v ||_2^2 , s.t. sum_i w_i = s, w_i >= 0
    with the sort-and-threshold algorithm.

    Parameters
    ----------
    v: (n,) numpy array,
       n-dimensional vector to project
    s: int or float, optional, default: 1,
       radius (sum of the coordinates) of the simplex

    Returns
    -------
    w: (n,) numpy array,
       Euclidean projection of v on the simplex. When v already lies on the
       simplex, v itself (not a copy) is returned.

    Raises
    ------
    AssertionError
        If s is not strictly positive.
    ValueError
        If v is not 1-D.
    """
    assert s > 0, "Radius s must be strictly positive (%d <= 0)" % s
    n, = v.shape  # will raise ValueError if v is not 1-D
    # check if we are already on the simplex
    if v.sum() == s and np.all(v >= 0):
        # best projection: itself!
        return v
    # cumulative sums of a sorted (decreasing) copy of v
    u = np.sort(v)[::-1]
    cssv = np.cumsum(u)
    # rho is the 0-based index of the last sorted coordinate that stays
    # strictly positive after thresholding
    rho = np.nonzero(u * np.arange(1, n + 1) > (cssv - s))[0][-1]
    # rho + 1 coordinates are active, so the threshold averages the excess
    # mass over rho + 1 entries (dividing by rho alone is off by one and
    # divides by zero when a single coordinate survives). Plain Python
    # numbers keep theta from upcasting the dtype of v.
    theta = float(cssv[rho] - s) / (int(rho) + 1)
    # shift by the threshold and clip the negative part
    return (v - theta).clip(min=0)
    
def euclidean_proj_l1ball(v, s=1):
    """ Compute the Euclidean projection on a L1-ball
    Solves the optimisation problem:
        min_w 0.5 * || w - v ||_2^2 , s.t. || w ||_1 <= s
    Parameters
    ----------
    v: (n,) torch tensor,
       n-dimensional vector to project (the code relies on the tensor API:
       ``.abs()``, ``.cpu()``, ``.device``)
    s: int, optional, default: 1,
       radius of the L1-ball
    Returns
    -------
    w: (n,) torch tensor,
       Euclidean projection of v on the L1-ball of radius s, on the same
       device as v. If v is already inside the ball, v itself is returned.
    Notes
    -----
    Solves the problem by a reduction to the positive simplex case: the
    magnitudes are projected on the simplex and the signs of v are restored.
    See also
    --------
    euclidean_proj_simplex
    """
    assert s > 0, "Radius s must be strictly positive (%d <= 0)" % s
    (_,) = v.shape  # unpacking raises ValueError when v is not 1-D
    magnitudes = v.abs()
    # Nothing to do when the L1-norm already fits inside the ball.
    if magnitudes.sum() <= s:
        return v
    # Otherwise the optimum lies on the boundary (norm == s): project the
    # magnitudes on the simplex with the NumPy routine ...
    simplex_point = euclidean_proj_simplex(magnitudes.cpu().numpy(), s=s)
    # ... then move back to the device of v and re-apply the original signs.
    projected = torch.tensor(simplex_point, device=v.device)
    return projected * torch.sign(v)

        
def compute_avg_grad(
    param: torch.Tensor, x: torch.Tensor, y: torch.Tensor, is_convex: bool = True,
) -> torch.Tensor:
    """
    Compute the batch-averaged gradient of the loss with respect to ``param``.

    With ``m = y * (x @ param)``:
      * ``is_convex=True``  -> gradient of ``log(1 + exp(-m))``
      * ``is_convex=False`` -> gradient of ``1 / (1 + exp(m))``
    These are the two objectives evaluated by ``compute_loss``.

    The per-sample weights are expressed with ``torch.sigmoid`` instead of
    explicit ``exp`` ratios: mathematically identical, but it no longer
    produces ``inf / inf = nan`` when ``-m`` is large.

    @param param: variable of the loss function, shape (dimension, ).
    @param x: the input data with shape (batch, dimension).
    @param y: the label, shape (batch, ).
    @param is_convex: selects the logistic (True) or sigmoid-type (False) loss.
    @return: tensor of shape (dimension, ), the mean gradient over the batch.
    """
    margin = y * (x @ param)
    # exp(-m) / (1 + exp(-m)) == sigmoid(-m)
    weight = torch.sigmoid(-margin)
    if not is_convex:
        # exp(-m) / (1 + exp(-m))^2 == sigmoid(-m) * sigmoid(m)
        weight = weight * torch.sigmoid(margin)
    # scale every sample row by its scalar coefficient, then average rows
    return ((-y * weight).unsqueeze(1) * x).mean(dim=0)

def LMO(grad, scale):
    """
    Linear minimization oracle used by the Frank-Wolfe updates.

    Builds a vector that is zero everywhere except at the coordinate where
    ``grad`` has the largest absolute value; that entry is set to
    ``scale * sign(grad)`` and the whole vector is negated before returning.

    @param grad: 1-D gradient tensor.
    @param scale: magnitude of the single non-zero entry (the l1 radius).
    @return: tensor shaped like ``grad`` with at most one non-zero entry.
    """
    vertex = torch.zeros_like(grad)
    top = grad.abs().argmax()
    vertex[top] = scale * torch.sign(grad[top])
    # point against the gradient
    return vertex.neg()


def compute_loss(param: torch.Tensor, x: torch.Tensor, y: torch.Tensor, is_convex:bool = True) -> torch.Tensor:
    """
    Compute the mean loss of a linear classifier.

    With ``m = y * (x @ param)``:
      * ``is_convex=True``  -> logistic loss ``log(1 + exp(-m))``
      * ``is_convex=False`` -> sigmoid-type loss ``1 / (1 + exp(m))``

    ``param`` is no longer squeezed in place: the caller's tensor keeps its
    shape, and a length-1 parameter is kept 1-D so that ``x @ param`` stays
    valid. The losses are evaluated with ``softplus`` / ``sigmoid``, which
    equal the explicit ``exp`` formulas but do not overflow for large margins.

    @param param: variable of the loss function.
    @param x: the input data with shape (batch, dimension) or (dimension, ).
    @param y: the label.
    @param is_convex: selects the logistic (True) or sigmoid-type (False) loss.
    @return: 0-dim tensor holding the mean loss.
    """
    # drop singleton axes on a view rather than mutating the argument
    weights = param.squeeze()
    if weights.dim() == 0:
        # a single-feature parameter must stay 1-D for the matmul
        weights = weights.reshape(1)
    margin = y * (x @ weights)
    if is_convex:
        # softplus(-m) == log(1 + exp(-m)), computed without overflow
        return torch.nn.functional.softplus(-margin).mean()  # logistic loss
    # sigmoid(-m) == 1 / (1 + exp(m))
    return torch.sigmoid(-margin).mean()

class CustomDataset(Dataset):
    """Map-style dataset that pairs ``data[idx]`` with ``label[idx]``."""

    def __init__(self, data: torch.Tensor, label: torch.Tensor) -> None:
        # Both tensors are stored as given (no copy); they are indexed along
        # their first axis.
        self.data, self.label = data, label

    def __len__(self) -> int:
        # The label tensor determines the number of samples.
        return len(self.label)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        sample = self.data[idx]
        target = self.label[idx]
        return sample, target
        
class BaseNode:
    """State and update rule shared by every node in decentralized learning.

    Subclasses are expected to override ``get_gamma`` and ``step``; the
    versions defined here do nothing and return ``None``.
    """

    def __init__(
        self,
        data: torch.Tensor,
        label: torch.Tensor,
        scale: float,
        base_lr: float = 1.0,
        is_convex: bool = True,
    ) -> None:
        """Store the local problem and initialise the iterate at zero.

        Args:
            data (torch.Tensor): x with the shape (m, d)
            label (torch.Tensor): y with the shape (m, )
            scale (float): the radius of the l1 region, i.e., lambda
            base_lr (float, optional): the basic learning rate
            is_convex (bool, optional): if true, use the logistic regression
                loss, else the non-convex loss
        """
        # local problem definition
        self.data, self.label = data, label
        self.scale, self.base_lr = scale, base_lr
        self.is_convex = is_convex

        # m samples of dimension d
        self.num_sample = data.shape[0]
        self.dimension = data.shape[1]

        # the iterate lives on the same device as the data
        self.x = torch.zeros((self.dimension,), device=self.data.device)
        # iter_num counts parameter updates; ifo_num is incremented by
        # subclasses with the number of samples used for gradients
        # (presumably a first-order-oracle call counter -- TODO confirm)
        self.iter_num = 0
        self.ifo_num = 0

    def get_gamma(self):
        """Step-size hook; the base version returns ``None``."""
        return None

    def step(self):
        """Per-iteration hook; the base version does nothing."""
        return None

    def get_grad(self, data, label):
        """Average gradient of the configured loss at the current iterate."""
        return compute_avg_grad(self.x, data, label, is_convex=self.is_convex)

    def local_update_param(self, grad):
        """Move the iterate towards the LMO vertex and return a copy of it.

        The step size is read from ``get_gamma`` before ``iter_num`` is
        incremented. If the new iterate has l1-norm above ``scale`` it is
        projected back onto the l1-ball.
        """
        vertex = LMO(grad=grad, scale=self.scale)

        step_size = self.get_gamma()
        self.iter_num += 1

        # convex combination of the current iterate and the vertex
        self.x = (1 - step_size) * self.x + step_size * vertex

        l1_norm = self.x.abs().sum()
        if l1_norm > self.scale:
            self.x = euclidean_proj_l1ball(self.x, self.scale)

        return self.x.clone()

    def loss(self) -> torch.Tensor:
        """Loss of the current iterate over all local data."""
        return compute_loss(param=self.x, x=self.data, y=self.label, is_convex=self.is_convex)
        
class DeFW_Node(BaseNode):
    """Node that applies the base Frank-Wolfe style update to a supplied gradient.

    Unlike ``BaseNode``, ``base_lr`` has no default here and must be given.
    """

    def __init__(
        self,
        data: torch.Tensor,
        label: torch.Tensor,
        scale: float,
        base_lr: float,
        is_convex: bool = True,
    ) -> None:
        super().__init__(data, label, scale, base_lr, is_convex)

    def get_gamma(self):
        """Step size: ``2 * base_lr / (k + 1)`` if convex, else ``base_lr / sqrt(k + 1)``.

        ``k`` is the current ``iter_num``.
        """
        horizon = self.iter_num + 1
        if not self.is_convex:
            return self.base_lr*1/math.sqrt(horizon)
        return self.base_lr*2/horizon

    def step(self, grad):
        """Update the iterate with ``grad`` and add ``num_sample`` to ``ifo_num``."""
        self.local_update_param(grad)

        # one pass over the whole local data set is accounted for
        self.ifo_num += self.num_sample



class DsgFw_Node(BaseNode):
    """Node with a periodically refreshed, recursively corrected gradient estimate.

    Two vectors are maintained next to the iterate ``x``:

    * ``v`` -- the local gradient estimate. It is reset to the full local
      gradient whenever ``(iter_num + 1) % q == 0`` and otherwise corrected
      with the mini-batch gradient difference between ``x`` and ``prev_x``.
    * ``y`` -- a tracking variable accumulating the successive changes of ``v``.
    """

    def __init__(
        self,
        data: torch.Tensor,
        label: torch.Tensor,
        scale: float,
        base_lr: float,
        is_convex: bool = True,
    ) -> None:
        super().__init__(data, label, scale, base_lr, is_convex)
        # kept as a public attribute; not read by the methods of this class
        self.data_set = CustomDataset(self.data, self.label)
        self.init_variables()
        self.prev_x = self.x.clone()

    def get_batch_size(self, q):
        """Return the mini-batch size for the current iteration.

        On refresh iterations (``(iter_num + 1) % q == 0``) the size is
        ``q**2``. Otherwise it is ``q**2`` scaled by ``(b / t)**2`` (convex)
        or ``b / t`` (non-convex), where ``b`` is the next multiple of ``q``
        above ``iter_num`` and ``t`` is ``iter_num``.

        ``t`` is clamped to at least 1: calling this before the first
        parameter update (``iter_num == 0``) used to raise ZeroDivisionError.
        """
        if (self.iter_num + 1) % q == 0:
            return int(q**2)
        # next multiple of q strictly above iter_num
        next_refresh = (int(self.iter_num / q) + 1) * q
        # guard the division for iter_num == 0
        t = max(self.iter_num, 1)
        if self.is_convex:
            growth = (next_refresh**2) / (t**2)
        else:
            growth = next_refresh / t
        return int(q**2 * growth)

    def get_gamma(self):
        """Step size: ``2 * base_lr / (k + 1)`` if convex, else ``base_lr / sqrt(k + 1)``."""
        if self.is_convex:
            return self.base_lr*2/(self.iter_num+1)
        else:
            return self.base_lr*1/math.sqrt(self.iter_num+1)

    def set_q(self):
        """Return the refresh period ``q`` derived from the local sample count.

        ``int(num_sample ** 0.25)`` in the convex case and
        ``int(num_sample ** 0.33)`` otherwise. Despite the name, nothing is
        stored on the instance.
        """
        if self.is_convex:
            return int(self.num_sample**0.25)
        else:
            return int(self.num_sample**0.33)

    def init_variables(self):
        """Initialise ``v`` and ``y`` with the full local gradient at ``x``."""
        self.v = self.get_grad(self.data, self.label)
        # the full pass over the local data is accounted for
        self.ifo_num += self.num_sample
        self.y = self.v.clone()

    def local_update_param(self, grad, new_X):
        """Frank-Wolfe style update starting from ``new_X``.

        ``new_X`` replaces the current iterate in the convex combination
        (presumably the neighbour-averaged iterate -- TODO confirm with the
        caller). The iterate held before the update is saved in ``prev_x``
        for the next gradient correction. Returns a copy of the new iterate.
        """
        v = LMO(grad=grad, scale=self.scale)
        gamma = self.get_gamma()
        self.iter_num += 1
        self.prev_x = self.x.clone()
        self.x = (1 - gamma) * new_X + gamma * v
        # project back if the step left the l1-ball
        if self.x.abs().sum() > self.scale:
            self.x = euclidean_proj_l1ball(self.x, self.scale)
        return self.x.clone()

    def get_batch(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Draw a mini-batch of local samples without replacement.

        Returns the pair ``(data, label)`` restricted to the first
        ``batch_size`` indices of a random permutation. Slicing caps the
        batch at ``num_sample`` rows when the requested size is larger.
        """
        q = self.set_q()
        batch_size = self.get_batch_size(q)

        chosen = torch.randperm(self.num_sample)[:batch_size]
        return self.data[chosen, :], self.label[chosen]

    def step(self):
        """Refresh or correct ``v``, then update the tracking variable ``y``."""
        q = self.set_q()
        if (self.iter_num + 1) % q == 0:
            # refresh: replace the estimate by the full local gradient
            prev_v = self.v.clone()
            grad = self.get_grad(self.data, self.label)
            self.v = grad.clone()
            self.ifo_num += self.num_sample
        else:
            # correction: add the mini-batch gradient difference between the
            # current and the previous iterate
            batch_data, batch_label = self.get_batch()
            batch_length = len(batch_data)
            grad = compute_avg_grad(param=self.x, x=batch_data, y=batch_label, is_convex=self.is_convex)
            prev_grad = compute_avg_grad(param=self.prev_x, x=batch_data, y=batch_label, is_convex=self.is_convex)
            prev_v = self.v.clone()
            self.update_local_v(grad, prev_grad)
            # two gradient evaluations per sampled point
            self.ifo_num += 2 * batch_length
        self.update_local_tracking(self.v, prev_v)

    def update_local_v(self, cur_grad, prev_grad):
        """Add ``cur_grad - prev_grad`` to ``v`` and return the new ``v``."""
        self.v = self.v + cur_grad - prev_grad
        return self.v

    def update_local_tracking(self, cur_v, prev_v):
        """Add ``cur_v - prev_v`` to the tracking variable ``y``."""
        self.y = self.y + cur_v - prev_v


class DVRGTFW_Node(BaseNode):
    """Node with a probabilistically refreshed gradient estimate.

    ``v`` is the local gradient estimate and ``y`` a tracking variable that
    accumulates the successive changes of ``v``. On each ``step(p)`` the
    estimate is reset to the full local gradient when ``p < rho`` and is
    otherwise corrected with a mini-batch gradient difference between ``x``
    and ``prev_x``.

    NOTE(review): the ``rho`` and ``batch_size`` constructor arguments are
    accepted but never used; both quantities are derived from ``num_sample``
    in ``init_variables``. This is kept as-is so existing runs keep their
    behaviour -- confirm whether the arguments were meant to take effect.
    """

    def __init__(
        self,
        data: torch.Tensor,
        label: torch.Tensor,
        scale: float,
        rho: float,
        batch_size: int,
        base_lr: float,
        is_convex: bool = True,
    ) -> None:
        super().__init__(
            data=data, label=label, scale=scale, base_lr=base_lr, is_convex=is_convex
        )

        self.data_set = CustomDataset(self.data, self.label)
        # must run before the DataLoader is built: it sets self.batch_size
        self.init_variables()
        self.data_loader = DataLoader(self.data_set, self.batch_size, shuffle=True)
        self.iter_data = iter(self.data_loader)

        self.prev_x = self.x.clone()

    def init_variables(self):
        """Initialise ``v``, ``prev_v``, ``y``, ``batch_size`` and ``rho``.

        ``batch_size`` is ``int(2 * sqrt(num_sample / 100))`` clamped to at
        least 1: the unclamped value is 0 for fewer than 25 samples, which
        made the ``DataLoader`` construction raise. ``rho`` is
        ``2b / (2b + num_sample)`` with ``b = batch_size``.
        """
        self.v = self.get_grad(self.data, self.label)
        self.prev_v = self.v.clone()
        self.batch_size = max(1, int(2*math.sqrt(self.num_sample/100)))
        # the full pass over the local data is accounted for
        self.ifo_num += self.num_sample
        self.rho = 2*self.batch_size/(2*self.batch_size+self.num_sample)
        self.y = self.v.clone()

    def local_update_param(self, grad):
        """Frank-Wolfe style update that also remembers the previous iterate.

        Same update as ``BaseNode.local_update_param``, with the iterate held
        before the update saved in ``prev_x`` for the next gradient
        correction. Returns a copy of the new iterate.
        """
        v = LMO(grad=grad, scale=self.scale)

        gamma = self.get_gamma()
        self.iter_num += 1
        # update the previous x
        self.prev_x = self.x.clone()

        self.x = (1-gamma) * self.x + gamma * v

        # project back if the step left the l1-ball
        if self.x.abs().sum() > self.scale:
            self.x = euclidean_proj_l1ball(self.x, self.scale)

        return self.x.clone()

    def get_gamma(self):
        """Step size: ``2 * base_lr / (k + 1)`` if convex, else ``base_lr / sqrt(k + 1)``."""
        if self.is_convex:
            return self.base_lr*2/(self.iter_num+1)
        else:
            return self.base_lr*1/math.sqrt(self.iter_num+1)

    def get_batch(self) -> Tuple[torch.Tensor, torch.Tensor, bool]:
        """Return the next mini-batch from the shuffled loader.

        The third element is True when the loader was exhausted and had to be
        restarted to serve this batch.
        """
        flag = False
        try:
            data, label = next(self.iter_data)
        except StopIteration:
            # epoch finished: restart the (reshuffled) loader
            self.iter_data = iter(self.data_loader)
            data, label = next(self.iter_data)
            flag = True
        return data, label, flag

    def step(self, p):
        """Refresh or correct ``v``, then update the tracking variable ``y``.

        Args:
            p: value compared against ``rho``; the full gradient is recomputed
                when ``p < rho`` (presumably a uniform random draw supplied by
                the caller -- TODO confirm).
        """
        if p < self.rho:
            # refresh: replace the estimate by the full local gradient
            self.prev_v = self.v.clone()
            grad = self.get_grad(self.data, self.label)
            self.v = grad.clone()
            self.ifo_num += self.num_sample
        else:
            # correction: add the mini-batch gradient difference between the
            # current and the previous iterate
            batch_data, batch_label, _ = self.get_batch()
            batch_length = len(batch_data)

            grad = compute_avg_grad(param=self.x, x=batch_data, y=batch_label, is_convex=self.is_convex)
            prev_grad = compute_avg_grad(param=self.prev_x, x=batch_data, y=batch_label, is_convex=self.is_convex)

            self.prev_v = self.v.clone()
            self.update_local_v(grad, prev_grad)

            # two gradient evaluations per sampled point
            self.ifo_num += 2 * batch_length

        self.update_local_tracking(self.v, self.prev_v)

    def update_local_v(self, cur_grad, prev_grad):
        """Add ``cur_grad - prev_grad`` to ``v`` and return the new ``v``."""
        self.v = self.v + cur_grad - prev_grad
        return self.v

    def update_local_tracking(self, cur_v, prev_v):
        """Add ``cur_v - prev_v`` to the tracking variable ``y``."""
        self.y = self.y + cur_v - prev_v
