import torch
from torch.optim import Optimizer
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from scipy.optimize import minimize
import numpy as np
import os
import cvxpy as cp
# Presumably set so the Intel OpenMP runtime does not abort when it detects a second copy of
# itself loaded into the process (e.g. one shipped with torch and one with numpy/scipy).
# NOTE(review): this silences a safety check rather than fixing the duplicate runtime, and it is
# set after `import torch` — confirm it still takes effect and is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


# sfnc: simple function with no constraints, i.e. psi(x) = f(x), s.t. x in R^d.
# sflc: simple function with linear constraints, i.e. psi(x) = f(x), s.t. Ax = b (optionally also x >= 0).

# constraints_info: should be a dictionary with two items, nonneg: bool and expressions: list. The former is True when
# the domain is nonnegative (x >= 0); the latter is a list of (a, b) pairs, each imposed as the linear constraint a @ x == b.

def argmin_sfnc(alpha, beta, x0, constraints_info):
    """Closed-form minimizer of <alpha, z> + (beta / 2) * ||z - x0||^2 over all of R^d.

    Setting the gradient alpha + beta * (z - x0) to zero gives z = x0 - alpha / beta.

    Args:
        alpha: linear coefficient (tensor or scalar).
        beta: positive scalar weight of the proximal term.
        x0: center of the proximal term.
        constraints_info: ignored; accepted so the signature matches ``argmin_sflc``.

    Returns:
        ``x0 - alpha / beta``.
    """
    shift = alpha / beta
    return x0 - shift

def argmin_sflc(alpha, beta, x0, constraints_info):
    """Solve the linearly constrained proximal subproblem with cvxpy.

    Computes
        argmin_z  0.5 * ||z - x0||^2 + <alpha / beta, z>
        s.t.      z >= 0 (when ``nonneg``) and the linear constraints in ``expressions``,
    whose minimizer is the same as that of <alpha, z> + (beta / 2) * ||z - x0||^2.

    Args:
        alpha: torch tensor, linear coefficient of the objective.
        beta: positive scalar weight of the proximal term.
        x0: 1-D torch tensor, center of the proximal term; the result takes its
            device and dtype.
        constraints_info: dict with keys
            'nonneg' (bool, default False): restrict z to the nonnegative orthant;
            'expressions' (list, default []): each item is either ``(a, b)``,
                imposed as ``a @ z == b`` (original behavior), or
                ``(a, b, sense)`` with ``sense`` in {'==', '<=', '>='}.
            A falsy value (None, [], {}) means no constraints.

    Returns:
        torch tensor with the minimizer, on ``x0``'s device with ``x0``'s dtype.

    Raises:
        ValueError: if a constraint uses an unsupported ``sense``.
        RuntimeError: if cvxpy returns no solution (e.g. the problem is infeasible).
    """
    info = constraints_info or {}
    nonneg = info.get('nonneg', False)
    expressions = info.get('expressions', [])

    n = x0.size(0)
    # cvxpy works on numpy data; detach/cpu so CUDA or grad-tracking tensors are accepted.
    x0_np = x0.detach().cpu().numpy()
    c_np = (alpha / beta).detach().cpu().numpy()
    z = cp.Variable(n, nonneg=nonneg)

    # 0.5 * ||z - x0||^2 is the same as 0.5 * quad_form(z - x0, I), without building I.
    obj = 0.5 * cp.sum_squares(z - x0_np) + c_np @ z

    constraints = []
    for expr in expressions:
        if len(expr) == 2:
            a, b = expr
            sense = '=='  # two-element items keep the original equality semantics
        else:
            a, b, sense = expr
        lhs = a @ z
        if sense == '==':
            constraints.append(lhs == b)
        elif sense == '<=':
            constraints.append(lhs <= b)
        elif sense == '>=':
            constraints.append(lhs >= b)
        else:
            raise ValueError(f"Unsupported constraint sense {sense!r}; use '==', '<=' or '>='.")

    prob = cp.Problem(cp.Minimize(obj), constraints)
    prob.solve()

    # On failure (infeasible/unbounded/solver error) cvxpy leaves z.value as None.
    if z.value is None:
        raise RuntimeError(f'argmin_sflc: cvxpy returned no solution (status: {prob.status}).')

    return torch.from_numpy(z.value).to(x0.device).type_as(x0)

class AGDA(Optimizer):
    """Accelerated gradient method with a line-searched proximal weight ``beta``.

    All parameters of the first param group are flattened into one vector. Each
    call to :meth:`step`:

    * forms ``z = tau * v + (1 - tau) * y`` from the current point ``y`` and the
      auxiliary sequence ``v``, and evaluates ``fun`` and its gradient at ``z``
      once;
    * accumulates ``a_{k+1} * grad`` into the aggregated gradient ``ag``;
    * for a trial ``beta`` solves ``v' = self.argmin(ag, beta, x0, constraints_info)``
      (``x0`` is the initial point) and sets ``y' = tau * v' + (1 - tau) * y``;
    * accepts ``beta`` when ``fun(y') <= stop_l(beta)``: ``beta`` is grown
      geometrically until accepted, then, if it had to grow, bisected between
      the last rejected and the smallest accepted value;
    * writes ``y'`` back into the parameters and tracks the best point seen
      according to ``fun + gun``.

    Args:
        params: iterable of tensors (only the first param group is used).
        beta_bar: initial value of ``beta``.
        r_bar: initial value of the running maximum distance ``r_bar``.
        fun: callable on the flattened vector returning a scalar tensor; it is
            differentiated with autograd.
        gun: callable on the flattened vector; only used (added to ``fun``) when
            tracking the best solution.
        subproblem_solving_method: 'sfnc' (closed form, unconstrained) or
            'sflc' (cvxpy, linear constraints).
        constraints_info: forwarded to the subproblem solver (see ``argmin_sflc``).

    Raises:
        ValueError: if ``subproblem_solving_method`` is unknown.
    """

    def __init__(self, params, beta_bar, r_bar, fun, gun, subproblem_solving_method = 'sfnc', constraints_info = None):
        defaults = dict(beta_bar=beta_bar, r_bar=r_bar, fun=fun, gun=gun)
        super().__init__(params, defaults)
        self.first_step = True
        # True until the first step has recorded beta_1 (the scale of the bisection tolerance).
        self.second_step = True
        self.iters = 0
        self.amount_of_call_for_value_oracle = 0
        self.amount_of_call_for_gradient_oracle = 0
        # None replaces the former shared mutable default; the stored value is unchanged ([]).
        self.constraints_info = [] if constraints_info is None else constraints_info

        if subproblem_solving_method == 'sfnc':
            self.argmin = argmin_sfnc
        elif subproblem_solving_method == 'sflc':
            self.argmin = argmin_sflc
        else:
            raise ValueError('Undefined subproblem solving method. Pls use agda.help() to obtain the help, or you can modify agda.py to support the other method.')

    def grad(self, fun, x):
        """Return the autograd gradient of ``fun`` at a detached copy of ``x``."""
        x = x.clone().detach().requires_grad_(True)
        val = fun(x)
        val.backward()
        return x.grad

    def _value_and_grad(self, fun, x):
        """Return ``(fun(x), grad fun(x))``, both detached, evaluated at a copy of ``x``."""
        x = x.clone().detach().requires_grad_(True)
        # backward() needs grad mode even if step() is called under torch.no_grad().
        with torch.enable_grad():
            val = fun(x)
            val.backward()
        return val.detach(), x.grad.detach().type_as(x)

    def step(self, closure=None):
        """Perform one iteration; returns ``closure()``'s result if a closure is given."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        group = self.param_groups[0]
        params = group['params']
        fun = group['fun']
        gun = group['gun']
        constraints_info = self.constraints_info

        y_k = parameters_to_vector(params).detach()
        dtype = y_k.dtype
        device = y_k.device

        if self.first_step:
            group['v'] = y_k.clone()
            group['A'] = torch.tensor(0., dtype=dtype, device=device)
            # as_tensor accepts both a Python number and an existing tensor without warning.
            group['a'] = torch.as_tensor(group['r_bar'], dtype=dtype, device=device)
            group['r_bar'] = torch.as_tensor(group['r_bar'], dtype=dtype, device=device)
            group['beta'] = group['beta_bar']
            group['beta_m1'] = torch.tensor(0., dtype=dtype, device=device)
            group['r'] = torch.tensor(0., dtype=dtype, device=device)
            group['best_solution'] = y_k.clone()
            group['init_buf'] = y_k
            group['beta_1'] = 1
            group['ag'] = torch.zeros_like(y_k)
            self.first_step = False

        ag_k = group['ag']
        v_k = group['v']
        x0 = group['init_buf']
        A_k = group['A']
        r_bar_km1 = group['r_bar']
        beta_k = group['beta']
        beta_1 = group['beta_1']
        best_solution = group['best_solution']

        r_k = torch.norm(x0 - v_k)
        r_bar_k = torch.maximum(r_bar_km1, r_k)
        a_k1 = torch.pow(A_k.sqrt() + r_bar_k.sqrt(), 2) - A_k
        A_k1 = A_k + a_k1
        tau_k = a_k1 / A_k1

        # z_{k+1}, f(z_{k+1}), its gradient and the aggregated gradient do not depend
        # on the trial beta, so they are evaluated once rather than once per trial.
        z_k1 = tau_k * v_k + (1 - tau_k) * y_k
        fz, grad_fz = self._value_and_grad(fun, z_k1)
        self.amount_of_call_for_gradient_oracle += 1
        ag_k1 = ag_k + a_k1.item() * grad_fz

        def trial(beta_trial):
            # Evaluate one candidate beta: returns (accepted, v_{k+1}, y_{k+1}).
            self.amount_of_call_for_value_oracle += 1
            eta = (beta_trial * r_bar_k ** 2 - beta_k * r_bar_km1 ** 2) / (8 * a_k1)
            v_trial = self.argmin(ag_k1, beta_trial, x0, constraints_info)
            y_trial = tau_k * v_trial + (1 - tau_k) * y_k
            diff = y_trial - z_k1
            breg_penalty = beta_trial / (64 * tau_k ** 2 * A_k1) * torch.norm(diff) ** 2
            stop_l = fz + torch.dot(grad_fz, diff) + breg_penalty + tau_k * eta / 2
            return bool(fun(y_trial) <= stop_l), v_trial, y_trial

        line_search_base = 1 + 1 / (self.iters + 1)
        tolerance = beta_1 * ((1 + self.iters) ** (-2))
        max_binary_search = 100

        # Phase 1: grow beta geometrically until the acceptance test holds.
        lk_iter = 0
        while True:
            beta_k1 = beta_k * (line_search_base ** lk_iter)
            accepted, v_k1, y_k1 = trial(beta_k1)
            if accepted:
                break
            lk_iter += 1

        # Phase 2: if beta had to grow, bisect between the last rejected value (lhs)
        # and the smallest accepted one (rhs). v_k1/y_k1 always belong to rhs, so the
        # stored beta and the new iterate stay consistent.
        if lk_iter > 0:
            lhs = beta_k1 / line_search_base
            rhs = beta_k1
            for _ in range(max_binary_search):
                if rhs - lhs <= tolerance:
                    break
                mid = (lhs + rhs) / 2
                accepted, v_mid, y_mid = trial(mid)
                if accepted:
                    rhs, v_k1, y_k1 = mid, v_mid, y_mid
                else:
                    lhs = mid
            beta_k1 = rhs

        group['ag'] = ag_k1
        group['A'] = A_k1
        group['a'] = a_k1
        group['r_bar'] = r_bar_k
        group['beta'] = beta_k1
        group['beta_m1'] = beta_k
        group['v'] = v_k1.detach()

        # Record beta_1 once, after the first step, so the tolerance scale stays fixed.
        if self.second_step:
            group['beta_1'] = beta_k1
            self.second_step = False

        self.iters += 1
        self.amount_of_call_for_value_oracle += 1

        y_new = y_k1.detach()
        if fun(y_new) + gun(y_new) < fun(best_solution) + gun(best_solution):
            group['best_solution'] = y_new

        with torch.no_grad():
            vector_to_parameters(y_k1, params)

        return loss

    def get_best_solution(self,verbose = False):
        """Return (optionally print) the point with the lowest ``fun + gun`` seen so far."""
        group = self.param_groups[0]
        if verbose:
            print(group['best_solution'])
        return group['best_solution']

    def get_amount_of_call_for_value_oracle(self):
        """Number of value-oracle evaluations counted so far."""
        return self.amount_of_call_for_value_oracle

    def get_amount_of_call_for_gradient_oracle(self):
        """Number of gradient-oracle evaluations (one per step) counted so far."""
        return self.amount_of_call_for_gradient_oracle

    def get_amount_of_call_for_oracle(self):
        """Total value-oracle plus gradient-oracle evaluations."""
        return self.amount_of_call_for_value_oracle + self.amount_of_call_for_gradient_oracle

    def help(self):
        """Print a short description of the subproblem methods and constraints_info."""
        # NOTE(review): the text below says Ax<=b, but argmin_sflc treats each (a, b) pair
        # in 'expressions' as a @ x == b — reconcile the description with the solver.
        print('''
                sfnc: simple function with no constrains, i.e., psi(x) = f(x), st. x in R^d.\\
                sflc: simple function with linear constrains, i.e. psi(x) = f(x), st. Ax<=b.\\
                constraint_info: should be a dictory with the two items: nonneg:bool and expressions:list. The former is True when the domain is nonnegtive, \\
                the latter is the list of specific expressions, like, a @ x <= b.
              '''
            )