from __future__ import annotations
import os
import json
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any, Sequence

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

from joblib import load as joblib_load

@dataclass
class DataBundle:
    """Container for the loaders, tensors and artefacts of one ID (+ optional OOD) setup.

    Instances are assembled by ``build_id_ood_from_config``. The OOD fields, the
    preprocessing artefacts and ``info`` are optional and default to ``None``.
    """

    # Training batches of the in-distribution (ID) dataset.
    train_loader: DataLoader
    # Test batches of the ID dataset.
    test_loader: DataLoader
    # The whole ID test split gathered into single feature / label tensors.
    x_test: torch.Tensor
    y_test: torch.Tensor
    # The whole OOD test split's features, or None when no OOD dataset is used.
    x_ood_test: Optional[torch.Tensor]
    # Size of dimension 1 of the feature matrix.
    in_dim: int
    # Number of ID classes.
    num_classes: int
    # Test loader over the OOD dataset, or None when no OOD dataset is used.
    ood_loader: Optional[DataLoader] = None
    # Optional fitted preprocessing artefacts (None when not loaded).
    vectorizer: Any = None
    svd: Any = None
    scaler: Any = None
    # Free-form metadata about where the data came from; None when not provided.
    info: Optional[Dict[str, Any]] = None

class NumpyTensorDS(Dataset):
    """Map-style dataset over an in-memory feature matrix ``X`` and labels ``y``.

    ``X`` and ``y`` may be numpy arrays, torch tensors, or anything
    ``torch.as_tensor`` accepts (e.g. nested lists). Both are converted once,
    up front, to ``x_dtype`` / ``y_dtype`` so ``__getitem__`` is a plain index.

    Raises:
        ValueError: if ``X`` and ``y`` do not have the same number of samples
            (size of dimension 0).
    """

    def __init__(self, X, y, x_dtype=torch.float32, y_dtype=torch.long):
        X = self._to_tensor(X)
        y = self._to_tensor(y)
        # Catch misaligned features/labels here instead of as a silent
        # mis-pairing or an IndexError somewhere in the middle of an epoch.
        if X.shape[0] != y.shape[0]:
            raise ValueError(
                f"X has {X.shape[0]} samples but y has {y.shape[0]}"
            )
        self.X = X.to(dtype=x_dtype)
        self.y = y.to(dtype=y_dtype)

    @staticmethod
    def _to_tensor(a):
        """Return *a* as a tensor; numpy arrays are wrapped without copying."""
        if isinstance(a, torch.Tensor):
            return a
        if isinstance(a, np.ndarray):
            return torch.from_numpy(a)
        return torch.as_tensor(a)

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, i):
        return self.X[i], self.y[i]

def _load_pt_tensors(pt_path: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    obj = torch.load(pt_path)
    Xtr = obj.get("Xtr") 
    ytr = obj.get("ytr") 
    Xte = obj.get("Xte")
    yte = obj.get("yte")

    return Xtr, ytr, Xte, yte

def _maybe_load_pipelines(base_dir: str, ds_name: str):

    tfidf_path  = os.path.join(base_dir, f"tfidf_{ds_name}.joblib")
    svd_path    = os.path.join(base_dir, f"svd_{ds_name}.joblib")
    scaler_path = os.path.join(base_dir, f"scaler_{ds_name}.joblib")

    vec = joblib_load(tfidf_path)  if os.path.exists(tfidf_path)  else None
    svd = joblib_load(svd_path)    if os.path.exists(svd_path)    else None
    scl = joblib_load(scaler_path) if os.path.exists(scaler_path) else None
    return vec, svd, scl

def _norm_name(s: str) -> str:
    return s.lower().replace("-", "_")

def _get_bs_for(dataset_name: str, data_cfg: dict) -> tuple[int, int]:
    """Resolve ``(train_batch_size, test_batch_size)`` for *dataset_name*.

    Per-dataset overrides come from ``data_cfg["batch_size_per_dataset"]``. When
    that is absent or empty, they come from ``data_cfg["batch_size"]["per_dataset"]``,
    but only when ``batch_size`` is a mapping. An override entry may set
    ``"train"`` and/or ``"test"``. Anything not overridden falls back to the
    global ``batch_size_train`` (default 128) / ``batch_size_test`` (default 4096).

    Override keys are normalised with ``_norm_name``, the same as the dataset
    name, so ``"AG-News"`` in the config matches dataset ``"ag_news"``.
    """
    from collections.abc import Mapping

    ds = _norm_name(dataset_name)
    global_tr = int(data_cfg.get("batch_size_train", 128))
    global_te = int(data_cfg.get("batch_size_test", 4096))

    per = data_cfg.get("batch_size_per_dataset") or {}
    if not per:
        nested = data_cfg.get("batch_size")
        # ``batch_size`` may be a plain number in some configs. Only a mapping
        # can carry a ``per_dataset`` table, so don't call ``.get`` on an int.
        if isinstance(nested, Mapping):
            per = nested.get("per_dataset") or {}

    per = {_norm_name(str(k)): v for k, v in per.items()}

    if ds in per:
        tr = int(per[ds].get("train", global_tr))
        te = int(per[ds].get("test", global_te))
        return tr, te
    return global_tr, global_te

def make_dataloaders(
    Xtr: torch.Tensor,
    ytr: torch.Tensor,
    Xte: torch.Tensor,
    yte: torch.Tensor,
    *,
    batch_size_train: int = 128,
    batch_size_test: int = 4096,
    num_workers: int = 0,
    pin_memory: bool = True,
    drop_last_train: bool = False,
) -> Tuple[DataLoader, DataLoader]:
    """Wrap the train and test splits in ``DataLoader``s.

    Both splits go through ``NumpyTensorDS``, which casts features to float32
    and labels to int64. The train loader reshuffles every epoch and drops its
    last partial batch only if ``drop_last_train`` is set. The test loader keeps
    the stored order. Worker count and memory pinning are shared by both.
    """
    train_ds, test_ds = (
        NumpyTensorDS(features, labels, x_dtype=torch.float32, y_dtype=torch.long)
        for features, labels in ((Xtr, ytr), (Xte, yte))
    )

    shared_opts = {"num_workers": num_workers, "pin_memory": pin_memory}

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size_train,
        shuffle=True,
        drop_last=drop_last_train,
        **shared_opts,
    )
    test_loader = DataLoader(
        test_ds,
        batch_size=batch_size_test,
        shuffle=False,
        **shared_opts,
    )
    return train_loader, test_loader

@torch.no_grad()
def get_x_y_from_loader(test_loader, device):
    """Drain *test_loader* into one feature tensor and one label tensor on *device*.

    Each ``(x, y)`` batch is moved to *device* and the batches are concatenated
    along dim 0 in iteration order. An empty loader raises, because
    ``torch.cat`` rejects an empty list.
    """
    moved = [(xb.to(device), yb.to(device)) for xb, yb in test_loader]
    feature_batches = [xb for xb, _ in moved]
    label_batches = [yb for _, yb in moved]
    return torch.cat(feature_batches, dim=0), torch.cat(label_batches, dim=0)

def _infer_num_classes(ytr: torch.Tensor, yte: torch.Tensor, num_classes_cfg: Optional[int] = None) -> int:
    if num_classes_cfg is not None and num_classes_cfg > 0:
        return int(num_classes_cfg)

    uniques = torch.unique(torch.cat([ytr.view(-1), yte.view(-1)], dim=0))
    return int(uniques.numel())

def load_dataset_bundle(
    *,
    dataset: str,
    data_dir: str,
    batch_size_train: int,
    batch_size_test: int,
    num_workers: int,
    pin_memory: bool,
    num_classes: Optional[int] = None,
    load_pipelines: bool = True,
    device: Optional[str] = None,
) -> Tuple[DataLoader, DataLoader, torch.Tensor, torch.Tensor, int, int, Any, Any, Any, Dict[str, Any]]:
    """Load one dataset's splits and wrap them for training/evaluation.

    Reads ``<data_dir>/dataset_<name>.pt``, where ``<name>`` is ``_norm_name(dataset)``.
    It builds train/test loaders, gathers the whole test split onto *device*,
    and (optionally) loads the saved preprocessing artefacts from *data_dir*.

    Args:
        dataset: Dataset name (normalised before building file names).
        data_dir: Directory holding the ``.pt`` file and joblib artefacts.
        batch_size_train / batch_size_test / num_workers / pin_memory:
            Forwarded to ``make_dataloaders``.
        num_classes: Optional override forwarded to ``_infer_num_classes``.
        load_pipelines: Whether to look for tfidf/svd/scaler artefacts.
        device: Where to place ``x_test`` / ``y_test``. ``None`` (the default)
            means CUDA when available, else CPU.

    Returns:
        ``(train_loader, test_loader, x_test, y_test, in_dim, n_classes,
        vectorizer, svd, scaler, info)``. The three artefacts are ``None`` when
        not loaded.

    Raises:
        ValueError: if train and test features disagree on every dimension but
            the first.
    """
    ds_name = _norm_name(dataset)
    pt_path = os.path.join(data_dir, f"dataset_{ds_name}.pt")
    Xtr, ytr, Xte, yte = _load_pt_tensors(pt_path)

    # The label inference below uses torch ops, so convert any numpy arrays now.
    Xtr, ytr, Xte, yte = (
        torch.from_numpy(a) if isinstance(a, np.ndarray) else a
        for a in (Xtr, ytr, Xte, yte)
    )

    # Train and test must share a feature space. Fail here with a clear
    # message rather than with a shape error inside the model later on.
    if tuple(Xtr.shape[1:]) != tuple(Xte.shape[1:]):
        raise ValueError(
            f"{pt_path!r}: train features have shape {tuple(Xtr.shape)} but "
            f"test features have shape {tuple(Xte.shape)}"
        )

    in_dim = int(Xtr.shape[1])
    n_classes = _infer_num_classes(ytr, yte, num_classes)

    train_loader, test_loader = make_dataloaders(
        Xtr, ytr, Xte, yte,
        batch_size_train=batch_size_train,
        batch_size_test=batch_size_test,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    x_test, y_test = get_x_y_from_loader(test_loader, device=device)

    vec, svd, scl = (None, None, None)
    if load_pipelines:
        vec, svd, scl = _maybe_load_pipelines(data_dir, ds_name)

    info = {
        "dataset": dataset,
        "pt_path": pt_path,
        "n_train": int(Xtr.shape[0]),
        "n_test": int(Xte.shape[0]),
    }
    return train_loader, test_loader, x_test, y_test, in_dim, n_classes, vec, svd, scl, info

def build_id_ood_from_config(cfg: Dict[str, Any]) -> DataBundle:
    """Build a ``DataBundle`` for an ID dataset and an optional OOD dataset.

    Settings are read from ``cfg["data"]`` if present, otherwise from ``cfg``:

    * ``id_dataset`` (required) and ``ood_dataset`` (optional).
    * ``id_dir`` (default ``"."``) and ``ood_dir`` (default: ``id_dir``).
    * ``num_workers`` (default 0) and ``pin_memory`` (default True).
    * ``num_classes`` (optional override for the ID dataset).
    * Batch sizes, resolved per dataset by ``_get_bs_for``.

    Preprocessing artefacts are loaded only for the ID dataset. For the OOD
    dataset only its test loader, test features and info are kept.

    Raises:
        KeyError: if ``id_dataset`` is missing.
        ValueError: if the OOD features' ``in_dim`` differs from the ID one.
    """
    data_cfg = cfg.get("data", cfg)
    id_dataset = data_cfg["id_dataset"]
    ood_dataset = data_cfg.get("ood_dataset", None)
    id_dir = data_cfg.get("id_dir", ".")
    ood_dir = data_cfg.get("ood_dir", id_dir)
    nwrk = int(data_cfg.get("num_workers", 0))
    pinmem = bool(data_cfg.get("pin_memory", True))
    num_classes = data_cfg.get("num_classes", None)

    id_tr_bs, id_te_bs = _get_bs_for(id_dataset, data_cfg)

    train_loader, test_loader, x_test, y_test, in_dim, n_classes, vec, svd, scl, info_id = load_dataset_bundle(
        dataset=id_dataset,
        data_dir=id_dir,
        batch_size_train=id_tr_bs,
        batch_size_test=id_te_bs,
        num_workers=nwrk,
        pin_memory=pinmem,
        num_classes=num_classes,
        load_pipelines=True,
    )

    x_ood_test = None
    ood_loader = None
    info = {"id": info_id}

    if ood_dataset is not None:
        ood_tr_bs, ood_te_bs = _get_bs_for(ood_dataset, data_cfg)

        (_ood_train_loader, ood_test_loader, x_ood_test_tmp, _y_ood_dummy, in_dim_ood, _nc_ood, _, _, _, info_ood) = load_dataset_bundle(
            dataset=ood_dataset,
            data_dir=ood_dir,
            batch_size_train=ood_tr_bs,
            batch_size_test=ood_te_bs,
            num_workers=nwrk,
            pin_memory=pinmem,
            num_classes=None,
            load_pipelines=False,
        )

        # OOD features are stored alongside the ID test features in the bundle,
        # so they must have the same dimensionality. Reject a mismatch here
        # instead of failing later when both are used together.
        if in_dim_ood != in_dim:
            raise ValueError(
                f"OOD dataset {ood_dataset!r} has in_dim={in_dim_ood}, "
                f"but ID dataset {id_dataset!r} has in_dim={in_dim}"
            )

        ood_loader = ood_test_loader
        x_ood_test = x_ood_test_tmp
        info["ood"] = info_ood

    return DataBundle(
        train_loader=train_loader,
        test_loader=test_loader,
        x_test=x_test,
        y_test=y_test,
        x_ood_test=x_ood_test,
        in_dim=in_dim,
        num_classes=n_classes,
        ood_loader=ood_loader,
        vectorizer=vec, svd=svd, scaler=scl,
        info=info,
    )