import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from sklearn.preprocessing import StandardScaler
from utils.timefeatures import time_features
import warnings

# NOTE(review): this ignores every warning for the whole process as soon as the
# module is imported (including pandas deprecation / chained-assignment
# warnings raised elsewhere), not just warnings from this file; consider
# scoping it with warnings.catch_warnings() around the code that needs it.
warnings.filterwarnings('ignore')

class Dataset_CGMacros_CrossSubject(Dataset):
    """Cross-subject CGMacros dataset for time-series forecasting.

    Each CSV file in ``root_path`` holds one subject. Files are sorted by name
    and split by subject: the first four form the training split, the fifth
    the validation split, and every remaining file the test split. Sliding
    windows are generated inside each subject only, so a sample never joins
    the end of one subject's recording to the start of another's.
    """

    def __init__(self, root_path, flag='train', size=None,
                 features='MS', target='Dexcom GL', scale=True, timeenc=0, freq='min'):
        """
        :param root_path: directory containing one CSV per subject (the
            train = first 4 / val = 5th / test = rest split expects >= 6 files)
        :param flag: 'train' / 'val' / 'test'
        :param size: [seq_len, label_len, pred_len]; defaults to [240, 120, 60]
        :param features: 'S' (single feature), 'M' (multivariate), 'MS'
            (multivariate input, single target). NOTE(review): stored but not
            read anywhere in this class -- the encoder input is always every
            column except 'Timestamp' and the target.
        :param target: name of the column to forecast
        :param scale: standardize features and target with statistics fitted
            on the training subjects only; when False the raw values are used
            and ``inverse_transform`` returns its input unchanged
        :param timeenc: time encoding (0: calendar fields month/day/weekday/
            hour/minute; otherwise ``time_features`` from utils.timefeatures)
        :param freq: sampling frequency passed to ``time_features``
            ('min' = minute-level)
        """
        if size is None:
            self.seq_len = 240
            self.label_len = 120
            self.pred_len = 60
        else:
            self.seq_len, self.label_len, self.pred_len = size

        assert flag in ['train', 'val', 'test']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features
        self.target = target
        self.scale = scale
        self.timeenc = timeenc
        self.freq = freq
        self.root_path = root_path

        self.__read_data__()

    def __read_data__(self):
        """Load this split's subjects and build the arrays used by ``__getitem__``.

        Sets ``data_x`` (feature columns, in training-file column order),
        ``data_y`` (target column, shape (N, 1)), ``data_stamp`` (time
        features), ``scaler`` / ``target_scaler`` (None when ``scale`` is
        False) and ``_starts`` (valid window start rows).

        :raises ValueError: if the selected split contains no CSV files.
        """
        # === 1. Split by subject (one CSV per subject, sorted by filename) ===
        all_files = sorted(f for f in os.listdir(self.root_path) if f.endswith('.csv'))
        file_split = {0: all_files[:4], 1: all_files[4:5], 2: all_files[5:]}
        selected_files = file_split[self.set_type]
        if not selected_files:
            raise ValueError(
                f"No CSV files for split {self.set_type} in {self.root_path!r}: "
                f"found {len(all_files)} file(s); the split uses the first 4 for "
                f"train, the 5th for val and the rest for test."
            )

        # Read every subject file exactly once; the statistics printout, the
        # selected split and the training statistics all reuse these frames.
        raw = {f: pd.read_csv(os.path.join(self.root_path, f)) for f in all_files}

        # Print per-subject statistics of the target column.
        print("\n==================== CGMacros 数据集统计 ====================")
        for f in all_files:
            target_col = raw[f][self.target]
            print(f"{f:<25} mean={target_col.mean():.2f}, std={target_col.std():.2f}")
        print("============================================================\n")

        # Sort each selected subject chronologically, then stack them. The
        # per-subject frames are kept so windows can respect subject boundaries.
        subject_dfs = [
            raw[f].assign(Timestamp=pd.to_datetime(raw[f]['Timestamp'])).sort_values('Timestamp')
            for f in selected_files
        ]
        df_raw = pd.concat(subject_dfs, axis=0, ignore_index=True)

        # === 2. Feature columns, ordered as in the training files so every
        # split feeds the fitted scaler and the model the same column layout ===
        train_df = pd.concat([raw[f] for f in file_split[0]], ignore_index=True)
        feature_cols = [c for c in train_df.columns if c not in ('Timestamp', self.target)]
        feature_frame = df_raw[feature_cols]
        target_frame = df_raw[[self.target]]

        # === 3. Standardize with training-subject statistics only, so the
        # validation/test subjects never influence the normalization ===
        if self.scale:
            self.scaler = StandardScaler().fit(train_df[feature_cols])
            self.target_scaler = StandardScaler().fit(train_df[[self.target]])
            self.data_x = self.scaler.transform(feature_frame)
            self.data_y = self.target_scaler.transform(target_frame)
        else:
            # Unscaled: use raw values, as float64 like the scaled path.
            self.scaler = None
            self.target_scaler = None
            self.data_x = feature_frame.to_numpy(dtype=np.float64)
            self.data_y = target_frame.to_numpy(dtype=np.float64)

        # === 4. Time features and per-subject window index ===
        self.data_stamp = self._build_time_features(df_raw['Timestamp'])
        self._starts = self._compute_window_starts([len(df) for df in subject_dfs])

    def _build_time_features(self, timestamps):
        """Encode a datetime Series as an array with one row per timestamp."""
        if self.timeenc == 0:
            # Columns: month, day, weekday (Monday=0), hour, minute.
            dt = timestamps.dt
            fields = (dt.month, dt.day, dt.weekday, dt.hour, dt.minute)
            return np.column_stack([s.to_numpy(dtype=np.int64) for s in fields])
        # Transposed so rows index timestamps -- assumes time_features returns
        # a features-first (n_features, N) array; TODO confirm in utils.timefeatures.
        return time_features(pd.to_datetime(timestamps.values), freq=self.freq).transpose(1, 0)

    def _compute_window_starts(self, lengths):
        """Return every start row whose full window stays within one subject.

        :param lengths: row counts of the stacked subjects, in stacking order
        """
        window = self.seq_len + self.pred_len
        starts = []
        offset = 0
        for n_rows in lengths:
            if n_rows >= window:
                starts.append(np.arange(offset, offset + n_rows - window + 1))
            offset += n_rows
        return np.concatenate(starts) if starts else np.empty(0, dtype=np.int64)

    def __getitem__(self, index):
        """Return (seq_x, seq_y, seq_x_mark, seq_y_mark) for the index-th window.

        seq_x spans ``seq_len`` rows; seq_y spans the last ``label_len`` rows of
        that input followed by the next ``pred_len`` rows. All rows of a window
        come from the same subject.
        """
        s_begin = int(self._starts[index])
        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        """Number of windows (0 when no subject is long enough for one window)."""
        return len(self._starts)

    def inverse_transform(self, data):
        """Map standardized target values back to the original units.

        Returns ``data`` unchanged when the dataset was built with ``scale=False``.
        """
        if self.target_scaler is None:
            return data
        return self.target_scaler.inverse_transform(data)
