import torch
from torch.optim import Optimizer
from torch import nn

import numpy as np
# Module-level NumPy generator; ProxSkip.solve uses it for the per-iteration
# communication coin flips. It is created without a seed, so successive runs
# of the script draw different sequences.
rng = np.random.default_rng()
from scipy import linalg
import math

from objectives import obj_fun

   
class ProxSkip:
    """Single-node ProxSkip solver with a randomly skipped proximal step.

    Every iteration takes a gradient step corrected by the control variate
    ``h``. With probability ``p`` it then applies ``prox`` and updates ``h``;
    each such application counts as one communication round.

    Attributes:
        step_sz: Step size (gamma).
        x_0: Private copy of the starting point.
        prox: Callable invoked as ``prox(v, step_sz / p)``.
        xs, error, sparsity, objFunc, lya: Histories appended on every
            communication round when ``data_collection`` is enabled. They are
            NOT cleared between ``solve`` calls.
        lya_start: Lyapunov value at ``x_0`` (set when collecting data).
        communication_rounds: Prox applications performed by the last solve.
    """

    def __init__(self, step_sz, x_0, prox):
        self.step_sz = step_sz
        self.x_0 = x_0.copy()
        self.xs = []
        self.prox = prox
        self.error = []
        self.sparsity = []
        self.objFunc = []
        self.lya = []
        self.lya_start = 0

    def solve(self, grad_est, d, p, stop="comm", comms=None, eps=None, x_sol=None,
              data_collection=False, generator=None):
        """Run ProxSkip from ``x_0`` and return the final iterate.

        Args:
            grad_est: Object whose ``grad(x)`` returns the gradient at ``x``.
                When ``data_collection`` is True it must also expose ``A``,
                ``b`` and ``mu`` (forwarded to ``obj_fun``).
            d: Dimension of the control variate ``h``.
            p: Probability of applying the prox in an iteration.
            stop: ``"comm"`` to stop after ``comms`` communication rounds, or
                ``"eps"`` to stop once the relative squared error
                ``||x - x_sol||^2 / ||x_0 - x_sol||^2`` drops to ``eps``.
            comms: Communication budget (required for ``stop="comm"``).
            eps: Relative error target (required for ``stop="eps"``).
            x_sol: Reference solution (required for ``stop="eps"`` and for
                data collection).
            data_collection: Record per-round histories on the instance.
            generator: Optional ``numpy.random.Generator`` for the coin flips;
                defaults to the module-level ``rng``.

        Returns:
            The final iterate ``x``.

        Raises:
            ValueError: On an unknown ``stop`` or missing required arguments.
        """
        if stop not in ("comm", "eps"):
            raise ValueError("ScaffNew | stop must be either comm or eps.")
        if stop == "comm" and comms is None:
            raise ValueError("ScaffNew | When stopping on comms, comms must be provided.")
        elif stop == "eps" and (eps is None or x_sol is None):
            raise ValueError("ScaffNew | If stop is eps, then eps and x_sol must be provided.")
        elif data_collection and x_sol is None:
            raise ValueError("ScaffNew | x_sol must be given for data collection.")

        # Fall back to the module-level generator so existing callers are unaffected.
        sampler = generator if generator is not None else rng

        self.communication_rounds = 0
        h = np.zeros((d))
        if data_collection:
            h_opt = grad_est.grad(x_sol)
            self.lya_start = self._lyapunov(self.x_0, h, x_sol, h_opt, p)
            # Problem data is only needed for objective logging, so grad_est
            # is not required to expose it unless data is being collected.
            A, b, mu = grad_est.A, grad_est.b, grad_est.mu
        x = self.x_0.copy()

        while (stop == "comm" and self.communication_rounds < comms) or \
                (stop == "eps" and self.eps(x, x_sol, self.x_0) > eps):
            x = x - self.step_sz * (grad_est.grad(x) - h)
            if sampler.binomial(1, p):
                hat_x = x.copy()
                x = self.prox(hat_x - self.step_sz / p * h, self.step_sz / p)
                # Control-variate update from the prox correction.
                h += p / self.step_sz * (x - hat_x)
                self.communication_rounds += 1
                if data_collection:
                    self.xs.append(x)
                    self.error.append(linalg.norm(x - x_sol)**2)
                    self.sparsity.append(linalg.norm(x, 0))  # count of non-zeros
                    self.objFunc.append(obj_fun(A, b, x, mu))
                    self.lya.append(self._lyapunov(x, h, x_sol, h_opt, p))
        return x

    def _lyapunov(self, x, h, x_sol, h_opt, p):
        """Return ||x - x_sol||^2 + step_sz^2 / p^2 * ||h - h_opt||^2."""
        return linalg.norm(x - x_sol)**2 + self.step_sz**2/p**2*linalg.norm(h - h_opt)**2

    def eps(self, x, x_sol, x_0):
        """Return the relative squared error ||x_sol - x||^2 / ||x_sol - x_0||^2."""
        return linalg.norm(x_sol - x)**2 / linalg.norm(x_sol - x_0)**2

class ProxSkipClient(Optimizer):
    """Client-side ProxSkip/Scaffnew optimizer for PyTorch parameters.

    A local step applies ``par <- par - lr * (grad - control)`` with a
    per-parameter control variate kept in ``client_state['control']``. After
    communication, the control variate is corrected from the difference
    between the parameters saved before communication
    (``client_state['par_hat']``) and the current parameters.

    Args:
        named_params: Iterable of ``(name, parameter)`` pairs, e.g.
            ``model.named_parameters()``. Dicts are not supported.
        client_state: Mapping with ``'control'`` and ``'par_hat'`` dicts keyed
            by parameter name; read and written in place.
        p: Stored in the param group as ``'p'``; not used by the update
            itself (presumably the communication probability - TODO confirm).
        lr: Local step size.
        weight_decay: L2 coefficient added to the gradient when > 0.
        dual_lr: Scale of the control-variate correction.
    """

    def __init__(self, named_params, client_state, p=1, lr=1e-3, weight_decay=0, dual_lr=1):
        if isinstance(named_params, dict):
            # Must raise: returning a value from __init__ is itself a TypeError.
            raise NotImplementedError("ProxSkipClient expects (name, parameter) pairs, not a dict.")
        self.client_state = client_state
        self.weight_decay = weight_decay
        self.param_names, params = zip(*named_params)
        self.dual_lr = dual_lr
        super().__init__(params, {'lr': lr, 'p': p})

    def update_control(self, actual_steps, global_mod_sparsity=0):
        """Correct every control variate after a communication round.

        ``control -= dual_lr / lr / actual_steps * (par_hat - par)``. Note that
        ``client_state['par_hat']`` is overwritten in place with
        ``par_hat - par``.

        Raises:
            NotImplementedError: If ``global_mod_sparsity > 0``; raised before
                any state is modified.
        """
        if global_mod_sparsity > 0:
            raise NotImplementedError("Subtle bug hasn't been fixed here.")
        with torch.no_grad():
            for group in self.param_groups:
                step_size = group['lr']
                # NOTE(review): names are indexed per group; correct only while
                # there is a single param group (as created by __init__).
                for i, par in enumerate(group['params']):
                    param_name = self.param_names[i]
                    diff = self.client_state['par_hat'][param_name].sub_(par.data)
                    self.client_state['control'][param_name].sub_(
                        diff, alpha=self.dual_lr/step_size/actual_steps)

    def step(self, closure=None, save_par_hat=False, sparsity=None, global_mod_sparsity=0, actual_steps=None):
        """Perform one local step and return the closure's loss, if any.

        Args:
            closure: Optional callable re-evaluating the model and returning
                the loss (standard ``Optimizer.step`` contract).
            save_par_hat: Snapshot the updated parameters into
                ``client_state['par_hat']``.
            sparsity: If truthy, prune all parameters with
                ``top_k_unstructured`` after the update.
            global_mod_sparsity: If > 0, prune after each control correction.
            actual_steps: If truthy and a control variate already exists,
                correct it from ``par_hat`` before stepping.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        with torch.no_grad():
            for group in self.param_groups:
                step_size = group['lr']
                for i, par in enumerate(group['params']):
                    if par.grad is None:
                        continue
                    param_name = self.param_names[i]

                    grad = par.grad
                    if self.weight_decay > 0.0:
                        # Out of place so the stored par.grad is not modified.
                        grad = grad.add(par, alpha=self.weight_decay)

                    if param_name not in self.client_state['control']:
                        self.client_state['control'][param_name] = torch.zeros_like(grad, requires_grad=False)
                    elif actual_steps:
                        # par_hat holds the parameters before communication,
                        # par.data those after; par_hat is reused in place as
                        # the difference buffer.
                        self.client_state['control'][param_name].sub_(
                            self.client_state['par_hat'][param_name].sub_(par.data),
                            alpha=self.dual_lr/step_size/actual_steps)

                        if global_mod_sparsity > 0:
                            # NOTE(review): prunes ALL parameters once per
                            # parameter, interleaved with the updates - confirm
                            # this is intended.
                            top_k_unstructured(self.param_groups, global_mod_sparsity)

                    par.sub_(grad.sub(self.client_state['control'][param_name]), alpha=step_size)

            if sparsity:
                top_k_unstructured(self.param_groups, sparsity)

            if save_par_hat:
                for group in self.param_groups:
                    for i, par in enumerate(group['params']):
                        if par.grad is not None:
                            self.client_state['par_hat'][self.param_names[i]] = par.data.detach().clone()
        return loss

def top_k_unstructured(params, sparsity, groups=True):
    """Zero the smallest-magnitude entries across a set of tensors, in place.

    All tensors are flattened and concatenated, so the magnitude threshold is
    global across tensors (unstructured), not per tensor. ``int(sparsity *
    total)`` entries are zeroed, clamped to ``[0, total - 1]`` so at least one
    entry always survives.

    Args:
        params: With ``groups=True``, a list of optimizer-style param groups
            (dicts with a ``'params'`` list); with ``groups=False``, an
            iterable of tensors (e.g. ``model.parameters()``).
        sparsity: Fraction of entries to zero.
        groups: Whether ``params`` is a list of param groups.

    Returns:
        List of the sparsified tensors, reshaped to the original shapes. The
        parameters themselves are also overwritten in place. Empty list when
        there are no tensors.
    """
    with torch.no_grad():
        # Hacky workaround if this function is called outside the optimizer.
        param_groups = params if groups else [{'params': list(params)}]
        tensors = [par for group in param_groups for par in group['params']]
        if not tensors:
            # torch.cat rejects an empty list; nothing to prune.
            return []

        shapes = [t.shape for t in tensors]
        split_sizes = tuple(t.numel() for t in tensors)
        total_size = int(sum(split_sizes))

        num_to_zero = int(sparsity * total_size)
        num_to_zero = max(0, min(total_size - 1, num_to_zero))

        flat = torch.cat([t.flatten() for t in tensors])
        _, indices = torch.topk(torch.abs(flat), num_to_zero, largest=False)
        flat[indices] = 0
        pruned = [chunk.reshape(shape) for chunk, shape in zip(torch.split(flat, split_sizes), shapes)]

        for par, new_value in zip(tensors, pruned):
            par.data.copy_(new_value)

        return pruned
