"""
MovieLens 100K Dataset Processing Module

This module provides a complete pipeline for:
1. Downloading and extracting the MovieLens 100k dataset
2. Preprocessing data for 5-class rating classification (ratings 1-5 become classes 0-4)
3. Training a recommender classifier
4. Extracting probabilities and labels for calibration and test sets

Usage:
    from data.MovieLens_data import get_movielens_data
    cal_probs, cal_labels, test_probs, test_labels = get_movielens_data()
"""

import os
import zipfile
import random
from typing import Tuple, Dict, Optional, List

import numpy as np
import pandas as pd
import requests
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader


# GroupLens download URL for the MovieLens 100K zip archive; fetched by
# download_movielens_100k() when the zip is not already present locally.
MOVIELENS_URL = "https://files.grouplens.org/datasets/movielens/ml-100k.zip"


def set_seed(seed: int = 42) -> None:
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN kernels.

    Args:
        seed: Value handed to every random number generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def download_movielens_100k(dest_path: str) -> None:
    """Download the MovieLens 100k zip to ``dest_path`` unless it already exists.

    The archive is streamed into a temporary ``<dest_path>.part`` file and is
    only renamed to ``dest_path`` after the whole response has been written.
    An interrupted download therefore never leaves a truncated zip behind that
    a later run would mistake for a complete one (the existence check below
    would otherwise skip the download forever).

    Args:
        dest_path: Target path of the zip file. Missing parent directories
            are created.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection errors or timeouts.
    """
    if os.path.exists(dest_path):
        print("MovieLens 100k zip already exists. Skipping download.")
        return

    os.makedirs(os.path.dirname(os.path.abspath(dest_path)), exist_ok=True)

    print("Downloading MovieLens 100k dataset...")
    tmp_path = dest_path + ".part"
    try:
        # Using the response as a context manager releases the connection.
        with requests.get(MOVIELENS_URL, stream=True, timeout=30) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=64 * 1024):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
        # Atomic rename: dest_path only ever holds a complete archive.
        os.replace(tmp_path, dest_path)
    finally:
        # After a successful replace this is a no-op; on failure it removes
        # the partial file so the next attempt starts from scratch.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print("Download completed.")


def extract_movielens_100k(zip_path: str, extract_to: str) -> None:
    """Unpack ``zip_path`` into ``extract_to`` unless ``extract_to/ml-100k`` exists.

    Args:
        zip_path: Path of the downloaded ml-100k archive.
        extract_to: Directory that receives the archive contents.
    """
    target_folder = os.path.join(extract_to, "ml-100k")
    if not os.path.exists(target_folder):
        print("Extracting ml-100k.zip...")
        with zipfile.ZipFile(zip_path, "r") as archive:
            archive.extractall(extract_to)
        print("Extraction completed.")
    else:
        print("ml-100k folder already exists. Skipping extraction.")


def load_ratings(path: str) -> pd.DataFrame:
    """Read the tab-separated ``u.data`` ratings file.

    Returns:
        DataFrame with columns ``user_id``, ``item_id``, ``rating``, ``timestamp``.
    """
    rating_columns = ["user_id", "item_id", "rating", "timestamp"]
    ratings = pd.read_csv(path, sep="\t", names=rating_columns, engine="python")
    return ratings


def load_users(path: str) -> pd.DataFrame:
    """Read the pipe-separated ``u.user`` demographics file.

    Returns:
        DataFrame with columns ``user_id``, ``age``, ``gender``,
        ``occupation``, ``zip_code``.
    """
    users = pd.read_csv(
        path,
        sep="|",
        names=["user_id", "age", "gender", "occupation", "zip_code"],
        engine="python",
    )
    return users


def load_items(path: str) -> pd.DataFrame:
    """Read the pipe-separated, latin-1 encoded ``u.item`` movie file.

    Returns:
        DataFrame with five metadata columns (id, title, dates, IMDb URL)
        followed by one 0/1 column per genre.
    """
    metadata_cols = [
        "item_id",
        "movie_title",
        "release_date",
        "video_release_date",
        "imdb_url",
    ]
    genre_cols = [
        "unknown", "Action", "Adventure", "Animation", "Children's",
        "Comedy", "Crime", "Documentary", "Drama", "Fantasy",
        "Film-Noir", "Horror", "Musical", "Mystery", "Romance",
        "Sci-Fi", "Thriller", "War", "Western",
    ]
    items = pd.read_csv(
        path,
        sep="|",
        names=metadata_cols + genre_cols,
        encoding="latin-1",
        engine="python",
    )
    return items


def split_data(
    df: pd.DataFrame,
    train_frac: float = 0.8,
    calib_frac: float = 0.1,
    test_frac: float = 0.1,
    seed: int = 42
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Shuffle ``df`` with ``seed`` and cut it into train/calibration/test parts.

    Boundaries are ``int(train_frac * n)`` and ``int((train_frac + calib_frac) * n)``;
    the test part takes whatever rows remain.

    Raises:
        ValueError: If the three fractions do not sum to 1 (within 1e-6).
    """
    fraction_total = train_frac + calib_frac + test_frac
    if abs(fraction_total - 1.0) > 1e-6:
        raise ValueError("train_frac + calib_frac + test_frac must sum to 1.0")

    shuffled = df.sample(frac=1, random_state=seed).reset_index(drop=True)
    size = len(shuffled)
    train_cut = int(train_frac * size)
    calib_cut = int((train_frac + calib_frac) * size)

    return (
        shuffled.iloc[:train_cut],
        shuffled.iloc[train_cut:calib_cut],
        shuffled.iloc[calib_cut:],
    )


def preprocess_data_classification(
    merged_df: pd.DataFrame
) -> Tuple[pd.DataFrame, Dict[int, int], Dict[int, int]]:
    """Turn the merged ratings/users/items frame into model-ready columns.

    Works on a copy of ``merged_df``:

    * ``rating`` 1..5 becomes ``rating_class`` 0..4.
    * ``user_id`` / ``item_id`` get dense 0-based indices (``user_id_idx``,
      ``item_id_idx``) in order of first appearance.
    * ``gender`` becomes M->0, F->1 (anything else -> 0); ``occupation`` and
      ``zip_code`` become first-appearance integer codes.
    * ``release_year`` is taken from the third ``-``-separated part of
      ``release_date`` (0 when missing or unparsable).
    * Title, date, video-date and IMDb URL columns are dropped.

    Returns:
        The processed frame, the user-id->index map and the item-id->index map.
    """
    out = merged_df.copy()

    # Class labels 0..4 for cross-entropy.
    out["rating_class"] = out["rating"] - 1

    # Dense embedding indices, numbered in order of first appearance.
    user2idx = {uid: pos for pos, uid in enumerate(out["user_id"].unique())}
    item2idx = {iid: pos for pos, iid in enumerate(out["item_id"].unique())}
    out["user_id_idx"] = out["user_id"].map(user2idx)
    out["item_id_idx"] = out["item_id"].map(item2idx)

    # Categorical user attributes -> integer codes.
    out["gender"] = out["gender"].map({"M": 0, "F": 1}).fillna(0).astype(int)
    for column in ("occupation", "zip_code"):
        codes = {value: pos for pos, value in enumerate(out[column].unique().tolist())}
        out[column] = out[column].map(codes).fillna(0).astype(int)

    def release_year_of(raw: Optional[str]) -> int:
        """Year from a 'DD-Mon-YYYY'-shaped string, or 0 if unavailable."""
        if pd.isna(raw):
            return 0
        pieces = str(raw).split("-")
        if len(pieces) != 3:
            return 0
        try:
            return int(pieces[2])
        except ValueError:
            return 0

    out["release_year"] = out["release_date"].apply(release_year_of)

    out = out.drop(
        columns=["movie_title", "release_date", "video_release_date", "imdb_url"]
    )
    return out, user2idx, item2idx


class MovieLensClassifDataset(Dataset):
    """PyTorch Dataset for MovieLens rating classification.

    Each item is ``((user_idx, item_idx, side_features), rating_class)``, where
    ``side_features`` is a float32 vector ordered like ``SIDE_COLUMNS``.
    """

    # Side-feature columns, in the order they appear in the feature vector.
    SIDE_COLUMNS: List[str] = [
        "age",
        "gender",
        "occupation",
        "zip_code",
        "timestamp",
        "release_year",
        "unknown",
        "Action",
        "Adventure",
        "Animation",
        "Children's",
        "Comedy",
        "Crime",
        "Documentary",
        "Drama",
        "Fantasy",
        "Film-Noir",
        "Horror",
        "Musical",
        "Mystery",
        "Romance",
        "Sci-Fi",
        "Thriller",
        "War",
        "Western"
    ]

    def __init__(self, df: pd.DataFrame):
        """Build tensors from ``df``.

        ``df`` must contain ``user_id_idx``, ``item_id_idx`` and
        ``rating_class``. Any ``SIDE_COLUMNS`` entry missing from ``df`` is
        treated as all zeros. ``df`` itself is not modified.
        """
        super().__init__()

        self.user_ids = torch.tensor(df["user_id_idx"].values, dtype=torch.long)
        self.item_ids = torch.tensor(df["item_id_idx"].values, dtype=torch.long)
        self.y = torch.tensor(df["rating_class"].values, dtype=torch.long)

        # reindex returns a new frame, so zero-filling missing columns no
        # longer writes into the caller's DataFrame (which is often an iloc
        # slice, where assignment would also raise SettingWithCopyWarning).
        side = df.reindex(columns=self.SIDE_COLUMNS, fill_value=0)
        self.side_info = torch.tensor(side.values, dtype=torch.float32)

    def __len__(self) -> int:
        return len(self.user_ids)

    def __getitem__(self, idx: int):
        return (self.user_ids[idx], self.item_ids[idx], self.side_info[idx]), self.y[idx]


class DeepRecommenderClassifier(nn.Module):
    """Embedding-based classifier for MovieLens ratings.

    User and item embeddings are concatenated with the output of a two-layer
    side-feature MLP and fed to a final MLP that emits ``num_classes`` logits.

    Note: the default ``side_in_dim`` (26) differs from the 25 entries of
    ``MovieLensClassifDataset.SIDE_COLUMNS``; callers should pass the actual
    side-feature width explicitly.
    """

    def __init__(
        self,
        num_users: int,
        num_items: int,
        user_emb_dim: int = 20,
        item_emb_dim: int = 20,
        side_in_dim: int = 26,
        side_hidden: int = 32,
        final_hidden: int = 32,
        num_classes: int = 5,
        dropout_rate: float = 0.3
    ):
        super().__init__()

        self.user_emb = nn.Embedding(num_users, user_emb_dim)
        self.item_emb = nn.Embedding(num_items, item_emb_dim)

        self.side_mlp = nn.Sequential(
            *self._dense_block(side_in_dim, side_hidden, dropout_rate),
            *self._dense_block(side_hidden, side_hidden, dropout_rate),
        )

        concat_width = user_emb_dim + item_emb_dim + side_hidden
        self.final_mlp = nn.Sequential(
            *self._dense_block(concat_width, final_hidden, dropout_rate),
            nn.Linear(final_hidden, num_classes),
        )

    @staticmethod
    def _dense_block(in_dim: int, out_dim: int, dropout_rate: float) -> List[nn.Module]:
        """Layers Linear -> BatchNorm1d -> ReLU -> Dropout, created in that order."""
        return [
            nn.Linear(in_dim, out_dim),
            nn.BatchNorm1d(out_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
        ]

    def forward(self, user_ids, item_ids, side_info):
        """Return unnormalized class logits, one row per input example."""
        features = torch.cat(
            [self.user_emb(user_ids), self.item_emb(item_ids), self.side_mlp(side_info)],
            dim=1,
        )
        return self.final_mlp(features)


class MovieLensTrainer:
    """Handles model creation, training, evaluation, and probability extraction.

    Wraps a ``DeepRecommenderClassifier`` with a cross-entropy loss and an Adam
    optimizer. Every DataLoader passed to the methods below must yield
    ``((user_ids, item_ids, side_info), labels)`` batches, which is what
    batching ``MovieLensClassifDataset`` items produces.
    """

    def __init__(
        self,
        num_users: int,
        num_items: int,
        side_in_dim: int,
        device: Optional[str] = None,
        user_emb_dim: int = 32,
        item_emb_dim: int = 32,
        side_hidden: int = 64,
        final_hidden: int = 64,
        dropout_rate: float = 0.3,
        num_classes: int = 5,
        lr: float = 1e-3
    ):
        """Build the model on ``device`` (CUDA if available when not given)."""
        self.device = torch.device(device) if device else torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        # Kept so extract_probabilities can shape an empty result correctly.
        self.num_classes = num_classes
        self.model = DeepRecommenderClassifier(
            num_users=num_users,
            num_items=num_items,
            user_emb_dim=user_emb_dim,
            item_emb_dim=item_emb_dim,
            side_in_dim=side_in_dim,
            side_hidden=side_hidden,
            final_hidden=final_hidden,
            num_classes=num_classes,
            dropout_rate=dropout_rate
        ).to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

    def _batch_to_device(self, batch):
        """Unpack one loader batch and move all four tensors to ``self.device``.

        Returns:
            ``(user_ids, item_ids, side_info, labels)``.
        """
        (user_ids, item_ids, side_info), y_class = batch
        return (
            user_ids.to(self.device),
            item_ids.to(self.device),
            side_info.to(self.device),
            y_class.to(self.device),
        )

    def train_one_epoch(self, loader: DataLoader) -> float:
        """Run one optimization pass over ``loader``.

        Returns:
            Mean per-sample training loss (0.0 for an empty loader).
        """
        self.model.train()
        total_loss = 0.0
        total_samples = 0

        for batch in loader:
            user_ids, item_ids, side_info, y_class = self._batch_to_device(batch)

            self.optimizer.zero_grad()
            logits = self.model(user_ids, item_ids, side_info)
            loss = self.criterion(logits, y_class)
            loss.backward()
            self.optimizer.step()

            # criterion averages over the batch; re-weight so the epoch mean
            # is per sample even when the last batch is smaller.
            batch_size = y_class.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size

        return total_loss / total_samples if total_samples > 0 else 0.0

    def evaluate(self, loader: DataLoader) -> Tuple[float, float]:
        """Return average cross-entropy loss and accuracy (in percent)."""
        self.model.eval()
        # Summed loss so the final division gives an exact per-sample mean.
        criterion = nn.CrossEntropyLoss(reduction="sum")
        total_loss = 0.0
        total_correct = 0
        total_samples = 0

        with torch.no_grad():
            for batch in loader:
                user_ids, item_ids, side_info, y_class = self._batch_to_device(batch)

                logits = self.model(user_ids, item_ids, side_info)
                loss = criterion(logits, y_class)

                preds = torch.argmax(logits, dim=1)
                total_correct += (preds == y_class).sum().item()
                total_loss += loss.item()
                total_samples += y_class.size(0)

        avg_loss = total_loss / total_samples if total_samples > 0 else 0.0
        accuracy = 100.0 * total_correct / total_samples if total_samples > 0 else 0.0
        return avg_loss, accuracy

    def train(
        self,
        train_loader: DataLoader,
        calib_loader: Optional[DataLoader] = None,
        num_epochs: int = 20
    ) -> None:
        """Train for ``num_epochs`` epochs, reporting calibration metrics if given."""
        print(f"\nStarting training for {num_epochs} epochs...")
        for epoch in range(1, num_epochs + 1):
            train_loss = self.train_one_epoch(train_loader)
            if calib_loader is not None:
                val_loss, val_acc = self.evaluate(calib_loader)
                print(
                    f"Epoch {epoch}/{num_epochs}: Train Loss={train_loss:.4f}, "
                    f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.2f}%"
                )
            else:
                print(f"Epoch {epoch}/{num_epochs}: Train Loss={train_loss:.4f}")

    def extract_probabilities(self, data_loader: DataLoader) -> Tuple[np.ndarray, np.ndarray]:
        """Extract softmax probabilities and labels from a DataLoader.

        Returns:
            ``(probs, labels)`` with shapes ``(N, num_classes)`` and ``(N,)``.
            An empty loader yields ``(0, num_classes)`` probabilities and an
            empty int64 label array.
        """
        self.model.eval()
        probs_list = []
        labels_list = []

        with torch.no_grad():
            for batch in data_loader:
                user_ids, item_ids, side_info, y_class = self._batch_to_device(batch)

                logits = self.model(user_ids, item_ids, side_info)
                probs = torch.softmax(logits, dim=1)
                probs_list.append(probs.cpu().numpy())
                labels_list.append(y_class.cpu().numpy())

        if not probs_list:
            return (
                np.empty((0, self.num_classes), dtype=np.float32),
                np.empty((0,), dtype=np.int64),
            )
        return np.concatenate(probs_list, axis=0), np.concatenate(labels_list, axis=0)

    def save_model(self, path: str) -> None:
        """Save the model state dict to ``path``, creating parent directories."""
        # Without this, a long training run would crash at the very end when
        # the target directory (e.g. "models/") does not exist yet.
        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
        torch.save(self.model.state_dict(), path)
        print(f"Model saved to {path}")

    def load_model(self, path: str) -> None:
        """Load a state dict from ``path`` into the model and switch to eval mode."""
        # NOTE(security): on PyTorch versions where torch.load defaults to
        # full unpickling, loading an untrusted checkpoint can execute
        # arbitrary code; only load files from trusted sources.
        self.model.load_state_dict(torch.load(path, map_location=self.device))
        self.model.eval()
        print(f"Model loaded from {path}")


def get_movielens_data(
    data_dir: Optional[str] = None,
    data_root: str = "/path/to/your/data",
    zip_path: Optional[str] = "ml-100k.zip",
    train_model: bool = True,
    model_path: Optional[str] = None,
    num_epochs: int = 20,
    batch_size: int = 256,
    seed: int = 42,
    num_workers: int = 0,
    user_emb_dim: int = 32,
    item_emb_dim: int = 32,
    side_hidden: int = 64,
    final_hidden: int = 64,
    dropout_rate: float = 0.3,
    lr: float = 1e-3,
    device: Optional[str] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Main function to get MovieLens calibration/test probabilities and labels.

    Runs the whole pipeline:
    1. Download and extract the archive if needed.
    2. Load and merge the ratings, users and items tables, then preprocess them.
    3. Split 80/10/10 into train/calibration/test.
    4. Train (or load) the classifier.
    5. Return softmax probabilities and labels for the calibration and test splits.

    Args:
        data_dir: Path to extracted ml-100k folder (default: data_root/ml-100k)
        data_root: Root folder for download/extraction when data_dir is not
            provided. The default is a placeholder and should be overridden.
        zip_path: Path to ml-100k.zip (relative to data_root if not absolute;
            None means data_root/ml-100k.zip)
        train_model: Whether to train the model (False to load existing)
        model_path: Path to save/load model weights. Required when
            train_model=False; parent directories are created when saving.
        num_epochs: Number of training epochs
        batch_size: Batch size for DataLoaders
        seed: Random seed for reproducibility
        num_workers: Number of DataLoader workers
        user_emb_dim: User embedding dimension
        item_emb_dim: Item embedding dimension
        side_hidden: Hidden size for side-info MLP
        final_hidden: Hidden size for final MLP
        dropout_rate: Dropout rate
        lr: Learning rate
        device: Torch device string ('cuda' or 'cpu')

    Returns:
        Tuple of (cal_probs, cal_labels, test_probs, test_labels)

    Raises:
        ValueError: If train_model is False and model_path is None.
        FileNotFoundError: If u.data/u.user/u.item are still missing after
            download and extraction (e.g. data_dir does not point at the
            extracted ml-100k folder).
    """
    # Fail fast on bad arguments, before any download or preprocessing work.
    if not train_model and model_path is None:
        raise ValueError("model_path must be provided when train_model=False")

    print("=" * 60)
    print("MovieLens 100K Processing Pipeline")
    print("=" * 60)

    set_seed(seed)

    if data_dir is None:
        data_root = os.path.abspath(data_root)
        data_dir = os.path.join(data_root, "ml-100k")
    else:
        data_dir = os.path.abspath(data_dir)
        data_root = os.path.dirname(data_dir)

    if zip_path is None:
        zip_path = os.path.join(data_root, "ml-100k.zip")
    elif not os.path.isabs(zip_path):
        zip_path = os.path.join(data_root, zip_path)

    # Step 1: Download & Extract (if needed)
    print("\n[1/4] Downloading and extracting data...")
    required_files = [
        os.path.join(data_dir, name) for name in ("u.data", "u.user", "u.item")
    ]
    if all(os.path.exists(p) for p in required_files):
        print("Found existing ml-100k data. Skipping download/extract.")
    else:
        os.makedirs(os.path.dirname(zip_path), exist_ok=True)
        download_movielens_100k(zip_path)
        extract_movielens_100k(zip_path, data_root)
        # The archive always unpacks to data_root/ml-100k, so a data_dir with a
        # different name (or a stale partial folder) still lacks the files.
        missing = [p for p in required_files if not os.path.exists(p)]
        if missing:
            raise FileNotFoundError(
                f"MovieLens files missing after download/extract: {missing}. "
                f"The archive extracts to {os.path.join(data_root, 'ml-100k')}; "
                "make sure data_dir points there."
            )

    # Step 2: Load and preprocess data
    print("\n[2/4] Loading and preprocessing data...")
    ratings_df = load_ratings(os.path.join(data_dir, "u.data"))
    users_df = load_users(os.path.join(data_dir, "u.user"))
    items_df = load_items(os.path.join(data_dir, "u.item"))

    merged_df = pd.merge(ratings_df, users_df, on="user_id", how="left")
    merged_df = pd.merge(merged_df, items_df, on="item_id", how="left")

    df_classif, user2idx, item2idx = preprocess_data_classification(merged_df)
    train_df, calib_df, test_df = split_data(df_classif, 0.8, 0.1, 0.1, seed=seed)

    print(f"Train: {train_df.shape}, Calib: {calib_df.shape}, Test: {test_df.shape}")

    train_data = MovieLensClassifDataset(train_df)
    calib_data = MovieLensClassifDataset(calib_df)
    test_data = MovieLensClassifDataset(test_df)

    # The model uses BatchNorm1d, which raises in training mode on a batch of
    # a single sample; drop the trailing batch only if it would have size 1.
    drop_last_train = batch_size > 1 and len(train_data) % batch_size == 1

    train_loader = DataLoader(
        train_data, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, drop_last=drop_last_train
    )
    calib_loader = DataLoader(calib_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    side_in_dim = train_data.side_info.shape[1]
    print(f"Side info dimension: {side_in_dim}")

    # Step 3: Model setup
    print("\n[3/4] Setting up model...")
    trainer = MovieLensTrainer(
        num_users=len(user2idx),
        num_items=len(item2idx),
        side_in_dim=side_in_dim,
        device=device,
        user_emb_dim=user_emb_dim,
        item_emb_dim=item_emb_dim,
        side_hidden=side_hidden,
        final_hidden=final_hidden,
        dropout_rate=dropout_rate,
        lr=lr
    )

    # Step 4: Train or load model
    if train_model:
        print("\n[4/4] Training model...")
        trainer.train(train_loader, calib_loader=calib_loader, num_epochs=num_epochs)
        if model_path:
            # Ensure the target directory exists so saving cannot fail after training.
            os.makedirs(os.path.dirname(os.path.abspath(model_path)), exist_ok=True)
            trainer.save_model(model_path)
    else:
        print("\n[4/4] Loading pre-trained model...")
        trainer.load_model(model_path)

    print("\nExtracting probabilities...")
    cal_probs, cal_labels = trainer.extract_probabilities(calib_loader)
    test_probs, test_labels = trainer.extract_probabilities(test_loader)

    print(f"\nCalibration set: {cal_probs.shape}, {cal_labels.shape}")
    print(f"Test set: {test_probs.shape}, {test_labels.shape}")

    print("\n" + "=" * 60)
    print("Pipeline completed successfully!")
    print("=" * 60)

    return cal_probs, cal_labels, test_probs, test_labels


if __name__ == "__main__":
    # Script entry point: run the full pipeline with default data paths and
    # store the trained weights at models/movielens_model.pth.
    pipeline_outputs = get_movielens_data(
        num_epochs=20,
        model_path="models/movielens_model.pth",
    )
    cal_probs, cal_labels, test_probs, test_labels = pipeline_outputs

    print("\nData ready for conformal prediction!")
    for split_name, probs_arr in (("Calibration", cal_probs), ("Test", test_probs)):
        print(f"{split_name} probabilities shape: {probs_arr.shape}")
