import math
from typing import Union, Optional

import numpy as np
import torch

class Optimizer:
    """Base class for stateful optimizers that turn a gradient into a step.

    Subclasses implement ``_compute_step``, which receives an *ascent*
    direction and returns the step tensor. The public entry points ``ascent``
    and ``descend`` validate the gradient's shape and advance the step
    counter ``t`` before delegating to ``_compute_step``.
    """

    def __init__(self, solution_shape, dtype: torch.dtype, device):
        """
        Args:
            solution_shape: Shape of the tensors this optimizer handles. May be
                an int (1-D), a tuple/list of ints, or a ``torch.Size``.
            dtype: dtype used for any internal state tensors.
            device: device used for any internal state tensors.
        """
        # gradient.shape is a torch.Size (a tuple subclass), which never equals
        # an int or a list; normalize so every accepted form compares correctly.
        if isinstance(solution_shape, (int, np.integer)):
            solution_shape = (int(solution_shape),)
        self.shape = torch.Size(solution_shape)
        self.dtype = dtype
        self.device = device
        self.t = 0  # number of ascent/descend calls performed so far

    def _check_gradient(self, gradient: torch.Tensor) -> None:
        """Raise ValueError if ``gradient`` does not have the configured shape."""
        if gradient.shape != self.shape:
            raise ValueError(
                "This optimizer was initialized for"
                f" a tensor of shape {self.shape},"
                " but the gradient provided has an incompatible shape:"
                f" {gradient.shape}"
            )

    def ascent(self, gradient: torch.Tensor) -> torch.Tensor:
        """Return a step that moves along ``gradient`` (maximization).

        Raises:
            ValueError: if ``gradient`` has the wrong shape (``t`` is then
                left unchanged).
        """
        self._check_gradient(gradient)
        self.t += 1
        return self._compute_step(gradient)

    def descend(self, gradient: torch.Tensor) -> torch.Tensor:
        """Return a step that moves against ``gradient`` (minimization).

        Raises:
            ValueError: if ``gradient`` has the wrong shape (``t`` is then
                left unchanged).
        """
        self._check_gradient(gradient)
        self.t += 1
        return self._compute_step(-gradient)

    def _compute_step(self, gradient):
        """Compute the step for an ascent direction; implemented by subclasses."""
        raise NotImplementedError

    def state_dict(self):
        """Return a dict holding the optimizer's serializable state."""
        return {'t': self.t}

    def load_state_dict(self, sd):
        """Restore state previously produced by ``state_dict``."""
        self.t = sd['t']

    def to(self, device):
        """Record ``device`` as the target device and return ``self``."""
        self.device = device
        return self
    
class Adam(Optimizer):
    """Adam optimizer that returns the step as a tensor.

    Keeps exponential moving averages of the gradient (``m``) and of the
    element-wise squared gradient (``v``), and folds the bias correction into
    a per-step scalar step size.
    """

    def __init__(self,
                 solution_shape,
                 dtype: torch.dtype,
                 device,
                 stepsize: float,
                 beta1: float = 0.9,
                 beta2: float = 0.999,
                 epsilon: float = 1e-08):
        super().__init__(solution_shape=solution_shape, dtype=dtype, device=device)
        self.stepsize = float(stepsize)
        self.beta1 = float(beta1)
        self.beta2 = float(beta2)
        self.epsilon = float(epsilon)
        # First- and second-moment accumulators, allocated with the configured
        # shape/dtype/device.
        self.m: torch.Tensor = torch.zeros(self.shape, dtype=self.dtype, device=self.device)
        self.v: torch.Tensor = torch.zeros(self.shape, dtype=self.dtype, device=self.device)

    def _compute_step(self, gradient):
        # Bias-corrected step size. Expects t >= 1 (ascent/descend increment t
        # before calling). math.sqrt keeps this a plain Python float, so the
        # tensor arithmetic below never mixes in a NumPy scalar.
        a = self.stepsize * math.sqrt(1 - self.beta2 ** self.t) / (1 - self.beta1 ** self.t)
        self.m = self.beta1 * self.m + (1 - self.beta1) * gradient
        self.v = self.beta2 * self.v + (1 - self.beta2) * (gradient * gradient)
        return a * self.m / (self.v.sqrt() + self.epsilon)

    def state_dict(self):
        """Return ``{'t', 'm', 'v'}``.

        The original code returned the result of ``dict.update``, which is
        always ``None``.
        """
        sd = super().state_dict()
        sd.update({
            'm': self.m,
            'v': self.v,
        })
        return sd

    def load_state_dict(self, sd):
        """Restore ``t``, ``m`` and ``v`` from ``sd``.

        The moment tensors are moved to this optimizer's device and dtype.

        Raises:
            ValueError: if a moment tensor's shape differs from the current
                one. Nothing is modified in that case.
        """
        # Validate everything before mutating, so a bad checkpoint cannot leave
        # the optimizer half-loaded. Raise rather than assert: asserts vanish
        # under ``python -O``.
        for key, current in (('m', self.m), ('v', self.v)):
            if current.shape != sd[key].shape:
                raise ValueError(
                    f"Provided shape ({sd[key].shape}) does not match with existing shape ({current.shape})"
                )
        super().load_state_dict(sd)
        self.m = sd['m'].to(device=self.device, dtype=self.dtype)
        self.v = sd['v'].to(device=self.device, dtype=self.dtype)

    def to(self, device):
        """Move the moment tensors to ``device`` and return ``self``."""
        self.m = self.m.to(device)
        self.v = self.v.to(device)
        return super().to(device)

class ClipUp(Optimizer):
    """Momentum optimizer with a clipped velocity.

    This code is based on OpenAI's SGD class. It works like SGD with momentum,
    but the velocity (which is also the returned step) is clipped so that its
    L2 norm never exceeds ``max_speed * stepsize``. Optionally, the gradient is
    normalized to unit length before use.
    """

    @staticmethod
    def clip(x: torch.Tensor, max_length: float):
        """Return ``x`` rescaled so that its L2 norm is at most ``max_length``."""
        length = x.norm(2)
        if length > max_length:
            return x * (max_length / length)
        return x

    def __init__(self,
                 solution_shape,
                 dtype: torch.dtype,
                 device,
                 stepsize: float,
                 momentum: float = 0.9,
                 max_speed: float = 2,
                 fix_gradient_size: bool = True):
        super().__init__(solution_shape=solution_shape, dtype=dtype, device=device)
        self.v = torch.zeros(self.shape, dtype=self.dtype, device=self.device)  # velocity
        self.stepsize = float(stepsize)
        self.momentum = float(momentum)
        self.max_speed = float(max_speed)
        self.fix_gradient_size = bool(fix_gradient_size)

    def _compute_step(self, gradient: torch.Tensor):
        if self.fix_gradient_size:
            g_len = gradient.norm(2)
            # A zero gradient has no direction. Dividing by its zero norm would
            # yield NaNs that permanently poison the velocity, so it is left
            # as-is in that case.
            if g_len > 0:
                gradient = gradient / g_len

        step = gradient * self.stepsize

        self.v = self.momentum * self.v + step
        self.v = self.clip(self.v, self.max_speed * self.stepsize)

        return self.v

    def state_dict(self):
        """Return ``{'t', 'v'}``.

        The original code returned the result of ``dict.update``, which is
        always ``None``.
        """
        sd = super().state_dict()
        sd.update({
            'v': self.v,
        })
        return sd

    def load_state_dict(self, sd):
        """Restore ``t`` and the velocity ``v`` from ``sd``.

        The velocity is moved to this optimizer's device and dtype.

        Raises:
            ValueError: if the velocity's shape differs from the current one.
                Nothing is modified in that case.
        """
        # Validate before mutating; raise rather than assert (asserts vanish
        # under ``python -O``).
        if self.v.shape != sd['v'].shape:
            raise ValueError(
                f"Provided shape ({sd['v'].shape}) does not match with existing shape ({self.v.shape})"
            )
        super().load_state_dict(sd)
        self.v = sd['v'].to(device=self.device, dtype=self.dtype)

    def to(self, device):
        """Move the velocity to ``device`` and return ``self``."""
        self.v = self.v.to(device)
        return super().to(device)
