from __future__ import print_function
import argparse
from tqdm import trange
import numpy as np
import torch

from src.envs.amod_env import Scenario, AMoD
from src.misc.utils import dictsum
import random
import pickle
from torch_geometric.data import Data, Batch
from generate_historical_data import Benchmark, MPC
from collections import defaultdict
from a2c_gnn_e2e import A2C                                    
import matplotlib.pyplot as plt
import pandas as pd
from datetime import datetime

# ---------------- Utils ----------------
def cosine_similarity(vec1, vec2):
    """Return the cosine of the angle between ``vec1`` and ``vec2``.

    Args:
        vec1: First vector (anything ``np.dot`` / ``np.linalg.norm`` accept).
        vec2: Second vector, compatible with ``vec1`` for ``np.dot``.

    Returns:
        ``dot(vec1, vec2) / (||vec1|| * ||vec2||)``, or the integer ``0`` when
        either vector has zero norm (the ratio would be undefined).
    """
    # The dot product is evaluated first so incompatible shapes raise even
    # when one of the inputs is a zero vector.
    inner = np.dot(vec1, vec2)
    len_a, len_b = np.linalg.norm(vec1), np.linalg.norm(vec2)
    if len_a != 0 and len_b != 0:
        return inner / (len_a * len_b)
    return 0

# ---------------- Argparse ----------------
def _str2bool(value):
    """Parse a command-line boolean such as ``--test False``.

    ``type=bool`` does not work with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so a flag declared that way can
    never be switched off by passing an explicit value.

    Args:
        value: Raw argument text (or an actual ``bool``).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string stays falsy, exactly as it was under ``type=bool``.
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(description='A2C-GNN End-to-End')

# Seed handed to the Scenario constructor.
parser.add_argument('--seed', type=int, default=1000)
# True -> run the evaluation branch; False -> run the training branch.
parser.add_argument('--test', type=_str2bool, default=True)
# Number of episodes for whichever branch runs.
parser.add_argument('--max_episodes', type=int, default=5)
# Passed to the environment as `tf` and to every `env.step` call.
parser.add_argument('--max_steps', type=int, default=90)
# Request CUDA; silently downgraded to CPU below when no GPU is available.
parser.add_argument('--cuda', type=_str2bool, default=True)
# NOTE(review): not read anywhere in the visible part of this script.
parser.add_argument("--city", type=str, default='rome')
args = parser.parse_args()

# Only keep CUDA enabled if it was requested AND a device is actually present.
args.cuda = args.cuda and torch.cuda.is_available()
device = torch.device("cuda" if args.cuda else "cpu")

# ---------------- Env init ----------------
# Build the simulation environment from the network CSV. The path is relative
# to the current working directory.
data = pd.read_csv('data/network_data.csv')
scenario = Scenario(data, tf=args.max_steps, seed=args.seed)
env = AMoD(scenario, data, beta=1, tf=args.max_steps, totalAcc=220)
# Stack the two endpoint columns into a single (2, num_rows) array first:
# building a tensor directly from a list of ndarrays is slow and makes PyTorch
# emit a UserWarning.
edge_index = torch.tensor(np.stack([data['i'].values, data['j'].values])).long()
N = data['i'].nunique()
# Length of the per-step lists below: ten entries more than `--max_steps`.
# NOTE(review): presumably head-room past the episode horizon — confirm
# against how `env.step` indexes these lists.
num = args.max_steps + 10

# Each list repeats the same CSV column once per step index; they are passed
# to `env.step` on every call.
price_ls = [data['price'] for _ in range(num)]
cost_ls = [data['cost'] for _ in range(num)]
demandTime_ls = [data['travel_time'] for _ in range(num)]

# ---------------- Training ----------------
# Runs `--max_episodes` training episodes, checkpoints the model whenever an
# episode matches or beats the best reward so far, and saves a reward curve.
if not args.test:
    log = {'train_reward': [], 'train_served_demand': [], 'train_reb_cost': []}
    train_episodes = args.max_episodes
    epochs = trange(train_episodes, dynamic_ncols=True)

    best_reward = -np.inf
    model = A2C(env=env, input_size=21, edge_index=edge_index, device=device).to(device)
    model.train()

    # Iterate the progress bar itself; looping over a separate range() would
    # leave the bar stuck at 0 while only its description changed.
    for step in epochs:
        obs = env.reset()
        done = False
        eps_reward, eps_served, eps_rebcost = 0, 0, 0

        while not done:
            # The observation fed to the policy always comes from
            # Initial_step(); the value returned by reset()/step() is unused.
            obs = env.Initial_step()
            paxAction, rebAction = model.select_action(obs)
            _, reward, done, info, _, _ = env.step(paxAction, rebAction, args.max_steps,
                                                   cost_ls, price_ls, demandTime_ls)
            model.rewards.append(reward)
            eps_reward += reward
            eps_served += info['served_demand']
            eps_rebcost += info['rebalancing_cost']

        # update A2C once per episode, using the rewards collected above
        model.training_step()

        # logging
        epochs.set_description(
            f"Episode {step} | Reward: {eps_reward:.2f} | ServedDemand: {eps_served:.2f} | Reb. Cost: {eps_rebcost:.2f}"
        )
        # `>=` means a tie also overwrites the checkpoint with the newer model.
        if eps_reward >= best_reward:
            model.save_checkpoint(path="ckpt/model_e2e_best.pth")
            best_reward = eps_reward
        log['train_reward'].append(eps_reward)
        log['train_served_demand'].append(eps_served)
        log['train_reb_cost'].append(eps_rebcost)

    # plot reward curve
    # Report the tracked best episode reward (stays -inf only if no episode ran).
    print(f"Max Reward: {best_reward}")
    plt.plot(np.array(log['train_reward']))
    plt.title('Rewards per Episode (End-to-End)')
    plt.xlabel('Episode')
    plt.ylabel('Total Reward')
    plt.grid(True)
    plt.savefig('Rewards_per_Episode_e2e.png')

# ---------------- Testing ----------------
# Loads the best checkpoint written by the training branch and rolls the policy
# out for `--max_episodes` episodes, printing per-episode totals.
else:
    model = A2C(env=env, input_size=21, edge_index=edge_index, device=device).to(device)
    model.load_checkpoint(path="ckpt/model_e2e_best.pth")   # replace with the actual saved model file
    # NOTE(review): the model is not switched to eval mode here — confirm
    # whether A2C contains layers that behave differently in train mode.
    test_episodes = args.max_episodes
    epochs = trange(test_episodes)

    for ep in epochs:
        obs = env.reset()
        done = False
        ep_reward, ep_served, ep_rebcost = 0, 0, 0
        while not done:
            obs = env.Initial_step()
            paxAction, rebAction = model.select_action(obs)
            _, reward, done, info, _, _ = env.step(paxAction, rebAction, args.max_steps,
                                                   cost_ls, price_ls, demandTime_ls)
            ep_reward += reward
            ep_served += info['served_demand']
            ep_rebcost += info['rebalancing_cost']
        print(f"Test Episode {ep}: Reward={ep_reward}, Served={ep_served}, RebCost={ep_rebcost}")
