import numpy as np
import pandas as pd
import torch
import random
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.model_selection import train_test_split
from torch import nn
from torch.nn.functional import mse_loss, l1_loss
from model import SpatioTemporalTransformer
# from model_no_phy import SpatioTemporalTransformer
import logging
from datetime import datetime
from torch.cuda.amp import GradScaler, autocast
from scipy.interpolate import griddata
from torch.optim import AdamW
from numpy.lib.stride_tricks import sliding_window_view
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts


from torch.optim.lr_scheduler import LambdaLR
def set_seed(seed=123):
    """Seed every random number generator used by the script.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU, the current
    CUDA device and all CUDA devices), then forces cuDNN to pick deterministic
    kernels and disables its autotuner so repeated runs match.

    Args:
        seed: integer seed applied to all generators.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # CUDA seeding is lazy, so these are safe on CPU-only machines too.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # multi-GPU

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

def warmup_scheduler(optimizer, warmup_steps=1000):
    """Build a LambdaLR that ramps the learning rate up linearly.

    For scheduler step ``s < warmup_steps`` the LR multiplier is
    ``s / max(1, warmup_steps)``; from ``warmup_steps`` on it stays at 1.0,
    i.e. the optimizer's base learning rate.

    Args:
        optimizer: optimizer whose parameter groups are scheduled.
        warmup_steps: number of steps in the linear warm-up phase.

    Returns:
        A ``LambdaLR`` wrapping ``optimizer``.
    """
    ramp_length = float(max(1, warmup_steps))

    def lr_lambda(step):
        # Linear ramp during warm-up, then the unmodified base LR.
        return float(step) / ramp_length if step < warmup_steps else 1.0

    return LambdaLR(optimizer, lr_lambda)

# EarlyStopping class implementation
class EarlyStopping:
    """Track validation loss, signal when to stop, and keep the best weights.

    A call whose loss is strictly worse than the best seen so far increments
    ``counter``; an equal or better loss resets it and (optionally) snapshots
    the model weights. Once ``counter`` reaches ``patience``, ``early_stop``
    becomes True.

    Args:
        patience: number of consecutive worse calls tolerated before stopping.
        verbose: print the counter on each worse call (only when ``epoch`` is
            passed to ``__call__``).
        restore_best_weights: keep a copy of the best weights so they can be
            restored with ``load_best_weights``.
    """

    def __init__(self, patience=5, verbose=False, restore_best_weights=True):
        self.patience = patience
        self.verbose = verbose
        self.restore_best_weights = restore_best_weights
        self.counter = 0
        self.best_score = None  # negated best validation loss (higher is better)
        self.early_stop = False
        self.best_weights = None  # independent copy of the best state_dict

    def __call__(self, val_loss, model, epoch=None):
        """Record ``val_loss`` for ``model``; ``epoch`` (0-based) is only used for printing."""
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_best_weights(model)
        elif score < self.best_score:
            self.counter += 1
            if self.verbose and epoch is not None:
                print(f"Epoch {epoch + 1}: EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_best_weights(model)
            self.counter = 0

    def save_best_weights(self, model):
        """Snapshot ``model``'s current weights if restoring is enabled."""
        if self.restore_best_weights:
            import copy

            # state_dict() returns tensors that share storage with the live
            # parameters; without a deep copy every later optimizer step would
            # silently overwrite the "best" snapshot.
            self.best_weights = copy.deepcopy(model.state_dict())

    def load_best_weights(self, model):
        """Load the best snapshot into ``model`` (no-op if none was taken)."""
        if self.restore_best_weights and self.best_weights is not None:
            model.load_state_dict(self.best_weights)

# 修改1：新的风向转换逻辑

def preprocess_df(df, *, t0=273.15):
    """Add wind components, Kelvin temperature and engineered features to *df*.

    *df* is modified in place and also returned. Required columns: ``wd``
    (wind direction in meteorological degrees), ``WSPM`` (wind speed),
    ``TEMP`` (deg C), ``time`` (strings whose characters 11-12 are the hour,
    e.g. ``'YYYY-MM-DD HH:MM:SS'``), ``PM2.5``, ``PM10`` and ``o3``.

    Args:
        df: input frame.
        t0: offset added to ``TEMP`` to produce ``TEMP_K``; the default
            273.15 converts Celsius to Kelvin. It used to be read from a
            script-level global, which failed when the function was imported.

    Returns:
        The same frame with the new feature columns.

    NOTE(review): rolling/diff/shift features are row-based (not per station
    and not time-based). If *df* interleaves several stations, these windows
    mix stations; confirm this is intended.
    """
    # Meteorological direction -> mathematical polar angle.
    wind_direction_deg = (270 - df["wd"]) % 360
    wind_direction_rad = np.radians(wind_direction_deg)

    df["u"] = df["WSPM"] * np.cos(wind_direction_rad)  # eastward component
    df["v"] = df["WSPM"] * np.sin(wind_direction_rad)  # northward component
    df['TEMP_K'] = df['TEMP'] + t0

    # Cyclical hour-of-day encoding.
    df['hour'] = df['time'].str[11:13].astype(int)
    df['hour_sin'] = np.sin(2 * np.pi * df['hour'] / 24)
    df['hour_cos'] = np.cos(2 * np.pi * df['hour'] / 24)

    # 24-row rolling stats. The leading rows lack a full window, so they are
    # back-filled. .bfill() replaces the deprecated fillna(method='bfill').
    df['pm2.5_mean_24h'] = df['PM2.5'].rolling(24).mean().bfill()
    df['pm2.5_max_24h'] = df['PM2.5'].rolling(24).max().bfill()

    # Interaction features.
    df['TEMP_PM10'] = df['TEMP'] * df['PM10']
    df['o3_PM10'] = df['o3'] * df['PM10']
    df['u_v'] = df['u'] * df['v']
    df['o3_TEMP'] = df['o3'] * df['TEMP']

    # Difference features.
    df['PM2.5_diff_1h'] = df['PM2.5'].diff(1).fillna(0)
    df['TEMP_delta_6h'] = df['TEMP'] - df['TEMP'].shift(6).bfill()
    return df

# 修改2：时间处理函数
def process_time(df):
    """Parse the ``time`` column and return the frame indexed by it.

    A ``datetime`` column (parsed with ``%Y-%m-%d %H:%M:%S``) is added to *df*
    in place. The returned frame is sorted by that column and uses it as the
    index. The original ``time`` column is kept.
    """
    stamp_format = '%Y-%m-%d %H:%M:%S'
    df['datetime'] = pd.to_datetime(df['time'], format=stamp_format)
    chronological = df.sort_values('datetime')
    return chronological.set_index('datetime')

class AirQualityDataset(Dataset):
    """Map-style dataset of pre-windowed samples plus their temperature sequence.

    ``X``, ``y`` and ``location_idx`` are indexed directly by sample index.
    ``df`` must be positionally aligned with the samples: rows
    ``idx .. idx + sequence_length - 1`` of its ``TEMP_K`` column form the
    temperature sequence for sample ``idx``.

    Each item is ``(X[idx], y[idx], location_idx[idx], temp_k)``, where
    ``temp_k`` is a 1-D float32 tensor.
    """

    def __init__(self, X, y, location_idx, df, sequence_length):
        self.X = X
        self.y = y
        self.location_idx = location_idx
        self.df = df
        self.sequence_length = sequence_length
        # Extract the temperature column once. A pandas .iloc slice inside
        # __getitem__ is much slower and would repeat for every sample of
        # every epoch.
        self._temp_k = df['TEMP_K'].to_numpy(dtype=np.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        temp_window = self._temp_k[idx:idx + self.sequence_length]
        return (
            self.X[idx],
            self.y[idx],
            self.location_idx[idx],
            # torch.tensor copies, so the item never aliases the cached array.
            torch.tensor(temp_window, dtype=torch.float32),
        )

def custom_collate(batch):
    """Stack dataset items field-by-field into batched tensors.

    Each item is indexed as ``(inputs, targets, location_idx, temp_k)``.
    Every field is stacked along a new leading batch dimension.

    Returns:
        ``(inputs, targets, location_idx, temp_k)``, each with batch first.
    """
    # One list per field position, in the same order as the dataset item.
    per_field = [[sample[position] for sample in batch] for position in range(4)]
    return tuple(torch.stack(column) for column in per_field)

if __name__ == '__main__':
    # Seed Python / NumPy / PyTorch so runs are reproducible.
    set_seed(42)

    # Use the GPU when present; otherwise fall back to CPU instead of failing
    # on the first tensor moved to an unavailable CUDA device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    sequence_length = 24  # input time steps per sample
    horizon = 72  # predicted time steps per sample

    # Training start time, embedded in the log file name.
    start_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_filename = f'traininguk_nope_{start_time}_model8-in{sequence_length}-out{horizon}.log'

    logging.basicConfig(
        filename=log_filename,
        level=logging.INFO,
        format='%(asctime)s:%(levelname)s:%(message)s'
    )

    # ---- Data loading and preprocessing ----
    df = pd.read_csv('UK.csv', usecols=[
        'time', 'PM2.5', 'TEMP', 'PM10', 'o3',
        'WSPM', 'wd', 'station', 'latitude', 'longitude'
    ])
    D_0 = 0.2  # reference diffusion coefficient (m^2/s)
    n = 0.5  # diffusion exponent
    T_0 = 273.15  # K, reference temperature (Celsius -> Kelvin offset)
    df = df.ffill()  # fill missing values with the previous row's value

    df = process_time(df)
    df = preprocess_df(df)

    # Distribution of gaps between consecutive timestamps.
    print("时间缺失情况:", df.index.to_series().diff().value_counts())

    features = ['TEMP', 'PM10', 'o3', 'latitude', 'longitude', 'TEMP_PM10', 'o3_PM10', 'u_v', 'o3_TEMP',
                'hour_sin', 'hour_cos', 'u', 'v', 'pm2.5_mean_24h', 'pm2.5_max_24h', 'PM2.5_diff_1h', 'TEMP_delta_6h']
    target = 'PM2.5'

    # Map each station name to a dense integer id.
    unique_stations = df['station'].unique()
    location_map = {station: idx for idx, station in enumerate(unique_stations)}
    df['location_idx'] = df['station'].map(location_map)

    # Standardise features and target. target_scaler is kept for mapping
    # predictions back to PM2.5 units.
    # NOTE(review): both scalers are fit on the full series, including rows
    # that later land in the validation/test splits.
    scaler = StandardScaler()
    df[features] = scaler.fit_transform(df[features])
    target_scaler = StandardScaler()
    df[target] = target_scaler.fit_transform(df[[target]])

    # ---- Spatial grid for the velocity field ----
    H, W = 63, 64  # grid height / width
    lat_min, lat_max = df['latitude'].min(), df['latitude'].max()
    lon_min, lon_max = df['longitude'].min(), df['longitude'].max()
    # Use a unit step when every row shares a latitude/longitude. A zero step
    # would produce NaN indices and crash the int cast below.
    lat_step = (lat_max - lat_min) / (H - 1) or 1.0
    lon_step = (lon_max - lon_min) / (W - 1) or 1.0

    lat_idx = ((df['latitude'] - lat_min) / lat_step).round().astype(int).clip(0, H - 1)
    lon_idx = ((df['longitude'] - lon_min) / lon_step).round().astype(int).clip(0, W - 1)

    u_values = df['u'].values.astype(np.float32)
    v_values = df['v'].values.astype(np.float32)

    velocity_field = torch.zeros((1, 2, H, W), dtype=torch.float32, device=device)
    # NOTE(review): every row is scattered into the grid, so each cell ends up
    # with one of the rows that maps to it. With duplicate indices, PyTorch
    # does not define which one wins. Confirm this snapshot is intended.
    grid_rows = torch.as_tensor(lat_idx.to_numpy(), device=device)
    grid_cols = torch.as_tensor(lon_idx.to_numpy(), device=device)
    velocity_field[0, 0, grid_rows, grid_cols] = torch.tensor(u_values, device=device)
    velocity_field[0, 1, grid_rows, grid_cols] = torch.tensor(v_values, device=device)

    # ---- Sliding-window samples (chronological order preserved) ----
    arr_X = df[features].values  # (N, F)
    arr_y = df[target].values  # (N,)
    arr_loc = df['location_idx'].values  # (N,)
    arr_temp = df['TEMP_K'].values  # (N,)

    N = len(df)
    S = sequence_length
    P = horizon  # separate name so the grid height H above is not overwritten
    M = N - S - P + 1  # number of complete (input, target) windows
    if M <= 0:
        raise ValueError(
            f"Need at least sequence_length + horizon = {S + P} rows, got {N}"
        )

    X_windows = sliding_window_view(arr_X, window_shape=S, axis=0)  # (N-S+1, F, S)
    y_windows = sliding_window_view(arr_y, window_shape=P, axis=0)  # (N-P+1, P)
    loc_windows = sliding_window_view(arr_loc, window_shape=P, axis=0)  # (N-P+1, P)
    temp_windows = sliding_window_view(arr_temp, window_shape=S, axis=0)  # (N-S+1, S)

    # sliding_window_view appends the window axis last, so move it in front
    # of the feature axis: (M, F, S) -> (M, S, F).
    X_np = X_windows[:M].transpose(0, 2, 1)
    print(">>> X_np.shape:", X_np.shape)
    # Targets/locations for sample i are the P steps right after its input window.
    y_np = y_windows[S:S + M]  # (M, P)
    loc_np = loc_windows[S:S + M]  # (M, P)
    temp_np = temp_windows[:M]  # (M, S)

    X = torch.from_numpy(X_np).float()
    y = torch.from_numpy(y_np).float()
    locations = torch.from_numpy(loc_np).long()
    temp_k_seq = torch.from_numpy(temp_np).float()

    # ---- Chronological split: train | val | test (never shuffled across) ----
    n_samples = len(X)
    test_size = int(0.2 * n_samples)
    val_size = int(0.1 * n_samples)
    # Explicit end indices. Negative slicing like X[:-k] returns an empty
    # slice when k == 0, which would empty the training set.
    train_end = n_samples - test_size - val_size
    val_end = n_samples - test_size

    X_train, X_val, X_test = X[:train_end], X[train_end:val_end], X[val_end:]
    y_train, y_val, y_test = y[:train_end], y[train_end:val_end], y[val_end:]
    loc_train, loc_val, loc_test = (
        locations[:train_end], locations[train_end:val_end], locations[val_end:]
    )

    # Each dataset gets the df slice whose row 0 is its first sample's first input row.
    train_dataset = AirQualityDataset(X_train, y_train, loc_train, df.iloc[:train_end + sequence_length], sequence_length)
    val_dataset = AirQualityDataset(X_val, y_val, loc_val, df.iloc[train_end:val_end + sequence_length], sequence_length)
    test_dataset = AirQualityDataset(X_test, y_test, loc_test, df.iloc[val_end:], sequence_length)

    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, drop_last=True, num_workers=4, collate_fn=custom_collate)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4, collate_fn=custom_collate)
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4, collate_fn=custom_collate)

    # ---- Model, optimiser, schedule, loss ----
    model = SpatioTemporalTransformer(
        input_dim=len(features),
        d_model=64,
        n_heads=8,
        num_encoder_layers=3,
        num_decoder_layers=2,
        sequence_length=sequence_length,
        n_locations=len(unique_stations),
        horizon=horizon
    )
    model = model.to(device)

    optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    warmup_steps = 1000  # for warmup_scheduler (not currently used)
    # NOTE(review): the training loop steps this scheduler once per batch, so
    # T_0=5 makes the first cosine cycle 5 batches long, not 5 epochs. Confirm.
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2)

    criterion = nn.SmoothL1Loss()
    loss_mse = nn.MSELoss()
    loss_mae = nn.L1Loss()
    scaler1 = GradScaler()  # loss scaler for mixed-precision training

    early_stopping = EarlyStopping(patience=7, verbose=True)

    num_epochs = 50
    logging.info("Training Start Time: {}".format(start_time))
    alpha = 0.5

    def train_one_epoch(model, train_loader):
        """Train ``model`` for one pass over ``train_loader`` with mixed precision.

        Uses script-level names: ``device``, ``optimizer``, ``scaler1``
        (GradScaler), ``scheduler`` (stepped once per batch), ``criterion``,
        ``target_scaler``, ``velocity_field``, and the loop variable ``epoch``
        (only for error messages).

        If a batch raises, the error is logged with its traceback and the
        epoch ends early (best-effort). Averages then cover only the batches
        that completed.

        Returns:
            ``(mean_loss, mean_mae, rmse)``. MAE and RMSE are in original
            target units; RMSE is the square root of the mean per-batch MSE.
        """
        model.train()
        running_loss, running_mae, running_mse = 0.0, 0.0, 0.0
        completed = 0

        for batch_idx, (inputs, targets, location_idx, temp_k) in enumerate(train_loader):
            inputs = inputs.to(device).permute(1, 0, 2)  # [seq_len, batch, features]
            targets = targets.to(device)  # [batch, horizon]
            location_idx = location_idx.to(device).permute(1, 0)  # [horizon, batch]
            temp_k = temp_k.to(device)  # [batch, seq]

            optimizer.zero_grad()

            try:
                # Only the forward pass and loss belong under autocast;
                # backward and the optimizer step run outside it.
                with autocast():
                    outputs = model(inputs, location_idx=location_idx, temp_k=temp_k, velocity_field=velocity_field)
                    loss = criterion(outputs, targets)
                scaler1.scale(loss).backward()
                # Gradients are still multiplied by the loss scale here. Unscale
                # them so max_norm=1.0 applies to the true gradient norm.
                scaler1.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler1.step(optimizer)
                scaler1.update()
                scheduler.step()

                running_loss += loss.item()
                # Cast to float32 before inverse-scaling. Autocast outputs may
                # be float16, which loses precision and can overflow when squared.
                preds = target_scaler.inverse_transform(
                    outputs.detach().float().cpu().numpy().reshape(-1, 1)
                ).reshape(outputs.shape)
                trues = target_scaler.inverse_transform(
                    targets.detach().float().cpu().numpy().reshape(-1, 1)
                ).reshape(targets.shape)
                errors = preds - trues
                running_mae += np.mean(np.abs(errors))
                running_mse += np.mean(errors ** 2)
                completed += 1
            except Exception as e:
                print(f"Error in Epoch {epoch}, Batch {batch_idx}: {e}")
                logging.exception("Error in Epoch %s, Batch %s", epoch, batch_idx)
                break

        denom = max(completed, 1)
        return running_loss / denom, running_mae / denom, (running_mse / denom) ** 0.5


    def validate_one_epoch(model, val_loader):
        """Evaluate ``model`` on ``val_loader`` without gradient tracking.

        Uses script-level names: ``device``, ``criterion``, ``target_scaler``
        and ``velocity_field``.

        Returns:
            ``(mean_loss, mean_mae, rmse)``. MAE and RMSE are in original
            target units. RMSE is the square root of the mean per-batch MSE,
            the same definition ``train_one_epoch`` uses, so train and val
            RMSE are comparable.
        """
        model.eval()
        val_loss, val_mae, val_mse = 0.0, 0.0, 0.0

        with torch.no_grad():
            for inputs, targets, location_idx, temp_k in val_loader:
                inputs = inputs.to(device).permute(1, 0, 2)  # [seq_len, batch, features]
                targets = targets.to(device)  # [batch, horizon]
                location_idx = location_idx.to(device).permute(1, 0)  # [horizon, batch]
                temp_k = temp_k.to(device)  # [batch, seq]

                with autocast():
                    outputs = model(inputs, location_idx=location_idx, temp_k=temp_k, velocity_field=velocity_field)
                    loss = criterion(outputs, targets)
                val_loss += loss.item()

                # Cast to float32 first: autocast outputs may be float16, which
                # loses precision and can overflow when squared.
                preds = target_scaler.inverse_transform(
                    outputs.float().cpu().numpy().reshape(-1, 1)
                ).reshape(outputs.shape)
                trues = target_scaler.inverse_transform(
                    targets.float().cpu().numpy().reshape(-1, 1)
                ).reshape(targets.shape)

                errors = preds - trues
                val_mae += np.mean(np.abs(errors))
                val_mse += np.mean(errors ** 2)

        n_batches = len(val_loader)
        return (
            val_loss / n_batches,
            val_mae / n_batches,
            (val_mse / n_batches) ** 0.5
        )


    best_val_loss = float('inf')
    best_model_path = 'best_model_seed7.pth'

    # ---- Training loop ----
    # The LR scheduler is stepped per batch inside train_one_epoch, so it is
    # not stepped here. CosineAnnealingWarmRestarts.step(x) treats x as an
    # epoch count, so passing the validation loss would reset the schedule.
    for epoch in range(num_epochs):
        avg_train_loss, avg_train_mae, avg_train_rmse = train_one_epoch(model, train_loader)
        avg_val_loss, avg_val_mae, avg_val_rmse = validate_one_epoch(model, val_loader)

        # Log before any early-stop break so the final epoch is recorded too.
        current_lr = optimizer.param_groups[0]['lr']
        logging.info(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, "
                     f"Train MAE: {avg_train_mae:.4f}, Val MAE: {avg_val_mae:.4f}, "
                     f"Train RMSE: {avg_train_rmse:.4f}, Val RMSE: {avg_val_rmse:.4f}, LR: {current_lr:.6f}")

        # Checkpoint whenever validation loss reaches a new minimum.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), best_model_path)

        early_stopping(avg_val_loss, model, epoch)
        if early_stopping.early_stop:
            print("Early stopping triggered")
            logging.info(f"Early stopping triggered at epoch {epoch + 1}")
            break

    # Restore the best weights from the checkpoint. torch.save serialized the
    # tensors at the best epoch, so later training steps cannot have altered it.
    if best_val_loss < float('inf'):
        model.load_state_dict(torch.load(best_model_path, map_location=device))
    else:
        early_stopping.load_best_weights(model)
    print("Loaded best model weights")
    logging.info("Loaded best model weights after training")

    # ---- Test evaluation ----
    model.eval()
    test_loss = 0.0
    predictions = []  # per-sample arrays of shape (horizon,), original units
    actuals = []

    with torch.no_grad():
        for inputs, targets, location_idx, temp_k in test_loader:
            inputs = inputs.to(device).permute(1, 0, 2)  # [seq_len, batch, features]
            targets = targets.to(device)  # [batch, horizon]
            location_idx = location_idx.to(device).permute(1, 0)  # [horizon, batch]
            temp_k = temp_k.to(device)  # [batch, seq]

            with autocast():
                outputs = model(inputs, location_idx=location_idx, temp_k=temp_k, velocity_field=velocity_field)
                loss = criterion(outputs, targets)
            test_loss += loss.item()

            # Cast to float32 first: autocast outputs may be float16.
            outputs_unscaled = target_scaler.inverse_transform(
                outputs.float().cpu().numpy().reshape(-1, 1)
            ).reshape(-1, horizon)
            targets_unscaled = target_scaler.inverse_transform(
                targets.float().cpu().numpy().reshape(-1, 1)
            ).reshape(-1, horizon)

            predictions.extend(outputs_unscaled)
            actuals.extend(targets_unscaled)

    avg_test_loss = test_loss / len(test_loader)
    # Exact dataset-level metrics over every sample and horizon step. Averaging
    # per-batch RMSEs would under-weight errors and over-weight a short last batch.
    test_errors = np.asarray(predictions) - np.asarray(actuals)
    avg_test_mae = float(np.mean(np.abs(test_errors)))
    avg_test_rmse = float(np.sqrt(np.mean(test_errors ** 2)))

    logging.info(f"Test Loss: {avg_test_loss:.4f}, Test MAE: {avg_test_mae:.4f}, Test RMSE: {avg_test_rmse:.4f}")

    end_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    logging.info("Training End Time: {}".format(end_time))