import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class Constraint(torch.optim.Optimizer):
    """Two-phase optimizer wrapper combining a constraint gradient with an objective gradient.

    One update consists of two backward passes:

    * ``first_step``: called after the backward pass of objective 1 (the
      "forget" objective); it logs a copy of every parameter gradient.
    * ``second_step``: called after the backward pass of objective 2 (the
      "retain" objective); it adds a non-negative multiple of the logged
      gradient to the current gradient and then lets ``base_optimizer``
      apply the update.

    ``step`` runs both phases from two closures, each of which must perform a
    full forward-backward pass and return its loss.

    Args:
        params: parameters (or parameter groups) handed to
            ``torch.optim.Optimizer.__init__``.
        base_optimizer: an already constructed optimizer instance; its
            ``param_groups`` are shared with this wrapper and its ``step()``
            performs the actual parameter update.
        g_star: stored in ``defaults`` and on ``self.g_star``; it is not read
            anywhere else in this class.
        alpha: weight of the constraint-violation term used in mode ``'all'``.
        beta: weight of the squared constraint-gradient norm in ``phi_t``.
        **kwargs: extra entries stored in the optimizer ``defaults``.
    """

    def __init__(self, params, base_optimizer, g_star=0.05, alpha=1., beta=1., **kwargs):
        defaults = dict(g_star=g_star, **kwargs)
        super(Constraint, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer
        # Share the groups so that zero_grad() and the loops below operate on
        # exactly the parameters the base optimizer updates.
        self.param_groups = self.base_optimizer.param_groups
        self.g_star = g_star
        self.alpha = alpha
        self.beta = beta
        self.g_constraint = .0
        # Value of the constraint objective; set by step() in mode 'all'.
        self.g_value = None
        # Kept for backward compatibility; not read by this class.
        self.grad_inner = 0.

    @staticmethod
    def adjust_gradient(alpha, beta, ratio):
        """Rescale ``alpha`` so its norm is at least ``ratio`` times the norm of ``beta``.

        The direction of ``alpha`` is preserved. If ``alpha`` is already long
        enough it is returned unchanged (the same tensor object).

        Parameters:
        alpha (torch.Tensor): The gradient vector to be adjusted.
        beta (torch.Tensor): The reference gradient vector.
        ratio (float): Required minimum of ``||alpha|| / ||beta||``.

        Returns:
        torch.Tensor: The adjusted gradient vector. A zero ``alpha`` has no
        direction and is therefore returned as zeros.
        """
        norm_alpha = alpha.norm()
        norm_beta = beta.norm()

        if norm_alpha < ratio * norm_beta:
            # Length that is missing to reach the target norm ratio * ||beta||.
            additional_norm = ratio * norm_beta - norm_alpha

            # Unit vector along alpha; the small constant avoids division by zero.
            alpha_normalized = alpha / (1e-6 + norm_alpha)

            # Extend alpha along its own direction by the missing length.
            alpha = alpha + alpha_normalized * additional_norm

        return alpha

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Log the current gradients as the constraint ("forget") gradients.

        For every parameter that has a gradient, a copy of it is stored in
        ``self.state[p]["constraint_grad"]`` and its L2 norm in
        ``self.state[p]["constraint_grad_norm"]``.

        Args:
            zero_grad: if true, call ``self.zero_grad()`` after logging.
        """
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                if p.grad is None:
                    # No gradient in this pass: drop anything logged by an
                    # earlier call so second_step() cannot reuse a stale one.
                    state.pop("constraint_grad", None)
                    continue

                # Real copy: zero_grad() may zero p.grad in place, which would
                # also wipe a tensor that merely aliases it.
                constraint_grad = p.grad.clone()
                state["constraint_grad"] = constraint_grad
                # L2 norm of the gradient computed on the forget set.
                state["constraint_grad_norm"] = constraint_grad.norm()

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, mode='one', zero_grad=False, step=True):
        """Combine the current gradients with the logged ones and update.

        With ``g_r`` the current ("retain") gradient of a parameter and
        ``g_f`` the gradient logged by ``first_step``, the gradient becomes
        ``g_r + lambda_t * g_f`` where::

            lambda_t = max((phi_t - <g_r, g_f>) / (||g_f||^2 + 1e-8), 0)

        * mode ``'one'``: ``phi_t = beta * ||g_f||^2``, which gives
          ``lambda_t ~= max(beta - <g_r, g_f> / ||g_f||^2, 0)``.
        * mode ``'all'``: ``phi_t = beta * ||g_f||^2 * alpha *
          (g_value - g_constraint)^2``.
        * any other mode leaves the gradients untouched.

        Parameters without a logged constraint gradient are skipped. The
        inner product ``<g_r, g_f>`` is stored in
        ``self.state[p]["grad_inner"]``.

        Args:
            mode: ``'one'`` or ``'all'`` (see above).
            zero_grad: if true, call ``self.zero_grad()`` at the end.
            step: if true (default), call ``self.base_optimizer.step()``.

        Raises:
            RuntimeError: in mode ``'all'`` when ``self.g_value`` was never set.
        """
        epsilon = 1e-8  # guards the division by ||g_f||^2
        delta = 2       # exponent of the norm / violation term

        violation_weight = None
        if mode == 'all':
            if self.g_value is None:
                raise RuntimeError(
                    "mode='all' requires self.g_value; call step(mode='all') "
                    "or set g_value before second_step()")
            violation_weight = self.alpha * (self.g_value - self.g_constraint) ** delta

        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                constraint_grad = state.get("constraint_grad")
                if constraint_grad is None:
                    continue

                if p.grad is None:
                    # Fall back to the logged gradient. Clone it so that the
                    # in-place add_ below cannot modify the stored state.
                    p.grad = constraint_grad.clone()
                    print('error')

                # Inner product between the retain-set gradient and the
                # forget-set gradient logged in first_step().
                grad_inner = (p.grad * constraint_grad).sum()
                state["grad_inner"] = grad_inner

                if mode == 'one':
                    phi_t = self.beta * torch.norm(constraint_grad) ** delta
                elif mode == 'all':
                    phi_t = (self.beta * torch.norm(constraint_grad) ** 2) * violation_weight
                else:
                    continue

                # The direction d_t solves a convex quadratic program whose
                # solution is d_t = g_r + lambda_t * g_f with lambda_t >= 0.
                lambda_t = (phi_t - grad_inner) / (torch.norm(constraint_grad) ** 2 + epsilon)
                lambda_t = torch.clamp(lambda_t, min=0)
                p.grad.add_(lambda_t * constraint_grad)

        if step:
            self.base_optimizer.step()

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, forget_closure=None, retain_closure=None, mode='one', g_constraint=None):
        """Run one full update: forget pass, gradient logging, retain pass, update.

        Args:
            forget_closure: callable doing a full forward-backward pass of the
                constraint ("forget") objective and returning its loss.
            retain_closure: callable doing a full forward-backward pass of the
                main ("retain") objective and returning its loss.
            mode: ``'one'`` or ``'all'``; in mode ``'all'`` the forget loss is
                additionally stored (detached) as ``self.g_value``.
            g_constraint: constraint level, stored as ``self.g_constraint``.

        Returns:
            tuple: ``(forget_loss, retain_loss)`` as returned by the closures.

        Raises:
            AssertionError: if a closure or ``g_constraint`` is missing.
            ValueError: if ``mode`` is neither ``'one'`` nor ``'all'``.
        """
        assert forget_closure is not None, "Requires forget_closure, but it was not provided, raise an error"
        assert retain_closure is not None, "Requires retain_closure but it was not provided, raise an error"
        assert g_constraint is not None, "Requires g constraint"
        if mode not in ('one', 'all'):
            # Fail before any closure runs instead of failing at `return`.
            raise ValueError("mode must be 'one' or 'all', got {!r}".format(mode))

        # step() runs under no_grad, but the closures must build a graph.
        forget_closure = torch.enable_grad()(forget_closure)
        retain_closure = torch.enable_grad()(retain_closure)

        self.g_constraint = g_constraint

        forget_loss = forget_closure()
        if mode == 'all':
            # Detach so the stored value does not keep the autograd graph alive.
            self.g_value = forget_loss.detach() if torch.is_tensor(forget_loss) else forget_loss
        self.first_step(zero_grad=True)

        retain_loss = retain_closure()
        self.second_step(mode=mode, zero_grad=True)
        return forget_loss, retain_loss
        

            
