from __future__ import print_function
import argparse
from tqdm import trange
import numpy as np
import torch

from src.envs.amod_env import Scenario, AMoD
from src.algos.c_sac import SAC
from src.algos.reb_flow_solver import solveRebFlow
from src.misc.utils import dictsum
import random
import pickle
from torch_geometric.data import Data, Batch
import json
import omegaconf
from generate_historical_data import Generate_History
from generate_historical_data import Benchmark, MPC
from generate_historical_data import Verify
from src.algos.opt_solver import solveOpt, RegsolveOpt, ValuesolveOpt
from collections import defaultdict
from src.algos.a2c_gnn import A2C
import matplotlib.pyplot as plt
from matplotlib.backend_bases import ResizeEvent
import warnings
import matplotlib
import pandas as pd
from datetime import datetime


class PairData(Data):
    """
    Data object holding two graphs at once: the state graph (suffix ``_s``)
    and the next-state graph (suffix ``_t``), plus the reward and action
    that connect them.
    """

    def __init__(self, edge_index_s=None, x_s=None, reward=None, action=None, edge_index_t=None, x_t=None):
        super().__init__()
        # Attributes are stored in the same order as the constructor arguments.
        fields = (
            ('edge_index_s', edge_index_s),
            ('x_s', x_s),
            ('reward', reward),
            ('action', action),
            ('edge_index_t', edge_index_t),
            ('x_t', x_t),
        )
        for attr_name, attr_value in fields:
            setattr(self, attr_name, attr_value)

    def __inc__(self, key, value, *args, **kwargs):
        """
        Return the offset added to ``key`` when several objects are batched.

        Each edge index is shifted by the number of rows of its own node
        feature matrix; every other key is delegated to the parent class.
        """
        # Edge-index attribute -> node-feature attribute that sizes its offset.
        feature_for_edges = {'edge_index_s': 'x_s', 'edge_index_t': 'x_t'}
        feature_name = feature_for_edges.get(key)
        if feature_name is None:
            return super().__inc__(key, value, *args, **kwargs)
        return getattr(self, feature_name).size(0)


class ReplayData:
    """
    Replay buffer for SAC agents.

    Transitions are kept in ``data_list`` as ``PairData`` objects and are
    sampled uniformly without replacement by ``sample_batch``.
    """

    def __init__(self, device, rew_scale):
        """
        :param device: stored on the instance; not used by the methods below.
        :param rew_scale: factor the rewards are multiplied by in
            ``create_dataset``.
        """
        self.device = device
        self.data_list = []
        self.rew_scale = rew_scale

    def create_dataset(self, edge_index, memory_path, size=60000, st=False, sc=False, store_path=''):
        """
        Load a pickled replay buffer and append ``size`` transitions to
        ``data_list``.

        :param edge_index: edge index shared by the state and next-state
            graph of every transition.
        :param memory_path: unused; kept so existing callers keep working.
        :param size: number of transitions requested from the stored buffer.
        :param st: standardize rewards (zero mean, unit std) before scaling.
        :param sc: min-max scale rewards before scaling; only applied when
            ``st`` is false.
        :param store_path: path of the pickle file holding the stored buffer.
        """
        # NOTE(security): pickle.load runs arbitrary code embedded in the
        # file, so store_path must only ever point at a trusted file.
        # The context manager closes the handle even if unpickling fails.
        with open(store_path, "rb") as buffer_file:
            replay_buffer = pickle.load(buffer_file)
        data = replay_buffer.sample_batch(size)

        if st:
            mean = data['reward'].mean()
            std = data['reward'].std()
            data['reward'] = (data['reward'] - mean) / (std + 1e-16)
        elif sc:
            # NOTE(review): yields NaN/inf if every reward is identical.
            lowest = data['reward'].min()
            highest = data['reward'].max()
            data['reward'] = (data['reward'] - lowest) / (highest - lowest)

        # Use the scale this buffer was built with rather than reaching for
        # the module-level ``args`` object.
        reward_batch = self.rew_scale * data["reward"]

        # The sampled batch is flat along dim 0; split it back into ``size``
        # graphs (presumably equal node counts per graph - TODO confirm).
        state_batch = data["x_s"]
        state_batch = state_batch.view(
            size, state_batch.size(0) // size, state_batch.size(1))
        next_state_batch = data["x_t"]
        next_state_batch = next_state_batch.view(
            size, next_state_batch.size(0) // size, next_state_batch.size(1))
        action_batch = data["action"]
        action_batch = action_batch.view(size, action_batch.size(0) // size)

        for i in range(len(state_batch)):
            self.data_list.append(PairData(
                edge_index, state_batch[i], reward_batch[i], action_batch[i], edge_index, next_state_batch[i]))

    def sample_batch(self, batch_size=32):
        """
        Return ``batch_size`` distinct transitions collated into one batch.

        ``random.sample`` raises ``ValueError`` when ``batch_size`` exceeds
        the number of stored transitions.
        """
        data = random.sample(self.data_list, batch_size)
        data = Batch.from_data_list(data, follow_batch=['x_s', 'x_t'])
        return data
    


def combined_shape(length, shape=None):
    """
    Build a shape tuple whose first entry is ``length``.

    ``shape`` may be omitted, a scalar (appended as one extra entry) or an
    iterable (unpacked after ``length``).
    """
    if shape is None:
        return (length,)
    if np.isscalar(shape):
        return (length, shape)
    return (length, *shape)

def cosine_similarity(vec1, vec2):
    """
    Return the cosine of the angle between ``vec1`` and ``vec2``.

    When either vector has zero norm the similarity is undefined and 0 is
    returned instead.
    """
    inner = np.dot(vec1, vec2)
    length_1 = np.linalg.norm(vec1)
    length_2 = np.linalg.norm(vec2)
    if length_1 != 0 and length_2 != 0:
        return inner / (length_1 * length_2)
    # Avoid dividing by zero.
    return 0


def _str2bool(value):
    """
    Convert a command-line string to a bool.

    ``type=bool`` makes argparse call ``bool("False")``, which is ``True``
    for every non-empty string, so such an option can never be switched off.
    The empty string still maps to ``False``, as it did with ``type=bool``.

    :raises argparse.ArgumentTypeError: if the text is not a known boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser(description='A2C-GNN')

# Per-city constants, keyed by city name.
demand_ratio = {'san_francisco': 2, 'washington_dc': 4.2, 'nyc_brooklyn': 9, 'rome': 1.8,
                'shenzhen_downtown_west': 2.5}
json_hr = {'san_francisco': 19, 'washington_dc': 19, 'nyc_brooklyn': 19, 'rome': 8,
           'shenzhen_downtown_west': 8}
beta = {'san_francisco': 0.2, 'washington_dc': 0.5, 'nyc_brooklyn': 0.5, 'porto': 0.1, 'rome': 0.1,
        'shenzhen_downtown_west': 0.5}

test_tstep = {'san_francisco': 3,
              'nyc_brooklyn': 3, 'shenzhen_downtown_west': 3}

# Simulator parameters
parser.add_argument('--seed', type=int, default=1000, metavar='S',
                    help='random seed (default: 1000)')
# demand_ratio and beta default to 0.5, so they must be parsed as floats;
# with type=int any fractional value given on the command line was rejected.
parser.add_argument('--demand_ratio', type=float, default=0.5, metavar='S',
                    help='demand_ratio (default: 0.5)')
parser.add_argument('--json_hr', type=int, default=7, metavar='S',
                    help='json_hr (default: 7)')
parser.add_argument('--json_tstep', type=int, default=3, metavar='S',
                    help='minutes per timestep (default: 3min)')
parser.add_argument('--beta', type=float, default=0.5, metavar='S',
                    help='cost of rebalancing (default: 0.5)')

# Model parameters
parser.add_argument('--test', type=_str2bool, default=True,
                    help='activates test mode for agent evaluation')
parser.add_argument('--cplexpath', type=str, default="/home/jy/opt/ibm/ILOG/CPLEX_Studio_Community2212",
                    help='defines directory of the CPLEX installation')
parser.add_argument('--directory', type=str, default='saved_files',
                    help='defines directory where to save files')
parser.add_argument('--max_episodes', type=int, default=1, metavar='N',
                    help='number of episodes to train agent (default: 1)')
parser.add_argument('--max_steps', type=int, default=90, metavar='N',
                    help='number of steps per episode (default: T=90)')
parser.add_argument('--num_data', type=int, default=20, metavar='N')

# parser.add_argument('--n', type=int, default=20,
#                     help='number of steps')

parser.add_argument('--cuda', type=_str2bool, default=False,
                    help='enables CUDA training')
parser.add_argument("--batch_size", type=int, default=6,
                    help='defines the batch size')
parser.add_argument("--alpha", type=float, default=0.3,
                    help='value of the entropy coefficient')
parser.add_argument("--hidden_size", type=int, default=256,
                    help='number of hidden units in the MLP layer')
parser.add_argument("--checkpoint_path", type=str, default='SAC_nyc_brooklyn',
                    help='path, where to save model checkpoints')
parser.add_argument("--city", type=str, default='rome',
                    help='defines city to train on')

# CQL parameters
parser.add_argument("--load_yaml", type=_str2bool, default=False,
                    help='to load CQL parameters from a yaml file')
parser.add_argument("--memory_path", type=str, default='Replaymemory_nyc_brooklyn_M',
                    help='path, where data is saved')
parser.add_argument("--min_q_weight", type=float, default=5,
                    help='conservative coefficient (eta in paper)')
parser.add_argument("--samples_buffer", type=int, default=12,
                    help='number of samples to take from the dataset')
parser.add_argument("--lagrange_thresh", type=float, default=-1,
                    help='lagrange threshold tau for automatic tuning of eta')
parser.add_argument("--rew_scale", type=float, default=0.1,
                    help='defines reward scale')
parser.add_argument("--st", type=_str2bool, default=False,
                    help='whether to standardize data')
parser.add_argument("--sc", type=_str2bool, default=False,
                    help='whether to scale data')
# Parse the command line and pick the compute device.
args = parser.parse_args()
# CUDA is used only when it was requested AND is actually available.
args.cuda = args.cuda and torch.cuda.is_available()
device = torch.device("cuda" if args.cuda else "cpu")
city = args.city
# Define AMoD Simulator Environment

if not args.test:
    # ------------------------------------------------------------------
    # Training branch: runs `args.max_episodes` episodes of the A2C model
    # on the network described by data/network_data.csv.
    # ------------------------------------------------------------------
    # if args.load_yaml == True:
    #     parameter = omegaconf.load(f"src/conf/config_{city}.yaml")
    #     args.memory = parameter.memory_path
    #     args.min_q_weight = parameter.min_q_weight
    #     args.samples_buffer = parameter.samples_buffer
    #     args.lagrange_thresh = parameter.lagrange_thresh
    #     args.rew_scale = parameter.rew_scale
    #     args.max_episodes = parameter.max_episodes

    # scenario = Scenario(json_file=f"data/toy_example.json", demand_ratio=demand_ratio[city],
    #                     json_hr=json_hr[city], sd=args.seed, json_tstep=args.json_tstep, tf=args.max_steps)

    # The network is read from a CSV; the columns used below are
    # 'i', 'j', 'cost', 'travel_time', 'demand' and 'price'.
    data = pd.read_csv('data/network_data.csv')
    scenario = Scenario(data, tf = args.max_steps, seed = args.seed)
    # NOTE(review): beta and totalAcc are hard-coded here; --beta is ignored.
    env = AMoD(scenario, data, beta=1, tf = args.max_steps, totalAcc = 220)

    # acc = data['totalAcc'][0]['acc']
    # 2 x E tensor: first row from column 'i', second row from column 'j'.
    edge_index = torch.tensor([data['i'].values, data['j'].values]).long()
    # edge_index = torch.vstack((torch.tensor([edge['i'] for edge in topology]).view(1, -1),
    #                            torch.tensor([edge['j'] for edge in topology]).view(1, -1))).long()
    # Number of distinct values in column 'i'.
    N = data['i'].nunique()

    # # Initialize Dataset
    # scenario_history = Scenario_History(json_file=f"data/scenario_{city}.json", demand_ratio=demand_ratio[city],
    #                             json_hr=json_hr[city], sd=args.seed, json_tstep=args.json_tstep, tf=args.max_steps)
    # Initialize Dataset
    # Length of the per-timestep lists built below: 10 more than the episode
    # length (the reason for the extra 10 entries is not visible here).
    num = args.max_steps+10

    # Per-edge lookups keyed by (i, j). None of these three dicts is read
    # again in this branch.
    rebTime = {(row['i'], row['j']): row['cost'] for _, row in data.iterrows()}
    demandTime = {(row['i'], row['j']): row['travel_time'] for _, row in data.iterrows()}
    demand_input = {(row['i'], row['j']): row['demand'] for _, row in data.iterrows()}

    # Each list repeats the same CSV column `num` times, i.e. prices, costs
    # and travel times are identical at every timestep.
    price_ls = [data['price'] for _ in range(num)]
    cost_ls = [data['cost'] for _ in range(num)]
    demandTime_ls = [data['travel_time'] for _ in range(num)]
    pairs = list(zip(data['i'], data['j']))

    # Parameters passed to model.test_agent below; each array is wrapped in
    # a one-element list.
    para = np.load('parameter_toy.npz')
    theta_f = [para['theta_f']]
    theta_g = [para['theta_g']]
    mu = [para['mu']]

    # with open('toy_theta_f_cvxpy.pkl', 'rb') as f:
    #     theta_f = pickle.load(f)
    # with open('toy_theta_g_cvxpy.pkl', 'rb') as f:
    #     theta_g = pickle.load(f)
    # with open('toy_mu_cvxpy.pkl', 'rb') as f:
    #     mu = pickle.load(f)

    # Initialize lists for logging
    log = {'train_reward': [],
           'train_served_demand': [],
           'train_reb_cost': []}
    train_episodes = args.max_episodes  # set max number of training episodes
    T = args.max_steps  # set episode length
    # NOTE(review): `epochs` is only used for set_description below; the loop
    # iterates range(train_episodes), so this progress bar never advances.
    epochs = trange(train_episodes, dynamic_ncols=True)  # epoch iterator


    # # Model 2 constraints + regularized opt
    #  # Training Loop
    best_reward = -np.inf  # set best reward
    # NOTE(review): `log` was already initialised identically a few lines up.
    log = {'train_reward': [],
           'train_served_demand': [],
           'train_reb_cost': []}
    # Initialize A2C-GNN
    # model = A2C(env=env, input_size=21).to(device)
    model = A2C(env=env, input_size=21).to(device)
    # model.load_checkpoint(path=f"ckpt/modeltoy_0424_145054.csv.pth")

    model.train()  # set model in train mode

    for step in range(train_episodes):
        # Roll out one episode. The meaning of the literal arguments 1 and 2
        # is defined by A2C.test_agent - TODO confirm.
        episode_reward, episode_served_demand, episode_rebalancing_cost, desiredAcc = model.test_agent(
                1, env, args.cplexpath, args.directory, theta_f, theta_g, mu, 2, args.max_steps, edge_index, cost_ls, price_ls, demandTime_ls)

        epochs.set_description(
            f"Episode {step} | Reward: {episode_reward:.2f} | ServedDemand: {episode_served_demand:.2f} | Reb. Cost: {episode_rebalancing_cost:.2f}")
        # Every 100th episode (including episode 0) compare against the
        # benchmark objective.
        if step % 100 == 0:
            benchmark_obj = Benchmark(env, args.max_steps, N, cost_ls, price_ls, demandTime_ls)
            # mpc_obj = MPC(env, args.max_steps, N)
            print('The performance ratio of Model 2 vs benchmark is:{}'.format(np.round(episode_reward/benchmark_obj, 2)))
            # print('The performance ratio of Model 2 vs MPC is:{}'.format(np.round(episode_reward/mpc_obj, 2)))
            print('desiredACC:', desiredAcc)

        # The timestamp is regenerated every episode, so each saved checkpoint
        # goes to its own file. The name already ends in ".csv" and ".pth" is
        # appended when saving.
        timestamp = datetime.now().strftime("%m%d_%H%M%S")
        filename = f"modeltoy2_{timestamp}.csv"
        model.training_step()
        # Checkpoint best performing model (ties are saved as well, since the
        # comparison is >=).
        if episode_reward >= best_reward:
            model.save_checkpoint(
                path=f"ckpt/{filename}.pth")

            best_reward = episode_reward

        log['train_reward'].append(episode_reward)
        log['train_served_demand'].append(episode_served_demand)
        log['train_reb_cost'].append(episode_rebalancing_cost)

    # Save the reward curve over all training episodes.
    plt.plot(np.array(log['train_reward']))
    plt.title('Rewards per Episode')
    plt.xlabel('Episode')
    plt.ylabel('Total Reward')
    plt.grid(True)
    plt.savefig('Rewards_per_Episode2.png')


else:
    # ------------------------------------------------------------------
    # Evaluation branch: runs two saved checkpoints for 5 episodes each and
    # compares their decisions with the Benchmark and MPC solutions.
    # ------------------------------------------------------------------
    data = pd.read_csv('data/network_data.csv')
    scenario = Scenario(data, tf = args.max_steps, seed = args.seed)
    # NOTE(review): beta and totalAcc are hard-coded here; --beta is ignored.
    env = AMoD(scenario, data, beta=1, tf = args.max_steps, totalAcc = 220)

    # 2 x E tensor: first row from column 'i', second row from column 'j'.
    edge_index = torch.tensor([data['i'].values, data['j'].values]).long()
    N = data['i'].nunique()

    # # Initialize Dataset
    # Length of the per-timestep lists built below (episode length + 10).
    num = args.max_steps+10

    # Per-edge lookups keyed by (i, j); only demand_input is used below
    # (it is passed to MPC).
    rebTime = {(row['i'], row['j']): row['cost'] for _, row in data.iterrows()}
    demandTime = {(row['i'], row['j']): row['travel_time'] for _, row in data.iterrows()}
    demand_input = {(row['i'], row['j']): row['demand'] for _, row in data.iterrows()}

    # Each list repeats the same CSV column `num` times, i.e. the values are
    # identical at every timestep.
    price_ls = [data['price'] for _ in range(num)]
    cost_ls = [data['cost'] for _ in range(num)]
    demandTime_ls = [data['travel_time'] for _ in range(num)]
    pairs = list(zip(data['i'], data['j']))


    # NOTE(security): pickle.load executes code embedded in the file; these
    # files must come from a trusted source.
    with open('toy_theta_f_cvxpy.pkl', 'rb') as f:
        theta_f = pickle.load(f)
    with open('toy_theta_g_cvxpy.pkl', 'rb') as f:
        theta_g = pickle.load(f)
    with open('toy_mu_cvxpy.pkl', 'rb') as f:
        mu = pickle.load(f)

    # NOTE(review): this hard-coded 4 overrides the N computed from the CSV
    # above.
    N = 4
    M = len(theta_f[0])
    # Node-by-edge matrix: column n has +1 in row i and -1 in row j for the
    # n-th (i, j) pair. Assumes i and j are integer ids in [0, N) and that
    # there are at most M pairs - TODO confirm.
    A = np.zeros((N,M))
    n = 0
    for i,j in pairs:
        A[i,n] = 1
        A[j,n] = -1
        n = n + 1

    # Two checkpoints are evaluated: index 0 is stepped with solveOpt,
    # index 1 with RegsolveOpt (see the `a == 0` / `a == 1` branch below).
    path_ls = ['modeltoy_0424_150727.csv.pth', 'modeltoy2_0425_194715.csv.pth']
    # q_exp_ls[a] collects, per episode, the env.acc trajectories of model a.
    q_exp_ls = [[],[]]
    for a in [0,1]:
        path = path_ls[a]
        model = A2C(env=env, input_size=21).to(device)

        # Load pre-trained model
        # model.load_checkpoint(path=f"ckpt/modeltoy_0424_150727.csv.pth")
        model.load_checkpoint(path=f"ckpt/{path}")
        # model.load_checkpoint(path = f"ckpt/modeltoy_1.pth")

        # NOTE(review): test_episodes only sizes the progress bar; the episode
        # loop below is hard-coded to 5 and never advances `epochs`.
        test_episodes = args.max_episodes  # set max number of training episodes
        T = args.max_steps  # set episode length
        epochs = trange(test_episodes)  # epoch iterator
        # Initialize lists for logging
        # (`log` is not written to again in this branch.)
        log = {'test_reward': [],
            'test_served_demand': [],
            'test_reb_cost': []}

        episode_rewards = []
        ratio = []
        ratio_MPC = []

        # Build price[i, j][t] from price_ls and install it on the environment.
        # `ind` is the position of the edge in env.edges, so this presumably
        # relies on env.edges following the CSV row order - TODO confirm.
        price = defaultdict(dict)
        ind = 0
        for i, j in env.edges:
            # print(i,j)
            for t in range(args.max_steps+10):
                # print(t)
                price[i,j][t] = price_ls[t][ind]
            ind += 1
        env.price = price
        served_demand = []
        MPC_demand = []
        MPC_reward = []
        q_current_mpc = []
        for episode in range(5):
            episode_reward = 0
            episode_served_demand = 0
            episode_rebalancing_cost = 0
            # This observation is replaced by env.Initial_step() on the first
            # loop iteration.
            obs = env.reset()
            done = False
            actions = []
            # `result` is filled per step but not read again in this file.
            result = {}
            # Per-step records: rebAction, paxAction and desiredAcc.
            f_ls, g_ls, qhat_ls = [], [], []
            while (not done):

                obs = env.Initial_step()
                action_rl = model.select_action(
                    obs, edge_index=False)
                actions.append(action_rl)

                # Scale each region's action by dictsum(env.acc, env.time + 1)
                # and truncate to int, then order the values by region key.
                desiredAcc = {env.region[i]: int(
                    action_rl[i] * dictsum(env.acc, env.time + 1))for i in range(len(env.region))}
                sorted_dict = dict(sorted(desiredAcc.items()))
                desiredAcc = np.array(list(sorted_dict.values()))
                qhat_ls.append(desiredAcc)
                # print('desiredACC:', desiredAcc)

                # paxAction, rebAction = solveOpt(env, desiredAcc)
                if a == 0:
                    # print('first opt:')
                    paxAction, rebAction = solveOpt(env, desiredAcc, cost_ls, price_ls, demandTime_ls)
                elif a == 1:
                    # print('second opt:')
                    paxAction, rebAction = RegsolveOpt(env, desiredAcc, theta_f, theta_g, mu, cost_ls, price_ls, demandTime_ls)
                # env.step returns six values; only reward, done and info are used.
                _, reward, done, info, _, _ = env.step(paxAction, rebAction, args.max_steps, cost_ls, price_ls, demandTime_ls)
                result[env.time-1] = [paxAction, rebAction, desiredAcc]
                f_ls.append(rebAction)
                g_ls.append(paxAction)

                episode_reward += reward
                # track performance over episode
                episode_served_demand += info['served_demand']
                episode_rebalancing_cost += info['rebalancing_cost']

            # Reference solutions for the episode that just finished.
            benchmark_obj = Benchmark(env, args.max_steps, N, cost_ls, price_ls, demandTime_ls)
            mpc_obj, f_ls0, g_ls0, q_ls0 = MPC(env, args.max_steps, N, cost_ls, price_ls, demandTime_ls, demand_input)
            print('The objective value of benckmark in episode {} is:{}'.format(episode+1, benchmark_obj))

            # q_current_mpc = np.array(q_ls0).transpose().tolist()
            n = len(q_ls0[0])
            # q_exp_ls = [[env.acc[i][t] for t in range(T)] for i in range(n)]
            # Record env.acc[i][t] for the first n nodes over T steps, and the
            # transposed MPC trajectory (one list per entry of q_ls0[0]).
            q_exp_ls[a].append([[env.acc[i][t] for t in range(T)] for i in range(n)])
            q_current_mpc.append(np.array(q_ls0).transpose().tolist())


            MPC_demand.append(int(np.sum(g_ls0)))
            MPC_reward.append(mpc_obj)

            # Send current statistics to screen
            epochs.set_description(
                f"Episode {episode+1} | Reward: {episode_reward:.2f} | ServedDemand: {episode_served_demand:.2f} | Reb. Cost: {episode_rebalancing_cost}")
            episode_rewards.append(int(episode_reward))
            ratio.append(np.round(episode_reward/benchmark_obj,2))
            ratio_MPC.append(np.round(episode_reward/mpc_obj,2))
            served_demand.append(int(episode_served_demand))

            # Step-by-step comparison between the agent and MPC: cosine
            # similarity and summed absolute difference for rebAction (f),
            # paxAction (g) and the desired vehicle distribution (q_hat).
            # Assumes the MPC lists have at least len(f_ls) entries and that
            # the actions support element-wise subtraction - TODO confirm.
            cos_sim_ls, diff_ls = [], []
            cos_sim2_ls, diff2_ls = [], []
            cos_sim3_ls, diff3_ls = [], []
            for i in range(len(f_ls)):
            # for i in range(train_num):
                vec1 = f_ls[i]
                vec2 = f_ls0[i]
                cos_sim = np.round(cosine_similarity(vec1, vec2),3)
                diff = np.absolute(vec1-vec2).sum()
                cos_sim_ls.append(cos_sim)
                diff_ls.append(diff)

                #print('similarity of f in time step ', i, ':', cos_sim, diff)

                vec1 = g_ls[i]
                vec2 = g_ls0[i]
                cos_sim2 = np.round(cosine_similarity(vec1, vec2),3)
                diff2 = np.absolute(vec1-vec2).sum()
                cos_sim2_ls.append(cos_sim2)
                diff2_ls.append(diff2)

                # MPC counterpart of desiredAcc: q_ls0[i] minus A applied to
                # the combined MPC flows of this step.
                qhat_mpc = q_ls0[i] - np.dot(A, np.array(f_ls0[i]) + np.array(g_ls0[i]))
                cos_sim3 = np.round(cosine_similarity(qhat_ls[i], qhat_mpc),3)
                cos_sim3_ls.append(cos_sim3)
                diff3 = np.absolute(qhat_ls[i]-qhat_mpc).sum()
                diff3_ls.append(diff3)


            # Episode averages. NOTE(review): raises ZeroDivisionError if the
            # episode produced no steps.
            l = len(cos_sim_ls)
            print('f:', sum(cos_sim_ls)/l, sum(diff_ls)/l)
            print('g:', sum(cos_sim2_ls)/l, sum(diff2_ls)/l)
            print('q_hat:', sum(cos_sim3_ls)/l, sum(diff3_ls)/l)

        # Summary over the 5 episodes of this checkpoint.
        print(path)
        print('reward of model:', episode_rewards, np.mean(np.array(episode_rewards)))
        print('reward of MPC:', MPC_reward, np.mean(np.array(MPC_reward)))

        print(ratio, np.mean(np.array(ratio)))
        print(ratio_MPC, np.mean(np.array(ratio_MPC)))
        print('served demand of model:', served_demand, np.mean(np.array(served_demand)))
        print('served demand of MPC:', MPC_demand, np.mean(np.array(MPC_demand)))

    # One figure per episode with one subplot per node, showing the first 50
    # steps. q_ls0 and q_current_mpc hold the values from the last checkpoint
    # iteration (a == 1), because they are rebound inside the loop above.
    # NOTE(review): axes[i] requires n > 1; plt.subplots(1, 1) returns a single
    # Axes object rather than an array.
    n = len(q_ls0[0])
    for episode in range(5):
        fig, axes = plt.subplots(n,1,figsize=(6,3*n))
        for i in range(n):
            axes[i].plot(q_current_mpc[episode][i][:50], label = 'mpc')
            axes[i].plot(q_exp_ls[0][episode][i][:50], label = 'unchanged')
            axes[i].plot(q_exp_ls[1][episode][i][:50], label = 'inv_opt')
            axes[i].legend()
            axes[i].set_title('Number of Vehciles in the {}th Node'.format(i))
        plt.tight_layout()
        plt.show()

