import json
import os
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
from torch.utils.data import Dataset

from src.model.regress_lm import core
from src.model.regress_lm.vocabs import SentencePieceVocab


class Binary_fit_Dataset(Dataset):
    def __init__(
        self,
        data_dir: str,
        split: str = "train",
        max_samples: Optional[int] = None,
        random_seed: int = 42,
        dataset_name: Optional[str] = None,
        compute_max_seq_len: bool = True,
        num_nan_policy: str = "mean",
        cat_nan_policy: str = "new",
        cat_policy: str = "indices",
        num_policy: str = "none",
        normalization: str = "standard",
        n_bins: int = 2,
        use_float: bool = False,
    ):
        if os.path.isabs(data_dir):
            self.data_dir = Path(data_dir)
        else:
            self.data_dir = Path(os.path.abspath(data_dir))
        self.split = split
        self.max_samples = max_samples
        self.random_seed = random_seed
        self.dataset_name = dataset_name
        self.compute_max_seq_len = compute_max_seq_len

        self.num_nan_policy = num_nan_policy
        self.cat_nan_policy = cat_nan_policy
        self.cat_policy = cat_policy
        self.num_policy = num_policy
        self.normalization = normalization
        self.n_bins = n_bins
        self.use_float = use_float

        self.tokenizer = None
        if self.compute_max_seq_len:
            try:
                self.tokenizer = SentencePieceVocab.from_t5()
            except Exception:
                self.compute_max_seq_len = False

        np.random.seed(random_seed)

        self.load_data()

        self.process_data()

    def load_data(self):
        """Load raw arrays: one named dataset, or every dataset under ``data_dir``."""
        if not self.dataset_name:
            self.load_all_datasets()
            return
        self.load_single_dataset(self.dataset_name)

    def load_single_dataset(self, dataset_name: str):
        dataset_dir = self.data_dir / dataset_name

        if not dataset_dir.exists():
            raise FileNotFoundError(f"Dataset directory not found: {dataset_dir}")

        info_file = dataset_dir / "info.json"
        if not info_file.exists():
            raise FileNotFoundError(f"info.json file not found: {info_file}")

        with open(info_file, "r", encoding="utf-8") as f:
            self.info = json.load(f)

        self.metadata = {
            "name": self.info.get("name", dataset_name),
            "dimension": self.info.get("n_num_features", 0)
            + self.info.get("n_cat_features", 0),
        }

        self.num_feature_names = self.info.get("num_feature_intro", {})
        self.cat_feature_names = self.info.get("cat_feature_intro", {})

        x_num_file = dataset_dir / f"N_{self.split}.npy"
        x_num_file_norm = dataset_dir / f"N_train.npy"
        x_cat_file = dataset_dir / f"C_{self.split}.npy"
        x_cat_file_norm = dataset_dir / f"C_train.npy"
        y_file = dataset_dir / f"y_{self.split}.npy"
        train_y_file = dataset_dir / f"y_train.npy"

        if not x_num_file.exists() and not x_cat_file.exists():
            raise FileNotFoundError(f"Feature files not found: {x_num_file} or {x_cat_file}")
        if not y_file.exists():
            raise FileNotFoundError(f"Label file not found: {y_file}")

        self.x_num_data = None
        self.x_num_data_norm = None
        self.x_cat_data = None
        self.x_cat_data_norm = None

        if x_num_file.exists():
            self.x_num_data = np.load(x_num_file, allow_pickle=True)
            if len(self.x_num_data.shape) == 1:
                self.x_num_data = self.x_num_data.reshape(-1, 1)

        if x_num_file_norm.exists():
            self.x_num_data_norm = np.load(x_num_file_norm, allow_pickle=True)
            if len(self.x_num_data_norm.shape) == 1:
                self.x_num_data_norm = self.x_num_data_norm.reshape(-1, 1)

        if x_cat_file.exists():
            self.x_cat_data = np.load(x_cat_file, allow_pickle=True)
            if len(self.x_cat_data.shape) == 1:
                self.x_cat_data = self.x_cat_data.reshape(-1, 1)

        if x_cat_file_norm.exists():
            self.x_cat_data_norm = np.load(x_cat_file_norm, allow_pickle=True)
            if len(self.x_cat_data_norm.shape) == 1:
                self.x_cat_data_norm = self.x_cat_data_norm.reshape(-1, 1)

        self.y_data = np.load(y_file, allow_pickle=True)
        if len(self.y_data.shape) == 1:
            self.y_data = self.y_data.reshape(-1, 1)

        self.train_y_data = np.load(train_y_file, allow_pickle=True)
        if len(self.train_y_data.shape) == 1:
            self.train_y_data = self.train_y_data.reshape(-1, 1)

    def load_all_datasets(self):
        self.datasets = {}
        self.dataset_info = {}
        self.x_data = []
        self.y_data = []

        for item in self.data_dir.iterdir():
            if item.is_dir() and not item.name.startswith("."):
                dataset_name = item.name

                info_file = item / "info.json"
                if not info_file.exists():
                    continue

                try:
                    with open(info_file, "r", encoding="utf-8") as f:
                        info = json.load(f)

                    metadata = {
                        "name": info.get("name", dataset_name),
                        "dimension": info.get("n_num_features", 0)
                        + info.get("n_cat_features", 0),
                    }

                    if info.get("task_type") != "regression":
                        continue

                    x_num_file = item / f"N_{self.split}.npy"
                    x_cat_file = item / f"C_{self.split}.npy"
                    y_file = item / f"y_{self.split}.npy"

                    if not x_num_file.exists() and not x_cat_file.exists():
                        continue
                    if not y_file.exists():
                        continue

                    x_num_data = None
                    x_cat_data = None

                    if x_num_file.exists():
                        x_num_data = np.load(x_num_file, allow_pickle=True)
                        if len(x_num_data.shape) == 1:
                            x_num_data = x_num_data.reshape(-1, 1)

                    if x_cat_file.exists():
                        x_cat_data = np.load(x_cat_file, allow_pickle=True)
                        if len(x_cat_data.shape) == 1:
                            x_cat_data = x_cat_data.reshape(-1, 1)

                    y_data = np.load(y_file, allow_pickle=True)
                    if len(y_data.shape) == 1:
                        y_data = y_data.reshape(-1, 1)

                    self.datasets[dataset_name] = {
                        "x_num_data": x_num_data,
                        "x_cat_data": x_cat_data,
                        "y_data": y_data,
                        "info": info,
                        "metadata": metadata,
                    }

                    self.dataset_info[dataset_name] = {
                        "name": info.get("name", dataset_name),
                        "dimension": info.get("n_num_features", 0)
                        + info.get("n_cat_features", 0),
                        "split": self.split,
                        "num_samples": len(y_data),
                    }

                except Exception:
                    continue

        if not self.datasets:
            raise ValueError("No valid regression datasets found")

    def process_data(self):
        """Run preprocessing for whichever loading mode is active."""
        if not self.dataset_name:
            self.process_all_datasets()
            return
        self.process_single_dataset()

    def process_single_dataset(self):
        N_data = {"train": self.x_num_data} if self.x_num_data is not None else None
        N_data_norm = {"train": self.x_num_data_norm} if self.x_num_data_norm is not None else None
        C_data = {"train": self.x_cat_data} if self.x_cat_data is not None else None
        C_data_norm = {"train": self.x_cat_data_norm} if self.x_cat_data_norm is not None else None
        y_data = {"train": self.y_data}
        y_train_data = {"train": self.train_y_data}

        N_data, N_data_norm, C_data, C_data_norm, num_new_value, imputer, cat_new_value = self.data_nan_process(
            N_data, N_data_norm, C_data, C_data_norm, self.num_nan_policy, self.cat_nan_policy
        )

        N_data, num_encoder = self.num_enc_process(
            N_data, self.num_policy, self.n_bins, y_data["train"], is_regression=True
        )

        N_data, N_data_norm, C_data, ord_encoder, mode_values, cat_encoder = self.data_enc_process(
            N_data, N_data_norm, C_data, C_data_norm, self.cat_policy, y_data["train"]
        )

        N_data, normalizer = self.data_norm_process(
            N_data, N_data_norm, self.normalization, self.random_seed
        )
        self.categories = self.get_categories(C_data)

        if N_data is not None and C_data is not None:
            self.x_data = np.concatenate([N_data["train"], C_data["train"]], axis=1)
        elif N_data is not None:
            self.x_data = N_data["train"]
        elif C_data is not None:
            self.x_data = C_data["train"]
        else:
            raise ValueError("No feature data found")
        self.train_y_data, self.train_y_info = self.data_label_process(y_train_data["train"], y_train_data["train"], is_regression=True)
        self.y_data, self.y_info = self.data_label_process(y_data["train"], y_train_data["train"], is_regression=True)

        self.num_new_value = num_new_value
        self.imputer = imputer
        self.cat_new_value = cat_new_value
        self.num_encoder = num_encoder
        self.ord_encoder = ord_encoder
        self.mode_values = mode_values
        self.cat_encoder = cat_encoder
        self.normalizer = normalizer
        self.dimension = self.x_data.shape[1]

    def data_label_process(self, y_data, train_y_data, is_regression):
        y = deepcopy(y_data)
        train_y = deepcopy(train_y_data)
        mean, std = np.mean(train_y), np.std(train_y)
        if std == 0:
            std = 1e-8
        y = (y - mean) / std
        train_y = (train_y - mean) / std
        y_min = np.min(train_y)
        y_max = np.max(train_y)
        y = (y - y_min) / (y_max - y_min + 1e-8)
        info = {"policy": "mean_std", "mean": mean, "std": std, "y_min": y_min, "y_max": y_max}
        return y, info

    def process_all_datasets(self):
        for dataset_name, dataset in self.datasets.items():
            N_data = (
                {"train": dataset["x_num_data"]}
                if dataset["x_num_data"] is not None
                else None
            )
            C_data = (
                {"train": dataset["x_cat_data"]}
                if dataset["x_cat_data"] is not None
                else None
            )
            y_data = {"train": dataset["y_data"]}

            N_data, C_data, num_new_value, imputer, cat_new_value = (
                self.data_nan_process(
                    N_data, C_data, self.num_nan_policy, self.cat_nan_policy
                )
            )

            N_data, num_encoder = self.num_enc_process(
                N_data,
                self.num_policy,
                self.n_bins,
                y_data["train"],
                is_regression=True,
            )

            N_data, C_data, ord_encoder, mode_values, cat_encoder = (
                self.data_enc_process(N_data, C_data, self.cat_policy, y_data["train"])
            )

            N_data, normalizer = self.data_norm_process(
                N_data, self.normalization, self.random_seed
            )

            if N_data is not None and C_data is not None:
                x_data = np.concatenate([N_data["train"], C_data["train"]], axis=1)
            elif N_data is not None:
                x_data = N_data["train"]
            elif C_data is not None:
                x_data = C_data["train"]
            else:
                raise ValueError(f"Dataset {dataset_name} has no feature data")

            dataset["x_data"] = x_data
            dataset["y_data"],dataset["y_info"] = self.data_label_process(y_data["train"], is_regression=True)
            dataset["categories"] = self.get_categories(C_data)

            dataset["num_new_value"] = num_new_value
            dataset["imputer"] = imputer
            dataset["cat_new_value"] = cat_new_value
            dataset["num_encoder"] = num_encoder
            dataset["ord_encoder"] = ord_encoder
            dataset["mode_values"] = mode_values
            dataset["cat_encoder"] = cat_encoder
            dataset["normalizer"] = normalizer
            dataset["dimension"] = x_data.shape[1]
            for i in range(len(x_data)):
                self.x_data.append(x_data[i])
                self.y_data.append(y_data["train"][i])

    def data_nan_process(
        self,
        N_data,
        N_data_norm,
        C_data,
        C_data_norm,
        num_nan_policy,
        cat_nan_policy,
        num_new_value=None,
        imputer=None,
        cat_new_value=None,
    ):
        if N_data is None:
            N = None
            N_norm = None
        else:
            N = deepcopy(N_data)
            N_norm = deepcopy(N_data_norm)
            N = {k: v.astype(float) for k, v in N.items()}
            N_norm = {k: v.astype(float) for k, v in N_norm.items()}

            num_nan_masks = {k: np.isnan(v) for k, v in N.items()}
            num_nan_masks_norm = {k: np.isnan(v) for k, v in N_norm.items()}
            if num_new_value is None:
                if num_nan_policy == "mean":
                    num_new_value = np.nanmean(N_norm["train"], axis=0)
                elif num_nan_policy == "median":
                    num_new_value = np.nanmedian(N_norm["train"], axis=0)
                else:
                    raise ValueError(f"Unknown numerical NaN policy: {num_nan_policy}")
                if np.isnan(num_new_value.astype(float)).any():
                    num_new_value = np.nan_to_num(num_new_value)

            if any(x.any() for x in num_nan_masks.values()):
                for k, v in N.items():
                    num_nan_indices = np.where(num_nan_masks[k])
                    v[num_nan_indices] = np.take(num_new_value, num_nan_indices[1])

            if any(x.any() for x in num_nan_masks_norm.values()):
                for k, v in N_norm.items():
                    num_nan_indices = np.where(num_nan_masks_norm[k])
                    v[num_nan_indices] = np.take(num_new_value, num_nan_indices[1])

        if C_data is None:
            C = None
            C_norm = None
        else:
            C = deepcopy(C_data)
            C_norm = deepcopy(C_data_norm)
            C = {k: v.astype(str) for k, v in C.items()}
            C_norm = {k: v.astype(str) for k, v in C_norm.items()}

            cat_nan_masks = {
                k: np.isin(v, ["nan", "NaN", "", None]) for k, v in C.items()
            }
            cat_nan_masks_norm = {
                k: np.isin(v, ["nan", "NaN", "", None]) for k, v in C_norm.items()
            }
            if cat_nan_policy == "new":
                if cat_new_value is None:
                    cat_new_value = "___null___"
                    imputer = None
            elif cat_nan_policy == "most_frequent":
                if imputer is None:
                    cat_new_value = None
                    from sklearn.impute import SimpleImputer

                    imputer = SimpleImputer(strategy="most_frequent")
                    imputer.fit(C_norm["train"])
            else:
                raise ValueError(f"Unknown categorical NaN policy: {cat_nan_policy}")

            if any(x.any() for x in cat_nan_masks.values()):
                if imputer:
                    C = {k: imputer.transform(v) for k, v in C.items()}
                else:
                    for k, v in C.items():
                        cat_nan_indices = np.where(cat_nan_masks[k])
                        v[cat_nan_indices] = cat_new_value
                    for k,v in C_norm.items():
                        cat_nan_indices = np.where(cat_nan_masks_norm[k])
                        v[cat_nan_indices] = cat_new_value

        return N, N_norm, C, C_norm, num_new_value, imputer, cat_new_value

    def num_enc_process(
        self,
        N_data,
        num_policy,
        n_bins=2,
        y_train=None,
        is_regression=False,
        encoder=None,
    ):
        if N_data is None or num_policy == "none":
            return N_data, None
        else:
            return N_data, None

    def data_enc_process(
        self,
        N_data,
        N_data_norm,
        C_data,
        C_data_norm,
        cat_policy,
        y_train=None,
        ord_encoder=None,
        mode_values=None,
        cat_encoder=None,
    ):
        if C_data is None:
            return N_data, N_data_norm, C_data, None, None, None

        import sklearn.preprocessing

        unknown_value = np.iinfo("int64").max - 3
        if ord_encoder is None:
            ord_encoder = sklearn.preprocessing.OrdinalEncoder(
                handle_unknown="use_encoded_value",
                unknown_value=unknown_value,
                dtype="int64",
            ).fit(C_data_norm["train"])
        C_data = {k: ord_encoder.transform(v) for k, v in C_data.items()}
        C_data_norm = {k: ord_encoder.transform(v) for k, v in C_data_norm.items()}

        if mode_values is not None:
            assert "test" == self.split
            for column_idx in range(C_data["train"].shape[1]):
                C_data["train"][:, column_idx][
                    C_data["train"][:, column_idx] == unknown_value
                ] = mode_values[column_idx]
        elif "val" == self.split or "test" == self.split:
            mode_values = [
                (
                    np.argmax(np.bincount(column[column != unknown_value]))
                    if np.any(column == unknown_value)
                    else column[0]
                )
                for column in C_data_norm["train"].T
            ]
            for column_idx in range(C_data["train"].shape[1]):
                C_data["train"][:, column_idx][
                    C_data["train"][:, column_idx] == unknown_value
                ] = mode_values[column_idx]

        if cat_policy == "indices":
            return N_data, N_data_norm,C_data, ord_encoder, mode_values, cat_encoder
        elif cat_policy == "ordinal":
            cat_encoder = ord_encoder
        elif cat_policy == "ohe":
            if cat_encoder is None:
                cat_encoder = sklearn.preprocessing.OneHotEncoder(
                    handle_unknown="ignore", sparse_output=False, dtype="float64"
                )
                cat_encoder.fit(C_data_norm["train"])
            C_data = {k: cat_encoder.transform(v) for k, v in C_data.items()}
        else:
            pass

        if N_data is None:
            return C_data, C_data_norm, None, ord_encoder, mode_values, cat_encoder
        else:
            return (
                {x: np.hstack((N_data[x], C_data[x])) for x in N_data},
                {x: np.hstack((N_data_norm[x], C_data_norm[x])) for x in N_data_norm},
                None,
                ord_encoder,
                mode_values,
                cat_encoder,
            )

    def data_norm_process(self, N_data, N_data_norm, normalization, seed, normalizer=None):
        if N_data is None or normalization == "none":
            return N_data, None

        if normalizer is None:
            if normalization == "standard":
                from sklearn.preprocessing import StandardScaler    

                normalizer = StandardScaler().fit(N_data_norm["train"])
            elif normalization == "minmax":
                from sklearn.preprocessing import MinMaxScaler

                normalizer = MinMaxScaler().fit(N_data_norm["train"])
            else:
                return N_data, None

        N_data = {k: normalizer.transform(v) for k, v in N_data.items()}
        return N_data, normalizer

    def get_categories(self, X_cat):
        if X_cat is None:
            return None
        return [
            len(set(X_cat["train"][:, i].tolist()))
            for i in range(X_cat["train"].shape[1])
        ]

    def __len__(self):
        if self.dataset_name:
            return len(self.x_data)
        else:
            return len(self.x_data)

    def __getitem__(self, idx):
        """Return the example at ``idx`` as a ``core.ExamplebyteRLNumeric``.

        In single-dataset mode the row comes from ``self.x_data`` /
        ``self.y_data`` and the label range from ``self.train_y_info``.

        In multi-dataset mode ``idx`` is a global index over the datasets in
        the order of ``self.datasets``. The row, its label and the label
        range are all taken from the owning dataset's own ``x_data``,
        ``y_data`` and ``y_info``, because no instance-wide ``y_info``
        exists in that mode and every dataset has its own label scaling.
        Negative indices count from the end.

        Raises:
            IndexError: In multi-dataset mode, if ``idx`` is out of range.
        """
        if self.dataset_name:
            x = self.x_data[idx]
            y = self.y_data[idx]
            y_info = self.train_y_info
        else:
            sizes = [len(dataset["x_data"]) for dataset in self.datasets.values()]
            position = idx + sum(sizes) if idx < 0 else idx
            if position < 0:
                raise IndexError(f"Index {idx} out of range")

            offset = 0
            for dataset, size in zip(self.datasets.values(), sizes):
                if position < offset + size:
                    local_idx = position - offset
                    x = dataset["x_data"][local_idx]
                    y = dataset["y_data"][local_idx]
                    y_info = dataset["y_info"]
                    break
                offset += size
            else:
                raise IndexError(f"Index {idx} out of range")

        return core.ExamplebyteRLNumeric(
            x=x,
            # Labels are stored as one-element rows; unwrap to a scalar.
            y=float(y[0]) if len(y.shape) > 0 else float(y),
            y_max=y_info["y_max"],
            y_min=y_info["y_min"],
        )

    def get_item_with_metadata(self, idx):
        if self.dataset_name:
            x = self.x_data[idx]
            y = self.y_data[idx]

            return {
                "x": x,
                "y": float(y[0]) if len(y.shape) > 0 else float(y),
                "metadata": self.metadata,
                "x_raw": self.x_data[idx],
                "y_raw": y,
                "dataset_name": self.dataset_name,
                "categories": self.categories,
                "y_info": self.y_info,
            }
        else:
            current_idx = 0
            for dataset_name, dataset in self.datasets.items():
                dataset_size = len(dataset["x_data"])
                if current_idx + dataset_size > idx:
                    local_idx = idx - current_idx
                    x = dataset["x_data"][local_idx]
                    y = dataset["y_data"][local_idx]

                    return {
                        "x": x,
                        "y": float(y[0]) if len(y.shape) > 0 else float(y),
                        "metadata": dataset["metadata"],
                        "x_raw": dataset["x_data"][local_idx],
                        "y_raw": y,
                        "dataset_name": dataset_name,
                        "categories": dataset["categories"],
                        "y_info": dataset["y_info"],
                    }
                current_idx += dataset_size

            raise IndexError(f"Index {idx} out of range")

    def get_dataset_info(self) -> Dict[str, Any]:
        if self.dataset_name:
            return {
                "name": self.metadata["name"],
                "dimension": (
                    self.x_data.shape[1]
                    if hasattr(self, "x_data")
                    else self.metadata["dimension"]
                ),
                "split": self.split,
                "num_samples": len(self.x_data) if hasattr(self, "x_data") else 0,
                "num_feature_names": self.num_feature_names,
                "cat_feature_names": self.cat_feature_names,
                "categories": self.categories,
                "y_info": self.y_info if hasattr(self, "y_info") else None,
                "info": self.info,
            }
        else:
            return {
                "total_datasets": len(self.datasets),
                "split": self.split,
                "total_samples": len(self.x_data) if hasattr(self, "x_data") else 0,
                "datasets": self.dataset_info,
            }

    def get_dataset_names(self) -> List[str]:
        """Return the loaded dataset names: the single one, or all keys of ``self.datasets``."""
        return [self.dataset_name] if self.dataset_name else list(self.datasets)


def load_regression_dataset(
    data_dir: str,
    split: str = "train",
    max_samples: Optional[int] = None,
    random_seed: int = 42,
    dataset_name: Optional[str] = None,
    compute_max_seq_len: bool = True,
    num_nan_policy: str = "mean",
    cat_nan_policy: str = "new",
    cat_policy: str = "indices",
    num_policy: str = "none",
    normalization: str = "standard",
    n_bins: int = 2,
    use_float: bool = False,
) -> Binary_fit_Dataset:
    """Build a ``Binary_fit_Dataset`` from the given options.

    Every argument is forwarded by keyword, unchanged, to the
    ``Binary_fit_Dataset`` constructor.

    Returns:
        The constructed dataset.
    """
    options = dict(
        data_dir=data_dir,
        split=split,
        max_samples=max_samples,
        random_seed=random_seed,
        dataset_name=dataset_name,
        compute_max_seq_len=compute_max_seq_len,
        num_nan_policy=num_nan_policy,
        cat_nan_policy=cat_nan_policy,
        cat_policy=cat_policy,
        num_policy=num_policy,
        normalization=normalization,
        n_bins=n_bins,
        use_float=use_float,
    )
    return Binary_fit_Dataset(**options)
