import argparse
import numpy as np
import torch
from tqdm import trange
import json
from src.envs.amod_env import Scenario, AMoD
from a2c_gnn_e2e import A2C, GNNParser


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value -- including ``"False"`` -- evaluates to ``True``. This
    converter interprets the usual textual spellings instead.

    Args:
        value: The raw string from the command line (or an actual ``bool``).

    Returns:
        The parsed boolean. The empty string maps to ``False``, matching what
        ``bool("")`` returned before.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean
            spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {"true", "t", "yes", "y", "1"}:
        return True
    if token in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line interface of the script; the parsed values are exposed as `args`.
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=10)
parser.add_argument("--max_episodes", type=int, default=5)
parser.add_argument("--max_steps", type=int, default=50)
parser.add_argument("--hidden_size", type=int, default=64)
parser.add_argument("--city", type=str, default="washington_dc")
# `type=_str2bool` (not `type=bool`) so that `--test False` really disables
# test mode; the default is unchanged.
parser.add_argument('--test', type=_str2bool, default=True,
                    help='activates test mode for agent evaluation '
                         '(pass False to train instead)')
args = parser.parse_args()

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
city = args.city


# Keyword arguments for the scenario, which is read from the city's JSON file.
# NOTE(review): the meaning of demand_ratio / json_hr / json_tstep / tf is
# defined by Scenario, which is not visible here -- confirm before changing.
_scenario_kwargs = dict(
    json_file=f"data/scenario_{city}.json",
    demand_ratio=1.0,
    json_hr=19,
    sd=args.seed,
    json_tstep=3,
    tf=args.max_steps,
)
scenario = Scenario(**_scenario_kwargs)

# Wrap the scenario in the AMoD environment used by both branches below.
env = AMoD(scenario, beta=0.5)


# Read the scenario JSON again to extract the graph topology for the GNN.
with open(f"data/scenario_{city}.json", "r") as file:
    data = json.load(file)

# One tensor per endpoint field of the "topology_graph" entries, stacked into
# a 2-row integer tensor (row 0 holds the "i" values, row 1 the "j" values).
_endpoints_i = torch.tensor([link["i"] for link in data["topology_graph"]])
_endpoints_j = torch.tensor([link["j"] for link in data["topology_graph"]])
edge_index = torch.vstack((_endpoints_i, _endpoints_j)).long()


# Build the actor-critic agent and move its parameters to the chosen device.
model = A2C(
    env=env,
    input_size=3,
    edge_index=edge_index,
    hidden_size=args.hidden_size,
    device=device,
)
model = model.to(device)

def _run_episode(env, model, max_steps, store_rewards=False):
    """Play one episode and return its accumulated statistics.

    Every iteration fetches an observation with ``env.Initial_step()``, asks
    ``model.select_action`` for the passenger and rebalancing actions, and
    applies them with ``env.step`` until the environment reports ``done``.

    Args:
        env: Environment exposing ``reset``, ``Initial_step`` and ``step``.
        model: Agent exposing ``select_action`` (and a ``rewards`` list when
            ``store_rewards`` is true).
        max_steps: Forwarded as the third positional argument of ``env.step``.
        store_rewards: When true, append each step reward to
            ``model.rewards``; the training branch does this before calling
            ``model.training_step()``.

    Returns:
        Tuple ``(reward, served_demand, rebalancing_cost)`` summed over the
        whole episode.
    """
    env.reset()
    total_reward, total_served, total_reb_cost = 0, 0, 0
    done = False
    while not done:
        obs = env.Initial_step()
        pax_action, reb_action = model.select_action(obs)
        # Only the reward, the done flag and the info dict are used here; the
        # remaining values returned by env.step are ignored.
        _, reward, done, info, _, _ = env.step(
            pax_action, reb_action, max_steps, None, None, None
        )
        if store_rewards:
            model.rewards.append(reward)
        total_reward += reward
        total_served += info["served_demand"]
        total_reb_cost += info["rebalancing_cost"]
    return total_reward, total_served, total_reb_cost


if not args.test:
    print("Starting training...")
    best_reward = -np.inf
    for i_episode in trange(args.max_episodes):
        ep_reward, ep_served, ep_cost = _run_episode(
            env, model, args.max_steps, store_rewards=True
        )

        # One update per episode, using the rewards collected during it.
        model.training_step()

        # Checkpoint whenever an episode beats the best reward seen so far.
        # NOTE(review): save_checkpoint() is called without a path while test
        # mode loads "ckpt/a2c_gnn.pth" -- confirm the default path matches.
        if ep_reward > best_reward:
            model.save_checkpoint()
            best_reward = ep_reward

        print(f"Ep {i_episode+1} | Reward {ep_reward:.1f} | Served {ep_served:.1f} | RebCost {ep_cost:.1f}")

else:
    print("Starting testing...")

    # Load pre-trained model. Loading stays best-effort, but only ordinary
    # errors are tolerated: a bare `except:` would also swallow
    # KeyboardInterrupt / SystemExit and hide the reason for the failure.
    try:
        model.load_checkpoint(path="ckpt/a2c_gnn.pth")
    except Exception as err:
        print("Warning: Could not load pre-trained model, using random initialization")
        print(f"  Reason: {type(err).__name__}: {err}")
    else:
        print("Loaded pre-trained model from ckpt/a2c_gnn.pth")

    model.eval()  # Set to evaluation mode

    # Test parameters
    test_episodes = 5
    log = {'test_reward': [], 'test_served_demand': [], 'test_reb_cost': []}

    print(f"Running {test_episodes} test episodes...")

    # NOTE(review): select_action runs with autograd enabled here; if it needs
    # no gradients at evaluation time, wrapping this loop in torch.no_grad()
    # would save memory -- confirm against A2C.select_action.
    for episode in range(test_episodes):
        print(f"\n--- Test Episode {episode+1} ---")

        # In test mode the step rewards are not appended to model.rewards.
        episode_reward, episode_served_demand, episode_rebalancing_cost = _run_episode(
            env, model, args.max_steps, store_rewards=False
        )

        log['test_reward'].append(episode_reward)
        log['test_served_demand'].append(episode_served_demand)
        log['test_reb_cost'].append(episode_rebalancing_cost)

        print(f"Episode {episode+1} Summary: Reward={episode_reward:.1f}, Served={episode_served_demand:.1f}, RebCost={episode_rebalancing_cost:.1f}")

    print("\n" + "="*50)
    print("TEST RESULTS SUMMARY")
    print("="*50)
    print(f"Episodes: {test_episodes}")
    print(f"Average Reward: {np.mean(log['test_reward']):.1f} ± {np.std(log['test_reward']):.1f}")
    print(f"Average Served Demand: {np.mean(log['test_served_demand']):.1f} ± {np.std(log['test_served_demand']):.1f}")
    print(f"Average Rebalancing Cost: {np.mean(log['test_reb_cost']):.1f} ± {np.std(log['test_reb_cost']):.1f}")
    print(f"Max Reward: {np.max(log['test_reward']):.1f}")
    print(f"Min Reward: {np.min(log['test_reward']):.1f}")

    mean_reward = np.mean(log['test_reward'])
    mean_served = np.mean(log['test_served_demand'])
    # Guard the ratio: with zero served demand the division would only emit a
    # NumPy runtime warning and yield inf/nan.
    if mean_served != 0:
        avg_reward_per_served = mean_reward / mean_served
    else:
        avg_reward_per_served = float("nan")
    print(f"Reward per Served Passenger: {avg_reward_per_served:.2f}")

    # Plain Python floats/ints so the summary is JSON-serialisable.
    test_results = {
        'episodes': int(test_episodes),
        'rewards': [float(x) for x in log['test_reward']],
        'served_demand': [float(x) for x in log['test_served_demand']],
        'rebalancing_cost': [float(x) for x in log['test_reb_cost']],
        'avg_reward': float(np.mean(log['test_reward'])),
        'std_reward': float(np.std(log['test_reward'])),
        'avg_served': float(np.mean(log['test_served_demand'])),
        'avg_reb_cost': float(np.mean(log['test_reb_cost']))
    }

    with open(f"test_results_{city}_e2e.json", "w", encoding="utf-8") as f:
        json.dump(test_results, f, indent=2)
    print(f"\nTest results saved to test_results_{city}_e2e.json")
