"""
经验数据收集使用示例 - 严格按照用户要求的格式
"""

import torch
from app.simulation.Simulation import Simulation
from app.network.Network import Network
import json
import datetime

def main():
    """
    主函数：演示如何收集和使用车辆经验数据
    严格按照格式：s(上一时刻路段,排队长度), a(上一时刻将要进入的路段id), r(上一时刻将要进入的路段权重), s'(当前路段,排队长度)
    """
    # 初始化仿真
    Simulation.start()
    
    # 模拟一些权重数据 (路段权重)
    num_edges = len(Simulation.edge_ids)
    model_weights = torch.rand(num_edges)  # 随机权重
    
    # 创建车辆ID映射 (示例)
    car_id_map = {f"car_{i}": i for i in range(100)}  # 假设有100辆车
    
    # 运行仿真并收集数据
    print("开始收集车辆经验数据...")
    print("数据格式: s(上一时刻路段,排队长度), a(上一时刻将要进入的路段id), r(上一时刻将要进入的路段权重), s'(当前路段,排队长度)")
    
    # 模拟运行多个时间步
    for tick in range(100):  # 运行100个时间步
        # 运行一步仿真
        vehicle_data = Simulation.loop(model_weights, car_id_map)
        
        # 每10步打印一次统计信息
        if tick % 10 == 0:
            print(f"时间步 {tick}: 收集到 {len(Simulation.get_experience_data())} 辆车的经验数据")
    
    # 获取所有经验数据
    all_experiences = Simulation.get_experience_data()
    
    print(f"\n总共收集到 {len(all_experiences)} 辆车的经验数据")
    
    # 打印每辆车的经验数据统计
    for car_id, experiences in all_experiences.items():
        print(f"车辆 {car_id}: {len(experiences)} 条经验")
        
        # 打印前几条经验数据作为示例
        if experiences:
            print(f"  示例经验数据:")
            for i, exp in enumerate(experiences[:3]):  # 只显示前3条
                print(f"    经验 {i+1}:")
                print(f"      s_prev: {exp['s_prev']} (上一时刻路段, 上一时刻排队长度)")
                print(f"      a_prev: {exp['a_prev']} (上一时刻将要进入的路段ID)")
                print(f"      r_prev: {exp['r_prev']} (上一时刻将要进入的路段权重)")
                print(f"      s_current: {exp['s_current']} (当前路段, 当前排队长度)")
                print()
    
    # 保存经验数据到文件
    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    filename = f'./app/data/vehicle_experiences_{timestamp}.json'
    Simulation.save_experience_data(filename)
    print(f"经验数据已保存到: {filename}")
    
    # 示例：获取特定车辆的经验数据
    if all_experiences:
        first_car_id = list(all_experiences.keys())[0]
        car_experiences = Simulation.get_vehicle_experience(first_car_id)
        print(f"\n车辆 {first_car_id} 的经验数据格式:")
        if car_experiences:
            exp = car_experiences[0]
            print(f"  s_prev: {exp['s_prev']} (上一时刻路段, 上一时刻排队长度)")
            print(f"  a_prev: {exp['a_prev']} (上一时刻将要进入的路段ID)")
            print(f"  r_prev: {exp['r_prev']} (上一时刻将要进入的路段权重)")
            print(f"  s_current: {exp['s_current']} (当前路段, 当前排队长度)")
    
    # 示例：构建强化学习训练数据
    print("\n构建强化学习训练数据:")
    training_data = []
    
    for car_id, experiences in all_experiences.items():
        car_data = []
        for exp in experiences:
            # 构建训练样本: (s_{t-1}, a_{t-1}, r, s_t)
            training_sample = {
                'state_prev': exp['s_prev'],      # (上一时刻路段, 上一时刻排队长度)
                'action_prev': exp['a_prev'],     # 上一时刻将要进入的路段ID
                'reward_prev': exp['r_prev'],     # 上一时刻将要进入的路段权重
                'state_current': exp['s_current'], # (当前路段, 当前排队长度)
                'car_id': car_id
            }
            car_data.append(training_sample)
        
        training_data.append({
            'car_id': car_id,
            'experiences': car_data
        })
    
    print(f"构建了 {len(training_data)} 个车辆的训练数据集")
    
    # 保存训练数据
    training_filename = f'./app/data/training_data_{timestamp}.json'
    with open(training_filename, 'w', encoding='utf-8') as f:
        json.dump(training_data, f, ensure_ascii=False, indent=2)
    print(f"训练数据已保存到: {training_filename}")
    
    # 打印数据格式说明
    print("\n数据格式说明:")
    print("每条经验数据包含:")
    print("  s_prev: (上一时刻路段ID, 上一时刻排队长度)")
    print("  a_prev: 上一时刻将要进入的路段ID")
    print("  r_prev: 上一时刻将要进入的路段权重")
    print("  s_current: (当前路段ID, 当前排队长度)")
    print("  timestamp: 时间戳")

def analyze_experience_data():
    """
    分析经验数据的函数
    """
    # 加载保存的经验数据
    with open('./app/data/vehicle_experiences_20241201_120000.json', 'r', encoding='utf-8') as f:
        experiences = json.load(f)
    
    print("经验数据分析:")
    print(f"总车辆数: {len(experiences)}")
    
    total_experiences = 0
    for car_id, car_experiences in experiences.items():
        total_experiences += len(car_experiences)
        print(f"车辆 {car_id}: {len(car_experiences)} 条经验")
    
    print(f"总经验数: {total_experiences}")
    
    # 分析权重分布
    all_weights = []
    for car_experiences in experiences.values():
        for exp in car_experiences:
            all_weights.append(exp['r_prev'])
    
    if all_weights:
        print(f"权重统计:")
        print(f"  平均权重: {sum(all_weights) / len(all_weights):.4f}")
        print(f"  最大权重: {max(all_weights):.4f}")
        print(f"  最小权重: {min(all_weights):.4f}")

def print_data_format_example():
    """
    打印数据格式示例
    """
    print("经验数据格式示例:")
    print("=" * 50)
    
    example_experience = {
        's_prev': ('E1', 5),           # (上一时刻路段, 上一时刻排队长度)
        'a_prev': 'E2',                # 上一时刻将要进入的路段ID
        'r_prev': 0.8,                 # 上一时刻将要进入的路段权重
        's_current': ('E2', 3),        # (当前路段, 当前排队长度)
        'timestamp': 1640995200.123    # 时间戳
    }
    
    print("单条经验数据:")
    for key, value in example_experience.items():
        if key == 's_prev':
            print(f"  {key}: {value} (上一时刻路段, 上一时刻排队长度)")
        elif key == 'a_prev':
            print(f"  {key}: {value} (上一时刻将要进入的路段ID)")
        elif key == 'r_prev':
            print(f"  {key}: {value} (上一时刻将要进入的路段权重)")
        elif key == 's_current':
            print(f"  {key}: {value} (当前路段, 当前排队长度)")
        else:
            print(f"  {key}: {value}")
    
    print("\n完整数据结构:")
    print("{" + f"'car_1': [{example_experience}, ...], 'car_2': [...], ..." + "}")

if __name__ == "__main__":
    main()
    # analyze_experience_data()
    # print_data_format_example() 