import torch
import numpy.linalg as linalg
import numpy as np
from collections import defaultdict


class Algorithm:
    """Base driver for iterative methods on a two-block operator.

    ``gradient_p`` splits a 1-D tensor ``x`` into ``x1 = x[:m]`` and
    ``x2 = x[m:]`` with ``m = len(x) // 2`` and evaluates
    ``(||x1||^(p-2) * x1 + x2, ||x2||^(p-2) * x2 - x1)``, optionally adding
    Gaussian noise.  Subclasses implement ``update(i, idx)`` to advance
    ``self.u`` (and usually ``self.h``) by one step; ``run`` drives the loop
    and records distances from ``self.solution`` in ``self.results``.

    Args:
        p: exponent of the operator; also determines ``alpha``, ``mu`` and
            the constants ``L_0``, ``L_1``, ``K_0``, ``K_1``, ``K_2`` that the
            subclasses' adaptive step sizes read.
        dim: length of the random starting point drawn when ``init`` is None.
        noise_ratio: scale of the additive Gaussian noise.
        lr_0: base learning rate; subclasses choose a default when None.
        init: optional ``(u, h)`` pair of starting tensors.
        noise: when truthy, ``gradient_p`` adds ``noise_ratio * N(0, I)``.
        gradient_fun: stored as ``self.gradient_fun``; not used in this class.
    """

    def __init__(self, p=2.0, dim=2, noise_ratio=5.0, lr_0=None, init=None, noise=False, gradient_fun=None):
        self.num_samples = 1
        self.dim = dim
        self.lr_0 = lr_0
        self.noise_ratio = noise_ratio
        self.lr = None
        if init is None:
            self.u = 100 * torch.randn(self.dim)
            self.h = self.u + torch.randn_like(self.u)
        else:
            self.u, self.h = init
        # Problem constants derived from p (consumed by the subclasses' step rules).
        self.alpha = (p - 2.0) / (p - 1.0)
        self.L_0 = 1 + (p - 1.0) * 2 ** 0.5 * 4 ** (1 / (p - 1.0))
        self.L_1 = 2 * (p - 1) * 2 ** (1 / (2 * (p - 1.0)))
        self.K_0 = self.L_0 * (2 ** ((self.alpha * self.alpha) / (1 - self.alpha)) + 1)
        self.K_1 = self.L_1 * 2 ** ((self.alpha * self.alpha) / (1 - self.alpha)) * 3 ** self.alpha
        self.K_2 = (
            self.L_1 ** (1 / (1 - self.alpha))
            * 2 ** ((self.alpha * self.alpha) / (1 - self.alpha))
            * 3 ** self.alpha
            * (1 - self.alpha) ** (self.alpha / (1 - self.alpha))
        )
        self.q = 1
        self.const = 2
        self.mu = 2 ** (1.0 - p)
        self.L = 2
        self.p = p
        self.abs = abs
        self.transhold = 1
        # Reference point for the distance metrics: the origin, where the
        # noiseless operator vanishes (for p >= 2).  It must have the same
        # shape as the iterates; ``torch.zeros(2 * dim)`` did not match the
        # length-``dim`` default start point and made ``run`` crash.
        self.solution = torch.zeros_like(self.u)
        print(self.solution.shape, self.solution)
        self.noise = noise
        self.results = defaultdict(list)
        self.dum = torch.tensor([[1.1, -1.0]])
        self.gradient_fun = gradient_fun
        # Histories start with two copies so that index [-2] is valid on the
        # very first step.
        self.h_history = [self.h, self.h]
        self.u_history = [self.u, self.u]
        self.a_k = 0.1
        self.Fh_history = [self.gradient_p(self.h, None), self.gradient_p(self.h, None)]
        self.Fu_history = [self.gradient_p(self.u, None), self.gradient_p(self.u, None)]
        self.scheduler = "constant"

    def update(self, i, idx=None):
        """Advance the iterates by one step (iteration ``i``, sample ``idx``).

        Must be overridden; the signature matches the call made by ``run`` so
        a missing override surfaces as NotImplementedError, not TypeError.
        """
        raise NotImplementedError

    def run(self, n_steps=4000, sampler="uni", batch_size=1, scheduler="constant", q=2/3, trashold=None):
        """Run ``n_steps`` iterations and return ``self.results``.

        Args:
            n_steps: number of calls to ``update``.
            sampler: ``"uni"`` draws ``batch_size`` indices uniformly (with
                replacement) per step; any other value passes ``None``.
            batch_size: indices drawn per step by the ``"uni"`` sampler.
            scheduler: ``"q"`` precomputes ``(100 + lr_0) / (100 + k**q)``;
                ``"constant"`` precomputes ``lr_0``; any other value (e.g.
                ``"golden"``, ``"thm"``) leaves the step size to the
                subclass's adaptive rule and sets ``self.lr`` to None.
            q: decay exponent of the ``"q"`` schedule.
            trashold: stored as ``self.transhold`` for the ``"q"`` schedule
                (defaults to ``n_steps // 2``).

        Returns:
            The ``defaultdict(list)`` of metrics. ``"u"``, ``"Dist2Sol"`` and
            ``"Dist2Sol-Avg"`` each gain ``n_steps + 1`` entries (initial point
            plus one per step). ``"Dist2Sol-Avg"`` is the distance of the
            average of the iterates, weighted by the precomputed learning rate,
            or uniformly when no schedule is precomputed.
        """
        self.scheduler = scheduler
        self.q = q
        if scheduler == "q":
            self.transhold = trashold if trashold is not None else n_steps // 2
            self.lr = np.array([(100 + self.lr_0) / (100 + k ** self.q) for k in range(1, n_steps + 1)])
        elif scheduler == "constant":
            self.lr = [self.lr_0] * n_steps
        else:
            # Adaptive schedulers compute the step inside ``update``; there is
            # no precomputed schedule (indexing None used to crash below).
            self.lr = None
        weights = self.lr if self.lr is not None else [1.0] * n_steps

        if sampler == "uni":
            idxs = [np.random.choice(self.num_samples, batch_size, replace=True) for _ in range(n_steps)]
        else:
            idxs = [None] * n_steps

        initial_dist = float(torch.norm(self.solution - self.u))
        self.results["u"].append(self.u)
        self.results["Dist2Sol"].append(initial_dist)
        self.results["Dist2Sol-Avg"].append(initial_dist)

        weighted_sum = torch.zeros_like(self.u)
        total_weight = 0
        for i, idx in enumerate(idxs):
            self.update(i, idx)
            self.results["u"].append(self.u)
            weighted_sum += weights[i] * self.u
            total_weight += weights[i]
            average_point = weighted_sum / total_weight
            self.results["Dist2Sol"].append(float(torch.norm(self.solution - self.u)))
            self.results["Dist2Sol-Avg"].append(float(torch.norm(self.solution - average_point)))
        return self.results

    def gradient_p(self, x, idx):
        """Evaluate the two-block operator at ``x`` (``idx`` is ignored here).

        With ``x1 = x[:m]``, ``x2 = x[m:]`` and ``m = len(x) // 2``, returns
        ``cat(||x1||^(p-2) x1 + x2, ||x2||^(p-2) x2 - x1)``, plus
        ``noise_ratio * N(0, I)`` when ``self.noise`` is truthy.
        """
        mid = x.shape[0] // 2
        first, second = x[:mid], x[mid:]
        value = torch.cat([
            torch.pow(torch.norm(first), self.p - 2) * first + second,
            torch.pow(torch.norm(second), self.p - 2) * second - first,
        ])
        if self.noise:
            return value + self.noise_ratio * torch.randn_like(value)
        return value

class SGD(Algorithm):
    """Clipped stochastic gradient method: ``u <- u - step * gradient_p(u)``.

    Args:
        *args, **kwargs: forwarded to ``Algorithm.__init__``.
        same: if true, the norm used for clipping / adaptive step sizes comes
            from the same operator call used for the step. Otherwise (the
            default), a second call to ``gradient_p`` supplies it, which draws
            fresh noise when noise is enabled.
    """

    def __init__(self, *args, same=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.same_sample = same
        if self.lr_0 is None:
            self.lr_0 = self.mu / (20 * self.L ** 2)
        self.a = self.mu
        # ``same`` used to be stored but ignored (this flag was hard-coded to
        # True).  Derive it from ``same``; the default same=False is unchanged.
        self.different_sample = not same
        self.d = (self.L ** 2) / self.mu

    def update(self, i, idx=None):
        """Take one step at iteration ``i``.

        ``idx`` is forwarded to ``gradient_p``; None means "all samples".
        """
        # ``is None``: with batch_size > 1, ``idx`` is a NumPy array and
        # ``idx == None`` would be an ambiguous element-wise truth value.
        if idx is None:
            idx = range(self.num_samples)
        grad = self.gradient_p(self.u, idx)
        if self.different_sample:
            grad_norm = torch.norm(self.gradient_p(self.u, idx))
        else:
            grad_norm = torch.norm(grad)
        self.results["grads"].append(grad)
        self.results["grad_norm"].append(grad_norm)
        step_size = self._step_size(i, grad_norm)
        self.u = self.u - step_size * grad

    def _step_size(self, i, grad_norm):
        """Return the step size for iteration ``i`` under ``self.scheduler``."""
        if self.scheduler in ("golden", "thm"):
            return self.adaptive_stepsize(grad_norm)
        if self.scheduler in ("q", "constant"):
            # Schedule precomputed by ``run``, clipped by min(1, 1/||g||).
            # "constant" (run's default) previously matched no branch and
            # raised UnboundLocalError.
            return self.lr[i] * min(1.0, 1.0 / grad_norm)
        raise ValueError(f"unknown scheduler: {self.scheduler!r}")

    def adaptive_stepsize(self, grad_norm):
        """Return mu / (3 (K_0^2 + K_1^2 + K_2^2)) * min(1, 1/grad_norm)."""
        gamma_k = self.mu / (3 * (self.K_0 * self.K_0 + self.K_1 * self.K_1 + self.K_2 * self.K_2)) * min([1, 1.0 / grad_norm])
        return gamma_k

class Popov(Algorithm):
    """Popov's method: one operator evaluation ``g = gradient_p(h)`` per step.

    Each step performs ``u <- u - step * g`` and then ``h <- u - step * g``,
    where ``step`` comes from the scheduler chosen in ``run``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.lr_0 is None:
            self.lr_0 = 1 / (2 * (3 ** 0.5) * self.L)
        self.a = self.mu / 4.0
        self.d = 2 * (3 ** 0.5) * self.L

    def update(self, i, idx=None):
        """Take one step at iteration ``i``.

        ``idx`` is forwarded to ``gradient_p``; None means "all samples".
        """
        # ``is None``: a NumPy index array makes ``idx == None`` ambiguous.
        if idx is None:
            idx = range(self.num_samples)
        grad = self.gradient_p(self.h, idx)
        grad_norm = torch.norm(grad)
        self.results["grads"].append(grad)
        self.results["grad_norm"].append(grad_norm)

        u_h_norm = torch.norm(self.u_history[-1] - self.h_history[-2])
        step_size = self._step_size(i, grad_norm, u_h_norm)
        self.u = self.u - step_size * grad

        u_h_norm = torch.norm(self.u - self.h_history[-1])
        step_size = self._step_size(i, grad_norm, u_h_norm)
        self.h = self.u - step_size * grad

        # Record ``u`` too. Previously only ``h`` was recorded, so
        # ``u_history[-1]`` was always the initial point.
        self.u_history.append(self.u)
        self.h_history.append(self.h)

    def _step_size(self, i, grad_norm, u_h_norm):
        """Return the step size for iteration ``i`` under ``self.scheduler``."""
        if self.scheduler in ("golden", "thm"):
            return self.adaptive_stepsize(grad_norm, u_h_norm)
        if self.scheduler in ("q", "constant"):
            # Schedule precomputed by ``run``, clipped by min(1, 1/||g||).
            # "constant" (run's default) previously matched no branch and
            # raised UnboundLocalError.
            return self.lr[i] * min(1.0, 1.0 / grad_norm)
        raise ValueError(f"unknown scheduler: {self.scheduler!r}")

    def adaptive_stepsize(self, grad_norm, u_h_norm):
        """Return the minimum of the bounds below; ``u_h_norm`` is currently unused."""
        gamma_k = min([1/(4*self.mu), 1/(6 * 2**(0.5)* self.K_0), 1/grad_norm, 1/(6 * 2**(0.5)* self.K_1* grad_norm**(self.alpha))])
        return gamma_k


class Extragradient(Algorithm):
    """Extragradient method with clipped or adaptive step sizes.

    Each step evaluates ``g_h = gradient_p(h)``, extrapolates
    ``u <- h - step * g_h``, evaluates ``g_u = gradient_p(u)`` and updates
    ``h <- h - step * g_u``. The "golden" scheduler uses ``self.a_k``, a
    running minimum of ``0.5 * ||u - h|| / ||F(u) - F(h)||`` over past
    extrapolation pairs.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.lr_0 is None:
            self.lr_0 = 1 / (10 * (3 ** 0.5) * self.L)
        self.a = self.mu
        self.d = 2 * (3 ** 0.5) * self.L

    def update(self, i, idx=None):
        """Take one step at iteration ``i``.

        ``idx`` is forwarded to ``gradient_p``; None means "all samples".
        """
        # ``is None``: a NumPy index array makes ``idx == None`` ambiguous.
        if idx is None:
            idx = range(self.num_samples)
        grad_h = self.gradient_p(self.h, idx)
        self.results["grads"].append(grad_h)
        grad_norm = torch.norm(grad_h)
        self.results["grad_norm"].append(grad_norm)

        # Local Lipschitz ratio on the last extrapolation pair.
        # ``u_history[-1]`` was computed from ``h_history[-2]``.
        # ``Fh_history[-1]`` holds the operator value at that same h, because
        # it is appended before ``h`` moves, while ``h_history`` gets the
        # post-step h.
        ratio = 0.5 * torch.norm(self.u_history[-1] - self.h_history[-2]) / (
            torch.norm(self.Fu_history[-1] - self.Fh_history[-1]) + 0.00000001
        )
        self.a_k = min([self.a_k, ratio])
        step_size = self._step_size(i, grad_norm)

        self.u = self.h - step_size * grad_h
        grad_u = self.gradient_p(self.u, idx)
        self.h = self.h - step_size * grad_u

        self.h_history.append(self.h)
        self.u_history.append(self.u)
        self.Fh_history.append(grad_h)
        # Was ``self.u_history.append(grad_u)``. That mixed operator values
        # into the iterate history and never updated ``Fu_history``.
        self.Fu_history.append(grad_u)

    def _step_size(self, i, grad_norm):
        """Return the step size for iteration ``i`` under ``self.scheduler``."""
        if self.scheduler == "golden":
            return self.a_k
        if self.scheduler == "thm":
            return self.adaptive_stepsize(grad_norm)
        if self.scheduler in ("q", "constant"):
            # Schedule precomputed by ``run``, clipped by min(1, 1/||g||).
            # "constant" (run's default) previously matched no branch and
            # raised UnboundLocalError.
            return self.lr[i] * min(1.0, 1.0 / grad_norm)
        raise ValueError(f"unknown scheduler: {self.scheduler!r}")

    def adaptive_stepsize(self, grad_norm):
        """Return the minimum of the mu, K_0, K_1 and K_2 bounds and 1/grad_norm."""
        gamma_k = min([1/self.mu, 1/(3 * 2**(0.5)* self.K_0), 1/grad_norm, 1/(3 * 2**(0.5)* self.K_1* grad_norm**(self.alpha)), 1/(3 * 2**(0.5)* self.K_2)])
        return gamma_k


