import torch
from torch.optim import Optimizer
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from scipy.optimize import minimize
import numpy as np
import os
# Sets the KMP_DUPLICATE_LIB_OK environment variable for this process.
# NOTE(review): presumably a workaround for the Intel OpenMP "duplicate
# runtime" abort when torch/numpy/scipy each bring their own libiomp —
# confirm it is still needed, and that setting it after the imports above
# is early enough on the target platforms.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"



class AGDA_LSFM(Optimizer):
    """Accelerated first-order optimizer with a non-decreasing proximal weight.

    Every :meth:`step` works on the flattened parameter vector of the first
    parameter group and performs, in order:

    1. update the radius estimate ``r_bar`` and the weights ``a``, ``A`` and
       ``tau = a / A``;
    2. solve a subproblem centred at the initial point, driven by the
       accumulated weighted gradient ``ag``, to obtain ``v``;
    3. evaluate the gradient of ``fun`` at ``z = tau * v + (1 - tau) * y``;
    4. solve a second subproblem centred at ``v``, driven by the weighted
       gradient at ``z``, to obtain ``x_bar`` and set the new iterate
       ``y = tau * x_bar + (1 - tau) * y``;
    5. update ``beta`` from gradient differences between ``y`` and ``z``;
       ``beta`` is never decreased.

    Only ``param_groups[0]`` is read and written; additional parameter groups
    are ignored by :meth:`step`.

    Args:
        params: Iterable of parameters (or parameter-group dicts) to optimize.
        beta_bar: Initial value of ``beta``.
        r_bar_fix: Initial value of ``r_bar`` (also stored as the initial ``a``).
        gun: Callable ``gun(y)`` forwarded to custom subproblem solvers and
            added (weighted) to the ``'scipy'`` subproblem objective.
        xi_func: Callable ``xi_func(center, y)``; multiplied by ``beta`` in the
            ``'scipy'`` subproblem objective.
        subproblem_solving_method: Solver for the first subproblem. Either a
            callable ``solver(linear, beta, gun, center)`` returning a NumPy
            array or tensor, or one of the strings ``'scipy'`` (numerical
            minimization with ``scipy.optimize.minimize``),
            ``'non composition with 2norm'`` (closed form
            ``center - linear / beta``) or ``'copt'`` (always raises).
        quad_factor: Multiplier of the ``beta * ||y - z||^2`` term in the
            quantity that drives the ``beta`` update.
        re_factor: Stored on the instance; not used by this class.
        eps: Stored on the instance; not used by this class.
        subproblem_solving_method2: Solver for the second subproblem; same
            accepted values as ``subproblem_solving_method``.

    NOTE(review): the closed form ``center - linear / beta`` is the exact
    minimizer only when ``gun`` contributes nothing and
    ``xi_func(c, y) = 0.5 * ||y - c||^2`` — confirm callers satisfy this.
    """

    def __init__(
        self,
        params,
        beta_bar,
        r_bar_fix,
        gun,
        xi_func,
        subproblem_solving_method='non composition with 2norm',
        quad_factor=1,
        re_factor=1,
        eps=1e-5,
        subproblem_solving_method2='non composition with 2norm',
    ):
        defaults = dict(beta_bar=beta_bar, r_bar_fix=r_bar_fix, gun=gun)
        super().__init__(params, defaults)
        self.first_step = True
        self.xi_func = xi_func
        self.__subproblem_solving_method = subproblem_solving_method
        self.__subproblem_solving_method2 = subproblem_solving_method2
        self.quad_factor = quad_factor
        self.re_factor = re_factor
        self.eps = eps
        self.iters = 0

    def grad(self, fun, x):
        """Return the gradient of the scalar-valued ``fun`` evaluated at ``x``.

        ``x`` is copied and detached first, so the caller's tensor and its
        autograd history are left untouched.
        """
        x = x.clone().detach().requires_grad_(True)
        # Force graph recording so this also works when the caller is inside
        # a ``torch.no_grad()`` block.
        with torch.enable_grad():
            val = fun(x)
        val.backward()
        return x.grad

    def _init_state(self, group, y_k):
        """Populate ``group`` with the optimizer state before the first step."""
        dtype = y_k.dtype
        device = y_k.device
        xi_func = self.xi_func

        group['v'] = y_k.clone()
        group['A'] = torch.tensor(0., dtype=dtype, device=device)
        group['a'] = torch.tensor(group['r_bar_fix'], dtype=dtype, device=device)
        group['r_bar'] = torch.tensor(group['r_bar_fix'], dtype=dtype, device=device)
        group['beta'] = group['beta_bar']
        group['r'] = torch.tensor(0., dtype=dtype, device=device)
        group['best_solution'] = y_k.clone()
        group['init_buf'] = y_k
        group['be'] = 1e10
        group['be_k'] = y_k.clone()
        group['d_k'] = torch.tensor(0., dtype=dtype, device=device)
        group['ag'] = torch.zeros_like(y_k)
        # Not read by step(); kept so the set of state keys is unchanged.
        group['phi'] = lambda x: group['beta'] * xi_func(y_k, x)

    def _solve_subproblem(self, method, linear, g_weight, beta, center, start):
        """Minimize ``<linear, y> + g_weight * gun(y) + beta * xi_func(center, y)``.

        Args:
            method: Solver name or callable (see the class docstring).
            linear: Coefficient vector of the linear term.
            g_weight: Weight of ``gun`` in the ``'scipy'`` objective.
            beta: Weight of ``xi_func`` / divisor in the closed form.
            center: First argument handed to ``xi_func`` and to custom solvers.
            start: Initial guess for the ``'scipy'`` solver.

        Returns:
            The minimizer as a detached tensor on ``center``'s device, with
            ``center``'s dtype.

        Raises:
            TypeError: If ``method`` is neither a string nor callable.
            Exception: For ``'copt'`` and for unknown method names.
        """
        gun = self.param_groups[0]['gun']
        xi_func = self.xi_func

        if callable(method):
            res = method(linear, beta, gun, center)
        elif not isinstance(method, str):
            raise TypeError(
                'subproblem solving method must be a string or a callable, got '
                + type(method).__name__
            )
        elif method == 'scipy':
            def objective(y_np):
                # scipy works on float64 CPU arrays; move each trial point to
                # the optimizer's device/dtype before evaluating.
                y = torch.as_tensor(y_np, dtype=center.dtype, device=center.device)
                val = torch.dot(linear, y) + g_weight * gun(y) + beta * xi_func(center, y)
                return val.item()

            res = minimize(objective, start.detach().cpu().numpy()).x
        elif method == 'non composition with 2norm':
            # Closed form, computed directly on the tensors (no NumPy round
            # trip, so it also works for non-CPU parameters).
            res = -linear / beta + center
        elif method == 'copt':
            raise Exception('Copt has not been involved so far.')
        else:
            raise Exception('This method has not been involved so far.')

        return torch.as_tensor(res).detach().to(center.device).type_as(center)

    def step(self, closure=None, fun=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is only evaluated for its return value; the
                update itself is driven by ``fun``.
            fun: Callable mapping the flat parameter vector (1-D tensor) to a
                scalar tensor. Required.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.

        Raises:
            TypeError: If ``fun`` is not provided.
        """
        if fun is None:
            raise TypeError(
                'step() requires `fun`: a callable mapping the flat parameter '
                'vector to a scalar tensor'
            )

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        group = self.param_groups[0]
        params = group['params']
        y_k = parameters_to_vector(params).detach()

        if self.first_step:
            self._init_state(group, y_k)
            self.first_step = False

        ag_k = group['ag']
        v_k = group['v']
        x0 = group['init_buf']
        A_k = group['A']
        beta_k = group['beta']

        # Radius estimate: running maximum of the previous estimate, the
        # distance from the initial point to v, and the stored distance d_k.
        r_k = torch.norm(x0 - v_k)
        r_bar_k = torch.stack([group['r_bar'], r_k, group['d_k']], dim=0).max(dim=0).values
        a_k1 = torch.pow(A_k.sqrt() + r_bar_k.sqrt(), 2) - A_k
        A_k1 = A_k + a_k1
        tau_k = a_k1 / A_k1

        # First subproblem: centred at the initial point, warm-started at the
        # previous v when solved numerically.
        v_k1 = self._solve_subproblem(
            self.__subproblem_solving_method, ag_k, A_k1, beta_k, x0, v_k
        )
        z_k1 = tau_k * v_k1 + (1 - tau_k) * y_k

        # The gradient at z is needed for the accumulator, the second
        # subproblem and the beta update; evaluate it exactly once.
        grad_z = self.grad(fun, z_k1).detach()
        weighted_grad = a_k1.item() * grad_z
        ag_k1 = ag_k + weighted_grad

        # Second subproblem: centred at v, governed by the *second* method.
        x_bar_k = self._solve_subproblem(
            self.__subproblem_solving_method2, weighted_grad, a_k1, beta_k, v_k1, z_k1
        )
        y_k1 = tau_k * x_bar_k + (1 - tau_k) * y_k

        grad_y = self.grad(fun, y_k1)
        diff = y_k1 - z_k1
        grad_dot_yx = torch.dot(grad_y.type_as(y_k1) - grad_z.type_as(z_k1), diff)
        norm2 = torch.norm(diff) ** 2
        breg_penalty = beta_k * norm2
        stop_l = 64 * tau_k ** 2 * A_k1 * grad_dot_yx - self.quad_factor * breg_penalty

        # beta only ever grows.
        beta_k1 = max(beta_k + stop_l / (norm2 + 32 * tau_k ** 2 * r_bar_k ** 2), beta_k)

        group['ag'] = ag_k1
        group['A'] = A_k1
        group['a'] = a_k1
        group['r_bar'] = r_bar_k
        group['beta'] = beta_k1
        group['v'] = v_k1.detach()
        group['d_k'] = torch.norm(x_bar_k - x0, p=2)
        self.iters += 1

        # Remember the iterate with the smallest r_bar / A seen so far.
        bound = r_bar_k / A_k1
        if bound <= group['be']:
            best = y_k1.detach()
            group['be'] = bound
            group['be_k'] = best
            group['best_solution'] = best

        with torch.no_grad():
            vector_to_parameters(y_k1, params)

        return loss

    def get_best_solution(self, verbose=False):
        """Return the flat iterate with the smallest ``r_bar / A`` seen so far.

        Requires at least one completed :meth:`step` (the state is created
        lazily). With ``verbose=True`` the returned vector is also printed.
        """
        group = self.param_groups[0]
        if verbose:
            print(group['be_k'])
        return group['be_k']

    def theory_bound(self):
        """Return ``(beta / A, r_bar)`` from the current optimizer state."""
        group = self.param_groups[0]
        A_k = group["A"]
        r_bar_k = group["r_bar"]
        beta_k = group["beta"]
        factor = beta_k / A_k
        return factor, r_bar_k

    def coeff(self):
        """Return ``A / beta`` from the current optimizer state."""
        group = self.param_groups[0]
        return group['A'] / group['beta']