from __future__ import print_function
import argparse
import tqdm
from tqdm import trange
import numpy as np, random
import torch
import os
import sys
import networkx as nx
from collections import defaultdict
import matplotlib.pyplot as plt
sys.path.append(os.getcwd())
from src.envs.scim_env import Network, SupplyChainIventoryManagement
from src.algos.graph_rl_agent import A2C, GNNParser, Actor, Critic
from src.algos.lcp_solver import solveLCP, MPC
from src.algos.stype_policy import s_type_policy_multi_factory
from src.misc.utils import dictsum
# from src.algos.mpc import MPC
from src.algos.inverse import stylized_invOptimization
import pickle
import time

def _str2bool(value):
    """Convert a command-line token into a bool (for use as argparse ``type``).

    ``type=bool`` does not work for flags: ``bool('False')`` is ``True``
    because every non-empty string is truthy, so ``--test False`` could never
    switch test mode off.  This converter understands the usual spellings.

    Args:
        value: The raw token from the command line (or an actual bool).

    Returns:
        True for ``true/t/yes/y/1/on``; False for ``false/f/no/n/0/off`` and
        for the empty string (which ``bool`` also mapped to False).  Matching
        is case-insensitive.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser(description='A2C-GNN')

# Simulator parameters
# Help texts use %(default)s so the documented default always matches the code.
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: %(default)s)')

# Model parameters
parser.add_argument('--algo', type=str, default='rl',
                    help='defines the algorithm to evaluate (only "rl" can use --test=False)')
# Boolean options accept an explicit value (--test False) or, given bare
# (--test), evaluate to True.
parser.add_argument('--test', type=_str2bool, nargs='?', const=True, default=True,
                    help='activates test mode for agent evaluation; '
                         'pass --test False to train (default: %(default)s)')
parser.add_argument('--cplexpath', type=str, default='/opt/ibm/ILOG/CPLEX_Studio128/opl/bin/x86-64_linux/',
                    help='defines directory of the CPLEX installation')
parser.add_argument('--directory', type=str, default='saved_files',
                    help='defines directory where to save files')
parser.add_argument('--max_episodes', type=int, default=1, metavar='N',
                    help='number of episodes to train agent (default: %(default)s)')
parser.add_argument('--max_steps', type=int, default=30, metavar='N',
                    help='number of steps per episode (default: %(default)s)')
parser.add_argument('--no-cuda', type=_str2bool, nargs='?', const=True, default=True,
                    help='disables CUDA training; pass --no-cuda False to allow CUDA '
                         '(default: %(default)s)')
parser.add_argument('--s_store', type=int, default=10, metavar='S',
                    help='optimal store order-up-to-level (default: %(default)s)')
parser.add_argument('--s_factory', type=int, default=10, metavar='S',
                    help='optimal factory order-up-to-level (default: %(default)s)')

args = parser.parse_args()

# Apply --seed. It used to be parsed but never used, so no run was
# reproducible. Seed every RNG this script imports (Python, NumPy, torch).
# NOTE(review): whether the environment / solvers draw from these global RNGs
# is not visible from this file - confirm if exact reproducibility matters.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# CUDA is used only when it is both allowed (--no-cuda False) and available.
args.cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if args.cuda else "cpu")

# Define SCIM Simulator Environment
# Nodes 0..num_w-1 form V_W and nodes num_w..num_w+num_s-1 form V_S; further
# down V_W is passed as the factory nodes and V_S as the warehouse nodes.
num_w, num_s = 4, 8
V_W = list(range(num_w))
V_S = list(range(num_w, num_w + num_s))
V = V_W + V_S

# Complete bipartite digraph: one edge from every V_W node to every V_S node,
# each carrying the same capacity attribute.
edge_list = [(w, s) for w in V_W for s in V_S]
G = nx.DiGraph()
G.add_nodes_from(V)
G.add_edges_from(edge_list, capacity=1000)

# Set network parameters
tf = 30
production_time = 6
product_prices = [15]
factory_nodes, warehouse_nodes = V_W, V_S

# Demand settings, one entry per V_S node.
# NOTE(review): presumably max demand and demand variance - confirm in Network.
dmax = [10] * num_s
dvar = [2] * num_s

# Per-node settings: production costs cover the V_W nodes only, storage
# settings list the V_W nodes first and the V_S nodes after them.
production_costs = [3] * num_w
storage_capacities = [30] * num_w + [20] * num_s
storage_costs = [1] * num_w + [2] * num_s

# Edge data is read from .npy files resolved against the current working
# directory.
edge_costs = np.load('edge_costs.npy')
edge_time = np.load('edge_time.npy')

# Cost terms handed to solveLCP.
params = dict(mT=edge_costs, mO=production_costs)

# theta = [np.random.uniform(0.1, 0.5, size=len(G.edges))]
# mu = [0]
K = 1

# Bundle the graph and all parameters into the scenario, then wrap it in the
# simulator environment.
network = Network(
    G=G,
    tf=tf,
    dmax=dmax,
    dvar=dvar,
    factory_nodes=factory_nodes,
    warehouse_nodes=warehouse_nodes,
    product_prices=product_prices,
    production_costs=production_costs,
    storage_capacities=storage_capacities,
    storage_costs=storage_costs,
    randomize_graph_args=(None, 'random-tt'),
    randomize_demand_args=(None, 'single-od'),
    edge_costs=edge_costs,
    edge_time=edge_time,
)
env = SupplyChainIventoryManagement(network)

def generate_G_B_subsets(V_S, V_W, E):
    """Build the two 0/1 node-by-edge incidence matrices of an edge set.

    Args:
        V_S: Sequence of nodes indexing the rows of the second matrix.
        V_W: Sequence of nodes indexing the rows of the first matrix.
        E: Sized iterable of ``(i, j)`` edges; column ``k`` of both matrices
            corresponds to the k-th edge in iteration order.

    Returns:
        Tuple ``(G, B)`` of float arrays. ``G`` has shape
        ``(len(V_W), len(E))`` with ``G[r, k] == 1`` when edge ``k`` starts at
        ``V_W[r]``. ``B`` has shape ``(len(V_S), len(E))`` with
        ``B[r, k] == 1`` when edge ``k`` ends at ``V_S[r]``. Edge endpoints
        that are not in the corresponding node list leave their column zero.
    """
    row_of_tail = {node: row for row, node in enumerate(V_W)}
    row_of_head = {node: row for row, node in enumerate(V_S)}

    num_edges = len(E)
    tail_incidence = np.zeros((len(V_W), num_edges))
    head_incidence = np.zeros((len(V_S), num_edges))

    for col, (tail, head) in enumerate(E):
        tail_row = row_of_tail.get(tail)
        if tail_row is not None:
            tail_incidence[tail_row, col] = 1
        head_row = row_of_head.get(head)
        if head_row is not None:
            head_incidence[head_row, col] = 1

    return tail_incidence, head_incidence

# Incidence matrices over the graph's edges, later handed to solveLCP.
# NOTE: this rebinds `G` - from here on it names the V_W incidence matrix, not
# the networkx graph built earlier.
G, B = generate_G_B_subsets(V_S, V_W, G.edges)

if args.algo == 'rl':
    # Initialize agent
    # NOTE: `parser` is rebound here from the argparse parser to the GNN
    # observation parser; `args` has already been parsed at this point.
    parser = GNNParser(env, edge_list=edge_list)
    # Actor and critic share the same network dimensions.
    gnn_dims = dict(node_size=10, edge_size=2, hidden_dim=64, out_channels=1)
    actor = Actor(num_factories=num_w, **gnn_dims)
    critic = Critic(**gnn_dims)
    model = A2C(env, parser, actor, critic, clip=2, baseline=None, parametrization='Gaus-Dirichlet')

if not args.test:
    #######################################
    #############Training Loop#############
    #######################################

    # `model` is only constructed for --algo=rl, so training anything else
    # would die with a NameError at `model.train()`; fail early and clearly.
    if args.algo != 'rl':
        raise ValueError(
            f"training (--test=False) requires --algo=rl, got --algo={args.algo!r}")

    # Both outputs live under --directory. The checkpoint path used to be
    # hard-coded to ./saved_files, ignoring --directory while the log path
    # honoured it (the two agree for the default 'saved_files').
    ckpt_path = f"./{args.directory}/ckpt/1f10s/graph_rl_cvxpy.pth"
    log_path = f"./{args.directory}/rl_logs/1f10s/graph_rl.pth"
    # Create the target folders up front so a save cannot fail on a missing
    # directory.
    for out_path in (ckpt_path, log_path):
        os.makedirs(os.path.dirname(out_path), exist_ok=True)

    #Initialize lists for logging
    log = {'train_reward': [],
           'train_served_demand': [],
           'train_reb_cost': []}
    train_episodes = args.max_episodes #set max number of training episodes
    T = args.max_steps #set episode length
    epochs = trange(train_episodes) #epoch iterator
    best_reward = -np.inf #set best reward
    model.train() #set model in train mode

    # --- MPC rollouts on freshly reset environments ---
    # NOTE(review): none of the lists/dicts filled here are read again in this
    # script (theta/mu below come from pickle files); presumably this is
    # data collection for an offline inverse-optimization step - confirm
    # before removing.
    f_list, w_list, qs_list, qw_list, v1_list = [], [], [], [], []
    d_part, cs_part, cw_part = [], [], []
    num_episode = 3
    lambda_ls = []
    demand_return = {}
    for i in range(num_episode):
        obs = env.reset() #initialize environment
        q_current = {node: 0 for node in V}  # Initial inventory
        f_dict = defaultdict(dict)
        w_dict = defaultdict(dict)
        for current_num in range(0, T):
            f, qs, qw, w, f0, w0, q1, r0, z0, v1 = MPC(env, current_num, q_current, f_dict, w_dict)
            v1_list.append(v1)
            f_list.append(f[0])
            w_list.append(w[0])
            # The inventory returned by this solve is the start of the next one.
            q_current = q1
            qs_list.append(qs[0])
            qw_list.append(qw[0])
            # Remember this step's decisions; they are passed back into MPC.
            f_dict[current_num] = f0
            w_dict[current_num] = w0

            lambda_ls.append(np.array([env.demand[current_num][node] for node in env.scenario.warehouse]))

        # Record T+10 demand entries per rollout under a running global index.
        for current_num in range(0, T+10):
            demand_return[i*(T+10)+current_num] = env.demand[current_num]

    # Parameters for solveLCP, loaded from files in the working directory.
    # NOTE: pickle must only be used on trusted files - loading executes
    # arbitrary code embedded in the file.
    with open('inv_theta_cvxpy.pkl', 'rb') as f:
        theta = pickle.load(f)
    with open('inv_mu_cvxpy.pkl', 'rb') as f:
        mu = pickle.load(f)

    for i_episode in epochs:
        obs = env.reset() #initialize environment
        # Recompute the MPC reference objective only every 100 episodes;
        # in between, the progress bar shows the most recent (stale) value.
        if i_episode % 100 == 0:

            q_current = {node: 0 for node in V}  # Initial inventory
            f_dict = defaultdict(dict)
            w_dict = defaultdict(dict)
            obj_mpc = 0
            f_list, w_list, qs_list, qw_list = [], [], [], []
            for current_num in range(0, T):
                f, qs, qw, w, f0, w0, q1, r0, z0, _ = MPC(env, current_num, q_current, f_dict, w_dict)
                f_list.append(f[0])
                w_list.append(w[0])
                q_current = q1
                qs_list.append(qs[0])
                qw_list.append(qw[0])
                f_dict[current_num] = f0
                w_dict[current_num] = w0
                obj_mpc = obj_mpc + r0

        episode_reward = 0
        for step in range(T):
            # use Graph-RL policy (RL)
            (prod, ship), (gaus_log_prob, dir_log_prob) = model.select_action(obs, show_log_prob=True)
            # Clamp the sampled production at zero.
            prod = np.maximum(prod.detach().cpu().numpy(), 0)
            # if i_episode < 500:
            #     prod = prod + np.random.randint(2,6)
            #     ship = ship.detach() + np.random.uniform(0.1,1) # could also add a small value to ship
            # solve LCP
            action, sd = solveLCP(env, params, theta, mu, G, B, K, ship, prod)
            # Take action in environment
            obs, reward, done, info, _, _ = env.step(action)
            episode_reward += reward
            # Store the transition in memory
            model.rewards.append(reward)
            # stop episode if terminating conditions are met
            if done:
                break
        # perform on-policy backprop
        grad_norms = model.training_step()

        # Send current statistics to screen
        epochs.set_description(f"Episode {i_episode+1} | Reward: {episode_reward:.2f} |MPC Objective: {obj_mpc:.2f} | Grad Norms: Actor={grad_norms['a_grad_norm']:.2f}, Critic={grad_norms['v_grad_norm']:.2f}")
        # Checkpoint best performing model
        if episode_reward >= best_reward:
            model.save_checkpoint(path=ckpt_path)
            best_reward = episode_reward
        # Log KPIs
        log['train_reward'].append(episode_reward)
        model.log(log, path=log_path)
else:

    if args.algo == 'rl':
        # Load pre-trained model
        # model.load_checkpoint(path=f"./saved_files/ckpt/1f10s/graph_rl_cvxpy.pth")
        model.load_checkpoint(path=f"./saved_files/ckpt/nfms/graph_rl_0924.pth")
    test_episodes = args.max_episodes #set max number of training episodes
    T = args.max_steps #set episode length
    epochs = trange(test_episodes) #epoch iterator
    #Initialize lists for logging
    log = {'test_reward': [], 
           'test_served_demand': [], 
           'test_reb_cost': []}
    task_reward_list = []
    ratio_list = []
    
    para = np.load('parameter_inv.npz')
    theta = [para['theta']]
    mu = [para['mu']]
    # print(theta)
    # print(mu)
    mpcobj_list = []
    mpcdemand_list = []
    expdemand_list = []
    for episode in epochs:
        episode_reward = 0
        obs = env.reset()
        done = False
        k = 0
       
        q_current = {i: 0 for i in V}  # Initial inventory
        f_dict = defaultdict(dict)
        w_dict = defaultdict(dict)
        obj_mpc = 0
        q_current_mpc = []
        # print('demand:', env.demand)
        f_list, w_list, qs_list, qw_list = [],[],[],[]
        served_demand_mpc = 0
        start_time = time.time()
        for current_num in range(0, T):
            f, qs, qw, w, f0, w0, q1, r0, z0, _ = MPC(env, current_num, q_current, f_dict, w_dict)
            sd = sum(z0.values())
            served_demand_mpc += sd
            # print(current_num, w[0])
            q_current_mpc.append(np.concatenate((qw[0], qs[0])))
            f_list.append(f[0])
            w_list.append(w[0])
            q_current = q1
            # print('q1:', q_current)
            qs_list.append(qs[0])
            qw_list.append(qw[0])
            f_dict[current_num] = f0
            w_dict[current_num] = w0
            obj_mpc = obj_mpc + r0
        end_time = time.time()
        print('mpc time:', end_time - start_time)
        served_demand_total = 0
        total_ship = 0
        flow_total = 0
        arrival_flow_total = 0
        start_time = time.time()
        while(not done):
            with torch.no_grad():
                a_probs , value = model(obs)
                mu_prod, sigma = a_probs[0][0], a_probs[0][1]
                alpha = a_probs[1]
                prod, ship = mu_prod, alpha / (alpha.sum() + 1e-16) 
                print('mu:', mu_prod)
                action, sd= solveLCP(env, params, theta, mu, G, B, K, ship, prod)
                print('action:', action)

                # prod, ship = s_type_policy_multi_factory(env, args.s_store, args.s_factory)
                # action = (prod, ship)
                # print('prod:', prod)
                # print('ship:', ship)
                # total_ship += sum(action[1].values())
    
            # Take action in environment
            obs, reward, done, info, served_demand, flow = env.step(action)
            episode_reward += reward
            served_demand_total += served_demand
            flow_total += flow
        end_time = time.time()
        print('exp time:', end_time - start_time)
        print('total demand:', sum(env.demand[t][i] for t in range(T) for i in V))
        print('total ship:', flow_total)
        print('served demand total:', served_demand_total)
        print('MPC served demand:', served_demand_mpc)
        print('final inventory:', sum(env.acc[env.time-1][i] for i in V_S))
        q_exp_ls = [[env.acc[t][i] for t in range(T)] for i in V]
        q_current_mpc = np.array(q_current_mpc).transpose()
        q_current_mpc = q_current_mpc.tolist()

        task_reward_list.append(episode_reward)
        # Send current statistics to screen
        epochs.set_description(f"Episode {episode+1} | Reward: {episode_reward:.2f} | MPC:{obj_mpc:.2f} | Aggregated: {np.mean(task_reward_list):.0f} +- {np.std(task_reward_list):.0f}")
        ratio_list.append(episode_reward/obj_mpc)
        mpcobj_list.append(obj_mpc)
        mpcdemand_list.append(served_demand_mpc)
        expdemand_list.append(served_demand_total)
    
    print(ratio_list)
    print('MPC reward:', mpcobj_list, np.mean(mpcobj_list), np.std(mpcobj_list))
    print('exp reward:', task_reward_list, np.mean(task_reward_list), np.std(task_reward_list))
    print('MPC demand:', mpcdemand_list, np.mean(mpcdemand_list), np.std(mpcdemand_list))
    print('exp demand:', expdemand_list, np.mean(expdemand_list), np.std(expdemand_list))

