from .comlib import *
from . import compute_grad, convert_vec_tuple, model_setup

def optimizer_generator(name, parameter_tuple,lr, maximize=False,**kwargs):
    """Build the optimizer selected by ``name`` over ``parameter_tuple``.

    Args:
        name: Optimizer selector; ``"Adam"`` or ``"SGD"``.
        parameter_tuple: Parameters handed to the optimizer constructor.
        lr: Learning rate, forwarded as ``lr``.
        maximize: Forwarded as ``maximize``.
        **kwargs: Extra keyword arguments forwarded to the constructor
            unchanged.

    Returns:
        An ``optim.Adam`` or ``optim.SGD`` instance.

    Raises:
        ValueError: If ``name`` is not one of the supported selectors.
    """
    if name == "Adam":
        return optim.Adam(parameter_tuple, lr=lr, maximize=maximize, **kwargs)
    if name == "SGD":
        return optim.SGD(parameter_tuple, lr=lr, maximize=maximize, **kwargs)
    # Falling through would hand back None, which only blows up later (and far
    # from the cause) at the first ``step`` call; fail here instead.
    raise ValueError(
        f"unsupported optimizer name: {name!r} (expected 'Adam' or 'SGD')"
    )
class Optimizer():
    """Wrapper that steps an optimizer from externally supplied gradients.

    The gradients are written onto the ``.grad`` attribute of the parameters
    taken from ``modelSetup.parameter_tuple`` and the wrapped optimizer is then
    stepped. The construction arguments are kept so the optimizer can be
    rebuilt with ``initialize_optimizer``.
    """

    def __init__(self, name, modelSetup: model_setup.ModelSetup,lr, maximize=False,**kwargs):
        """Store the configuration and build the underlying optimizer.

        Args:
            name: Selector passed to ``optimizer_generator``.
            modelSetup: Object providing ``parameter_tuple`` (and ``form_dict``,
                which ``stepByGradVec`` reads).
            lr: Learning rate.
            maximize: Forwarded to the optimizer constructor.
            **kwargs: Extra optimizer keyword arguments, stored and forwarded.
        """
        self.optimizer_name = name
        self.modelSetup = modelSetup
        self.parameter_tuple = modelSetup.parameter_tuple
        self.lr = lr
        self.maximize = maximize
        self.optimizer_kwargs = kwargs

        self.optimizer = optimizer_generator(
            name,
            self.parameter_tuple,
            lr,
            maximize=maximize,
            **kwargs,
        )

    def stepByGradTuple(self,gradient_tuple,closure=None):
        """Assign ``gradient_tuple[i]`` as the grad of parameter ``i``, then step.

        Args:
            gradient_tuple: Indexable collection of gradients, one per entry
                of ``self.parameter_tuple``, in the same order.
            closure: Passed through to the optimizer's ``step``.
        """
        for position, parameter in enumerate(self.parameter_tuple):
            parameter.grad = gradient_tuple[position]
        self.optimizer.step(closure)

    def stepByGradVec(self,gradient_vec,closure=None):
        """Convert ``gradient_vec`` to a gradient tuple and step with it.

        The conversion is delegated to
        ``convert_vec_tuple.unflatten_to_tuple_with_form_dict`` using
        ``self.modelSetup.form_dict``.
        """
        form_dict = self.modelSetup.form_dict
        as_tuple = convert_vec_tuple.unflatten_to_tuple_with_form_dict(gradient_vec, form_dict)
        self.stepByGradTuple(as_tuple, closure)

    def initialize_optimizer(self):
        """Replace ``self.optimizer`` with a fresh one built from the stored settings."""
        self.optimizer = optimizer_generator(
            self.optimizer_name,
            self.parameter_tuple,
            self.lr,
            maximize=self.maximize,
            **self.optimizer_kwargs,
        )

class OptimizerWithGradClip(Optimizer):
    """``Optimizer`` variant that clips the gradient norm before each step."""

    def __init__(self, optimizer_name, modelSetup, lr, maximize=False,max_norm=1.0, **kwargs):
        """Build the base optimizer and remember the clipping threshold.

        Args:
            optimizer_name: Selector forwarded to the base class as its name.
            modelSetup: Forwarded to the base class.
            lr: Learning rate, forwarded to the base class.
            maximize: Forwarded to the base class.
            max_norm: Threshold handed to ``clip_grad_norm_`` on every step.
            **kwargs: Extra optimizer keyword arguments, forwarded to the base
                class.
        """
        super().__init__(optimizer_name, modelSetup, lr, maximize, **kwargs)
        self.max_norm = max_norm

    def stepByGradTuple(self,gradient_tuple,closure=None):
        """Assign the gradients, clip their norm to ``max_norm``, then step.

        Args:
            gradient_tuple: Indexable collection of gradients, one per entry
                of ``self.parameter_tuple``, in the same order.
            closure: Passed through to the optimizer's ``step``.
        """
        parameters = self.parameter_tuple
        for position, parameter in enumerate(parameters):
            parameter.grad = gradient_tuple[position]
        # Clipping has to happen after the grads are in place and before the
        # optimizer consumes them.
        torch.nn.utils.clip_grad_norm_(parameters, self.max_norm)
        self.optimizer.step(closure)


def create_optimizer(name:Literal["default","gradClip"],**kwargs):
    """Create an optimizer wrapper of the kind selected by ``name``.

    Args:
        name: ``"default"`` for ``Optimizer`` or ``"gradClip"`` for
            ``OptimizerWithGradClip``.
        **kwargs: Constructor arguments for the selected wrapper. The
            optimizer algorithm is given as ``optimizer_name`` for both kinds,
            because the keyword ``name`` is already taken by this function's
            own selector. ``max_norm`` is optional for ``"gradClip"``; when
            omitted the class default applies.

    Returns:
        The constructed ``Optimizer`` or ``OptimizerWithGradClip``.

    Raises:
        ValueError: If ``name`` is not a known selector.
    """
    if name == "default":
        init_kwargs = dict(kwargs)
        # ``Optimizer.__init__`` calls its first parameter ``name``, which can
        # never arrive through **kwargs here; translate ``optimizer_name``.
        if "optimizer_name" in init_kwargs:
            init_kwargs["name"] = init_kwargs.pop("optimizer_name")
        return Optimizer(**init_kwargs)
    if name == "gradClip":
        # ``max_norm`` travels inside kwargs. Passing it a second time
        # explicitly raised a duplicate-keyword TypeError when supplied and
        # replaced the class default with None when it was not.
        return OptimizerWithGradClip(**kwargs)

    raise ValueError(
        f"invalid name: {name!r} (expected 'default' or 'gradClip')"
    )