import torch
import argparse
import json
import os
import random
import sys
import time
import numpy as np
import pandas as pd
import pickle
from pathlib import Path
from torch.utils.data import DataLoader
from tqdm import tqdm
# import qpth
import cvxpy as cp
from cvxpylayers.torch import CvxpyLayer

def active_constraint_mask(h_values, gamma, tol=1e-4):
    """Return a boolean mask of the constraints treated as active.

    A constraint is flagged when its value lies within ``tol`` of zero and
    its dual multiplier is larger than ``tol``.

    Args:
        h_values: Array-like of constraint values at the inner solution.
        gamma: Array-like of dual multipliers, same length as ``h_values``.
        tol: Threshold used for both tests.

    Returns:
        Boolean NumPy array with one entry per constraint.
    """
    h_values = np.asarray(h_values)
    gamma = np.asarray(gamma)
    return (np.abs(h_values) < tol) & (gamma > tol)


def main():
    """Run the bilevel experiment with bilinear constraints chosen on the CLI.

    Samples a random problem instance from ``--seed``, optimizes ``x`` with
    Adam using hypergradients from the selected ``--solver``, writes a
    per-iteration log to ``results/<folder>_bilinear/ydim<ydim>/`` and pickles
    the list of clipped gradients next to it.

    Raises:
        RuntimeError: If a cvxpy solve in the ``ffo`` branch returns no
            primal or dual solution.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--solver', default='ffo', choices=['ffo', 'BR', 'cvxpylayer'])
    parser.add_argument('--lr', default=1e-3, type=float)
    parser.add_argument('--xdim', default=100, type=int)
    parser.add_argument('--ydim', default=500, type=int)
    parser.add_argument('--n-constraints', default=100, type=int)
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--steps', type=int, default=20)
    parser.add_argument('--n-iterations', type=int, default=1000)
    parser.add_argument('--eps', type=float, default=0.01)
    parser.add_argument('--delta', type=float, default=0.01)
    parser.add_argument('--folder', type=str, default='test')

    args = parser.parse_args()

    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    lr = args.lr
    eps = args.eps
    delta = args.delta
    solver = args.solver
    folder = args.folder

    # Problem dimensions
    x_dim = args.xdim
    y_dim = args.ydim
    n_constraints = args.n_constraints
    D = 1  # gradient clipping size

    # Random problem data. The order of the torch.rand calls fixes the
    # instance drawn for a given seed, so it must not be rearranged.
    L = torch.rand(y_dim, y_dim)
    c = torch.rand(y_dim)
    c = c / torch.norm(c)  # normalize it
    reg = 1
    Q = L.t() @ L + reg * torch.eye(y_dim)  # L^T L is PSD; the identity shift makes Q positive definite
    P = torch.rand(x_dim, y_dim)
    A = torch.rand(n_constraints, y_dim, x_dim)  # A y - b \leq 0
    b = torch.rand(n_constraints, x_dim)

    c_cp = c.numpy()
    Q_np = Q.numpy()
    P_np = P.numpy()
    A_np = A.numpy()
    b_np = b.numpy()

    print('Verifying if the random samples are consistent...')
    print('c average: {}'.format(np.mean(c_cp)))
    print('Q average: {}'.format(np.mean(Q_np)))
    print('P average: {}'.format(np.mean(P_np)))
    print('A average: {}'.format(np.mean(A_np)))
    print('b average: {}'.format(np.mean(b_np)))

    cvxpy_solver = cp.SCS

    # Problem functions (pytorch versions)
    def f(x, y):
        """Outer objective: linear term in y plus small ridge terms on x and y."""
        return c @ y + 0.01 * torch.norm(x)**2 + 0.01 * torch.norm(y)**2

    def g(x, y):
        """Inner objective: quadratic in y with a bilinear coupling to x."""
        return 1/2 * y.t() @ Q @ y + x.t() @ P @ y

    def h(x, y):
        """Constraint values [x^T A_i y - x^T b_i]_i, one per constraint."""
        return ((A @ x) @ y) - b @ x

    # Outer variable and its optimizer
    x = torch.rand(x_dim, requires_grad=True)  # random initialization
    optimizer = torch.optim.Adam([x], lr=lr)

    # Output location; makedirs also creates a missing 'results' parent.
    directory_path = 'results/{}_bilinear'.format(folder) + '/ydim{}'.format(y_dim)
    os.makedirs(directory_path, exist_ok=True)
    output_stem = directory_path + '/{}_eps{}_seed{}'.format(solver, eps, seed)

    grad_list = []

    # The context manager closes the log even if a solve raises mid-run.
    with open(output_stem + '.txt', 'w') as f_output:
        # NOTE(review): the 'ffo' branch writes five columns under this
        # three-column header; kept as-is for existing log consumers.
        f_output.write('Iteration, loss, time \n')

        if solver == 'cvxpylayer':
            # Cvxpy (differentiable optimization) approach
            for i in range(args.n_iterations):
                start_time = time.time()
                # Perturbation
                s = torch.normal(torch.zeros(x_dim), torch.ones(x_dim)) * delta
                xx = x + s
                A_xx = A @ xx
                b_xx = b @ xx

                # NOTE(review): the layer is rebuilt every iteration although
                # it does not depend on the iterate; it is left inside the
                # timed region so recorded timings stay comparable.
                x_cp = cp.Parameter(x_dim)
                A_cp = cp.Parameter((n_constraints, y_dim))
                b_cp = cp.Parameter(n_constraints)
                y_cp = cp.Variable(y_dim)
                q_cp = x_cp.T @ P_np
                objective = cp.Minimize(0.5 * cp.quad_form(y_cp, Q_np) + q_cp @ y_cp)
                constraints = [A_cp @ y_cp - b_cp <= 0]
                problem = cp.Problem(objective, constraints)
                cvxpylayer = CvxpyLayer(problem, parameters=[x_cp, A_cp, b_cp], variables=[y_cp])

                solution, = cvxpylayer(xx, A_xx, b_xx)
                loss = f(xx, solution)
                gradient = torch.autograd.grad(loss, xx)[0]
                gradient = torch.clamp(gradient, min=-D, max=D)
                with torch.no_grad():
                    x.grad = gradient
                    # Copy so the stored gradient is detached from x.grad,
                    # which the optimizer resets below.
                    grad_list.append(gradient.numpy().copy())

                optimizer.step()
                optimizer.zero_grad()
                time_elapsed = time.time() - start_time
                loss_value = loss.item()
                print(i, loss_value, time_elapsed)
                # Log the scalar: formatting the tensor itself would emit
                # 'tensor(..., grad_fn=...)', whose comma breaks the row.
                f_output.write('{}, {}, {} \n'.format(i, loss_value, time_elapsed))

        elif solver == 'ffo':
            # Our algorithm
            eps_abs = 1e-6  # Accuracy of the inner optimization problem and the Lagrangian optimization problem
            warm_start = True
            lamb_init = 1 / eps
            lamb = lamb_init
            # NOTE(review): this replaces the --delta value with zeros, so the
            # perturbation below is identically zero in this branch (the
            # normal draw still advances the RNG). Confirm this is intended.
            delta = torch.zeros(x_dim)
            y_warm_start = None  # NumPy solution of the previous penalty problem
            for i in range(args.n_iterations):
                start_time = time.time()
                # Perturbed x to get xx
                s = torch.normal(torch.zeros(x_dim), torch.ones(x_dim)) * delta
                xx = x + s

                # Pre-solve inner optimization problem:  min_y max_\gamma g(x,y) + \gamma^\top h(x,y)
                x_cp = xx.detach().numpy()
                y_inner = cp.Variable(y_dim)
                q_cp = x_cp.T @ P_np
                objective = cp.Minimize(0.5 * cp.quad_form(y_inner, Q_np) + q_cp @ y_inner)
                constraints = [(A_np @ x_cp) @ y_inner - b_np @ x_cp <= 0]
                # Seed with the previous penalty solution once one exists.
                if y_warm_start is not None:
                    y_inner.value = y_warm_start

                problem = cp.Problem(objective, constraints)
                problem.solve(solver=cvxpy_solver, eps_abs=eps_abs, warm_start=warm_start)

                constrained_inner_problem_time = time.time() - start_time
                mid_time = time.time()

                y_opt = y_inner.value
                gamma_opt = constraints[0].dual_value  # We only have one set of constraints
                if y_opt is None or gamma_opt is None:
                    raise RuntimeError(
                        'Inner problem returned no solution at iteration {} (status: {})'.format(i, problem.status))

                # Checking active constraints
                h_opt_np = (A_np @ x_cp) @ y_opt - b_np @ x_cp  # Bilinear constraints
                active_constraints = active_constraint_mask(h_opt_np, gamma_opt)
                has_active = bool(active_constraints.any())

                # Solve Lagrangian optimization problem: min_y f(x,y) + \lambda ( g(x,y) + \gamma^\top h(x,y) - g(x,y*) - \gamma^\top h(x,y*) ) + 1/2 * \lambda^2 * |h(x,y*)|^2
                y_pen = cp.Variable(y_dim)
                f_cp = c_cp.T @ y_pen + 0.01 * cp.sum_squares(x_cp) + 0.01 * cp.sum_squares(y_pen)
                g_cp = 0.5 * cp.quad_form(y_pen, Q_np) + q_cp @ y_pen
                g_opt_cp = 0.5 * cp.quad_form(y_opt, Q_np) + q_cp @ y_opt
                h_cp = (A_np @ x_cp) @ y_pen - b_np @ x_cp

                penalized = f_cp + lamb * (g_cp + gamma_opt.T @ h_cp - g_opt_cp - gamma_opt.T @ h_opt_np)
                if has_active:
                    # Quadratic penalty only on the constraints flagged active.
                    h_cp_active = (A_np[active_constraints] @ x_cp) @ y_pen - b_np[active_constraints] @ x_cp
                    penalized = penalized + 0.5 * lamb**2 * cp.sum_squares(h_cp_active)
                objective = cp.Minimize(penalized)
                y_pen.value = y_opt
                # NOTE(review): `constraints` was built on y_inner, so it does
                # not restrict y_pen in this problem; kept as in the original.
                problem = cp.Problem(objective, constraints)
                problem.solve(solver=cvxpy_solver, eps_abs=eps_abs, warm_start=warm_start)
                lagrangian_time = time.time() - mid_time

                y_warm_start = y_pen.value
                if y_warm_start is None:
                    raise RuntimeError(
                        'Penalty problem returned no solution at iteration {} (status: {})'.format(i, problem.status))

                # Convert numpy to tensor
                y_opt = torch.tensor(y_opt).float()
                y_lamb_opt = torch.tensor(y_warm_start).float()
                gamma_opt = torch.tensor(gamma_opt).float()

                # Compute the final Lagrangian and track gradient over x
                if has_active:
                    constraint_violation = torch.norm(h(xx, y_lamb_opt)[active_constraints])
                else:
                    constraint_violation = 0
                final_lagrangian = (
                    f(xx, y_lamb_opt)
                    + lamb * (g(xx, y_lamb_opt) + gamma_opt.T @ h(xx, y_lamb_opt)
                              - g(xx, y_opt) - gamma_opt.T @ h(xx, y_opt))
                    + 0.5 * lamb**2 * constraint_violation**2
                )

                # Gradient and variable update
                gradient = torch.autograd.grad(final_lagrangian, xx)[0]
                gradient = torch.clamp(gradient, min=-D, max=D)
                with torch.no_grad():
                    x.grad = gradient
                    # Copy so the stored gradient is detached from x.grad,
                    # which the optimizer resets below.
                    grad_list.append(gradient.numpy().copy())
                optimizer.step()
                optimizer.zero_grad()

                # Compute hyperobjective (this is not accurate but should be close enough)
                loss = f(xx, y_opt).item()

                # lambda update # TODO
                lamb = lamb_init

                # Output file
                time_elapsed = time.time() - start_time
                print(i, np.sum(active_constraints), loss, time_elapsed, constrained_inner_problem_time, lagrangian_time)
                f_output.write('{}, {}, {}, {}, {} \n'.format(i, loss, time_elapsed, constrained_inner_problem_time, lagrangian_time))

    pickle_path = output_stem + '.pickle'
    with open(pickle_path, 'wb') as handle:
        pickle.dump(grad_list, handle)

    print(pickle_path)


if __name__ == '__main__':
    main()
    
