import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import torch


def set_seed(seed: int = 2025) -> None:
    """Seed every random number generator this module relies on.

    Calls, in order, ``random.seed``, ``np.random.seed``,
    ``torch.manual_seed`` and ``torch.cuda.manual_seed_all`` with the same
    value.

    Args:
        seed: Value handed unchanged to each seeding function.
    """
    import random

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)


def ensure_dir(p: str) -> None:
    """Create directory *p* together with any missing parent directories.

    Uses ``exist_ok=True``, so calling it on an existing directory is a no-op.
    """
    target = Path(p)
    target.mkdir(exist_ok=True, parents=True)


def windowize(
    series: np.ndarray,
    win_len: int = 512,
    stride: int = 512,
    min_valid: int = 8,
) -> Tuple[np.ndarray, np.ndarray]:
    """Cut a 1-D series into fixed-length windows with validity masks.

    Full windows start at ``0, stride, 2*stride, ...`` for as long as they fit
    inside the series. If samples remain from the first start position that
    no longer fits, one extra window is built from ``series[pos:]``,
    left-padded with zeros up to ``win_len``. Its mask is 0 on the padding
    and 1 on real samples. Any window whose mask has ``min_valid`` or fewer
    ones is dropped.

    Args:
        series: 1-D input array.
        win_len: Length of every output window; must be positive.
        stride: Step between consecutive window starts; must be positive.
        min_valid: Windows with at most this many valid (mask == 1) samples
            are discarded. Defaults to 8, the previously hard-coded value.

    Returns:
        ``(windows, masks)``, each of shape ``(n_windows, win_len)``.
        ``windows`` keeps ``series.dtype``; ``masks`` has dtype ``int``.
        If no window survives (e.g. an empty or very short series), both
        arrays have shape ``(0, win_len)`` rather than raising.

    Raises:
        ValueError: if ``win_len`` or ``stride`` is not positive.
    """
    assert series.ndim == 1
    if win_len <= 0:
        raise ValueError(f"win_len must be positive, got {win_len}")
    if stride <= 0:
        # A non-positive stride would never advance ``pos`` (infinite loop).
        raise ValueError(f"stride must be positive, got {stride}")

    T = len(series)
    windows: List[np.ndarray] = []
    masks: List[np.ndarray] = []

    # A full window has no padding, so its mask is all ones and the
    # min_valid test depends only on win_len; evaluate it once.
    keep_full = win_len > min_valid
    pos = 0
    while pos + win_len <= T:
        if keep_full:
            windows.append(series[pos:pos + win_len])
            masks.append(np.ones(win_len, dtype=int))
        pos += stride

    # Tail: ``pos`` is now the first start whose full window would overrun
    # the series, so 0 < rem < win_len whenever this branch runs.
    if pos < T:
        rem = T - pos
        pad_left = win_len - rem
        win = np.concatenate([np.zeros(pad_left, series.dtype), series[pos:]])
        mask = np.concatenate([np.zeros(pad_left, int), np.ones(rem, int)])
        if rem > min_valid:  # equals mask.sum()
            windows.append(win)
            masks.append(mask)

    if not windows:
        # np.stack([]) raises; return well-shaped empty arrays instead.
        return (
            np.empty((0, win_len), dtype=series.dtype),
            np.empty((0, win_len), dtype=int),
        )
    return np.stack(windows, axis=0), np.stack(masks, axis=0)


def load_dataset_from_folder(
    folder_path: str,
    win_len: int = 512,
    stride: int = 512,
) -> Tuple[
    List[np.ndarray], List[np.ndarray], List[np.ndarray],
    Dict[str, Tuple[np.ndarray, np.ndarray]], Dict[Tuple[str, str], int]
]:
    """Load per-file training windows and full test series from a CSV folder.

    A file in *folder_path* is used only when all of these hold:

    * its name ends with ``.csv``;
    * its name without the ``.csv`` suffix, split on ``_``, has at least 7
      fields. Field 1 is the dataset name, field 4 the domain, and field 6
      the train/test split index ``train_idx``;
    * field 6 parses as a positive integer;
    * the CSV has both a ``Data`` and a ``Label`` column.

    Other files are skipped silently.

    For each used file, ``Data`` is normalized with the mean and standard
    deviation of its first ``train_idx`` samples. Only centering is applied
    when that std is <= 1e-12. The normalized prefix ``[:train_idx]`` is cut
    into windows with ``windowize``. The whole normalized series and its
    integer labels are kept for testing.

    Files are visited in sorted name order, so domain ids are reproducible
    across runs and filesystems. An id is assigned to a
    ``(dataset_name, domain)`` pair only once a file for it is actually used.

    Args:
        folder_path: Directory to scan (not recursive).
        win_len: Window length forwarded to ``windowize``.
        stride: Window stride forwarded to ``windowize``.

    Returns:
        ``(train_files, train_masks, train_domains, test_files, domain_id_map)``:

        * ``train_files``: per-file window arrays.
        * ``train_masks``: per-file mask arrays.
        * ``train_domains``: per-file arrays holding that file's domain id,
          one entry per window.
        * ``test_files``: mapping from file name to
          ``(normalized_series, labels)``.
        * ``domain_id_map``: mapping from ``(dataset_name, domain)`` to an
          integer id, numbered 0, 1, 2, ... in order of first use.
    """
    train_files: List[np.ndarray] = []
    train_masks: List[np.ndarray] = []
    train_domains: List[np.ndarray] = []
    test_files: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}
    domain_id_map: Dict[Tuple[str, str], int] = {}

    for fname in sorted(os.listdir(folder_path)):
        if not fname.endswith(".csv"):
            continue

        # Split the name without its extension, so that field 6 is a clean
        # token even when it is the last field (e.g. "1000", not "1000.csv").
        parts = fname[: -len(".csv")].split("_")
        if len(parts) < 7:
            continue

        # Validate the split index before paying for the CSV read.
        try:
            train_idx = int(parts[6])
        except ValueError:
            continue
        if train_idx <= 0:
            # An empty training prefix would give NaN mean/std below.
            continue

        df = pd.read_csv(os.path.join(folder_path, fname))
        if "Data" not in df.columns or "Label" not in df.columns:
            continue

        # Assign a domain id only now that this file is known to be usable.
        key = (parts[1], parts[4])
        did = domain_id_map.setdefault(key, len(domain_id_map))

        ts = df["Data"].to_numpy()
        lab = df["Label"].astype(int).to_numpy()

        train_part = ts[:train_idx]
        mu, sigma = train_part.mean(), train_part.std()
        ts_norm = (ts - mu) / sigma if sigma > 1e-12 else (ts - mu)

        train_x, train_m = windowize(ts_norm[:train_idx], win_len=win_len, stride=stride)
        train_files.append(train_x)
        train_masks.append(train_m)
        train_domains.append(np.full(len(train_x), did, dtype=int))

        test_files[fname] = (ts_norm, lab)

    return train_files, train_masks, train_domains, test_files, domain_id_map


def stack_train_lists(
    train_files: List[np.ndarray],
    train_masks: List[np.ndarray],
    train_domains: List[np.ndarray],
    device: Optional[Union[str, torch.device]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Concatenate per-file training arrays and move them to one device.

    Args:
        train_files: Per-file window arrays, concatenated along axis 0.
            Presumably each is ``(n_i, win_len)`` as produced by
            ``windowize`` — TODO confirm against callers.
        train_masks: Per-file mask arrays, concatenated along axis 0.
        train_domains: Per-file domain-id arrays, concatenated along axis 0.
        device: Target device for the returned tensors. ``None`` (default)
            selects ``"cuda"`` when ``torch.cuda.is_available()`` and
            ``"cpu"`` otherwise. Previously ``"cuda"`` was hard-coded, which
            crashed on CPU-only hosts.

    Returns:
        ``(X, M, D)``:

        * ``X``: float32 tensor with a size-1 axis inserted at dim 1.
        * ``M``: int64 mask tensor.
        * ``D``: int64 domain-id tensor.

    Raises:
        ValueError: if any input list is empty (nothing to concatenate).
    """
    if not train_files or not train_masks or not train_domains:
        raise ValueError("stack_train_lists: no training arrays to concatenate")

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    X_np = np.concatenate(train_files, axis=0)
    M_np = np.concatenate(train_masks, axis=0)
    D_np = np.concatenate(train_domains, axis=0)

    # unsqueeze(1) adds a channel axis, e.g. (N, L) -> (N, 1, L) for 2-D windows.
    X = torch.from_numpy(X_np).float().unsqueeze(1).to(device)
    M = torch.from_numpy(M_np).long().to(device)
    D = torch.from_numpy(D_np).long().to(device)
    return X, M, D
