import scipy.io
import numpy as np
import os
import h5py

import torch
from torch.utils.data import Dataset
from einops import rearrange

try:
    from pde_datasets.data_utils import *
    from utils import *
except:
    from data_utils import *
import xarray as xr
import os
# Request single-threaded OpenMP through the environment.
# NOTE(review): numpy and torch are imported earlier in this module, so their
# OpenMP/BLAS thread pools may already be configured by the time this line
# runs; if single-threaded execution matters, confirm it still takes effect
# (or move it above those imports).
os.environ["OMP_NUM_THREADS"] = "1"
class UnitTransformer:
    """Zero-mean / unit-variance normalizer fitted on a reference tensor.

    The mean and standard deviation are reduced over the first two axes
    (``dim=(0, 1)``) with ``keepdim=True``. Any remaining trailing axes keep
    their own statistics, e.g. one mean/std per channel.
    """

    _EPS = 1e-8  # keeps the division finite when a standard deviation is zero

    def __init__(self, tensor):
        reduce_dims = (0, 1)
        self.mean = tensor.mean(dim=reduce_dims, keepdim=True)
        self.std = tensor.std(dim=reduce_dims, keepdim=True)

    def _scale(self):
        """Return the (epsilon-padded) standard deviation used for scaling."""
        return self.std + self._EPS

    def encode(self, tensor):
        """Map ``tensor`` into normalized units."""
        centered = tensor - self.mean
        return centered / self._scale()

    def decode(self, tensor):
        """Invert :meth:`encode`, mapping normalized values back to original units."""
        rescaled = tensor * self._scale()
        return rescaled + self.mean

    def to(self, device):
        """Move the stored statistics to ``device`` in place and return ``self``."""
        self.mean, self.std = self.mean.to(device), self.std.to(device)
        return self

    def cuda(self):
        """Move the stored statistics to the current CUDA device; returns ``self``."""
        self.mean, self.std = self.mean.cuda(), self.std.cuda()
        return self

class Burgers(Dataset):
    """1D Burgers dataset read from a ``.mat`` file with keys ``'a'`` and ``'u'``.

    Exactly one of ``n_train`` / ``n_test`` must be given. Either the first
    ``n_train`` samples or the last ``n_test`` samples are loaded, with the
    spatial axis subsampled by stride ``sub``. Each item is ``(x, u)``, where
    ``x`` is the input ``a`` (as one channel) concatenated with the grid from
    ``torch1dgrid(S)``.

    Raises:
        ValueError: if both or neither of ``n_train`` / ``n_test`` are given.
    """

    def __init__(
            self, datapath, nx, sub, n_train=None, n_test=None):
        # Validate the split arguments before reading the (possibly large) file.
        if n_train and n_test:
            raise ValueError("Burgers: pass exactly one of n_train / n_test, not both.")
        if not n_train and not n_test:
            raise ValueError("Burgers: one of n_train / n_test must be a positive count.")

        # Points kept by a ``::sub`` stride over ``nx`` points: ceil(nx / sub).
        # This equals ``nx // sub`` whenever ``sub`` divides ``nx``.
        self.S = int((nx - 1) // sub + 1)
        data = scipy.io.loadmat(datapath)
        rows = slice(None, n_train) if n_train else slice(-n_test, None)
        self.a = torch.tensor(data['a'][rows, ::sub], dtype=torch.float)
        self.u = torch.tensor(data['u'][rows, ::sub], dtype=torch.float)

        self.mesh = torch1dgrid(self.S)

    def __len__(self):
        return self.a.shape[0]

    def __getitem__(self, idx):
        # (S,) input -> (S, 1), then append the grid channel(s).
        a = self.a[idx]
        return torch.cat([a.unsqueeze(1), self.mesh], dim=1), self.u[idx]


class DarcyFlow(Dataset):
    """2D Darcy-flow dataset read from a ``.mat`` file (keys ``'coeff'``, ``'sol'``).

    Samples ``offset`` .. ``offset + num - 1`` are loaded, and both spatial
    axes are subsampled with stride ``sub``. Each item is ``(x, u)``, where
    ``x`` is the coefficient field (as a trailing channel) concatenated with
    the grid from ``torch2dgrid(S, S)``.
    """

    def __init__(
            self, datapath, nx, sub, offset=0, num=1):
        # Points kept by a ``::sub`` stride over ``nx`` points: ceil(nx / sub).
        # For example, (421, 5) -> 85, (421, 1) -> 421 and (420, 2) -> 210.
        # Using ``nx // sub + 1`` would give 211 for the last case and make
        # the mesh disagree with the data.
        self.S = int((nx - 1) // sub + 1)
        data = scipy.io.loadmat(datapath)
        coeff = data['coeff']
        sol = data['sol']
        rows = slice(offset, offset + num)
        self.a = torch.tensor(coeff[rows, ::sub, ::sub], dtype=torch.float)
        self.u = torch.tensor(sol[rows, ::sub, ::sub], dtype=torch.float)

        self.mesh = torch2dgrid(self.S, self.S)

    def __len__(self):
        return self.a.shape[0]

    def __getitem__(self, idx):
        # Add a trailing channel axis to the coefficient, then append the grid channels.
        a = self.a[idx]
        return torch.cat([a.unsqueeze(2), self.mesh], dim=2), self.u[idx]


class Airfoil(Dataset):
    """Airfoil dataset built from two input ``.npy`` arrays and one output array.

    The two inputs are stacked on a new last axis. The target is index 4
    along axis 1 of the output array. Only the first 221 x 51 grid points
    are kept. Samples ``[:n_train]`` are used when ``n_test`` is falsy,
    otherwise ``[n_train:n_train + n_test]``. Items concatenate the stacked
    inputs with ``torch2dgrid(221, 51)`` on the last axis.

    Raises:
        ValueError: if ``n_train`` is falsy.
    """

    # Grid extent kept from the raw arrays (stride 1 along both axes).
    GRID_S1 = 221
    GRID_S2 = 51

    def __init__(self, input1_path, input2_path, output_path, n_train, n_test=None):
        # Validate before loading any file.
        if not n_train:
            raise ValueError("Airfoil: n_train must be a positive count.")

        inputs = np.stack([np.load(input1_path), np.load(input2_path)], axis=-1)
        output = np.load(output_path)[:, 4]

        s1, s2 = self.GRID_S1, self.GRID_S2
        self.mesh = torch2dgrid(s1, s2)

        rows = slice(n_train, n_train + n_test) if n_test else slice(None, n_train)
        self.input = torch.tensor(inputs[rows, :s1, :s2], dtype=torch.float)
        self.output = torch.tensor(output[rows, :s1, :s2], dtype=torch.float)

    def __len__(self):
        return self.input.shape[0]

    def __getitem__(self, idx):
        sample = self.input[idx]
        return torch.cat([sample, self.mesh], dim=2), self.output[idx]
class PipeDataset(Dataset):
    """
    Pipe-flow dataset loader tailored to the IPOT model.

    Follows Transolver's preprocessing (grid downsampling, then flattening to
    a point cloud) but returns samples in the ``(X, Y)`` form IPOT expects:

    * ``X`` has shape ``(num_points, 4)``: the normalized (x, y) coordinates,
      used both as input features and as positions, concatenated.
    * ``Y`` has shape ``(num_points, 1)``: the solution. In training mode it
      is normalized with training-set statistics; in test mode it is left in
      original units.
    """

    def __init__(self,
                 data_path,
                 n_total,
                 n_train,
                 downsample_x,
                 downsample_y,
                 is_train=True):
        """
        Load, split, downsample and flatten the data, then fit the normalizers.

        Args:
            data_path (str): Directory containing ``Pipe_X.npy``,
                ``Pipe_Y.npy`` and ``Pipe_Q.npy``.
            n_total (int): Total number of samples used (train + test).
            n_train (int): Number of training samples. These are the first
                ``n_train`` samples; the test split is the remainder of the
                first ``n_total``.
            downsample_x (int): Stride along the first grid axis.
            downsample_y (int): Stride along the second grid axis.
            is_train (bool): Expose the training split if True, otherwise the
                test split.

        Raises:
            ValueError: If ``n_train`` is not in ``[1, n_total]``.
        """
        super().__init__()

        if not 1 <= n_train <= n_total:
            raise ValueError(
                f"PipeDataset: expected 1 <= n_train <= n_total, "
                f"got n_train={n_train}, n_total={n_total}")

        # 1. Load raw data. Coordinates are stacked to (N, H, W, 2); the
        #    solution is channel 0 of Pipe_Q, giving (N, H, W).
        input_x = np.load(f'{data_path}/Pipe_X.npy')[:n_total]
        input_y = np.load(f'{data_path}/Pipe_Y.npy')[:n_total]
        self.coords_raw = torch.tensor(np.stack([input_x, input_y], axis=-1), dtype=torch.float)
        self.solution_raw = torch.tensor(np.load(f'{data_path}/Pipe_Q.npy')[:n_total, 0], dtype=torch.float)

        print(f"原始加载数据形状: Coords={self.coords_raw.shape}, Solution={self.solution_raw.shape}")

        # 2. Split before downsampling. The test split is everything after the
        #    training samples, so the splits never overlap. When
        #    n_train == n_total the test split is simply empty.
        self.train_coords_flat, self.train_solution_flat = self._downsample_flatten(
            self.coords_raw[:n_train], self.solution_raw[:n_train], downsample_x, downsample_y)
        self.test_coords_flat, self.test_solution_flat = self._downsample_flatten(
            self.coords_raw[n_train:], self.solution_raw[n_train:], downsample_x, downsample_y)

        # 3. Fit the normalizers on the training split only.
        self.coord_normalizer = UnitTransformer(self.train_coords_flat)
        self.solution_normalizer = UnitTransformer(self.train_solution_flat)

        # 4. Select which split __len__ / __getitem__ expose.
        self.is_train = is_train
        active = self.train_coords_flat if is_train else self.test_coords_flat
        self.num_samples = active.shape[0]

        print(f"数据集处理完成。当前模式: {'训练' if is_train else '测试'}, "
              f"样本数: {self.num_samples}, "
              f"每个样本点数: {self.train_coords_flat.shape[1]}")

    @staticmethod
    def _downsample_flatten(coords, solution, stride_x, stride_y):
        """Subsample both grid axes and flatten them into a single point axis.

        ``coords`` goes from (N, H, W, 2) to (N, h * w, 2), and ``solution``
        from (N, H, W) to (N, h * w), where ``h, w`` are the strided lengths.
        ``flatten`` also handles N == 0 (an empty split).
        """
        coords = coords[:, ::stride_x, ::stride_y]
        solution = solution[:, ::stride_x, ::stride_y]
        return coords.flatten(1, 2), solution.flatten(1, 2)

    def __len__(self):
        """Return the number of samples in the selected split."""
        return self.num_samples

    def __getitem__(self, idx):
        """
        Return one sample as the ``(X, Y)`` pair IPOT expects.

        All normalization happens here.
        """
        # 1. Pick the flattened data of the active split.
        if self.is_train:
            coords = self.train_coords_flat[idx]
            solution = self.train_solution_flat[idx]
        else:
            coords = self.test_coords_flat[idx]
            solution = self.test_solution_flat[idx]

        # 2. Normalize the coordinates. The normalizer statistics keep the two
        #    reduced leading axes as size-1 dims. Encoding a bare
        #    (num_points, C) sample would therefore broadcast it to
        #    (1, num_points, C). Adding and then removing a sample axis keeps
        #    the sample's own shape.
        coords_normalized = self.coord_normalizer.encode(coords.unsqueeze(0)).squeeze(0)

        # 3. For Pipe data the input features are the normalized coordinates
        #    themselves, so features and positions are the same tensor.
        features = coords_normalized
        positions = coords_normalized
        model_input_x = torch.cat([features, positions], dim=-1)  # (num_points, 4)

        # 4. Labels: normalized during training; raw units during evaluation
        #    so metrics are computed on the original scale.
        if self.is_train:
            solution = self.solution_normalizer.encode(solution.unsqueeze(0)).squeeze(0)
        model_label_y = solution.unsqueeze(-1)  # (num_points, 1)

        return model_input_x, model_label_y

class Elasticity(Dataset):
    """Elasticity point-cloud dataset built from three ``.npy`` files.

    Each array is permuted so the sample axis comes first:

    * ``input1_path`` is 2D, permuted ``(1, 0)``; these are per-sample parameters.
    * ``input2_path`` is 3D, permuted ``(2, 0, 1)``; these are per-point (x, y).
    * ``output_path`` is 2D, permuted ``(1, 0)``; these are per-point targets.

    The mesh features per point are (x, y, angle, radius), in polar
    coordinates about the point (1e-4, 1e-4). Per-sample parameters are
    repeated over all points and concatenated with them. Samples ``[:n_train]``
    are used when ``n_test`` is falsy, otherwise ``[n_train:n_train + n_test]``.
    Items are ``(input, mesh, output)``.

    Raises:
        ValueError: if ``n_train`` is falsy.
    """

    def __init__(self, input1_path, input2_path, output_path, n_train=None, n_test=None):
        # Validate before loading any file.
        if not n_train:
            raise ValueError("Elasticity: n_train must be a positive count.")

        input_rr = torch.tensor(np.load(input1_path), dtype=torch.float).permute(1, 0)
        input_xy = torch.tensor(np.load(input2_path), dtype=torch.float).permute(2, 0, 1)
        output = torch.tensor(np.load(output_path), dtype=torch.float).permute(1, 0)

        # Feature engineering: polar coordinates about a point just off the
        # origin (presumably to stay away from the atan2/radius singularity at 0).
        self.center = torch.tensor([0.0001, 0.0001]).reshape(1, 1, 2)
        offset = input_xy - self.center
        angle = torch.atan2(offset[:, :, 1], offset[:, :, 0])
        radius = torch.norm(offset, dim=-1, p=2)
        mesh = torch.stack([input_xy[:, :, 0], input_xy[:, :, 1], angle, radius], dim=-1)

        # Repeat each sample's parameters for every point, then append the mesh features.
        input_rr = input_rr.unsqueeze(1).repeat(1, mesh.shape[1], 1)
        inputs = torch.cat([input_rr, mesh], dim=-1)
        print(input_rr.shape, mesh.shape, inputs.shape)

        rows = slice(n_train, n_train + n_test) if n_test else slice(None, n_train)
        self.input = inputs[rows]
        self.mesh = mesh[rows]
        self.output = output[rows]

    def __len__(self):
        return self.input.shape[0]

    def __getitem__(self, idx):
        return self.input[idx], self.mesh[idx], self.output[idx]


class Plasticity(Dataset):
    """Plasticity dataset read from a ``.mat`` file (keys ``'input'``, ``'output'``).

    Exactly one of ``n_train`` (first samples) or ``n_test`` (last samples)
    must be given. Each sample's input is reshaped to ``(s1, 1, 1, 1)`` and
    repeated over the ``s2`` and ``t`` axes. Items concatenate it with
    ``torch3dgrid(s1, s2, t)`` on the last axis.

    Raises:
        ValueError: if both or neither of ``n_train`` / ``n_test`` are given.
    """

    def __init__(self, datapath, s1, s2, t, n_train=None, n_test=None):
        # Validate the split arguments before reading the file.
        if n_train and n_test:
            raise ValueError("Plasticity: pass exactly one of n_train / n_test, not both.")
        if not n_train and not n_test:
            raise ValueError("Plasticity: one of n_train / n_test must be a positive count.")

        data = scipy.io.loadmat(datapath)
        inputs = data['input']
        output = data['output']

        count = n_train if n_train else n_test
        rows = slice(None, n_train) if n_train else slice(-n_test, None)
        # Broadcast each 1D input of length s1 over the s2 and t axes.
        self.input = (torch.tensor(inputs[rows], dtype=torch.float)
                      .reshape(count, s1, 1, 1, 1)
                      .repeat(1, 1, s2, t, 1))
        self.output = torch.tensor(output[rows], dtype=torch.float)

        self.mesh = torch3dgrid(s1, s2, t)

    def __len__(self):
        return self.input.shape[0]

    def __getitem__(self, idx):
        sample = self.input[idx]
        return torch.cat([sample, self.mesh], dim=3), self.output[idx]


class NavierStokes(Dataset):
    """2D Navier-Stokes dataset read from the ``'u'`` array of an HDF5 file.

    Frames ``T_start:T_in`` form the input and ``T_in:T_in + T_out`` the
    target. Both spatial axes are subsampled with stride ``sub``. Training
    mode uses the first ``n_train`` trajectories; test mode uses the last
    ``n_test``. The indexing assumes ``'u'`` is laid out as
    (time, x, y, trajectory); ``transpose(0, 3)`` swaps time and trajectory.

    Raises:
        ValueError: if ``is_train`` is False and ``n_test`` is falsy.
    """

    def __init__(self, datapath, nx, sub, T_start=0, T_in=10, T_out=40, n_train=None, n_test=None, is_train=True):
        self.T_start = T_start
        self.T_in = T_in
        self.T_out = T_out
        self.sub = sub
        self.n_train = n_train
        self.n_test = n_test
        self.is_train = is_train
        # Points kept by a ``::sub`` stride over ``nx`` points: ceil(nx / sub).
        # This equals ``nx // sub`` whenever ``sub`` divides ``nx``.
        self.S = (nx - 1) // sub + 1

        # Validate before touching the file. A falsy n_test would make
        # ``-n_test:`` select every trajectory (0) or raise TypeError (None).
        if not is_train and not n_test:
            raise ValueError("NavierStokes: n_test must be a positive count when is_train=False.")

        samples = slice(None, n_train) if is_train else slice(-n_test, None)
        # Open read-only and close the file once the slices are copied into tensors.
        with h5py.File(datapath, 'r') as f:
            data = f['u']
            self.a = torch.tensor(data[T_start:T_in, ::sub, ::sub, samples], dtype=torch.float).transpose(0, 3)
            self.u = torch.tensor(data[T_in:T_in + T_out, ::sub, ::sub, samples], dtype=torch.float).transpose(0, 3)

        self.mesh = torch2dgrid(self.S, self.S)

        print(self.a.shape, self.u.shape, self.mesh.shape)

    def __len__(self):
        # Taken from the loaded tensor so it always matches what __getitem__
        # can index, including when n_train is None or exceeds the stored count.
        return self.a.shape[0]

    def __getitem__(self, idx):
        return torch.cat((self.a[idx], self.mesh), dim=-1), self.u[idx]


class ERA5_temperature(Dataset):
    """ERA5 2 m temperature (``t2m``) forecasting dataset read from a GRIB file.

    The time axis is cut into windows of 7 input steps followed by the next
    7 target steps, with window starts 7 steps apart. The first ``n_train``
    windows form the training split and the following ``n_test`` windows the
    test split. Axis 1 is cropped to its first 720 rows, and both spatial
    axes are subsampled with stride ``sub``. Items are ``(x, y)``, where
    ``x`` holds the 7 input steps as trailing channels, concatenated with
    the mesh.

    Raises:
        ValueError: if the file holds fewer than 14 time steps.
    """

    _WINDOW = 7  # steps per input window, per target window, and between window starts

    def __init__(self, datapath, sub, n_train, n_test, is_train):
        self.n_train = n_train
        self.n_test = n_test
        self.is_train = is_train
        # NOTE(review): nothing below draws from NumPy's global RNG. The seed
        # is kept only because it is an observable process-wide side effect.
        np.random.seed(0)

        # Materialize the variable, then release the file handle.
        with xr.open_dataset(datapath, engine='cfgrib') as ds:
            data = np.array(ds["t2m"])

        data = data[:, :720, :][:, ::sub, ::sub]
        # Take the mesh resolution from the data itself so it always matches
        # the tensors it is concatenated with in __getitem__.
        h, s = data.shape[1], data.shape[2]

        w = self._WINDOW
        # Include every window that fits entirely, i.e. start + 2 * w <= T.
        starts = range(0, data.shape[0] - 2 * w + 1, w)
        if not starts:
            raise ValueError(
                f"ERA5_temperature: need at least {2 * w} time steps, got {data.shape[0]}")
        # (windows, steps, H, W) -> (windows, H, W, steps): steps become channels.
        x_data = np.stack([data[i:i + w] for i in starts]).transpose(0, 2, 3, 1)
        y_data = np.stack([data[i + w:i + 2 * w] for i in starts]).transpose(0, 2, 3, 1)

        # Keep only the n_test windows that __len__ exposes (all remaining if None).
        test_rows = slice(n_train, None if n_test is None else n_train + n_test)
        self.x_train = torch.tensor(x_data[:n_train], dtype=torch.float)
        self.y_train = torch.tensor(y_data[:n_train], dtype=torch.float)
        self.x_test = torch.tensor(x_data[test_rows], dtype=torch.float)
        self.y_test = torch.tensor(y_data[test_rows], dtype=torch.float)

        self.mesh = torch2dgrid(h, s, bot=(-0.5, 0), top=(0.5, 2))

        print(self.x_train.shape, self.y_train.shape, self.x_test.shape, self.y_test.shape, self.mesh.shape)

    def __len__(self):
        # Size of the materialized split. This equals n_train / n_test
        # whenever enough windows exist.
        return (self.x_train if self.is_train else self.x_test).shape[0]

    def __getitem__(self, idx):
        if self.is_train:
            x, y = self.x_train, self.y_train
        else:
            x, y = self.x_test, self.y_test
        return torch.cat((x[idx], self.mesh), dim=-1), y[idx]

