import logging
import os
import pickle
from typing import List, Dict

import numpy as np

from setting import CACHE_DIR
from utils import get_all_meta_data, get_all_step_trace

# Configure the root logger when this module is imported (INFO level,
# "timestamp - level - message" format) and create the module-level logger
# used by everything below.
# NOTE: logging.basicConfig() is a no-op if the root logger already has
# handlers, so an application that configured logging earlier keeps its setup.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


class StepTraceDataset:
    """
    Dataset of step traces flattened into per-operation feature rows.

    Each step (a nested dict with optional "children") is flattened depth-first
    into a list of operations. Indexing the dataset encodes those operations as
    a float32 array of ``[op_id, layer_id, start, end, duration]`` rows, where
    the three timing columns are min-max normalized over the whole dataset.

    The parsed dataset is pickled to ``cache_path``. If that file already
    exists it is loaded and ``all_step_trace`` is ignored, so a stale cache
    must be deleted manually when the underlying traces change.

    Attributes:
        cache_path (str): Location of the pickle cache.
        samples (list): One dict per step with keys "step_id", "step_label"
            and "operations".
        op2id (dict): Mapping from operation names to IDs ("<PAD>" is 0).
        layer2id (dict): Mapping from layer names to IDs ("<PAD>" is 0).
        start_min, start_max, end_min, end_max, dur_min, dur_max (float):
            Normalization bounds computed over every collected operation.
    """

    # Attributes persisted to / restored from the cache file. Keeping a single
    # list guarantees that _save_cache and _load_cache stay in sync.
    _CACHE_FIELDS = (
        "samples",
        "op2id",
        "layer2id",
        "start_min",
        "start_max",
        "end_min",
        "end_max",
        "dur_min",
        "dur_max",
    )

    def __init__(
        self,
        all_step_trace: List[Dict],
        cache_path: str = os.path.join(CACHE_DIR, "step_trace_cache.pkl")
    ):
        """
        Initialize the dataset, preferring an existing cache over a rebuild.

        Args:
            all_step_trace (List[Dict]): List of step trace dictionaries. Only
                used when no cache file exists at ``cache_path``.
            cache_path (str): Path to cache file.
        """
        self.cache_path = cache_path

        if os.path.exists(self.cache_path):
            logger.info("Loading StepTraceDataset from cache: %s", self.cache_path)
            self._load_cache()
        else:
            logger.info("Building StepTraceDataset from raw traces.")
            self._build_dataset(all_step_trace)
            if self.samples:
                self._save_cache()
            else:
                # Caching an empty dataset would shadow real data on the next
                # run, because an existing cache always wins over the input.
                logger.warning(
                    "No step traces provided; skipping cache write to %s",
                    self.cache_path,
                )

    def _build_dataset(self, all_step_trace: List[Dict]) -> None:
        """
        Build samples, vocabularies and normalization bounds from raw traces.

        Args:
            all_step_trace (List[Dict]): List of step trace dictionaries.
        """
        self.samples = []
        self.op2id = {"<PAD>": 0}
        self.layer2id = {"<PAD>": 0}

        all_starts, all_ends, all_durations = [], [], []

        for step in all_step_trace:
            step_id = step.get("name", "")
            step_label = self._get_step_label(step)

            operations = []
            self._collect_operations(step, operations)

            for op in operations:
                # IDs are assigned in first-seen order; 0 is reserved for <PAD>.
                if op["name"] not in self.op2id:
                    self.op2id[op["name"]] = len(self.op2id)
                if op["layer"] not in self.layer2id:
                    self.layer2id[op["layer"]] = len(self.layer2id)

                all_starts.append(op["start"])
                all_ends.append(op["end"])
                all_durations.append(op["duration"])

            self.samples.append(
                {
                    "step_id": step_id,
                    "step_label": step_label,
                    "operations": operations,
                }
            )

        if all_starts:
            self.start_min, self.start_max = np.min(all_starts), np.max(all_starts)
            self.end_min, self.end_max = np.min(all_ends), np.max(all_ends)
            self.dur_min, self.dur_max = np.min(all_durations), np.max(all_durations)
        else:
            # np.min/np.max raise on empty input; fall back to degenerate
            # bounds so an empty dataset is still a valid (length-0) object.
            self.start_min = self.start_max = 0.0
            self.end_min = self.end_max = 0.0
            self.dur_min = self.dur_max = 0.0

    def _save_cache(self) -> None:
        """Save dataset to cache file, creating the parent directory if needed."""
        cache_dir = os.path.dirname(self.cache_path)
        # dirname is "" for a bare filename, and os.makedirs("") raises.
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)

        payload = {field: getattr(self, field) for field in self._CACHE_FIELDS}

        # Write to a temporary file and rename it into place so that an
        # interrupted write never leaves a truncated cache behind.
        tmp_path = self.cache_path + ".tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(payload, f)
        os.replace(tmp_path, self.cache_path)

        logger.info("StepTraceDataset cached at %s", self.cache_path)

    def _load_cache(self) -> None:
        """
        Load dataset from cache file.

        Raises:
            KeyError: If the cache file lacks one of the expected fields.
        """
        # SECURITY: pickle.load executes arbitrary code embedded in the file.
        # Only point cache_path at files this program wrote itself.
        with open(self.cache_path, "rb") as f:
            cache = pickle.load(f)
        for field in self._CACHE_FIELDS:
            setattr(self, field, cache[field])

    def _normalize(self, value, vmin, vmax):
        """
        Min-max normalize a value.

        Values inside [vmin, vmax] map to [0, 1]; values outside that range
        are not clipped.

        Args:
            value (float): Value to normalize.
            vmin (float): Minimum value.
            vmax (float): Maximum value.

        Returns:
            float: Normalized value, or 0.0 when vmin == vmax.
        """
        if vmax == vmin:
            # Degenerate range: avoid division by zero.
            return 0.0
        return (value - vmin) / (vmax - vmin)

    def _get_step_label(self, step: Dict) -> int:
        """
        Recursively get label for a step.

        Args:
            step (dict): Step dictionary, optionally with "is_error" and
                "children" keys.

        Returns:
            int: 1 if the step or any descendant has a truthy "is_error",
            0 otherwise.
        """
        if step.get("is_error", False):
            return 1
        for child in step.get("children", []):
            if self._get_step_label(child):
                return 1
        return 0

    def _collect_operations(self, node: Dict, operations: List[Dict]) -> None:
        """
        Recursively collect operations from a step node in pre-order.

        The node itself is appended first, then each of its "children". The
        operation name is the node name truncated at the first "," and then
        at the first " at ".

        Args:
            node (dict): Step node.
            operations (list): Accumulated operations list; extended in place.
        """
        name = node.get("name", "")
        name = name.split(",")[0]
        name = name.split(" at ")[0]
        start = node.get("start", 0)
        end = node.get("end", 0)
        op_info = {
            "name": name,
            "start": start,
            "end": end,
            "duration": end - start,
            "layer": node.get("layer", "<PAD>"),
            "is_error": int(node.get("is_error", False)),
        }
        operations.append(op_info)

        for child in node.get("children", []):
            self._collect_operations(child, operations)

    def __len__(self) -> int:
        """Return the number of steps in the dataset."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict:
        """
        Encode one step as numpy arrays.

        Args:
            idx (int): Index into ``samples``.

        Returns:
            dict: "step_id" (the step's name), "ops_tensor" (float32 array with
            one ``[op_id, layer_id, start, end, duration]`` row per operation),
            "op_labels" (int64 array of per-operation error flags) and
            "step_label" (0-d int64 array).
        """
        sample = self.samples[idx]
        ops_encoded = []
        op_labels = []

        for op in sample["operations"]:
            ops_encoded.append(
                [
                    self.op2id[op["name"]],
                    self.layer2id[op["layer"]],
                    self._normalize(op["start"], self.start_min, self.start_max),
                    self._normalize(op["end"], self.end_min, self.end_max),
                    self._normalize(op["duration"], self.dur_min, self.dur_max),
                ]
            )
            op_labels.append(op["is_error"])

        # The integer IDs share a float32 array with the normalized timings.
        ops_array = np.array(ops_encoded, dtype=np.float32)
        op_labels_array = np.array(op_labels, dtype=np.int64)

        return {
            "step_id": sample["step_id"],
            "ops_tensor": ops_array,
            "op_labels": op_labels_array,
            "step_label": np.array(sample["step_label"], dtype=np.int64),
        }


def collate_fn(batch):
    """
    Merge per-step samples into padded batch arrays.

    Operation sequences are right-padded to the longest sequence in the
    batch: feature rows are padded with zeros and operation labels with -100.

    Args:
        batch (list): Samples with keys "step_id", "ops_tensor", "op_labels"
            and "step_label".

    Returns:
        dict: "step_ids" (list), "ops_tensor" (float32 array of shape
        (batch, max_len, feat_dim)), "op_labels" (int64 array of shape
        (batch, max_len)) and "step_labels" (stacked step labels).
    """
    step_ids = [sample["step_id"] for sample in batch]
    step_labels = np.stack([sample["step_label"] for sample in batch])

    # Sequence length of every sample, computed once and reused below.
    lengths = [sample["ops_tensor"].shape[0] for sample in batch]
    longest = max(lengths)
    width = batch[0]["ops_tensor"].shape[1]
    n_samples = len(batch)

    ops_padded = np.zeros((n_samples, longest, width), dtype=np.float32)
    op_labels_padded = np.full((n_samples, longest), -100, dtype=np.int64)

    for row, (sample, length) in enumerate(zip(batch, lengths)):
        ops_padded[row, :length, :] = sample["ops_tensor"]
        op_labels_padded[row, :length] = sample["op_labels"]

    return {
        "step_ids": step_ids,
        "ops_tensor": ops_padded,
        "op_labels": op_labels_padded,
        "step_labels": step_labels,
    }


class _Subset:
    """View of a dataset restricted to a fixed list of indices."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]


class _BatchLoader:
    """Minimal re-iterable batch loader: optional shuffle, then collate."""

    def __init__(self, dataset, batch_size, shuffle, collate_fn):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.collate_fn = collate_fn
        self.indices = list(range(len(dataset)))

    def __len__(self):
        # Number of batches, counting a final partial batch.
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        # Shuffle a copy so every pass starts from the same base order.
        order = list(self.indices)
        if self.shuffle:
            np.random.shuffle(order)
        for start in range(0, len(order), self.batch_size):
            chunk = order[start : start + self.batch_size]
            yield self.collate_fn([self.dataset[j] for j in chunk])


def build_dataloader(task: str, batch_size: int, train_valid_test_rate: List[float]):
    """
    Build train/valid/test dataloaders.

    The dataset is shuffled once with numpy's global random state and split
    by the first two ratios; the test split receives whatever remains, so the
    third ratio is not read.

    Args:
        task (str): 'v'/'vertical' or 'h'/'horizontal' task type.
        batch_size (int): Batch size; must be positive.
        train_valid_test_rate (list): Ratios for train/valid/test split.

    Returns:
        tuple: Train, validation, and test dataloaders. Only the train loader
        reshuffles on each iteration.

    Raises:
        ValueError: If ``task`` is not recognized or ``batch_size`` is not
            positive.
    """
    from math import floor

    # Validate the cheap arguments before the expensive trace loading below.
    if task in ["v", "vertical"]:
        use_vertical = True
        cache_name = "v_step_trace_cache.pkl"
    elif task in ["h", "horizontal"]:
        use_vertical = False
        cache_name = "h_step_trace_cache.pkl"
    else:
        logger.error("Invalid task: %s", task)
        raise ValueError("task must be either 'vertical' or 'horizontal'")

    if batch_size <= 0:
        # A zero step breaks range()/len(); a negative one yields no batches.
        raise ValueError("batch_size must be a positive integer")

    all_meta_data = get_all_meta_data()
    all_v_step_trace, all_h_step_trace = get_all_step_trace(all_meta_data)

    dataset = StepTraceDataset(
        all_v_step_trace if use_vertical else all_h_step_trace,
        cache_path=os.path.join(CACHE_DIR, cache_name),
    )

    total_len = len(dataset)
    train_size = floor(total_len * train_valid_test_rate[0])
    valid_size = floor(total_len * train_valid_test_rate[1])

    indices = list(range(total_len))
    np.random.shuffle(indices)
    train_indices = indices[:train_size]
    valid_indices = indices[train_size : train_size + valid_size]
    test_indices = indices[train_size + valid_size :]

    train_loader = _BatchLoader(
        _Subset(dataset, train_indices),
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )
    valid_loader = _BatchLoader(
        _Subset(dataset, valid_indices),
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )
    test_loader = _BatchLoader(
        _Subset(dataset, test_indices),
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )

    return train_loader, valid_loader, test_loader


if __name__ == "__main__":
    # Smoke run: build the vertical-task loaders and report their sizes.
    loaders = build_dataloader(
        "v", batch_size=32, train_valid_test_rate=[0.7, 0.2, 0.1]
    )
    train_loader, valid_loader, test_loader = loaders

    logger.info("Train loader batches: %d", len(train_loader))
    logger.info("Valid loader batches: %d", len(valid_loader))
    logger.info("Test loader batches: %d", len(test_loader))

    # Inspect only the first training batch, if there is one. These DEBUG
    # records are visible only when the logging level is lowered below the
    # INFO level configured at the top of this module.
    batch = next(iter(train_loader), None)
    if batch is not None:
        logger.debug("Sample batch step_ids: %s", batch["step_ids"][:2])
        logger.debug("Ops tensor shape: %s", batch["ops_tensor"].shape)
        logger.debug("Step labels shape: %s", batch["step_labels"].shape)
