import os
import glob
import numpy as np
import pyvista as pv
from tqdm import tqdm
import torch
from torch_geometric.data import Data
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
import random
import re

# Reference (initial) temperature. In load_all_data's per-sample normalization
# (unit_normalize=True) it is subtracted from the T and surface-temperature values
# before they are z-scored; because the per-sample mean is removed right afterwards,
# the offset has no effect on the normalized result apart from floating-point rounding.
# NOTE(review): 293.15 presumably means Kelvin (20 degC) - confirm against the solver output.
init_temperature = 293.15

@torch.no_grad()
def _row_all_finite(x: torch.Tensor) -> torch.Tensor:
    if x.dim() == 1:
        return torch.isfinite(x)
    return torch.isfinite(x).all(dim=1)

def drop_nan_rows(data, min_keep: int = None, log_prefix: str = "", drop_surface: bool = True,
                  verbose: bool = False):
    """Remove points whose ``pos``/``t``/``q`` values are not all finite (NaN or +/-Inf).

    ``data.pos``, ``data.t`` and ``data.q`` are filtered with one shared row mask, so they
    must have the same number of rows. When ``drop_surface`` is set and ``data`` carries both
    ``surf`` and ``surf_pos``, those two are filtered with their own shared mask; if no
    surface row survives, both attributes are deleted. ``data`` is modified in place and
    returned, with ``data.nan_drop_ratio`` set to the fraction of main rows removed.

    Args:
        data: object exposing tensor attributes ``pos``, ``t`` and ``q`` (e.g. a PyG ``Data``).
        min_keep: if given, the minimum number of valid main rows required.
        log_prefix: text prepended to error and log messages (typically the sample key).
        drop_surface: whether to also filter ``surf`` / ``surf_pos``.
        verbose: if True, print how many main rows were dropped (surface drops are
            always printed).

    Returns:
        The same ``data`` object.

    Raises:
        ValueError: if ``pos``, ``t`` or ``q`` is missing, if no main row is valid, or if
            fewer than ``min_keep`` main rows are valid.
    """
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    missing = [name for name in ("pos", "t", "q") if not hasattr(data, name)]
    if missing:
        raise ValueError(f"{log_prefix} missing required attribute(s): {', '.join(missing)}")

    n_main = data.pos.size(0)

    # _row_all_finite reduces over dim 1 for 2-D input, so q of shape (N,) and (N, 1)
    # yield the same (N,) row mask.
    mask_main = _row_all_finite(data.pos) & _row_all_finite(data.t) & _row_all_finite(data.q)
    keep_idx = mask_main.nonzero(as_tuple=False).squeeze(1)
    n_keep = keep_idx.numel()

    if n_keep == 0:
        raise ValueError(f"{log_prefix} all points are invalid (NaN/Inf).")

    if min_keep is not None and n_keep < min_keep:
        raise ValueError(f"{log_prefix} valid points {n_keep} < min_keep {min_keep}. "
                         f"Increase min_keep or fix source NaNs.")

    data.pos = data.pos[keep_idx]
    data.t = data.t[keep_idx]
    data.q = data.q[keep_idx]

    if drop_surface and hasattr(data, "surf") and hasattr(data, "surf_pos"):
        n_surf = data.surf_pos.size(0)
        mask_surf = _row_all_finite(data.surf_pos) & _row_all_finite(data.surf)
        keep_sidx = mask_surf.nonzero(as_tuple=False).squeeze(1)
        if keep_sidx.numel() > 0:
            data.surf_pos = data.surf_pos[keep_sidx]
            data.surf = data.surf[keep_sidx]
        else:
            # No usable surface row: drop the fields instead of keeping empty tensors.
            delattr(data, "surf_pos")
            delattr(data, "surf")

        dropped_surf = n_surf - keep_sidx.numel()
        if dropped_surf > 0:
            print(f"{log_prefix}[NaN] drop {dropped_surf}/{n_surf} surface rows ({dropped_surf / max(1, n_surf):.2%})")

    dropped = n_main - n_keep
    data.nan_drop_ratio = float(dropped) / float(max(1, n_main))
    if verbose and dropped > 0:
        print(f"{log_prefix}[NaN] drop {dropped}/{n_main} main rows ({data.nan_drop_ratio:.2%})")
    return data

def load_vtu_point_data(vtu_path):
    """Read a mesh file with PyVista and return ``(points, point_arrays)``.

    ``point_arrays`` maps every point-data array name found in the mesh to its values.
    """
    mesh = pv.read(vtu_path)
    point_data = mesh.point_data
    arrays = {}
    for name in point_data:
        arrays[name] = point_data[name]
    return mesh.points, arrays


def load_combined_vtu(key, q_path, t_path, surf_path=None):
    q_pts, q_data = load_vtu_point_data(q_path)
    t_pts, t_data = load_vtu_point_data(t_path)

    def sorted_t_keys(t_data):
        keys = list(t_data.keys())
        def _k(x):
            m = re.findall(r'\d+', x)
            return (int(m[0]) if m else 10**9, x)
        return sorted(keys, key=_k)

    if not np.allclose(q_pts, t_pts, rtol=1e-5, atol=1e-8):
        raise ValueError(f"Point mismatch at {key}")

    pos = torch.tensor(q_pts, dtype=torch.float)
    t_keys = sorted_t_keys(t_data)
    t_stack = np.stack([t_data[k] for k in t_keys], axis=1)
    t = torch.tensor(t_stack, dtype=torch.float)

    q_raw = next(iter(q_data.values()))
    q = torch.tensor(q_raw[:, None] if q_raw.ndim == 1 else q_raw, dtype=torch.float)

    data = Data(pos=pos, t=t, q=q)

    if surf_path is not None:
        surf_pts, surf_data = load_vtu_point_data(surf_path)
        surf_keys = sorted_t_keys(surf_data)
        surf_array = np.stack([surf_data[k] for k in surf_keys], axis=1) if len(surf_keys) > 1 \
            else next(iter(surf_data.values()))
        data.surf = torch.tensor(surf_array, dtype=torch.float)
        data.surf_pos = torch.tensor(surf_pts, dtype=torch.float)

    return data


def downsample_data(data, downsample_count=None, ratio=None, surf_downsample_count=None, surf_ratio=None,
                    data_type=None, seed=42):
    rng = np.random.default_rng(seed)
    num_points = data.pos.size(0)

    if downsample_count is not None:
        keep = min(num_points, downsample_count)
    elif ratio is not None:
        if not (0 < ratio <= 1):
            raise ValueError("ratio must be between 0 and 1")
        keep = max(1, int(num_points * ratio))
    else:
        keep = num_points

    if data_type == 'unstructured_data':
        idx = rng.choice(num_points, keep, replace=False)
    else:
        idx = np.linspace(0, num_points - 1, num=keep, dtype=int)

    downsampled = Data(
        pos=data.pos[idx],
        t=data.t[idx],
        q=(data.q[idx] if data.q.dim() > 0 else torch.tensor(data.q, dtype=torch.float)[idx])
    )

    if hasattr(data, 'surf'):
        surf_points = data.surf.size(0)
        if surf_downsample_count is not None:
            surf_keep = min(surf_points, surf_downsample_count)
        elif surf_ratio is not None:
            if not (0 < surf_ratio <= 1):
                raise ValueError("surf_ratio must be between 0 and 1")
            surf_keep = max(1, int(surf_points * surf_ratio))
        else:
            surf_keep = surf_points

        if data_type == 'unstructured_data':
            surf_idx = rng.choice(surf_points, surf_keep, replace=False)
        else:
            surf_idx = np.linspace(0, surf_points - 1, num=surf_keep, dtype=int)
        downsampled.surf = data.surf[surf_idx]
        downsampled.surf_pos = data.surf_pos[surf_idx]

    return downsampled

_FREQ_RE = re.compile(r'(.+?)_([0-9]+)(kHz|Hz|MHz)_([0-9]+)A$')

def _parse_prefix_and_freq_token(key: str):
    m = _FREQ_RE.match(key)
    if not m:
        toks = re.findall(r'([0-9]+(?:kHz|Hz|MHz))', key)
        freq = toks[-1] if toks else None
        if freq and f"_{freq}_" in key:
            prefix = key.rsplit(f"_{freq}_", 1)[0]
        else:
            prefix = key
        return prefix, freq
    prefix, v, unit, amp = m.groups()
    return prefix, f"{int(v)}{unit}"

def freq_index_to_token(idx: int) -> str:
    """Map a frequency index in [1, 10] to its frequency token.

    The index is squared: 1 -> "1kHz", 2 -> "4kHz", ..., 10 -> "100kHz".

    Raises:
        ValueError: if ``idx`` lies outside [1, 10].
    """
    in_range = 1 <= idx <= 10
    if not in_range:
        raise ValueError("target_freq index must be in [1, 10]")
    khz = idx * idx
    return str(khz) + "kHz"

def select_tasks_by_fixed_frequency(root_dir: str,
                                    data_type: str,
                                    use_surf: bool,
                                    target_freq_index: int,
                                    train_num: int = None,
                                    test_num: int = None):
    """Split the Q/T sample files under ``root_dir`` into train/test task lists by frequency.

    Each ``<root_dir>/<sim>/<data_type>/<key>_Q.vtu`` with a matching ``<key>_T.vtu`` (and,
    when ``use_surf``, a ``<sim>/surf_input/<key>_Tsurf.vtu``) becomes a record
    ``(key, q_path, t_path, surf_path)``. Records are grouped by the ``(prefix, frequency)``
    parsed from ``key``; only one record is kept per pair (the last one visited). Records at
    the frequency selected by ``target_freq_index`` are training candidates, all others are
    test candidates. Both lists are shuffled with the global ``random`` module and then
    optionally truncated to ``train_num`` / ``test_num``.

    Simulation directories and keys are visited in sorted order, so the result is
    reproducible once ``random.seed(...)`` has been called.

    Returns:
        ``(train_candidates, test_candidates)`` lists of records.

    Raises:
        ValueError: if ``target_freq_index`` is outside [1, 10].
        RuntimeError: if no usable record is found.
    """
    # Validate the index before scanning the filesystem.
    target_token = freq_index_to_token(target_freq_index)

    sim_dirs = sorted(
        os.path.join(root_dir, d)
        for d in os.listdir(root_dir)
        if os.path.isdir(os.path.join(root_dir, d))
    )
    groups = {}  # prefix -> {freq_token: (key, q_path, t_path, surf_path)}
    overwritten = 0
    for sim in sim_dirs:
        data_dir = os.path.join(sim, data_type)
        q_files = sorted(glob.glob(os.path.join(data_dir, '*_Q.vtu')))
        t_files = sorted(glob.glob(os.path.join(data_dir, '*_T.vtu')))
        if len(q_files) == 0 or len(q_files) != len(t_files):
            continue
        q_map = {os.path.basename(p).replace('_Q.vtu', ''): p for p in q_files}
        t_map = {os.path.basename(p).replace('_T.vtu', ''): p for p in t_files}
        surf_map = {}
        if use_surf:
            surf_dir = os.path.join(sim, 'surf_input')
            s_files = sorted(glob.glob(os.path.join(surf_dir, '*_Tsurf.vtu')))
            surf_map = {os.path.basename(p).replace('_Tsurf.vtu', ''): p for p in s_files}

        # Sorted rather than raw set order: str hashing is randomized per process, so set
        # iteration order (and hence the shuffles below) would differ between runs.
        for key in sorted(q_map.keys() & t_map.keys()):
            if use_surf and key not in surf_map:
                continue
            prefix, ftok = _parse_prefix_and_freq_token(key)
            if ftok is None:
                continue
            slot = groups.setdefault(prefix, {})
            if ftok in slot:
                overwritten += 1
            slot[ftok] = (key, q_map[key], t_map[key], surf_map.get(key) if use_surf else None)

    if overwritten:
        print(f"[FixedFreq] warning: {overwritten} record(s) shared a (prefix, frequency) pair "
              f"with an earlier record and replaced it; only one record per pair is kept")

    if not groups:
        raise RuntimeError("No valid (prefix, freq) groups found")

    train_candidates, test_candidates = [], []
    for fmap in groups.values():
        if target_token in fmap:
            train_candidates.append(fmap[target_token])
        for ft, rec in fmap.items():
            if ft != target_token:
                test_candidates.append(rec)

    random.shuffle(train_candidates)
    random.shuffle(test_candidates)
    if train_num is not None:
        train_candidates = train_candidates[:int(train_num)]
    if test_num is not None:
        test_candidates = test_candidates[:int(test_num)]
    return train_candidates, test_candidates

def _discover_tasks(root_dir, data_type, use_surf):
    """Scan ``root_dir/<sim>/<data_type>`` for Q/T pairs; return sorted task tuples.

    Each task is ``(key, q_path, t_path, surf_path)``; ``surf_path`` comes from
    ``<sim>/surf_input/<key>_Tsurf.vtu`` and is ``None`` when ``use_surf`` is False.
    """
    tasks = []
    # Sorted so that task order (and random.sample over it) is reproducible across systems.
    for name in sorted(os.listdir(root_dir)):
        sim = os.path.join(root_dir, name)
        if not os.path.isdir(sim):
            continue
        data_dir = os.path.join(sim, data_type)
        q_files = sorted(glob.glob(os.path.join(data_dir, '*_Q.vtu')))
        t_files = sorted(glob.glob(os.path.join(data_dir, '*_T.vtu')))

        if len(q_files) != len(t_files):
            print(f"Skipping {sim}: `{data_type}` Mismatch in number of documents")
            continue

        q_map = {os.path.basename(p).replace('_Q.vtu', ''): p for p in q_files}
        t_map = {os.path.basename(p).replace('_T.vtu', ''): p for p in t_files}
        surf_map = {}
        if use_surf:
            surf_dir = os.path.join(sim, 'surf_input')
            surf_files = sorted(glob.glob(os.path.join(surf_dir, '*_Tsurf.vtu')))
            surf_map = {os.path.basename(p).replace('_Tsurf.vtu', ''): p for p in surf_files}

        for key in sorted(q_map):
            if key in t_map and (not use_surf or key in surf_map):
                tasks.append((key, q_map[key], t_map[key], surf_map.get(key)))
            else:
                print(f"Missing pair or surface for key: {key} in {sim}")
    return tasks


def _prepare_sample(data, key, downsample_count, downsample_ratio,
                    surf_downsample_count, surf_downsample_ratio, data_type, use_surf):
    """Drop NaN/Inf rows, downsample and validate one sample; return it or ``None`` if unusable."""
    # Require enough valid rows to satisfy the requested downsampling.
    if downsample_count is not None:
        min_keep = int(downsample_count)
    elif downsample_ratio is not None:
        min_keep = max(1, int(data.pos.size(0) * float(downsample_ratio)))
    else:
        min_keep = None
    data = drop_nan_rows(data,
                         min_keep=min_keep,
                         log_prefix=f"[{key}] ",
                         drop_surface=True)
    data = downsample_data(data,
                           downsample_count=downsample_count,
                           ratio=downsample_ratio,
                           surf_downsample_count=surf_downsample_count,
                           surf_ratio=surf_downsample_ratio,
                           data_type=data_type)
    ok_main = (data.pos.numel() > 0) and (data.t.numel() > 0) and (data.q.numel() > 0)
    ok_surf = (not use_surf) or (hasattr(data, "surf") and data.surf.numel() > 0)
    if not (ok_main and ok_surf):
        return None
    data.sim_key = key
    return data


def _load_samples(tasks, max_workers, downsample_count, downsample_ratio,
                  surf_downsample_count, surf_downsample_ratio, data_type, use_surf):
    """Read all tasks in a thread pool and post-process them; return ``(data_list, skipped)``.

    ``data_list`` follows the order of ``tasks``, independent of which file finishes first.
    """
    results = [None] * len(tasks)
    skipped = 0
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {
            executor.submit(load_combined_vtu, key, q, t, surf): i
            for i, (key, q, t, surf) in enumerate(tasks)
        }
        for fut in tqdm(as_completed(futures), total=len(futures), desc="Loading all files"):
            i = futures[fut]
            key = tasks[i][0]
            try:
                results[i] = _prepare_sample(fut.result(), key,
                                             downsample_count, downsample_ratio,
                                             surf_downsample_count, surf_downsample_ratio,
                                             data_type, use_surf)
            except Exception as e:  # best effort: report and skip unreadable/invalid samples
                print(f"Skipping {key}: {e}")
                skipped += 1
                continue
            if results[i] is None:
                skipped += 1
    return [d for d in results if d is not None], skipped


def _normalize_per_sample(data_list, use_surf):
    """Z-score t/q/surf within each sample and pos/surf_pos across all samples (in place).

    NOTE(review): the returned mean/std for t, q and surf belong to the LAST sample only and
    cannot de-normalize other samples; callers needing that must keep per-sample statistics.
    """
    mean_t = std_t = mean_q = std_q = None
    mean_surf = std_surf = None
    for d in data_list:
        d.t = d.t - init_temperature
        mean_t = d.t.mean(0, keepdim=True)
        std_t = d.t.std(0, unbiased=False, keepdim=True) + 1e-8
        d.t = (d.t - mean_t) / std_t

        mean_q = d.q.mean(0, keepdim=True)
        std_q = d.q.std(0, unbiased=False, keepdim=True) + 1e-8
        d.q = (d.q - mean_q) / std_q

        if use_surf and hasattr(d, 'surf'):
            d.surf = d.surf - init_temperature
            mean_surf = d.surf.mean(0, keepdim=True)
            std_surf = d.surf.std(0, unbiased=False, keepdim=True) + 1e-8
            d.surf = (d.surf - mean_surf) / std_surf

    all_pos = torch.cat([d.pos for d in data_list], dim=0)
    mean_pos = all_pos.mean(0, keepdim=True)
    std_pos = all_pos.std(0, unbiased=False, keepdim=True) + 1e-8
    for d in data_list:
        d.pos = (d.pos - mean_pos) / std_pos

    # Pool surface coordinates over every sample that actually carries them.
    mean_surf_pos = std_surf_pos = None
    with_surf_pos = [d for d in data_list if hasattr(d, 'surf_pos')] if use_surf else []
    if with_surf_pos:
        all_surf_pos = torch.cat([d.surf_pos for d in with_surf_pos], dim=0)
        mean_surf_pos = all_surf_pos.mean(0, keepdim=True)
        std_surf_pos = all_surf_pos.std(0, unbiased=False, keepdim=True) + 1e-8
        for d in with_surf_pos:
            d.surf_pos = (d.surf_pos - mean_surf_pos) / std_surf_pos

    stats = {'mean_t': mean_t, 'std_t': std_t, 'mean_pos': mean_pos,
             'std_pos': std_pos, 'mean_q': mean_q, 'std_q': std_q}
    if use_surf:
        stats.update({'mean_surf': mean_surf, 'std_surf': std_surf,
                      'mean_surf_pos': mean_surf_pos, 'std_surf_pos': std_surf_pos})
    return stats


def _normalize_global(data_list, use_surf):
    """Z-score every field with mean/std pooled over all samples (in place); return NumPy stats.

    Means are kept as running point-weighted averages across samples; squared deviations are
    accumulated in a second pass and divided by the total point count (population std).
    1e-8 is added to each std when dividing.
    """
    def as_tensor(arr):
        # np.asarray turns NumPy scalars (what .mean(0)/.sum(0) give for 1-D fields such as a
        # single-array surface file) into 0-d arrays, which torch.from_numpy accepts.
        return torch.from_numpy(np.asarray(arr))

    total_main_points = 0
    total_surf_points = 0
    mean_t, mean_pos, mean_q = None, None, None
    mean_surf, mean_surf_pos = None, None
    for d in data_list:
        t_np = d.t.numpy()
        pos_np = d.pos.numpy()
        q_np = d.q.numpy()
        n = t_np.shape[0]

        if mean_t is None:
            mean_t, mean_pos, mean_q = t_np.mean(0), pos_np.mean(0), q_np.mean(0)
        else:
            new_total = total_main_points + n
            mean_t = (mean_t * total_main_points + t_np.sum(0)) / new_total
            mean_pos = (mean_pos * total_main_points + pos_np.sum(0)) / new_total
            mean_q = (mean_q * total_main_points + q_np.sum(0)) / new_total
        total_main_points += n

        if use_surf and hasattr(d, 'surf'):
            surf_np = d.surf.numpy()
            surf_pos_np = d.surf_pos.numpy()
            m = surf_np.shape[0]
            if mean_surf is None:
                mean_surf = surf_np.mean(0)
                mean_surf_pos = surf_pos_np.mean(0)
            else:
                new_surf_total = total_surf_points + m
                mean_surf = (mean_surf * total_surf_points + surf_np.sum(0)) / new_surf_total
                mean_surf_pos = (mean_surf_pos * total_surf_points + surf_pos_np.sum(0)) / new_surf_total
            total_surf_points += m

    # Surface statistics exist only if at least one sample carried a surface field.
    has_surf = use_surf and mean_surf is not None

    var_t = np.zeros_like(mean_t)
    var_pos = np.zeros_like(mean_pos)
    var_q = np.zeros_like(mean_q)
    var_surf = np.zeros_like(mean_surf) if has_surf else None
    var_surf_pos = np.zeros_like(mean_surf_pos) if has_surf else None
    for d in data_list:
        var_t += ((d.t - as_tensor(mean_t)) ** 2).sum(0).numpy()
        var_pos += ((d.pos - as_tensor(mean_pos)) ** 2).sum(0).numpy()
        var_q += ((d.q - as_tensor(mean_q)) ** 2).sum(0).numpy()
        if has_surf and hasattr(d, 'surf'):
            var_surf += ((d.surf - as_tensor(mean_surf)) ** 2).sum(0).numpy()
            var_surf_pos += ((d.surf_pos - as_tensor(mean_surf_pos)) ** 2).sum(0).numpy()

    std_t = np.sqrt(np.clip(var_t / total_main_points, a_min=0, a_max=None))
    std_pos = np.sqrt(np.clip(var_pos / total_main_points, a_min=0, a_max=None))
    std_q = np.sqrt(np.clip(var_q / total_main_points, a_min=0, a_max=None))
    std_surf = np.sqrt(np.clip(var_surf / total_surf_points, a_min=0, a_max=None)) if has_surf else None
    std_surf_pos = np.sqrt(np.clip(var_surf_pos / total_surf_points, a_min=0, a_max=None)) if has_surf else None

    for d in data_list:
        d.t = (d.t - as_tensor(mean_t).float()) / (as_tensor(std_t).float() + 1e-8)
        d.pos = (d.pos - as_tensor(mean_pos).float()) / (as_tensor(std_pos).float() + 1e-8)
        d.q = (d.q - as_tensor(mean_q).float()) / (as_tensor(std_q).float() + 1e-8)
        if has_surf and hasattr(d, 'surf'):
            d.surf = (d.surf - as_tensor(mean_surf).float()) / (as_tensor(std_surf).float() + 1e-8)
            d.surf_pos = (d.surf_pos - as_tensor(mean_surf_pos).float()) / (as_tensor(std_surf_pos).float() + 1e-8)

    stats = {'mean_t': mean_t, 'std_t': std_t, 'mean_pos': mean_pos,
             'std_pos': std_pos, 'mean_q': mean_q, 'std_q': std_q}
    if use_surf:
        stats.update({'mean_surf': mean_surf, 'std_surf': std_surf,
                      'mean_surf_pos': mean_surf_pos, 'std_surf_pos': std_surf_pos})
    return stats


def load_all_data(root_dir,
                  max_workers=8,
                  data_num=None,
                  downsample_count=None,
                  downsample_ratio=None,
                  surf_downsample_ratio=None,
                  surf_downsample_count=None,
                  data_type='unstructured_data',
                  normalize=True,
                  unit_normalize=True,
                  use_surf=True,
                  selected_tasks=None):
    """Load, clean, downsample and optionally normalize Q/T(/surface) VTU samples.

    Tasks are ``(key, q_path, t_path, surf_path)`` tuples. Unless ``selected_tasks`` is given,
    they are discovered under ``root_dir/<sim>/<data_type>`` in sorted order. If ``data_num``
    is set, a random subset of that size is drawn with the global ``random`` module. Files
    are read in a thread pool. Each sample passes through ``drop_nan_rows`` and
    ``downsample_data``. Samples are skipped if they fail, end up empty, or lack a surface
    field while ``use_surf`` is set. The returned list follows task order.

    Normalization (when ``normalize``):
      * ``unit_normalize=True``: t, q and surf are z-scored per sample (t/surf after
        subtracting ``init_temperature``); pos and surf_pos use statistics pooled over all
        samples. The t/q/surf entries of ``stats`` come from the last sample only.
      * ``unit_normalize=False``: every field uses statistics pooled over all samples.

    Returns:
        ``(data_list, stats)``; ``stats`` is empty when ``normalize`` is False.

    Raises:
        RuntimeError: if no task is found or no sample could be loaded.
    """
    if selected_tasks is not None:
        tasks = list(selected_tasks)
    else:
        tasks = _discover_tasks(root_dir, data_type, use_surf)

    if data_num is not None:
        data_num = min(data_num, len(tasks))
        tasks = random.sample(tasks, data_num)
    if not tasks:
        raise RuntimeError("No data tasks found")

    data_list, skipped = _load_samples(tasks, max_workers,
                                       downsample_count, downsample_ratio,
                                       surf_downsample_count, surf_downsample_ratio,
                                       data_type, use_surf)
    if not data_list:
        raise RuntimeError("No data loaded")

    print(f"Loaded {len(data_list)} samples, Skipped {skipped}")

    stats = {}
    if normalize:
        if unit_normalize:
            stats = _normalize_per_sample(data_list, use_surf)
        else:
            stats = _normalize_global(data_list, use_surf)
    return data_list, stats


def load_data_fixed_frequency(root_dir: str,
                              data_type: str,
                              use_surf: bool,
                              target_freq_index: int,
                              train_num: int = None,
                              test_num: int = None,
                              # Loader options below mirror load_all_data and are forwarded to it
                              max_workers: int = 8,
                              downsample_count=None,
                              downsample_ratio=None,
                              surf_downsample_ratio=None,
                              surf_downsample_count=None,
                              normalize=True,
                              unit_normalize=True):
    """Build a train/test split where training samples are those at one target frequency.

    Tasks are chosen with ``select_tasks_by_fixed_frequency``. Train and test tasks are
    then loaded together by ``load_all_data``. The loaded graphs are separated again by
    their ``sim_key``; a key present in both task lists counts as training.

    NOTE(review): because train and test are loaded in one ``load_all_data`` call, the
    dataset-level normalization statistics (pos always; every field when
    ``unit_normalize=False``) are computed over test samples too.

    Returns:
        ``(train_graphs, test_graphs, stats)``.
    """
    train_tasks, test_tasks = select_tasks_by_fixed_frequency(
        root_dir=root_dir,
        data_type=data_type,
        use_surf=use_surf,
        target_freq_index=target_freq_index,
        train_num=train_num,
        test_num=test_num,
    )

    loaded, stats = load_all_data(
        root_dir,
        max_workers=max_workers,
        data_num=None,
        downsample_count=downsample_count,
        downsample_ratio=downsample_ratio,
        surf_downsample_ratio=surf_downsample_ratio,
        surf_downsample_count=surf_downsample_count,
        data_type=data_type,
        normalize=normalize,
        unit_normalize=unit_normalize,
        use_surf=use_surf,
        selected_tasks=train_tasks + test_tasks,
    )

    train_keys = {task[0] for task in train_tasks}
    test_keys = {task[0] for task in test_tasks}
    sim_keys = [getattr(graph, 'sim_key', None) for graph in loaded]
    train_graphs = [graph for graph, k in zip(loaded, sim_keys) if k in train_keys]
    test_graphs = [graph for graph, k in zip(loaded, sim_keys)
                   if k not in train_keys and k in test_keys]

    print(f"[FixedFreq] target={freq_index_to_token(target_freq_index)} | train={len(train_graphs)} | test={len(test_graphs)}")
    return train_graphs, test_graphs, stats

if __name__ == '__main__':
    # Example 1: fixed-frequency split - train on one frequency, test on the others.
    train_graphs, test_graphs, stats = load_data_fixed_frequency(
        root_dir="../Aletheia/typeI_double-layer",
        data_type='unstructured_data',  # 'structured_data' / 'unstructured_data'
        use_surf=True,
        target_freq_index=1,  # 1..10, mapped to (index ** 2) kHz
        train_num=480,
        test_num=120,
        max_workers=8,
        downsample_count=8000,
        surf_downsample_count=8000,
        normalize=True,
        unit_normalize=True,
    )

    # Examples 2 and 3: random subsets of a dataset. "/dataset path/" appears to be a
    # placeholder that must be replaced with a real directory.
    shared_kwargs = dict(max_workers=8, downsample_count=8000)

    ds_structured, stats_structured = load_all_data(
        "/dataset path/",
        data_num=200,
        surf_downsample_count=8000,
        data_type='structured_data',
        **shared_kwargs,
    )
    print(f"Structured total: {len(ds_structured)}")

    ds_unstructured, stats_unstructured = load_all_data(
        "/dataset path/",
        data_num=100,
        surf_downsample_ratio=0.5,
        data_type='unstructured_data',
        **shared_kwargs,
    )
    print(f"Unstructured total: {len(ds_unstructured)}")
