#####################################################################
# Main code file for:
#  Donti, Agarwal, et al. (2021). "Adversarially robust learning for
#  security-constrained optimal power flow." Conference on Neural
#  Information Processing Systems (NeurIPS).
#####################################################################


# Power flow and optimal power flow solvers
from lib.parseSUGAR import parse
from SUGAR import main as opf
from classes.SUGAROpt import SUGAROpt

# Gradient computation functions
from linear_sensitivity2 import contingency_sensitivity 

import numpy as np
from scipy.sparse.linalg import spsolve
import cvxpy as cp
import matplotlib.pyplot as plt
import os
import argparse

def main():
    """Run the adversarial SCOPF attack/defense loop from the command line.

    Parses command-line options and builds ``case_data`` from the
    ``case.raw/.con/.inl/.rop`` files under ``--filepath``. It then solves the
    base case with all contingency weights (``alpha``) set to zero and
    alternates for ``--maxOuterEpochs`` rounds between two phases:

    * ATTACK: up to ``--maxInnerEpochs`` steps. Each step re-solves the
      contingency case with ``opf`` and adds the (L1-norm-capped) gradient
      from ``dLoss_dalpha`` to the alphas. It then projects them with
      ``project_alpha`` onto {alpha >= 0, sum(alpha) <= --k}. The phase stops
      early once the largest alpha change drops below ``--alphaChangeTol``.
    * DEFENSE: one re-solve of the base case with ``opt.base_gauss_step``
      enabled, using the coupling variables stored by the last attack solve.

    Contingency, defense (base-case) and total losses are recorded once per
    attack iteration and once per defense step, and plotted with matplotlib.
    """
    parser = argparse.ArgumentParser(description='Adversarial SCOPF')
    parser.add_argument('--filepath', type=str,
        default=os.path.join('raw', 'Trial2_Real-Time-10x14', 'Network_02R-074-1', 'scenario_1'))
    parser.add_argument('--alphaInitType', type=str, choices=['random', 'zeros', 'prev'],
        default='prev')
    parser.add_argument('--maxOuterEpochs', type=int, default=10)
    parser.add_argument('--maxInnerEpochs', type=int, default=20)
    parser.add_argument('--alphaChangeTol', type=float, default=1e-3)
    parser.add_argument('--k', type=int, default=3)
    args = parser.parse_args()

    # The defense step uses the contingency solution (con_sol) produced by
    # the attack phase, so at least one attack iteration is required.
    if args.maxInnerEpochs < 1:
        parser.error('--maxInnerEpochs must be at least 1')

    ## Construct case data
    case_name_full = os.path.join(args.filepath, 'case.raw')
    con_file = os.path.join(args.filepath, 'case.con')
    inl_file = os.path.join(args.filepath, 'case.inl')
    rop_file = os.path.join(args.filepath, 'case.rop')

    opt = SUGAROpt(case_name_full, con_file, inl_file, rop_file)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not opt.partial_contingency:
        raise ValueError("Code assumes partial contingencies")

    case_data = parse(opt)
    case_data['contingency'] = None
    case_data['umax_P_con'] = None
    case_data['umin_P_con'] = None

    ## Get base case ACOPF solution (alpha = 0) for initialization
    opt.base_gauss_step = False
    opt.contingency_case = 0
    V_init = []

    # Set current alpha and prev alpha (for network stepping) to 0
    for alpha_type in ['alpha_gen', 'alpha_branch', 'alpha_xfmr']:
        alpha_dim = len(case_data[alpha_type])
        case_data[alpha_type] = np.zeros(alpha_dim)
        case_data['prev_{}'.format(alpha_type)] = np.zeros(alpha_dim)
    case_data['alpha_load'] = np.ones(len(case_data['load']))  # TODO Aayushya: verify, should it be 1s?

    (simID, flag_success, iteration_count_final, cs_eps_final, cost, bus_base, generator_base,
        V_base, Y_base, _, _, _, _, case_data, opt) = opf(opt, case_data, V_init)

    case_data['bus'] = bus_base
    case_data['generator'] = generator_base
    opt.cs_init = cs_eps_final
    cs_base = cs_eps_final

    V_init = np.copy(V_base)
    # Every defense solve restarts from this original base-case solution.
    V_base_init = np.copy(V_base)

    ## Main attack/defense loop

    # Each history gets one entry per attack iteration and one per defense
    # step, so all three stay the same length.
    con_loss = np.asarray([])
    def_loss = np.asarray([])
    total_loss = np.asarray([])
    for outer_epoch in range(args.maxOuterEpochs):

        ###################################################
        # ATTACK
        ###################################################

        # Initialize alpha
        init_alpha(case_data, args)

        max_delta_alpha_hist = np.asarray([])
        for inner_epoch in range(args.maxInnerEpochs):
            print("==========================================")
            print("STARTING ATTACK: ", inner_epoch)

            ## Run contingency SUGAR
            opt.stamp_pf = True
            opt.contingency_case += 1
            opt.homotopy_enable = False
            opt.NStep_end_factor = 0
            case_data['contingency'] = None
            opt.base_gauss_step = False

            (simID, flag_success, iteration_count_final, cs_eps_final, cost, bus_con, generator_con,
                con_sol, Y_con, umax_P_con, umin_P_con, umax_d_con, umin_d_con, case_data, opt) \
                = opf(opt, case_data, V_init)

            # Receive coupling variables for base gauss step
            case_data['umax_P_con'] = umax_P_con
            case_data['umin_P_con'] = umin_P_con
            case_data['umax_d_con'] = umax_d_con
            case_data['umin_d_con'] = umin_d_con

            ## Update alphas (step along the loss gradient: the attacker
            ## tries to increase the loss)

            # Get loss gradients
            delta_alpha_gen, delta_alpha_branch, delta_alpha_xfmr = \
                dLoss_dalpha(V_base, con_sol, Y_con, case_data, opt)
            delta_alpha_xfmr = 0   # Not dealing with transformer contingencies

            # Cap the L1 norm of the step at delta_norm
            delta_norm = 0.5
            grad_norm = np.sum([np.sum(np.abs(x)) for x in
                [delta_alpha_gen, delta_alpha_branch, delta_alpha_xfmr]])
            if grad_norm > delta_norm:
                scale_factor = delta_norm / grad_norm
                delta_alpha_gen = delta_alpha_gen * scale_factor
                delta_alpha_branch = delta_alpha_branch * scale_factor

            # Largest single-entry change; `initial=0.0` keeps this defined
            # when a component (e.g. a case with no generators) is empty.
            max_delta_alpha = np.max([np.max(np.abs(x), initial=0.0)
                    for x in [delta_alpha_gen, delta_alpha_branch, delta_alpha_xfmr]])
            max_delta_alpha_hist = np.append(max_delta_alpha_hist, max_delta_alpha)

            # Save previous alpha
            case_data['prev_alpha_gen'] = case_data['alpha_gen']
            case_data['prev_alpha_branch'] = case_data['alpha_branch']
            case_data['prev_alpha_xfmr'] = case_data['alpha_xfmr']

            # Take step in alpha
            alpha_gen_temp = case_data['alpha_gen'] + delta_alpha_gen
            alpha_branch_temp = case_data['alpha_branch'] + delta_alpha_branch
            alpha_xfmr_temp = case_data['alpha_xfmr'] + delta_alpha_xfmr

            # Project alpha onto {alpha >= 0, sum(alpha) <= k}
            case_data['alpha_gen'], case_data['alpha_branch'], case_data['alpha_xfmr'] = \
                project_alpha(alpha_gen_temp, alpha_branch_temp, alpha_xfmr_temp, args.k)

            ## Compute losses

            con_loss_iter = calcConLoss(V_base, con_sol, case_data, opt)
            gen_cost = genCost(con_sol, case_data, opt)
            print('Contingency Loss: {:.4f}'.format(con_loss_iter))
            print('Gen cost: {:.4f}'.format(gen_cost))

            base_loss_iter = calcDefenseLoss(V_base, case_data, opt)
            print('Base case loss: {:.4f}'.format(base_loss_iter))

            total_loss_iter = con_loss_iter + base_loss_iter
            print('Total loss: {:.4f}'.format(total_loss_iter))

            con_loss = np.append(con_loss, con_loss_iter)
            def_loss = np.append(def_loss, base_loss_iter)
            total_loss = np.append(total_loss, total_loss_iter)

            # Copy the solution to be used for the next contingency
            V_init = np.copy(con_sol)

            print("Max delta alpha: {:.4f}".format(max_delta_alpha))
            print("==========================================")

            if max_delta_alpha < args.alphaChangeTol:
                break

        # Per-epoch diagnostics: one point per attack iteration of this epoch.
        n_attack_iters = len(max_delta_alpha_hist)
        plt.plot(range(n_attack_iters), max_delta_alpha_hist)
        plt.show()
        # con_loss accumulates across all epochs (attack and defense entries),
        # so plot only the entries appended by this epoch's attack phase.
        plt.plot(range(n_attack_iters), con_loss[-n_attack_iters:])
        plt.show()

        ###################################################
        # DEFENSE
        ###################################################
        print("==========================================")
        print("STARTING DEFENSE ")

        opt.contingency_case = 0
        opt.stamp_pf = False
        opt.homotopy_enable = False
        case_data['contingency'] = None
        opt.base_gauss_step = True
        opt.cs_init = cs_base

        # Compute updated dispatch
        (simID, flag_success, iteration_count_final, cs_eps_final, cost, bus_base, generator_base,
            V_base, Y_base, _, _, _, _, case_data, opt) = opf(opt, case_data, V_base_init)

        cs_base = cs_eps_final

        loss = calcDefenseLoss(V_base, case_data, opt)
        # Generation cost of the updated base-case dispatch.
        gen_cost = genCost(V_base, case_data, opt)
        print('Loss: {:.4f}'.format(loss))
        print('Gen cost: {:.4f}'.format(gen_cost))

        # Contingency loss of the last attack solution against the new base case.
        con_loss_iter = calcConLoss(V_base, con_sol, case_data, opt)

        total_loss_iter = con_loss_iter + loss
        con_loss = np.append(con_loss, con_loss_iter)
        def_loss = np.append(def_loss, loss)
        total_loss = np.append(total_loss, total_loss_iter)

        print("======= END DEFENSE ===========")

    t = range(len(def_loss))
    plt.plot(t, def_loss, 'bs')
    plt.show()

    plt.plot(t, con_loss, 'r--', t, def_loss, 'bs')
    plt.show()
    plt.plot(t, total_loss)
    plt.show()


def init_alpha(case_data, args):
    """Reset the contingency weights stored in ``case_data`` per ``args``.

    ``args.alphaInitType`` picks the strategy:
      * 'zeros'  -- alpha_gen, alpha_branch and alpha_xfmr become zero arrays
                    shaped like the current ones;
      * 'random' -- draw uniform [0, 1) values for all three, then project
                    them jointly with ``project_alpha`` using budget ``args.k``;
      * anything else (i.e. 'prev') -- keep the current alphas unchanged.

    Regardless of strategy, ``alpha_load`` is reset to an array of ones with
    one entry per element of ``case_data['load']``. ``case_data`` is modified
    in place; nothing is returned.
    """
    alpha_keys = ('alpha_gen', 'alpha_branch', 'alpha_xfmr')
    strategy = args.alphaInitType

    if strategy == 'zeros':
        for key in alpha_keys:
            case_data[key] = np.zeros_like(case_data[key])
    elif strategy == 'random':
        lengths = [len(case_data[key]) for key in alpha_keys]
        samples = np.random.uniform(low=0, high=1, size=sum(lengths))
        pieces = np.split(samples, np.cumsum(lengths)[:-1])
        new_gen, new_branch, new_xfmr = project_alpha(*pieces, args.k)
        case_data['alpha_gen'] = new_gen
        case_data['alpha_branch'] = new_branch
        case_data['alpha_xfmr'] = new_xfmr

    case_data['alpha_load'] = np.ones(len(case_data['load']))


def project_alpha(a_gen, a_branch, a_xfmr, k):
    """Project contingency weights onto the budget set {x : x >= 0, sum(x) <= k}.

    The three weight vectors are concatenated, projected jointly in the
    Euclidean norm, and split back into pieces of the original lengths.

    The projection is computed exactly in closed form. Negatives are first
    clipped to zero. If the clipped vector already meets the budget, it is
    the answer. Otherwise the budget constraint is active and the result is
    max(a - tau, 0), with tau > 0 chosen so that the entries sum to k. tau is
    found with the standard sort-based simplex-projection algorithm (Duchi
    et al., 2008). There is no solver tolerance or solver failure, and any
    of the three pieces may be empty.

    Args:
        a_gen, a_branch, a_xfmr: 1-D array-likes of generator, branch and
            transformer weights.
        k: budget on the total weight; must be non-negative.

    Returns:
        Tuple ``(alpha_gen, alpha_branch, alpha_xfmr)`` of float arrays with
        the same lengths as the corresponding inputs.

    Raises:
        ValueError: if ``k`` is negative (the feasible set is empty).
    """
    a_gen = np.asarray(a_gen, dtype=float).ravel()
    a_branch = np.asarray(a_branch, dtype=float).ravel()
    a_xfmr = np.asarray(a_xfmr, dtype=float).ravel()
    if k < 0:
        raise ValueError('k must be non-negative, got {}'.format(k))

    ng, nb = a_gen.size, a_branch.size
    x = np.maximum(np.concatenate([a_gen, a_branch, a_xfmr]), 0.0)

    if x.sum() > k:
        if k == 0:
            # The only feasible point is the origin.
            x = np.zeros_like(x)
        else:
            # Budget is active: x = max(a - tau, 0) with tau > 0 such that
            # sum(x) == k. Clipping first does not change this because tau > 0.
            u = np.sort(x)[::-1]
            css = np.cumsum(u)
            j = np.arange(1, u.size + 1)
            rho = np.nonzero(u - (css - k) / j > 0)[0][-1]
            tau = (css[rho] - k) / (rho + 1.0)
            x = np.maximum(x - tau, 0.0)

    # Explicit end offsets so an empty trailing piece (nx == 0) splits correctly.
    return x[:ng], x[ng:ng + nb], x[ng + nb:]


def genCost(sol, case_data, opt):
    """Return the total generation cost of the dispatch contained in ``sol``.

    Generator outputs are ``Pg = sol[case_data['P_index']]``. When
    ``opt.use_rop_cost`` is true the per-generator cost is ``a*Pg**2 + b*Pg``,
    with ``a``/``b`` taken from ``case_data['gen_cost_a']`` and
    ``case_data['gen_cost_b']``. Otherwise it is the linear ``b*Pg``. The
    per-generator costs are summed.

    ``case_data['gen_cost_c']`` is looked up but does not enter the cost.
    """
    coeff_a, coeff_b, _coeff_c = (case_data[name] for name in
                                  ('gen_cost_a', 'gen_cost_b', 'gen_cost_c'))
    p_gen = sol[case_data['P_index']]

    linear_part = np.multiply(coeff_b, p_gen)
    if not opt.use_rop_cost:
        return np.sum(linear_part)

    quadratic_part = np.multiply(coeff_a, np.multiply(p_gen, p_gen))
    return np.sum(quadratic_part + linear_part)


def calcDefenseLoss(sol_base, case_data, opt):
    """Return the scaled defense (base-case) loss of the solution ``sol_base``.

    The unscaled loss is ``genCost(sol_base, ...)`` plus the following
    quadratic penalties, each weighted by 50:
      * squared real and imaginary infeasibility entries
        (``Lr_index`` / ``Li_index``);
      * deviation of generator outputs (``P_index``) from the midpoint of
        ``[Pmin, Pmax]``;
      * deviation of the ``d_index`` entries from the midpoint of
        ``[dmin_base, dmax_base]``;
      * only if ``opt.stamp_current_meas``: deviation of the ``Imag2_index``
        entries from ``Imag_max / 2``.
    The result is multiplied by 1000 before being returned.
    """
    penalty_weight = 50
    base_relative_factor = 1000

    p_gen = sol_base[case_data['P_index']]
    p_max = case_data['Pmax']
    p_min = case_data['Pmin']
    resid_real = sol_base[case_data['Lr_index']]
    resid_imag = sol_base[case_data['Li_index']]
    d_vals = sol_base[case_data['d_index']]
    d_max = case_data['dmax_base']
    d_min = case_data['dmin_base']
    if opt.stamp_current_meas:
        i_mag = sol_base[case_data['Imag2_index']]
        i_mag_max = case_data['Imag_max']

    total = genCost(sol_base, case_data, opt)
    total += penalty_weight * np.sum(np.square(resid_real) + np.square(resid_imag))
    total += penalty_weight * np.sum((p_gen - (p_min + p_max) / 2) ** 2)
    total += penalty_weight * np.sum((d_vals - (d_min + d_max) / 2) ** 2)
    if opt.stamp_current_meas:
        # line limits
        total += penalty_weight * np.sum((i_mag - i_mag_max / 2) ** 2)

    # NOTE(review): no transformer-limit penalty is applied — confirm intended.
    return base_relative_factor * total


def calcConLoss(sol_base, sol_con, case_data, opt):
    """Return the contingency loss of ``sol_con`` relative to ``sol_base``.

    The loss is ``genCost(sol_con, ...)`` plus the following quadratic
    penalties, each weighted by 50:
      * squared real and imaginary infeasibility entries of ``sol_con``;
      * deviation of contingency generator outputs from the midpoint of
        ``[Pmin, Pmax]``;
      * deviation of the contingency ``d_index`` entries from the midpoint of
        ``[dmin_con, dmax_con]``;
      * deviation of the generator change ``P_con - P_base`` from the midpoint
        of ``[Pmin_ramp, Pmax_ramp]`` (the only use of ``sol_base``);
      * only if ``opt.stamp_current_meas``: deviation of the ``Imag2_index``
        entries from ``Imag_max / 2``.
    Unlike ``calcDefenseLoss`` the result is not rescaled.
    """
    penalty_weight = 50

    p_base = sol_base[case_data['P_index']]
    p_con = sol_con[case_data['P_index']]
    p_max = case_data['Pmax']
    p_min = case_data['Pmin']
    ramp_max = case_data['Pmax_ramp']
    ramp_min = case_data['Pmin_ramp']
    resid_real = sol_con[case_data['Lr_index']]
    resid_imag = sol_con[case_data['Li_index']]
    d_vals = sol_con[case_data['d_index']]
    d_max = case_data['dmax_con']
    d_min = case_data['dmin_con']
    if opt.stamp_current_meas:
        i_mag = sol_con[case_data['Imag2_index']]
        i_mag_max = case_data['Imag_max']

    total = genCost(sol_con, case_data, opt)
    total += penalty_weight * np.sum(np.square(resid_real) + np.square(resid_imag))
    total += penalty_weight * np.sum((p_con - (p_min + p_max) / 2) ** 2)
    total += penalty_weight * np.sum((d_vals - (d_min + d_max) / 2) ** 2)
    total += penalty_weight * np.sum(((p_con - p_base) - (ramp_min + ramp_max) / 2) ** 2)
    if opt.stamp_current_meas:
        # line limits
        total += penalty_weight * np.sum((i_mag - i_mag_max / 2) ** 2)

    return total


def dgenCost(sol, case_data, opt):
    """Return d(genCost)/dPg per generator, evaluated at the dispatch in ``sol``.

    With ``opt.use_rop_cost`` the derivative of ``a*Pg**2 + b*Pg`` is
    ``2*a*Pg + b``. Otherwise the slope of the linear cost,
    ``case_data['gen_cost_b']``, is returned as-is. ``gen_cost_c`` is looked
    up but does not contribute.
    """
    coeff_a, coeff_b, _coeff_c = (case_data[name] for name in
                                  ('gen_cost_a', 'gen_cost_b', 'gen_cost_c'))
    p_gen = sol[case_data['P_index']]

    if not opt.use_rop_cost:
        return coeff_b
    return 2 * np.multiply(coeff_a, p_gen) + coeff_b


def dLoss_dalpha(sol_base, sol_con, Y, case_data, opt):
    """Gradient of a contingency penalty loss with respect to the alphas.

    The function works in three steps:
      1. Build ``dloss_dV``, the partial derivative of a penalty-style loss
         with respect to every entry of the contingency solution ``sol_con``.
         Only the entries selected by ``case_data['*_index']`` are non-zero.
      2. Solve ``Y @ dV = dloss_dV`` with ``spsolve``.
      3. Pass ``dV`` (reshaped to a column vector) to
         ``contingency_sensitivity`` to obtain the per-alpha gradients.

    Args:
        sol_base: base-case solution vector; only its ``P_index`` entries are
            used, in the ramp term.
        sol_con: contingency-case solution vector.
        Y: sparse matrix accepted by ``spsolve``; in ``main`` this is the
            ``Y_con`` returned by the contingency ``opf`` solve (presumably
            the system matrix at ``sol_con`` — confirm).
        case_data: case dictionary providing the index arrays and limits.
        opt: SUGAR options. ``stamp_current_meas`` is read here,
            ``use_rop_cost`` via ``dgenCost``, and the whole object is
            forwarded to ``contingency_sensitivity``.

    Returns:
        ``(dloss_dalpha_gen, dloss_dalpha_branch, dloss_dalpha_xfmr)`` exactly
        as returned by ``contingency_sensitivity``.

    NOTE(review): the loss differentiated here gives weight ``a = 0.1`` to the
    P / d / ramp / current terms and weight 1 to the generation cost and the
    infeasibility terms. ``calcConLoss`` instead weights its penalty terms by
    50, so this is the gradient of a differently weighted loss than the one
    reported — confirm that is intended.
    NOTE(review): an adjoint sensitivity usually solves with ``Y.T``. Solving
    with ``Y`` is equivalent only if ``Y`` is symmetric, or if
    ``contingency_sensitivity`` accounts for it — confirm.
    """

    # d(loss)/d(solution); same shape as sol_con.
    dloss_dV = np.zeros(sol_con.shape)

    P_con = sol_con[case_data['P_index']]
    infeas_con_real = sol_con[case_data['Lr_index']]
    infeas_con_imag = sol_con[case_data['Li_index']]

    Vmag2_con = sol_con[case_data['d_index']] ##d values = square of voltage magnitude
    Vmax_con = case_data['dmax_con']
    Vmin_con = case_data['dmin_con']

    Pmax = case_data['Pmax']
    Pmin = case_data['Pmin']

    P_base = sol_base[case_data['P_index']]
    Pmax_ramp = case_data['Pmax_ramp']
    Pmin_ramp = case_data['Pmin_ramp']

    # Penalty weight for the P, d, ramp and current terms (see NOTE in the docstring).
    a=0.1
    # Generation-cost slope at the contingency dispatch.
    dloss_dV[case_data['P_index']] += dgenCost(sol_con, case_data, opt)
    # d/dx of infeas_real**2 + infeas_imag**2.
    dloss_dV[case_data['Lr_index']] += 2 * infeas_con_real
    dloss_dV[case_data['Li_index']] += 2 * infeas_con_imag
    # d/dP of a * (P - midpoint of [Pmin, Pmax])**2.
    dloss_dV[case_data['P_index']] += 2 * a * (P_con - (Pmin + Pmax)/2)
    # d/dd of a * (d - midpoint of [dmin_con, dmax_con])**2.
    dloss_dV[case_data['d_index']] += 2 * a * (Vmag2_con - (Vmin_con + Vmax_con)/2)
    # d/dP_con of a * ((P_con - P_base) - midpoint of [Pmin_ramp, Pmax_ramp])**2.
    # Note the three P_index contributions are accumulated in this order.
    dloss_dV[case_data['P_index']] += 2 * a * ((P_con - P_base) - (Pmin_ramp + Pmax_ramp)/2)


    if opt.stamp_current_meas:
        Imag = sol_con[case_data['Imag2_index']]
        Imag_max = case_data['Imag_max']
        # d/dI of a * (I - Imag_max/2)**2.
        dloss_dV[case_data['Imag2_index']] += 2 * a * (Imag - Imag_max/2)

    # Solve Y @ dV = dloss_dV (UMFPACK is used if available).
    dV = spsolve(Y, dloss_dV, use_umfpack=True)

    # Map dV onto the gen / branch / xfmr alphas via the project helper
    # contingency_sensitivity (its internals are not visible here).
    (dloss_dalpha_gen, dloss_dalpha_branch, dloss_dalpha_xfmr) = \
        contingency_sensitivity(opt, case_data, dV.reshape(-1,1), sol_con)

    return dloss_dalpha_gen, dloss_dalpha_branch, dloss_dalpha_xfmr

if __name__ == "__main__":
    # Run the attack/defense loop only when executed as a script, not on import.
    main()

