#!/usr/bin/env python  
# -*- coding: utf-8 -*-
# Filename: data_utils_fix.py
# (最终版) 专为长时序自回归任务设计的数据加载工具
# 适配于优化后的 process_data.py 生成的数据

import os
import torch
import numpy as np
import dgl
import pickle
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.nn.modules.loss import _WeightedLoss
# 假设这些在您的项目中存在
from utils import UnitTransformer, MultipleTensors
from models.cgpt import CGPTNO, StructuredRecursiveGNOT
from models.mmgpt import GNOT, SR_GNOT
from dgl.nn.pytorch import SumPooling, AvgPooling
class MIOSegmentDataset(Dataset):
    """
    Dataset of fixed-length time-series segments for autoregressive training.

    The pickle file at ``data_path`` is read as a flat list with one sample per
    time step, each sample laid out as ``[coords, state_t, theta, None]``
    (indices 0, 1 and 2 are used here). Every dataset item is a window of
    ``segment_length`` consecutive time steps that never crosses a simulation
    boundary.

    Args:
        data_path: Path of the pickle file to load.
        name: Dataset name; stored on the instance only.
        segment_length: Number of consecutive time steps per segment.
        sim_starts: Flat index at which each simulation starts, ascending.
        num_samples: ``'all'`` or an int limiting how many leading time steps
            of the flat list are kept.
        normalize_y_config: Dict with keys ``'use'`` and ``'method'``
            controlling normalization of the states.
        normalize_x_config: Same, for coordinates and parameters.
    """
    def __init__(self, data_path, name, segment_length, sim_starts, num_samples, normalize_y_config, normalize_x_config):
        super().__init__()
        self.data_path = data_path
        self.name = name
        self.segment_length = segment_length

        # 1. Load the flat data (one sample per time step).
        print(f"Loading flat data from {self.data_path}...")
        # NOTE(security): pickle can execute arbitrary code while loading;
        # only point this at trusted files.
        # The context manager guarantees the file handle is closed.
        with open(self.data_path, "rb") as data_file:
            flat_data_list = pickle.load(data_file)
        if num_samples != 'all' and isinstance(num_samples, int):
            flat_data_list = flat_data_list[:num_samples]
        print(f"Loaded {len(flat_data_list)} total time steps.")

        # 2. Collect every start index whose whole segment fits inside one
        #    simulation AND inside the (possibly truncated) data.
        self.valid_start_indices = self._build_start_indices(
            sim_starts, len(flat_data_list), self.segment_length
        )

        # 3. Convert the samples to tensors and cache them.
        self._preprocess(flat_data_list)

        # 4. Build and apply the normalizers. The state is both the model input
        #    and the target, so a single normalizer covers both.
        self.y_normalizer = self._setup_normalizer(normalize_y_config, self.all_states)
        self.x_normalizer, self.up_normalizer = self._setup_normalizers_x(normalize_x_config, self.all_coords, self.all_u_p)
        self._normalize_data()

        # 5. Publish the dimensions derived from the data.
        self._update_dataset_config()

        print(f"MIOSegmentDataset initialized. Number of segments: {len(self.valid_start_indices)}")

    @staticmethod
    def _build_start_indices(sim_starts, total_steps, segment_length):
        """
        Return every flat index at which a full segment can start.

        A segment starting at ``j`` covers ``j .. j + segment_length - 1`` and
        must stay inside its simulation. Each simulation's end is additionally
        clamped to ``total_steps`` so that truncating the data (``num_samples``)
        can never yield a segment that reads past the end of the data.

        Args:
            sim_starts: Start index of each simulation, ascending.
            total_steps: Number of time steps actually available.
            segment_length: Number of time steps per segment.

        Returns:
            List of valid start indices in ascending order.
        """
        boundaries = list(sim_starts) + [total_steps]
        start_indices = []
        for sim_begin, sim_end in zip(boundaries[:-1], boundaries[1:]):
            # Simulations lying (partly) beyond the available data are cut off;
            # for fully missing simulations the range below is simply empty.
            sim_end = min(sim_end, total_steps)
            start_indices.extend(range(sim_begin, sim_end - segment_length + 1))
        return start_indices

    def _preprocess(self, flat_data_list):
        """
        Convert the raw sample list into lists of float tensors.

        Sample layout: ``sample = [coords, state_t, theta, None]``; the first
        three entries must be numpy arrays (``torch.from_numpy`` is used).
        """
        # coords and theta are presumably constant within one simulation, but
        # they are stored per time step to keep the indexing uniform.
        self.all_coords = [torch.from_numpy(sample[0]).float() for sample in flat_data_list]
        self.all_states = [torch.from_numpy(sample[1]).float() for sample in flat_data_list]
        self.all_u_p = [torch.from_numpy(sample[2]).float() for sample in flat_data_list]

        # Edge-less DGL graph used as a template for every item; this assumes
        # all samples have the same number of nodes as the first one.
        num_nodes = self.all_coords[0].shape[0]
        self.g_template = dgl.DGLGraph()
        self.g_template.add_nodes(num_nodes)

    def _setup_normalizer(self, config, data_list):
        """
        Build the state normalizer.

        Returns a ``UnitTransformer`` fitted on all states concatenated along
        dim 0 when ``config['use']`` is truthy and ``config['method']`` is
        ``'unit'``; otherwise ``None``.
        """
        if not config.get('use') or not data_list:
            return None
        all_data = torch.cat(data_list, dim=0)
        if config.get('method') == 'unit':
            normalizer = UnitTransformer(all_data)
            print("State/Target features normalized using UnitTransformer.")
            return normalizer
        return None

    def _setup_normalizers_x(self, config, coord_list, up_list):
        """
        Build the coordinate and parameter normalizers.

        Returns ``(x_normalizer, up_normalizer)``; both are ``None`` unless
        ``config['use']`` is truthy and ``config['method']`` is ``'unit'``.
        Coordinates are concatenated along dim 0, parameters are stacked.
        """
        if not config.get('use') or not coord_list:
            return None, None
        if config.get('method') != 'unit':
            return None, None
        coord_all = torch.cat(coord_list, dim=0)
        up_all = torch.stack(up_list, dim=0)
        x_normalizer = UnitTransformer(coord_all)
        up_normalizer = UnitTransformer(up_all)
        print("Coordinate and Parameter features normalized using UnitTransformer.")
        return x_normalizer, up_normalizer

    def _normalize_data(self):
        """
        Apply whichever normalizers are currently set to the cached tensors.

        Each normalizer is checked against ``None`` on its own, so a missing
        one is skipped. Calling this more than once with normalizers set
        transforms the data repeatedly.
        """
        if self.y_normalizer is not None:
            self.all_states = [self.y_normalizer.transform(s) for s in self.all_states]
        if self.x_normalizer is not None:
            self.all_coords = [self.x_normalizer.transform(c) for c in self.all_coords]
        if self.up_normalizer is not None:
            self.all_u_p = [self.up_normalizer.transform(up) for up in self.all_u_p]

    def _update_dataset_config(self):
        """Store the feature dimensions read from the first cached sample."""
        self.config = {
            'input_dim': self.all_coords[0].shape[1],   # -> coord_dim
            'theta_dim': self.all_u_p[0].shape[0],      # -> theta_dim
            'output_dim': self.all_states[0].shape[1],  # -> state_dim
            'branch_sizes': []  # this data format has no branch inputs
        }

    def __len__(self):
        """Return the number of available segments."""
        return len(self.valid_start_indices)

    def __getitem__(self, idx):
        """
        Return the ``idx``-th segment.

        With ``T = segment_length`` and ``t`` the segment's start index:
          * input sequence:  state_t, ..., state_{t+T-2}
          * target sequence: state_{t+1}, ..., state_{t+T-1}

        Returns:
            ``(g_template, coords, u_p, MultipleTensors([]), features_seq,
            targets_seq)`` where both sequences are stacked along a new
            leading time dimension.
        """
        start_idx = self.valid_start_indices[idx]
        stop_idx = start_idx + self.segment_length

        # Targets are the inputs shifted forward by one time step.
        features_seq = torch.stack(
            [self.all_states[i] for i in range(start_idx, stop_idx - 1)], dim=0
        )
        targets_seq = torch.stack(
            [self.all_states[i] for i in range(start_idx + 1, stop_idx)], dim=0
        )

        # Static quantities are taken from the first step of the segment
        # (assumed not to change within a segment).
        coords = self.all_coords[start_idx]
        u_p = self.all_u_p[start_idx]

        # Layout matches what collate_segment_batch unpacks.
        return self.g_template, coords, u_p, MultipleTensors([]), features_seq, targets_seq

# ==============================================================================
# Collate Function (保持不变)
# ==============================================================================

def collate_segment_batch(batch):
    """
    Merge a list of dataset samples into a single batched tuple.

    Each element of ``batch`` must be a 6-tuple
    ``(graph, coords, u_p, inputs_f, features_seq, targets_seq)``.

    Returns:
        ``(g_batch, coords_batch, u_p_batch, inputs_f_batch, features_batch,
        targets_batch)``. The graphs are merged with ``dgl.batch``; the tensor
        fields are stacked along a new leading batch dimension. The per-sample
        ``inputs_f`` entries are discarded and replaced by one empty
        ``MultipleTensors``.

    Raises:
        ValueError: If ``batch`` is empty or a sample does not have exactly
            six fields (raised by the tuple unpacking).
    """
    # Transpose the list of samples into one tuple per field.
    graphs, coord_items, theta_items, _unused_branch_inputs, feature_items, target_items = zip(*batch)

    def _stack(tensors):
        # New leading dimension = batch index.
        return torch.stack(tensors, dim=0)

    merged_graph = dgl.batch(graphs)
    stacked_coords = _stack(coord_items)
    stacked_theta = _stack(theta_items)
    stacked_features = _stack(feature_items)
    stacked_targets = _stack(target_items)

    # This data format carries no branch inputs, so an empty container is
    # handed to the training loop.
    empty_branch_inputs = MultipleTensors([])

    return (
        merged_graph,
        stacked_coords,
        stacked_theta,
        empty_branch_inputs,
        stacked_features,
        stacked_targets,
    )
# ==============================================================================
# 主数据加载函数
# ==============================================================================

def get_segment_dataset(args):
    """
    Build the training and test segment datasets selected by ``args.dataset``.

    Reads ``args.dataset``, ``args.segment_length``, ``args.train_num``,
    ``args.test_num``, ``args.use_normalizer`` and ``args.normalize_x``, plus
    the optional ``args.num_train_sims`` / ``args.num_test_sims``.

    Side effects: stores ``y_normalizer``, ``x_normalizer``, ``up_normalizer``
    and ``dataset_config`` (all taken from the training dataset) on ``args``.

    Returns:
        ``(train_dataset, test_dataset)``.

    Raises:
        NotImplementedError: If ``args.dataset`` has no preset below.
    """
    # name -> (train pickle, test pickle, time steps per simulation,
    #          default number of train simulations, default number of test simulations)
    presets = {
        "ns2d_autoregressive": ('./data/ns_train.pkl', './data/ns_test.pkl', 20, 500, 100),
        "kf2d": ('./data/kf_train.pkl', './data/kf_test.pkl', 100, 100, 20),
    }
    if args.dataset not in presets:
        raise NotImplementedError(f"Dataset '{args.dataset}' not configured for segmentation.")
    train_path, test_path, timesteps_per_sim, default_train_sims, default_test_sims = presets[args.dataset]

    # The simulation counts may be overridden on args; the defaults are assumptions
    # about the pickled data and must match the files actually used.
    num_train_sims = getattr(args, 'num_train_sims', default_train_sims)
    num_test_sims = getattr(args, 'num_test_sims', default_test_sims)

    # Simulation k starts at flat index k * timesteps_per_sim.
    sim_starts_train = list(range(0, num_train_sims * timesteps_per_sim, timesteps_per_sim))
    sim_starts_test = list(range(0, num_test_sims * timesteps_per_sim, timesteps_per_sim))

    normalize_y_config = {'use': args.use_normalizer, 'method': 'unit'}
    normalize_x_config = {'use': args.normalize_x, 'method': 'unit'}

    train_dataset = MIOSegmentDataset(
        data_path=train_path,
        name=args.dataset,
        segment_length=args.segment_length,
        sim_starts=sim_starts_train,
        num_samples=args.train_num,  # counts time steps, not simulations
        normalize_y_config=normalize_y_config,
        normalize_x_config=normalize_x_config
    )

    # The test set must not fit its own statistics: it is built un-normalized
    # and then transformed with the normalizers fitted on the training set.
    normalization_off = {'use': False}
    test_dataset = MIOSegmentDataset(
        data_path=test_path,
        name=args.dataset,
        segment_length=args.segment_length,
        sim_starts=sim_starts_test,
        num_samples=args.test_num,
        normalize_y_config=normalization_off,
        normalize_x_config={'use': False}
    )
    shared_attrs = ('y_normalizer', 'x_normalizer', 'up_normalizer')
    for attr in shared_attrs:
        setattr(test_dataset, attr, getattr(train_dataset, attr))
    test_dataset._normalize_data()

    # Expose the normalizers and the data-derived dimensions to the caller.
    for attr in shared_attrs:
        setattr(args, attr, getattr(train_dataset, attr))
    args.dataset_config = train_dataset.config

    return train_dataset, test_dataset
# ==============================================================================
# Model Instantiation
# ==============================================================================
def get_model(args):
    """
    Instantiate the model named by ``args.model_name``.

    The data-dependent dimensions come from ``args.dataset_config`` (set by
    ``get_segment_dataset``); all other hyper-parameters are read from
    ``args``.

    Returns:
        A ``GNOT``, ``CGPTNO``, ``SR_GNOT`` or ``StructuredRecursiveGNOT``
        instance.

    Raises:
        ValueError: If ``args.dataset_config`` is missing or ``None``.
        NotImplementedError: If ``args.model_name`` is not one of the
            supported names.
    """
    # getattr with a default so that a config that was never assigned yields
    # the explanatory ValueError below instead of a bare AttributeError.
    config = getattr(args, 'dataset_config', None)
    if config is None:
        raise ValueError("args.dataset_config is not set. Load dataset before instantiating model.")

    coord_dim = config['input_dim']
    state_dim = config['output_dim']
    theta_dim = config['theta_dim']
    branch_sizes = config['branch_sizes']
    # The model predicts the next state, so the output size equals state_dim.
    output_size = config['output_dim']

    print("\n--- Model Configuration ---")
    print(f"Model Name: {args.model_name}")
    print(f"Coordinate Dimension (coord_dim): {coord_dim}")
    print(f"State/Solution Dimension (state_dim): {state_dim}")
    print(f"Parameter Dimension (theta_dim): {theta_dim}")
    print(f"Output Dimension: {output_size}")
    print("-" * 27)

    # Keyword arguments shared by every supported model.
    model_kwargs = {
        'coord_dim': coord_dim, 'state_dim': state_dim, 'theta_dim': theta_dim,
        'branch_sizes': branch_sizes, 'output_size': output_size,
        'n_hidden': args.n_hidden, 'n_layers': args.n_layers, 'n_head': args.n_head,
        'n_inner': args.n_inner, 'mlp_layers': args.mlp_layers,
        'attn_type': args.attn_type, 'act': args.act,
        'ffn_dropout': args.ffn_dropout, 'attn_dropout': args.attn_dropout
    }

    # Model-specific arguments are added per branch.
    if args.model_name == "GNOT":
        model_kwargs.update({'n_experts': args.n_experts, 'space_dim': args.space_dim})
        return GNOT(**model_kwargs)
    elif args.model_name == "CGPT":
        return CGPTNO(**model_kwargs)
    elif args.model_name == "SR_GNOT":
        model_kwargs.update({
            'space_dim': args.space_dim, 'n_experts': args.n_experts,
            'capacity_schedule': args.capacity_schedule,
            'final_keep_ratio': args.final_keep_ratio
        })
        return SR_GNOT(**model_kwargs)
    elif args.model_name == "StructuredRecursiveGNOT":
        model_kwargs.update({
            'num_fine_nodes': getattr(args, 'num_fine_nodes', 4096),
            'num_coarse_nodes': getattr(args, 'num_coarse_nodes', 256),
            'final_keep_ratio': args.final_keep_ratio
        })
        return StructuredRecursiveGNOT(**model_kwargs)
    else:
        raise NotImplementedError(f"Model '{args.model_name}' is not implemented in get_model.")

# ==============================================================================
# Loss Function
# ==============================================================================
class WeightedLpRelLoss(_WeightedLoss):
    """
    Relative Lp error, pooled per graph with DGL ``SumPooling``.

    Args:
        p: Order of the norm.
        component: ``'all'`` (metrics keep the trailing dimension),
            ``'all-reduce'`` (metrics reduced to a scalar) or a value
            convertible to ``int`` that selects a single column of ``pred`` /
            ``target``.
        normalizer: Optional object exposing ``transform(x, inverse=True)``;
            when given, ``pred`` and ``target`` are mapped back through it
            before the error is computed.
    """
    def __init__(self, p=2, component='all-reduce', normalizer=None):
        super().__init__()
        self.p = p
        self.component = component
        self.normalizer = normalizer
        self.sum_pool = SumPooling()

    def _lp_losses(self, g, pred, target):
        """Return the per-graph relative error ``(sum|e|^p / sum|t|^p)^(1/p)``."""
        order = self.p
        numerator = self.sum_pool(g, torch.abs(pred - target) ** order)
        # The small constant guards against division by zero for all-zero targets.
        denominator = self.sum_pool(g, torch.abs(target) ** order) + 1e-8
        return (numerator / denominator) ** (1 / order)

    def forward(self, g, pred, target):
        """
        Compute the loss for a batched graph.

        Returns:
            ``(loss, reg, metrics)``: ``loss`` is the mean over all entries of
            the pooled losses, ``reg`` is a zero tensor shaped like ``loss``,
            and ``metrics`` is a detached numpy copy (mean over dim 0 for
            ``component == 'all'``, overall mean otherwise).
        """
        if self.normalizer is not None:
            # Evaluate the error after undoing the normalization.
            pred = self.normalizer.transform(pred, inverse=True)
            target = self.normalizer.transform(target, inverse=True)

        keep_components = self.component == 'all'
        if not keep_components and self.component != 'all-reduce':
            # A specific component was requested: keep only that column.
            column = int(self.component)
            pred = pred[:, column]
            target = target[:, column]

        losses = self._lp_losses(g, pred, target)
        summary = losses.mean(dim=0) if keep_components else losses.mean()
        metrics = summary.clone().detach().cpu().numpy()

        loss = losses.mean()
        reg = torch.zeros_like(loss)
        return loss, reg, metrics
def get_loss_func(name, args, **kwargs):
    """
    Return the loss module selected by ``name``.

    Args:
        name: ``'rel2'`` / ``'rel1'`` for the relative L2 / L1 loss
            (``WeightedLpRelLoss``), ``'l2'`` / ``'l1'`` for the absolute
            variants (``WeightedLpLoss``).
        args: Object providing ``component``.
        **kwargs: Optional ``normalizer`` forwarded to the loss; defaults to
            ``None`` when omitted.

    Raises:
        NotImplementedError: If ``name`` is unknown, or if an absolute loss is
            requested while ``WeightedLpLoss`` is not available in this module.
    """
    # .get() so that callers that do not normalize can simply omit the keyword.
    normalizer = kwargs.get('normalizer')

    if name in ('rel2', 'rel1'):
        order = 2 if name == 'rel2' else 1
        return WeightedLpRelLoss(p=order, component=args.component, normalizer=normalizer)

    if name in ('l2', 'l1'):
        # WeightedLpLoss is neither defined nor imported in the visible part of
        # this module; resolve it lazily and report a clear error instead of a
        # bare NameError when it is missing.
        try:
            absolute_loss_cls = WeightedLpLoss
        except NameError as err:
            raise NotImplementedError(
                f"Loss '{name}' requires WeightedLpLoss, which is not defined or imported in this module."
            ) from err
        order = 2 if name == 'l2' else 1
        return absolute_loss_cls(p=order, component=args.component, normalizer=normalizer)

    raise NotImplementedError(f"Loss '{name}' is not implemented in get_loss_func.")