
import sys, os

curr_dir = os.path.abspath(os.getcwd())
sys.path.append(curr_dir)

import numpy as np
import pandas as pd
import random as rd
import datetime, time
import math as m
import argparse
import copy
import json, torch
from pyomo.environ import *
#from utils import *
from pyomo.opt import SolverFactory, TerminationCondition

def flatten_dict(dictionary, parent_key='', separator=';'):
    """Flatten a nested dictionary into a single-level dict with string keys.

    Nested keys are joined with ``separator``; tuple keys are first collapsed
    into one string by joining their elements with ``'_'``
    (e.g. ``('a', 'b') -> 'a_b'``). Empty nested dictionaries contribute no
    entries.

    Every key is converted to ``str``, including top-level ones, so that
    non-string keys (e.g. ``1`` or ``0``) can be used as a prefix for their
    nested keys instead of raising ``TypeError`` or being dropped.

    Args:
        dictionary: The (possibly nested) mapping to flatten.
        parent_key: Prefix prepended to every key; empty at the top level.
        separator: String placed between nesting levels.

    Returns:
        A new dict mapping the joined string keys to the leaf values, in
        traversal order.
    """
    flat = {}
    for key, value in dictionary.items():
        if isinstance(key, tuple):
            key_text = '_'.join(map(str, key))
        else:
            key_text = str(key)

        full_key = f"{parent_key}{separator}{key_text}" if parent_key else key_text

        if isinstance(value, dict):
            # Recursing into an empty dict yields nothing, so empty
            # sub-dictionaries are skipped without a special case.
            flat.update(flatten_dict(value, full_key, separator=separator))
        else:
            flat[full_key] = value
    return flat
def flatten_and_track_mappings(dictionary, separator=';'):
    """Flatten a nested dict into a float32 vector plus an index-to-key map.

    The dictionary is flattened with ``flatten_dict``. Only numeric leaf
    values (Python ``int``/``float`` and NumPy integer/floating scalars) are
    placed in the output array; any other leaf is skipped.

    Args:
        dictionary: The (possibly nested) mapping to flatten.
        separator: String used by ``flatten_dict`` to join nesting levels.

    Returns:
        A tuple ``(flattened_array, mappings)`` where ``flattened_array`` is a
        1-D ``float32`` NumPy array and ``mappings`` is a list of
        ``(index, key_parts)`` pairs. ``index`` is the position of the value in
        ``flattened_array`` and ``key_parts`` is the flattened key split on
        ``separator``.
    """
    flattened_dict = flatten_dict(dictionary, separator=separator)

    # NumPy scalars such as np.float32 or np.int64 do not subclass the Python
    # numeric types, so they are listed explicitly to avoid dropping them.
    numeric_types = (int, float, np.integer, np.floating)

    flattened_values = []
    mappings = []
    for key, value in flattened_dict.items():
        if not isinstance(value, numeric_types):
            continue
        # The index is taken from the values kept so far so that it always
        # matches the value's position in the returned array, even when
        # earlier non-numeric leaves were skipped.
        mappings.append((len(flattened_values), str(key).split(separator)))
        flattened_values.append(value)

    flattened_array = np.array(flattened_values).astype("float32")

    return flattened_array, mappings
def build_optimization_model(env):
    """Build the Pyomo generation/transmission expansion model for ``env``.

    The model spans periods ``0..env.T``. Period 0 is the initial state, in
    which every generator count, every generator addition and every
    transmission-line indicator is fixed to zero; periods ``1..env.T`` carry
    the state transition and demand constraints.

    Transmission variables and constraints are only created when
    ``env.transmission_lines`` is non-empty. Each line name is split on
    ``'_'``: the first part is treated as the region the flow leaves and the
    second part as the region the flow enters.

    Components created on the model:
        num_gen[i, r, t]: integer generator count, bounded by
            ``0 <= num_gen <= env.maxgen[i][r]``.
        add_gen[i, r, t]: integer, non-negative generator additions.
        pow_flow[l, t]: real flow, bounded by ``+/- env.tlcap[l] * bin_trans``.
        bin_trans[l, t]: binary, non-decreasing in ``t``.
        bin_trans_add[l, t]: binary, equal to
            ``bin_trans[l, t] - bin_trans[l, t-1]``.

    Args:
        env: Object exposing ``regions``, ``generators``,
            ``transmission_lines``, ``T``, ``maxgen[i][r]``, ``gencap[i]``,
            ``tlcap[l]``, ``demand[r]`` (a mapping looked up with ``str(t)``,
            missing periods count as zero demand) and
            ``installcost["generators"][i]`` /
            ``installcost["transmission"][l]``.

    Returns:
        A ``ConcreteModel`` whose objective maximises the negated total
        installation cost of generators and transmission lines.
    """
    regions = env.regions
    generators = env.generators
    transmission_lines = env.transmission_lines
    # (from_region, to_region) for each line, derived from its name.
    trans_dict = {line: tuple(line.split('_')) for line in transmission_lines}
    T = env.T
    t_init = 0
    time_all = range(t_init, t_init + T + 1)
    time_add = range(t_init + 1, t_init + T + 1)
    num_gen_0 = {(i, r): 0 for i in generators for r in regions}
    bin_trans_0 = {line: 0 for line in transmission_lines}
    has_transmission = len(transmission_lines) > 0

    model = ConcreteModel()

    # Each component is declared exactly once: re-assigning an existing
    # component name makes Pyomo replace it and emit a warning.
    model.num_gen = Var(generators, regions, time_all, domain=Integers)
    model.add_gen = Var(generators, regions, time_all, domain=Integers)
    if has_transmission:
        model.pow_flow = Var(transmission_lines, time_all)  # Reals by default
        model.bin_trans = Var(transmission_lines, time_all, domain=Binary)
        model.bin_trans_add = Var(transmission_lines, time_add, domain=Binary)

    # Initial state at t_init.
    def init_numgen_rule(model, i, r):
        return model.num_gen[i, r, t_init] == num_gen_0[(i, r)]
    model.init_numgen = Constraint(generators, regions, rule=init_numgen_rule)

    def init_addgen_rule(model, i, r):
        return model.add_gen[i, r, t_init] == 0
    model.init_addgen = Constraint(generators, regions, rule=init_addgen_rule)

    # Bounds on the installed generator count.
    def lb_numgen_rule(model, i, r, t):
        return model.num_gen[i, r, t] >= 0
    model.lb_numgen = Constraint(generators, regions, time_all, rule=lb_numgen_rule)

    def ub_numgen_rule(model, i, r, t):
        return model.num_gen[i, r, t] <= env.maxgen[i][r]
    model.ub_numgen = Constraint(generators, regions, time_all, rule=ub_numgen_rule)

    if has_transmission:
        def init_bin_trans_rule(model, line):
            return model.bin_trans[line, t_init] == bin_trans_0[line]
        model.init_bintrans = Constraint(transmission_lines, rule=init_bin_trans_rule)

        # Flow is limited to +/- capacity and forced to zero while the
        # line indicator is 0.
        def lb_powflow_rule(model, line, t):
            return model.pow_flow[line, t] >= -env.tlcap[line] * model.bin_trans[line, t]
        model.lb_powflow = Constraint(transmission_lines, time_all, rule=lb_powflow_rule)

        def ub_powflow_rule(model, line, t):
            return model.pow_flow[line, t] <= env.tlcap[line] * model.bin_trans[line, t]
        model.ub_powflow = Constraint(transmission_lines, time_all, rule=ub_powflow_rule)

        # The line indicator can never decrease from one period to the next.
        def lb_bin_trans_rule(model, line, t):
            return model.bin_trans[line, t] >= model.bin_trans[line, t - 1]
        model.lb_bin_trans = Constraint(transmission_lines, time_add, rule=lb_bin_trans_rule)

        # bin_trans_add marks the period in which the indicator switches on.
        def val_add_trans_rule(model, line, t):
            return (
                model.bin_trans_add[line, t]
                == model.bin_trans[line, t] - model.bin_trans[line, t - 1]
            )
        model.val_add_trans = Constraint(transmission_lines, time_add, rule=val_add_trans_rule)

    def lb_addgen_rule(model, i, r, t):
        return model.add_gen[i, r, t] >= 0
    model.lb_addgen = Constraint(generators, regions, time_all, rule=lb_addgen_rule)

    # State transition: count this period = count last period + additions.
    def state_rule(model, i, r, t):
        return model.num_gen[i, r, t] == model.num_gen[i, r, t - 1] + model.add_gen[i, r, t]
    model.state = Constraint(generators, regions, time_add, rule=state_rule)

    def demand_check_rule(model, r, t):
        demand_t = env.demand[r].get(str(t), 0) if t <= T else 0
        gen_supply = sum(model.num_gen[i, r, t] * env.gencap[i] for i in generators)

        if not has_transmission:
            return demand_t <= gen_supply

        # Net inflow - outflow for region r
        inflow = sum(
            model.pow_flow[line, t]
            for line in transmission_lines
            if r == trans_dict[line][1]
        )
        outflow = sum(
            model.pow_flow[line, t]
            for line in transmission_lines
            if r == trans_dict[line][0]
        )
        return demand_t <= gen_supply + inflow - outflow

    model.demand_check = Constraint(regions, time_add, rule=demand_check_rule)

    def objective_rule(model):
        # Costs are negated because the objective sense is maximize.
        gen_cost = -sum(
            model.add_gen[i, r, t] * env.installcost["generators"][i]
            for i in generators
            for r in regions
            for t in time_all
        )

        if not has_transmission:
            return gen_cost

        trans_cost = -sum(
            model.bin_trans_add[line, t] * env.installcost["transmission"][line]
            for line in transmission_lines
            for t in time_add
        )
        return gen_cost + trans_cost

    model.obj = Objective(rule=objective_rule, sense=maximize)
    return model

def optimal_simulation(env, solver, tee: bool = True, raise_on_infeasible: bool = True):
    """Solve the expansion model for ``env`` and return per-period actions.

    The model from ``build_optimization_model`` is solved, then for every
    period ``1..env.T`` the solution is rescaled and flattened:

    * ``addgen[(i, r)] = 2 * add_gen / env.maxgen[i][r] - 1``
    * ``powflow[l] = 2 * (pow_flow + env.tlcap[l]) / (2 * env.tlcap[l]) - 1``

    Each flattened action is printed as it is produced.

    Args:
        env: Environment object passed to ``build_optimization_model``; also
            supplies ``T``, ``generators``, ``regions``,
            ``transmission_lines``, ``maxgen`` and ``tlcap``.
        solver: Either an object with a ``solve`` method, or a value whose
            ``str()`` is handed to ``SolverFactory``.
        tee: Forwarded to the solver's ``solve`` call.
        raise_on_infeasible: When true, raise if the termination condition is
            neither ``optimal`` nor ``locallyOptimal``.

    Returns:
        A list with one flattened action array per period, in period order.

    Raises:
        RuntimeError: If the solve did not terminate optimally and
            ``raise_on_infeasible`` is true.
    """
    model = build_optimization_model(env)

    if hasattr(solver, "solve"):
        opt = solver
    else:
        opt = SolverFactory(str(solver))
    results = opt.solve(model, tee=tee)

    term = results.solver.termination_condition
    accepted = (TerminationCondition.optimal, TerminationCondition.locallyOptimal)
    if raise_on_infeasible and term not in accepted:
        raise RuntimeError(
            f"Optimization did not solve to optimality. Termination condition: {term}"
        )

    actions = []
    for t in range(1, env.T + 1):
        # Generator additions rescaled with 0 as the lower end of the range
        # and maxgen as the upper end.
        addgen = {}
        for i in env.generators:
            for r in env.regions:
                added = model.add_gen[i, r, t].value
                upper = env.maxgen[i][r]
                addgen[(i, r)] = 2 * (added - 0) / (upper - 0) - 1

        # Power flows rescaled with -tlcap and +tlcap as the range ends.
        powflow = {}
        for line in env.transmission_lines:
            flow = model.pow_flow[line, t].value
            cap = env.tlcap[line]
            powflow[line] = 2 * (flow + cap) / (cap + cap) - 1

        action_flatt, _ = flatten_and_track_mappings({'addgen': addgen, 'powflow': powflow})
        print(action_flatt)
        actions.append(action_flatt)
    return actions
