from __future__ import print_function
import argparse
from tqdm import trange
import numpy as np
import torch

from src.envs.amod_env import Scenario, AMoD, Scenario_History
from src.misc.utils import dictsum
import random
import pickle
from torch_geometric.data import Data, Batch
import json
import omegaconf
# from generate_historical_data import Generate_History
from generate_historical_data import Benchmark, MPC
from src.algos.opt_solver import solveOpt, RegsolveOpt
from collections import defaultdict
from src.algos.a2c_gnn import A2C
import matplotlib.pyplot as plt
from matplotlib.backend_bases import ResizeEvent
import warnings
import matplotlib
import time

class PairData(Data):
    """
    One Data object holding a transition as two graphs: the source graph
    s_t (``edge_index_s``, ``x_s``) and the target graph s_t+1
    (``edge_index_t``, ``x_t``), together with the ``reward`` and ``action``
    attached to that transition.
    """

    def __init__(self, edge_index_s=None, x_s=None, reward=None, action=None, edge_index_t=None, x_t=None):
        super().__init__()
        # Attributes are assigned in the same fixed order as before so the
        # order in which they are stored on the object does not change.
        ordered_fields = (
            ('edge_index_s', edge_index_s),
            ('x_s', x_s),
            ('reward', reward),
            ('action', action),
            ('edge_index_t', edge_index_t),
            ('x_t', x_t),
        )
        for field_name, field_value in ordered_fields:
            setattr(self, field_name, field_value)

    def __inc__(self, key, value, *args, **kwargs):
        """
        Return the increment associated with attribute ``key``.

        Each edge index is incremented by the number of rows of the node
        feature matrix of its own graph (``x_s`` or ``x_t``); every other
        key defers to the parent implementation.
        """
        if key == 'edge_index_s':
            increment = self.x_s.size(0)
        elif key == 'edge_index_t':
            increment = self.x_t.size(0)
        else:
            increment = super().__inc__(key, value, *args, **kwargs)
        return increment


class ReplayData:
    """
    Replay buffer for SAC agents.

    Holds a list of ``PairData`` transitions built from a pickled replay
    buffer and serves random mini-batches from it.
    """

    def __init__(self, device, rew_scale):
        """
        Args:
            device: stored on the instance; not used by the methods below.
            rew_scale: multiplicative factor applied to every reward when
                the dataset is created.
        """
        self.device = device
        self.data_list = []
        self.rew_scale = rew_scale

    def create_dataset(self, edge_index, memory_path, size=60000, st=False, sc=False, store_path=''):
        """
        Load a pickled replay buffer and append its transitions to ``data_list``.

        Args:
            edge_index: edge index shared by the source and target graph of
                every transition.
            memory_path: unused; kept so existing callers keep working.
            size: number of transitions requested from the stored buffer.
            st: if True, standardize rewards (zero mean, unit std).
            sc: if True (and ``st`` is False), min-max scale rewards.
            store_path: path of the pickle file holding the replay buffer.
        """
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # replay buffers from trusted locations.
        with open(store_path, "rb") as buffer_file:
            replay_buffer = pickle.load(buffer_file)
        data = replay_buffer.sample_batch(size)

        if st:
            mean = data['reward'].mean()
            std = data['reward'].std()
            data['reward'] = (data['reward'] - mean) / (std + 1e-16)
        elif sc:
            # NOTE(review): if all rewards are equal the range is zero and
            # this division yields NaN/inf, exactly as before.
            low = data['reward'].min()
            high = data['reward'].max()
            data['reward'] = (data['reward'] - low) / (high - low)

        state_batch = data["x_s"]
        action_batch = data["action"]
        # Use the scale given to the constructor rather than the global CLI
        # namespace, so the class does not depend on module-level `args`.
        reward_batch = self.rew_scale * data["reward"]
        next_state_batch = data["x_t"]

        # The sampled tensors are flat over (transition, node); fold them
        # back into one leading entry per transition.
        state_batch = state_batch.view(size, state_batch.size(0) // size, state_batch.size(1))
        next_state_batch = next_state_batch.view(
            size, next_state_batch.size(0) // size, next_state_batch.size(1))
        action_batch = action_batch.view(size, action_batch.size(0) // size)
        for i in range(len(state_batch)):
            self.data_list.append(PairData(
                edge_index, state_batch[i], reward_batch[i], action_batch[i], edge_index, next_state_batch[i]))

    def sample_batch(self, batch_size=32):
        """Return ``batch_size`` transitions drawn without replacement, collated into one Batch."""
        data = random.sample(self.data_list, batch_size)
        data = Batch.from_data_list(data, follow_batch=['x_s', 'x_t'])
        return data
    


def combined_shape(length, shape=None):
    """
    Build a shape tuple with ``length`` as the leading dimension.

    ``shape`` may be omitted (result is ``(length,)``), a scalar (result is
    ``(length, shape)``) or an iterable of dimensions appended after
    ``length``.
    """
    if shape is None:
        return (length,)
    if np.isscalar(shape):
        return (length, shape)
    return (length, *shape)

def cosine_similarity(vec1, vec2):
    """
    Return the cosine similarity of ``vec1`` and ``vec2``.

    The result is ``dot(vec1, vec2) / (|vec1| * |vec2|)``; if either vector
    has zero norm the function returns 0 instead of dividing.
    """
    numerator = np.dot(vec1, vec2)
    length1, length2 = np.linalg.norm(vec1), np.linalg.norm(vec2)
    # Guard against division by zero.
    if not (length1 != 0 and length2 != 0):
        return 0
    return numerator / (length1 * length2)


def _str2bool(value):
    """
    Convert a command-line string into a bool.

    ``type=bool`` in argparse evaluates ``bool("False")``, which is True for
    every non-empty string, so flags such as ``--test False`` could never be
    switched off. This converter accepts the usual spellings instead. The
    empty string still maps to False, which was the only way to obtain False
    before.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(description='A2C-GNN')

# Per-city settings, looked up with the value of --city further below.
demand_ratio = {'san_francisco': 2, 'washington_dc': 4.2, 'nyc_brooklyn': 9, 'rome': 1.8,
                'shenzhen_downtown_west': 2.5}
json_hr = {'san_francisco': 19, 'washington_dc': 19, 'nyc_brooklyn': 19, 'rome': 8,
           'shenzhen_downtown_west': 8}
# Only read for 'nyc_brooklyn' in the evaluation branch of this script.
test_tstep = {'san_francisco': 3,
              'nyc_brooklyn': 3, 'shenzhen_downtown_west': 3}
beta = {'san_francisco': 0.2, 'washington_dc': 0.5, 'nyc_brooklyn': 0.5, 'porto': 0.1, 'rome': 0.1,
        'shenzhen_downtown_west': 0.5}


# Simulator parameters
parser.add_argument('--seed', type=int, default=1000, metavar='S',
                    help='random seed (default: 1000)')
# float, not int: the default (0.5) and any fractional value given on the
# command line must be parseable.
parser.add_argument('--demand_ratio', type=float, default=0.5, metavar='S',
                    help='demand_ratio (default: 0.5)')
parser.add_argument('--json_hr', type=int, default=7, metavar='S',
                    help='json_hr (default: 7)')
parser.add_argument('--json_tstep', type=int, default=3, metavar='S',
                    help='minutes per timestep (default: 3min)')
# Model parameters
parser.add_argument("--city", type=str, default='washington_dc',
                    help='defines city to train on')

parser.add_argument('--test', type=_str2bool, default=True,
                    help='activates test mode for agent evaluation')
parser.add_argument('--cplexpath', type=str, default="/Users/yuexuanwang/opt/CPLEX_Studio2211/opl/bin/x86-64_osx/",
                    help='defines directory of the CPLEX installation')
parser.add_argument('--directory', type=str, default='saved_files',
                    help='defines directory where to save files')
parser.add_argument('--max_episodes', type=int, default=1, metavar='N',
                    help='number of episodes to train agent (default: 1)')
parser.add_argument('--max_steps', type=int, default=50, metavar='N',
                    help='number of steps per episode (default: T=50)')

parser.add_argument('--cuda', type=_str2bool, default=False,
                    help='enables CUDA training')
parser.add_argument("--batch_size", type=int, default=6,
                    help='defines the batch size')
parser.add_argument("--alpha", type=float, default=0.3,
                    help='value of the entropy coefficient')
parser.add_argument("--hidden_size", type=int, default=256,
                    help='number of hidden units in the MLP layer')
parser.add_argument("--checkpoint_path", type=str, default='SAC_nyc_brooklyn',
                    help='path, where to save model checkpoints')


# CQL parameters
parser.add_argument("--load_yaml", type=_str2bool, default=False,
                    help='to load CQL parameters from a yaml file')
parser.add_argument("--memory_path", type=str, default='Replaymemory_nyc_brooklyn_M',
                    help='path, where data is saved')
parser.add_argument("--min_q_weight", type=float, default=5,
                    help='conservatie coeffiecent (eta in paper)')
parser.add_argument("--samples_buffer", type=int, default=12,
                    help='number of samples to take from the dataset')
parser.add_argument("--lagrange_thresh", type=float, default=-1,
                    help='lagrange treshhold tau for automatic tuning of eta')
parser.add_argument("--rew_scale", type=float, default=0.1,
                    help='defines reward scale')
parser.add_argument("--st", type=_str2bool, default=False,
                    help='whether to standardize data')
parser.add_argument("--sc", type=_str2bool, default=False,
                    help='wether to scale data')
args = parser.parse_args()
# CUDA is used only when requested and actually available.
args.cuda = args.cuda and torch.cuda.is_available()
device = torch.device("cuda" if args.cuda else "cpu")

city = args.city

if not args.test:
    # ---- Training mode: build the environment, train the A2C-GNN policy,
    # ---- keep the best checkpoint and plot the reward curve.

    scenario = Scenario(json_file=f"data/scenario_{city}.json", demand_ratio=demand_ratio[city],
                        json_hr=json_hr[city], sd=args.seed, json_tstep=args.json_tstep, tf=args.max_steps)

    env = AMoD(scenario, beta=beta[city])

    with open(f'data/scenario_{city}.json', "r") as file:
        data = json.load(file)
    # Edge index of shape (2, num_edges): row 0 holds the 'i' endpoints and
    # row 1 the 'j' endpoints of the scenario's topology graph.
    edge_index = torch.vstack((torch.tensor([edge['i'] for edge in data["topology_graph"]]).view(1, -1),
                               torch.tensor([edge['j'] for edge in data["topology_graph"]]).view(1, -1))).long()
    i_values = [entry['i'] for entry in data['topology_graph']]
    # Number of distinct 'i' endpoints in the topology graph.
    N = len(set(i_values))

    # Initialize Dataset
    tripAttr = scenario.tripAttr
    rebTime = scenario.rebTime
    demandTime = scenario.demandTime
    num = scenario.num

    # Store the static cost vectors: one list per time step t < num.
    price_ls = [[] for _ in range(num)]
    cost_ls = [[] for _ in range(num)]
    demandTime_ls = [[] for _ in range(num)]

    for i, j, t, d, p in tripAttr:
        if t < num:
            price_ls[t].append(p)
            cost_ls[t].append(rebTime[i, j][t])
            # Missing demand-time entries are filled with 0 in place.
            # NOTE(review): indexing demandTime[i, j] for an absent pair only
            # works if demandTime creates entries on access — confirm.
            if (i, j) not in demandTime or t not in demandTime[i, j]:
                demandTime[i, j][t] = 0
            demandTime_ls[t].append(demandTime[i, j][t])

    # Replace the per-step vectors by their rounded mean over time steps,
    # repeated for every step (np.stack needs equally long per-step lists).
    price_avg = np.round(np.mean(np.stack(price_ls), axis=0))
    price_ls = [price_avg for _ in range(num)]
    demandTime_ls = [np.round(np.mean(np.stack(demandTime_ls), axis=0)) for _ in range(num)]

    # NOTE(security): pickle.load can execute arbitrary code; these files
    # must come from a trusted source.
    with open('wc_theta_f.pkl', 'rb') as f:
        theta_f = pickle.load(f)
    with open('wc_theta_g.pkl', 'rb') as f:
        theta_g = pickle.load(f)
    with open('wc_mu.pkl', 'rb') as f:
        mu = pickle.load(f)

    # Initialize lists for logging
    log = {'train_reward': [],
           'train_served_demand': [],
           'train_reb_cost': []}
    train_episodes = args.max_episodes  # set max number of training episodes
    epochs = trange(train_episodes, dynamic_ncols=True)  # epoch iterator

    # Model 2 constraints + regularized opt
    best_reward = -np.inf  # set best reward
    # Initialize A2C-GNN
    model = A2C(env=env, input_size=21).to(device)

    model.train()  # set model in train mode

    # Iterate the tqdm iterator itself so the progress bar advances with
    # every episode (a plain range() left it stuck at 0).
    for step in epochs:
        episode_reward, episode_served_demand, episode_rebalancing_cost, desiredAcc = model.test_agent(
                1, env, args.cplexpath, args.directory, theta_f, theta_g, mu, 2, args.max_steps, edge_index, cost_ls, price_ls, demandTime_ls, beta[city])

        epochs.set_description(
            f"Episode {step} | Reward: {episode_reward:.2f} | ServedDemand: {episode_served_demand:.2f} | Reb. Cost: {episode_rebalancing_cost:.2f}")
        if step % 500 == 0:
            # Periodically compare the episode reward against the benchmark objective.
            benchmark_obj = Benchmark(env, args.max_steps, N, cost_ls, price_ls, demandTime_ls, beta[city])
            print('The performance ratio of Model 2 vs benchmark is:{}'.format(np.round(episode_reward/benchmark_obj, 2)))
            print('desiredACC:', desiredAcc[:20,:])

        # perform on-policy backprop
        model.training_step()
        # Checkpoint best performing model
        if episode_reward >= best_reward:
            model.save_checkpoint(
                path="ckpt/model333_wc.pth")
            best_reward = episode_reward

        log['train_reward'].append(episode_reward)
        log['train_served_demand'].append(episode_served_demand)
        log['train_reb_cost'].append(episode_rebalancing_cost)

    # Save the per-episode reward curve.
    plt.plot(np.array(log['train_reward']))
    plt.title('Rewards per Episode')
    plt.xlabel('Episode')
    plt.ylabel('Total Reward')
    plt.grid(True)
    plt.savefig('Rewards_per_Episode2.png')
        
else:
    # ---- Evaluation mode: load two pre-trained checkpoints and, for each,
    # ---- run test episodes and compare the result against the Benchmark
    # ---- and MPC objectives.

    # NOTE(security): pickle.load can execute arbitrary code; these files
    # must come from a trusted source.
    with open('wc_theta_f.pkl', 'rb') as f:
        theta_f = pickle.load(f)
    with open('wc_theta_g.pkl', 'rb') as f:
        theta_g = pickle.load(f)
    with open('wc_mu.pkl', 'rb') as f:
        mu = pickle.load(f)

    # test pre-trained model
    # For 'nyc_brooklyn' the time step comes from test_tstep; every other
    # city uses the --json_tstep argument.
    if city == 'nyc_brooklyn':
        scenario = Scenario(json_file=f"data/scenario_{city}.json", demand_ratio=demand_ratio[city],
                            json_hr=json_hr[city], sd=args.seed, json_tstep=test_tstep[city], tf=args.max_steps)
    else:
        scenario = Scenario(json_file=f"data/scenario_{city}.json", demand_ratio=demand_ratio[city],
                            json_hr=json_hr[city], sd=args.seed, json_tstep=args.json_tstep, tf=args.max_steps)

    env = AMoD(scenario, beta=beta[city])

    # Initialize Dataset
    tripAttr = scenario.tripAttr
    rebTime = scenario.rebTime
    demandTime = scenario.demandTime
    demand_input = scenario.demand_input
    num = scenario.num

    # Store the static cost vectors: one list per time step t < num.
    price_ls = [[] for _ in range(num)]
    cost_ls = [[] for _ in range(num)]
    demandTime_ls = [[] for _ in range(num)]

    # Distinct (i, j) pairs in order of first appearance; this order defines
    # the columns of matrix A below.
    pairs = []
    for i, j, t, d, p in tripAttr:
        if t < num:
            price_ls[t].append(p)
            if (i,j) not in pairs:
                pairs.append((i,j))
            cost_ls[t].append(rebTime[i,j][t])
            # Missing demandTime / demand_input entries are filled with 0 in place.
            # NOTE(review): indexing [i, j] for an absent pair only works if
            # these mappings create entries on access — confirm.
            if (i, j) not in demandTime or t not in demandTime[i, j]:
                demandTime[i, j][t] = 0
            if (i, j) not in demand_input or t not in demand_input[i, j]:
                demand_input[i, j][t] = 0
            demandTime_ls[t].append(demandTime[i,j][t])

    # Replace the per-step vectors by their rounded mean over time steps,
    # repeated for every step (np.stack needs equally long per-step lists).
    price_avg = np.round(np.mean(np.stack(price_ls), axis=0))
    price_ls = [price_avg for _ in range(num)]
    demandTime_ls = [np.round(np.mean(np.stack(demandTime_ls), axis=0)) for _ in range(num)]

    # Collect every node that appears as an endpoint of some pair.
    nodes = set()
    for i, j in pairs:
        nodes.add(i)
        nodes.add(j)
    N = len(nodes)
    M = len(theta_f[0])
    # A has one row per node and one column per pair: +1 in the row of the
    # pair's first node, -1 in the row of its second node.
    # NOTE(review): assumes node ids lie in 0..N-1 and len(pairs) <= M — confirm.
    A = np.zeros((N,M))
    n = 0
    for i, j in pairs:
        A[i,n] = 1
        A[j,n] = -1
        n = n + 1

    with open(f'data/scenario_{city}.json', "r") as file:
        data = json.load(file)
    # acc = data['totalAcc'][0]['acc']
    # Edge index of shape (2, num_edges): row 0 holds the 'i' endpoints and
    # row 1 the 'j' endpoints of the scenario's topology graph.
    edge_index = torch.vstack((torch.tensor([edge['i'] for edge in data["topology_graph"]]).view(1, -1),
                               torch.tensor([edge['j'] for edge in data["topology_graph"]]).view(1, -1))).long()

    # path_ls = ['model1_wc.pth', 'model2_wc.pth']
    # Checkpoint index a selects the solver used below:
    # a == 0 -> solveOpt, a == 1 -> RegsolveOpt.
    path_ls = ['model1_wc.pth', 'model2_wc.pth']
    for a in [0, 1]:
        path = path_ls[a]
        model = A2C(env=env, input_size=21).to(device)

        # Load pre-trained model
        model.load_checkpoint(path=f"ckpt/"+path_ls[a])

        test_episodes = args.max_episodes  # set max number of training episodes
        T = args.max_steps  # set episode length
        # NOTE(review): this tqdm iterator is never iterated; it is only used
        # for set_description, and the loop below runs a fixed 5 episodes
        # regardless of --max_episodes.
        epochs = trange(test_episodes)  # epoch iterator
        # Initialize lists for logging
        log = {'test_reward': [],
            'test_served_demand': [],
            'test_reb_cost': []}

        # Per-checkpoint statistics, one entry per episode.
        episode_rewards = []
        ratio = []
        ratio_MPC = []
        MPC_demand, MPC_reward = [], []
        served_demand = []
        for episode in range(5):
            episode_reward = 0
            episode_served_demand = 0
            episode_rebalancing_cost = 0
            obs = env.reset()
            done = False
            actions = []
            tripAttr = scenario.tripAttr
            rebTime = scenario.rebTime
            result = {}
            # NOTE(review): nothing is appended to these lists (the appends
            # below are commented out), so they stay empty.
            f_ls, g_ls, qhat_ls = [], [], []
            start_time = time.time()
            while (not done):

                obs = env.Initial_step()
                # o = model.parse_obs(obs, model.device)

                # action_rl = model.select_action(
                #     o.x, o.edge_index, deterministic=True)
                # action_rl = model.select_action(
                #     obs, edge_index=False)
                action_rl = model.select_action(
                    obs, edge_index)
                # actions.append(action_rl)

                # Per region: truncate action_rl[i] * dictsum(env.acc, env.time + 1)
                # to an int, then order the values by region key.
                # (dictsum presumably totals env.acc at time + 1 — confirm.)
                desiredAcc = {env.region[i]: int(
                    action_rl[i] * dictsum(env.acc, env.time + 1))for i in range(len(env.region))}
                sorted_dict = dict(sorted(desiredAcc.items()))
                desiredAcc = np.array(list(sorted_dict.values()))
                # qhat_ls.append(desiredAcc)
                print('desiredACC:', desiredAcc)

                # Turn the desired distribution into passenger and rebalancing
                # actions with the solver matching the checkpoint.
                if a == 0:
                    paxAction, rebAction = solveOpt(env, desiredAcc, cost_ls, price_ls, demandTime_ls, beta[city])
                elif a == 1:
                    paxAction, rebAction = RegsolveOpt(env, desiredAcc, theta_f, theta_g, mu, cost_ls, price_ls, demandTime_ls, beta[city])
                _, reward, done, info, _, _ = env.step(paxAction, rebAction, args.max_steps, cost_ls, price_ls, demandTime_ls)
                # result[env.time-1] = [paxAction, rebAction, desiredAcc]
                # f_ls.append(rebAction)
                # g_ls.append(paxAction)

                episode_reward += reward
                # track performance over episode
                episode_served_demand += info['served_demand']
                episode_rebalancing_cost += info['rebalancing_cost']
            end_time = time.time()
            print(f"test episode running time: {end_time-start_time} 秒")
            # NOTE(review): the scenario file is re-read on every episode, and
            # N is reassigned here to the number of distinct 'i' endpoints,
            # which is derived differently from the N used to size A above.
            with open(f'data/scenario_{city}.json', "r") as file:
                data = json.load(file)
            i_values = [entry['i'] for entry in data['topology_graph']]
            N = len(set(i_values))
            benchmark_obj = Benchmark(env, args.max_steps, N, cost_ls, price_ls, demandTime_ls, beta[city])

            # Time the MPC baseline separately from the policy roll-out.
            start_time = time.time()
            mpc_obj, f_ls0, g_ls0, q_ls0 = MPC(env, args.max_steps, N, cost_ls, price_ls, demandTime_ls, demand_input, beta[city])
            end_time = time.time()
            print(f"mpc running time: {end_time-start_time} 秒")

            print('The objective value of benckmark in episode {} is:{}'.format(episode+1, benchmark_obj))
            MPC_demand.append(int(np.sum(g_ls0)))
            MPC_reward.append(mpc_obj)
            served_demand.append(int(episode_served_demand))
            # Send current statistics to screen
            epochs.set_description(
                f"Episode {episode+1} | Reward: {episode_reward:.2f} | ServedDemand: {episode_served_demand:.2f} | Reb. Cost: {episode_rebalancing_cost}")
            episode_rewards.append(int(episode_reward))
            # Episode reward relative to the benchmark and MPC objectives.
            ratio.append(np.round(episode_reward/benchmark_obj,2))
            ratio_MPC.append(np.round(episode_reward/mpc_obj,2))

            # Step-by-step comparison of the policy's f, g and q_hat vectors
            # with the MPC ones (cosine similarity and summed absolute
            # difference).
            # NOTE(review): f_ls is empty here, so this loop body does not run
            # and the result lists stay empty.
            cos_sim_ls, diff_ls = [], []
            cos_sim2_ls, diff2_ls = [], []
            cos_sim3_ls, diff3_ls = [], []
            for i in range(len(f_ls)):
            # for i in range(train_num):
                vec1 = f_ls[i]
                vec2 = f_ls0[i]
                cos_sim = np.round(cosine_similarity(vec1, vec2),3)
                diff = np.absolute(vec1-vec2).sum()
                cos_sim_ls.append(cos_sim)
                diff_ls.append(diff)

                #print('similarity of f in time step ', i, ':', cos_sim, diff)

                vec1 = g_ls[i]
                vec2 = g_ls0[i]
                cos_sim2 = np.round(cosine_similarity(vec1, vec2),3)
                diff2 = np.absolute(vec1-vec2).sum()
                cos_sim2_ls.append(cos_sim2)
                diff2_ls.append(diff2)

                # MPC counterpart of q_hat: q minus A applied to (f + g).
                qhat_mpc = q_ls0[i] - np.dot(A, np.array(f_ls0[i]) + np.array(g_ls0[i]))
                cos_sim3 = np.round(cosine_similarity(qhat_ls[i], qhat_mpc),3)
                cos_sim3_ls.append(cos_sim3)
                diff3 = np.absolute(qhat_ls[i]-qhat_mpc).sum()
                diff3_ls.append(diff3)

                #print('similarity of g in time step ', i, ':', cos_sim2, diff2)
                #print('--------------------------------------')

            l = len(cos_sim_ls)
            # print('f:', sum(cos_sim_ls)/l, sum(diff_ls)/l)
            # print('g:', sum(cos_sim2_ls)/l, sum(diff2_ls)/l)
            # print('q_hat:', sum(cos_sim3_ls)/l, sum(diff3_ls)/l)

        # Summary for this checkpoint: each list is printed with its mean.
        print(path)
        print(episode_rewards, np.mean(np.array(episode_rewards)))
        print(ratio, np.mean(np.array(ratio)))
        print(ratio_MPC, np.mean(np.array(ratio_MPC)))

        print('served demand of MPC:', MPC_demand, np.mean(np.array(MPC_demand)))
        print('served demand of model:', served_demand, np.mean(np.array(served_demand)))

        print('reward of MPC:', MPC_reward, np.mean(np.array(MPC_reward)))
        print('reward of model:', episode_rewards, np.mean(np.array(episode_rewards)))


