import torch
import numpy.linalg as linalg
import numpy as np
import math
from collections import defaultdict

def get_solution(J, ab):
    """Return the reference solution used by the experiments.

    The candidate ``u_star = -J^{-1} ab`` is still evaluated, so a ``J``
    that ``torch.linalg.inv`` rejects (e.g. singular) raises here. The
    value actually returned, however, is the fixed point ``[0.0, 0.0]``.

    NOTE(review): ``u_star`` is computed but discarded; presumably it was
    meant to be returned (or the computation removed). Confirm before
    relying on this function for a problem whose minimiser is not the
    origin.
    """
    inverse = torch.linalg.inv(J)
    _u_star = -inverse @ ab  # evaluated but unused; see NOTE(review) above
    return torch.tensor([0.0, 0.0])


class Algorithm:
    """Base class for iterative optimisation methods.

    Subclasses implement :meth:`update`, which advances the iterate ``self.u``
    by one step. :meth:`run` drives the loop and, before the first step and
    after every step, records the iterate, its squared Euclidean distance to
    ``solution`` and the residual ``function(u) - function(solution)``.

    Args:
        L_0: constant stored for subclasses' step-size rules (presumably an
            (L0, L1)-smoothness constant).
        L_1: constant stored for subclasses' step-size rules (presumably an
            (L0, L1)-smoothness constant).
        function: callable returning the objective value at a point.
        gradient_fun: callable returning the gradient at a point, or ``None``.
            ``__init__`` already evaluates the gradient, so ``None`` makes
            construction raise ``NotImplementedError``.
        solution: reference point (presumably a minimiser) used for the
            distance and residual metrics.
        lr_0: stored as ``self.lr_0``; not read by this class.
        init: optional ``(u, h)`` pair giving the starting iterate and an
            auxiliary value stored as ``self.h``. When omitted, the start is a
            random-direction tensor shaped like ``solution`` with Euclidean
            norm 5, and ``self.h`` is ``None``.
    """

    def __init__(self, L_0, L_1, function, gradient_fun, solution, lr_0=None, init=None):
        self.lr_0 = lr_0
        self.lr = None
        if init is None:
            direction = torch.randn_like(solution)
            self.u = 5 * direction / torch.norm(direction)
            self.h = None  # keep the attribute defined on both paths
        else:
            self.u, self.h = init
        self.L_0 = L_0
        self.L_1 = L_1
        self.function = function
        self.solution = solution
        self.f_star = self.function(solution)
        # Metric name -> list of per-step values; accumulates across run() calls.
        self.results = defaultdict(list)
        self.gradient_fun = gradient_fun
        self.initial_distance = torch.norm(self.solution - self.u)
        # Two-entry histories seeded with the starting point; not read by this class.
        self.u_history = [self.u, self.u]
        self.Fu_history = [self.gradient(self.u), self.gradient(self.u)]

    def update(self, i=None):
        """Advance ``self.u`` by one step; ``i`` is the 0-based step index.

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError

    def run(self, n_steps=400):
        """Run ``n_steps`` iterations and return the collected metrics.

        If the instance has ``lr_type == "normalized"``, ``self.beta`` is
        (re)built as ``initial_distance / sqrt(i + 1)`` for ``i`` in
        ``range(n_steps)`` before iterating.

        Returns:
            ``self.results``: a ``defaultdict(list)`` whose ``"u"``,
            ``"Dist2Sol"`` and ``"func_residual"`` lists gain
            ``n_steps + 1`` entries (initial point plus one per step).
        """
        # lr_type is defined by subclasses; the base class may not have it.
        if getattr(self, "lr_type", None) == "normalized":
            # NOTE(review): this replaces any beta given to the subclass; confirm intended.
            self.beta = [self.initial_distance / math.sqrt(i + 1) for i in range(n_steps)]
        self._record()
        for i in range(n_steps):
            self.update(i)
            self._record()
        return self.results

    def _record(self):
        """Append the current iterate and its distance/residual to ``self.results``."""
        self.results["u"].append(self.u)
        self.results["Dist2Sol"].append(float(torch.norm(self.solution - self.u) ** 2))
        self.results["func_residual"].append(float(self.function(self.u) - self.f_star))

    def gradient(self, x):
        """Return ``gradient_fun(x)``.

        Raises:
            NotImplementedError: if no gradient function was supplied.
        """
        if self.gradient_fun is None:
            raise NotImplementedError("Implement gradient function")
        return self.gradient_fun(x)

class GD(Algorithm):
    """Gradient descent with a selectable step-size rule.

    Each step sets ``u <- u - eta * grad``. With ``g = ||grad||`` and
    ``f* = function(solution)``, the step size ``eta`` depends on ``lr_type``:

    - ``"polyak"``:      ``(f(u) - f*) / g**2``
    - ``"first_type"``:  ``log(1 + L_1 g / (L_0 + L_1 g)) / (L_1 g)``,
      or ``1 / L_0`` when ``L_1 == 0`` (the limit of that expression)
    - ``"second_type"``: ``1 / (L_0 + 1.5 L_1 g)``
    - ``"normalized"``:  ``beta[i] / g``
    - ``"clipped"``:     ``min(1 / (2 L_0), 1 / (3 L_1 g))``

    When ``g == 0`` the iterate is left unchanged. The step would be zero
    anyway, but several rules above divide by ``g`` and would yield NaN.

    Args:
        *args, **kwargs: forwarded to ``Algorithm.__init__``.
        lr_type: name of the step-size rule (default ``"first_type"``).
        beta: optional sequence indexed by step number, read by the
            ``"normalized"`` rule. NOTE(review): ``Algorithm.run`` rebuilds
            ``self.beta`` when ``lr_type == "normalized"``, so a value passed
            here may be overwritten; confirm this is intended.
    """

    _LR_TYPES = ("polyak", "first_type", "second_type", "normalized", "clipped")

    def __init__(self, *args, lr_type="first_type", beta=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.lr_type = lr_type
        self.beta = beta

    def update(self, i):
        """Take one gradient step, logging the gradient and its norm.

        Appends the gradient to ``results["grads"]`` and its norm to
        ``results["grad_norm"]``, then updates ``self.u``.

        Args:
            i: 0-based step index (used by the ``"normalized"`` rule).

        Raises:
            ValueError: if ``self.lr_type`` is not a known rule.
        """
        if self.lr_type not in self._LR_TYPES:
            raise ValueError(f"Unknown lr_type: {self.lr_type!r}")
        grad = self.gradient(self.u)
        grad_norm = torch.norm(grad)
        self.results["grads"].append(grad)
        self.results["grad_norm"].append(grad_norm)
        if grad_norm == 0:
            # Stationary point: the step is zero regardless of the rule, and
            # evaluating rules that divide by grad_norm would produce NaN.
            return
        step_size = self._step_size(i, grad_norm)
        self.u = self.u - step_size * grad

    def _step_size(self, i, grad_norm):
        """Return the step size for step ``i`` under ``self.lr_type`` (``grad_norm > 0``)."""
        if self.lr_type == "polyak":
            f_star_difference = self.function(self.u) - self.f_star
            return f_star_difference / (grad_norm * grad_norm)
        if self.lr_type == "first_type":
            if self.L_1 == 0:
                # Limit of the expression below as L_1 -> 0; evaluating it directly gives 0/0.
                return 1 / self.L_0
            inside_log = 1 + self.L_1 * grad_norm / (self.L_0 + self.L_1 * grad_norm)
            return torch.log(inside_log) / (self.L_1 * grad_norm)
        if self.lr_type == "second_type":
            return 1 / (self.L_0 + (3 / 2) * self.L_1 * grad_norm)
        if self.lr_type == "normalized":
            return self.beta[i] / grad_norm
        if self.lr_type == "clipped":
            return min(1 / (2 * self.L_0), 1 / (3 * self.L_1 * grad_norm))
        raise ValueError(f"Unknown lr_type: {self.lr_type!r}")


