"""Dataset loading and processing utilities for molecular property data.

This module handles loading, normalizing, transforming, and merging
molecular property datasets into standardized formats.
"""

from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional
import json
import os
import pandas as pd
import numpy as np

from moltenflow.data.smiles_dataset import download_dataset
from moltenflow.data.data_utils import canonicalize_smiles


# Raw-CSV schema for each druglike property benchmark:
#   smiles_col          - column of the raw CSV holding SMILES strings
#   target_col          - column of the raw CSV holding the measured property
#   target_name         - standardized column name used after processing
#   target_cols         - None when the CSV carries no measured targets
#   computed_properties - properties computed from SMILES (RDKit) instead
DRUGLIKE_DATASET_CONFIGS = dict(
    esol=dict(
        smiles_col="smiles",
        target_col="measured log solubility in mols per litre",
        target_name="esol",
    ),
    lipophilicity=dict(
        smiles_col="smiles",
        target_col="exp",
        target_name="lipo",
    ),
    freesolv=dict(
        smiles_col="smiles",
        target_col="expt",
        target_name="freesolv",
    ),
    # ZINC250K ships SMILES only; its properties are computed from the molecules.
    zinc250k=dict(
        smiles_col="smiles",
        target_cols=None,
        computed_properties=["qed", "sas", "plogp"],
    ),
)


@dataclass
class MolDataset:
    """SMILES strings bundled with optional per-molecule tables.

    Fields:
        smiles: One SMILES string per molecule.
        y: Target property table, or None for unlabeled data. The loaders in
            this module build it from the same rows as ``smiles`` (alignment is
            not enforced here).
        c: Conditional variables (e.g., temperature, pressure), or None.
    """

    smiles: list[str]
    y: pd.DataFrame | None = None
    c: pd.DataFrame | None = None


def load_csv_dataset(
    path: str, smiles_col: str, target_cols: list[str] | None = None
) -> MolDataset:
    """Read a CSV of molecules into a :class:`MolDataset`.

    Args:
        path: Location of the CSV file.
        smiles_col: Column holding SMILES; every value is cast to ``str``.
        target_cols: Columns copied into ``y``. When None or empty, ``y`` is None.

    Returns:
        MolDataset with ``smiles`` and ``y`` filled in; ``c`` is left as None.
    """
    frame = pd.read_csv(path)
    smiles_list = list(frame[smiles_col].astype(str))

    targets = None
    if target_cols:
        targets = frame[target_cols].copy()

    return MolDataset(smiles=smiles_list, y=targets)


def load_and_normalize(source_cfg: Dict[str, Any], dataset_name: str) -> pd.DataFrame:
    """Load a CSV and normalize its column names to the standard format.

    The SMILES column is renamed to ``smiles``. Condition columns that exist in
    the CSV (e.g. ``temperature``, ``pressure``) keep their names. The target
    column is renamed to ``target_name`` when the config provides one; otherwise
    it keeps its original name. All other CSV columns are dropped.

    Args:
        source_cfg: Configuration dict with keys:
            - path: Path to CSV file
            - smiles_col: Name of SMILES column
            - target_col: Name of target property column
            - target_name: Optional output name for the target column
              (defaults to ``target_col``)
            - cond_cols: Optional list of condition column names; those absent
              from the CSV are skipped silently
            - default_pressure: Optional constant used to create a ``pressure``
              column when the result has none
        dataset_name: Name of dataset, used in error messages.

    Returns:
        DataFrame with columns ``smiles``, the present condition columns, the
        target column, and possibly a filled-in ``pressure`` column.

    Raises:
        KeyError: If the SMILES or target column is missing from the CSV.
    """
    df = pd.read_csv(source_cfg["path"])

    smiles_col = source_cfg["smiles_col"]
    target_col = source_cfg["target_col"]
    target_name = source_cfg.get("target_name") or target_col
    # Tolerate both an absent key and an explicit `cond_cols: None`.
    cond_cols = source_cfg.get("cond_cols") or []

    # Insertion order of this mapping fixes the output column order.
    col_mapping = {smiles_col: "smiles"}
    for cond_col in cond_cols:
        if cond_col in df.columns:
            col_mapping[cond_col] = cond_col  # keep temperature, pressure as-is
    col_mapping[target_col] = target_name

    missing = [col for col in col_mapping if col not in df.columns]
    if missing:
        raise KeyError(
            f"Columns {missing} not found in dataset '{dataset_name}' "
            f"(available: {list(df.columns)})"
        )

    df = df[list(col_mapping)].rename(columns=col_mapping)

    # Some sources are measured at a single pressure and omit the column.
    if "pressure" not in df.columns and "default_pressure" in source_cfg:
        df["pressure"] = source_cfg["default_pressure"]

    return df


def apply_transforms(df: pd.DataFrame, transforms_cfg: Dict[str, Optional[str]]) -> pd.DataFrame:
    """Return a copy of ``df`` with per-column transforms applied.

    Args:
        df: DataFrame holding the columns to transform; it is not modified.
        transforms_cfg: Mapping of column name to transform. ``"log"`` applies
            the natural logarithm and ``None`` leaves the column unchanged.
            Entries naming a column absent from ``df`` are skipped without
            their transform being validated.

    Returns:
        A new DataFrame with the transformed columns.

    Raises:
        ValueError: For an unsupported transform on a column present in ``df``.
    """
    out = df.copy()

    for column_name, kind in transforms_cfg.items():
        if column_name not in out.columns or kind is None:
            continue
        if kind != "log":
            raise ValueError(f"Unknown transform type: {kind}")
        out[column_name] = np.log(out[column_name])

    return out


def _scaled_gap(a: np.ndarray, b: np.ndarray, tolerance: float) -> np.ndarray:
    """Return broadcast ``|a - b| / tolerance``; tolerance <= 0 means exact match (0 or inf)."""
    if tolerance > 0:
        return np.abs(a - b) / tolerance
    return np.where(a == b, 0.0, np.inf)


def merge_datasets(
    df_co2: pd.DataFrame,
    df_visc: pd.DataFrame,
    temp_tolerance: float = 1.0,
    pressure_tolerance: float = 0.01,
) -> pd.DataFrame:
    """Merge CO2 and viscosity datasets with tolerance-based matching.

    Rows are only matched within the same SMILES. A CO2 row and a viscosity row
    are compatible when BOTH ``|dT| <= temp_tolerance`` and
    ``|dP| <= pressure_tolerance``. Among compatible viscosity rows, the one
    with the smallest Euclidean distance in tolerance-normalized (T, P) space
    wins; on ties, the first one wins. One viscosity row may be the best match
    for several CO2 rows. Matched output rows carry the CO2 row's conditions.

    Every input row is represented in the output:
      * each CO2 row appears once, with ``viscosity`` NaN if nothing matched;
      * each viscosity row that was never chosen appears once with ``x_CO2`` NaN.
    Rows whose SMILES is missing are dropped (pandas ``groupby`` default).

    Args:
        df_co2: CO2 solubility dataframe with columns
            [smiles, temperature, pressure, x_CO2].
        df_visc: Viscosity dataframe with columns
            [smiles, temperature, pressure, viscosity].
        temp_tolerance: Temperature tolerance in Kelvin (default: ±1.0 K).
            A value <= 0 requires exact equality.
        pressure_tolerance: Pressure tolerance in bar (default: ±0.01 bar).
            A value <= 0 requires exact equality.

    Returns:
        DataFrame with columns [smiles, temperature, pressure, x_CO2, viscosity].
        Rows are grouped by SMILES in sorted order and have a fresh RangeIndex.
        When both inputs are empty, the result is empty but still has these columns.
    """
    columns = ["smiles", "temperature", "pressure", "x_CO2", "viscosity"]
    parts: list[pd.DataFrame] = []

    co2_groups = {name: group for name, group in df_co2.groupby("smiles")}
    visc_groups = {name: group for name, group in df_visc.groupby("smiles")}

    # sorted(): set iteration order of str keys is hash-seeded, i.e. not reproducible.
    for smiles in sorted(co2_groups.keys() | visc_groups.keys()):
        co2_rows = co2_groups.get(smiles)
        visc_rows = visc_groups.get(smiles)

        if co2_rows is None:  # viscosity data only
            parts.append(visc_rows.assign(x_CO2=np.nan)[columns])
            continue
        if visc_rows is None:  # CO2 data only
            parts.append(co2_rows.assign(viscosity=np.nan)[columns])
            continue

        # Pairwise (n_co2, n_visc) gaps, scaled so that 1.0 == the tolerance.
        temp_gap = _scaled_gap(
            co2_rows["temperature"].to_numpy()[:, np.newaxis],
            visc_rows["temperature"].to_numpy()[np.newaxis, :],
            temp_tolerance,
        )
        pres_gap = _scaled_gap(
            co2_rows["pressure"].to_numpy()[:, np.newaxis],
            visc_rows["pressure"].to_numpy()[np.newaxis, :],
            pressure_tolerance,
        )

        # Per-axis box check. A radius-sqrt(2) circle would also accept e.g. a
        # 1.4x temperature gap at equal pressure, violating the ±tolerance contract.
        compatible = (temp_gap <= 1.0) & (pres_gap <= 1.0)
        distance = np.where(compatible, np.sqrt(temp_gap**2 + pres_gap**2), np.inf)
        has_match = compatible.any(axis=1)
        best = distance.argmin(axis=1)  # first minimum; meaningful only where has_match

        visc_values = visc_rows["viscosity"].to_numpy()
        parts.append(
            pd.DataFrame(
                {
                    "smiles": smiles,
                    "temperature": co2_rows["temperature"].to_numpy(),
                    "pressure": co2_rows["pressure"].to_numpy(),
                    "x_CO2": co2_rows["x_CO2"].to_numpy(),
                    "viscosity": np.where(has_match, visc_values[best], np.nan),
                }
            )
        )

        # Track usage positionally so non-unique index labels cannot leak
        # "matched" status between SMILES groups.
        used = np.zeros(len(visc_rows), dtype=bool)
        used[best[has_match]] = True
        leftover = visc_rows[~used]
        if len(leftover) > 0:
            parts.append(leftover.assign(x_CO2=np.nan)[columns])

    if not parts:
        return pd.DataFrame(columns=columns)
    return pd.concat(parts, ignore_index=True)


def save_processed(df: pd.DataFrame, path: str | Path) -> None:
    """Write ``df`` to a parquet file (pyarrow engine, index dropped) and report it.

    Missing parent directories are created. A short summary (row count,
    columns, file size in KB) is printed to stdout.

    Args:
        df: Processed dataframe to persist.
        path: Destination parquet path.
    """
    out_path = Path(path)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    df.to_parquet(out_path, index=False, engine="pyarrow")

    summary_lines = (
        f"Saved {len(df)} rows to {out_path}",
        f"  Columns: {list(df.columns)}",
    )
    for line in summary_lines:
        print(line)
    print(f"  File size: {out_path.stat().st_size / 1024:.1f} KB")


def load_processed_dataset(
    path: str,
    smiles_col: str,
    target_cols: list[str],
    cond_cols: list[str] | None = None,
    drop_nan: bool = True,
) -> tuple[MolDataset, int]:
    """Read a processed parquet file into a :class:`MolDataset`.

    Args:
        path: Parquet file to read.
        smiles_col: Column holding SMILES (values are cast to ``str``).
        target_cols: Columns copied into ``y``.
        cond_cols: Columns copied into ``c``; when None or empty, ``c`` is None.
        drop_nan: When True, discard rows with a NaN in any target column.
            When False, keep every row, e.g. for masked training on NaN targets.

    Returns:
        ``(dataset, n_dropped)``, where ``n_dropped`` is the number of rows
        removed by the NaN filter (always 0 when ``drop_nan`` is False).
    """
    frame = pd.read_parquet(path)
    n_before = len(frame)

    if drop_nan:
        frame = frame.dropna(subset=target_cols)
    # Without filtering, the length is unchanged and this evaluates to 0.
    n_dropped = n_before - len(frame)

    dataset = MolDataset(
        smiles=list(frame[smiles_col].astype(str)),
        y=frame[target_cols].copy(),
        c=frame[cond_cols].copy() if cond_cols else None,
    )
    return dataset, n_dropped


def load_property_dataset(
    name: str,
    data_dir: str = "data/raw",
    processed_dir: str = "data/processed",
    drop_fragments: bool = True,
    deduplicate: bool = True,
    force_reprocess: bool = False,
) -> pd.DataFrame:
    """Load and preprocess a druglike property dataset.

    Processing steps:
    1. Reuse ``<processed_dir>/<name>.parquet`` when it exists, unless
       ``force_reprocess`` is set or its metadata JSON records different
       ``drop_fragments``/``deduplicate`` options.
    2. Download the raw CSV if it is not cached.
    3. Drop rows with a NaN target.
    4. Canonicalize SMILES with RDKit and drop invalid ones (parse failure).
    5. Optionally drop multi-fragment SMILES (containing ".").
    6. Optionally deduplicate by canonical SMILES (keep first).
    7. Rename the target column to its standardized name.
    8. Save the processed parquet and metadata JSON.

    Args:
        name: Dataset name ("esol", "lipophilicity", "freesolv")
        data_dir: Directory for raw CSV downloads
        processed_dir: Directory for processed parquet output
        drop_fragments: If True, drop SMILES containing "."
        deduplicate: If True, deduplicate by canonical SMILES (keep first)
        force_reprocess: If True, reprocess even if parquet exists

    Returns:
        DataFrame with columns ["smiles", "<target_name>"]

    Raises:
        ValueError: If the dataset name is not recognized; if the dataset has
            no target column in its CSV (e.g. "zinc250k", whose properties are
            computed from SMILES); or if the expected SMILES/target column is
            missing from the CSV.
    """
    if name not in DRUGLIKE_DATASET_CONFIGS:
        raise ValueError(
            f"Unknown dataset '{name}'. Available: {list(DRUGLIKE_DATASET_CONFIGS.keys())}"
        )

    config = DRUGLIKE_DATASET_CONFIGS[name]
    # Configs without a target (zinc250k) would otherwise fail with a bare KeyError.
    if config.get("target_col") is None or config.get("target_name") is None:
        raise ValueError(
            f"Dataset '{name}' has no target column to load; its properties are "
            "computed from SMILES (use load_zinc250k_with_properties for 'zinc250k')"
        )
    target_name = config["target_name"]

    processed_path = Path(processed_dir) / f"{name}.parquet"
    metadata_path = Path(processed_dir) / f"{name}_metadata.json"
    processing_options = {"drop_fragments": drop_fragments, "deduplicate": deduplicate}

    # The cache filename does not encode the options, so check the metadata
    # before reusing it. Missing/unreadable metadata: trust the cache (legacy).
    if processed_path.exists() and not force_reprocess:
        cached_options = None
        try:
            with open(metadata_path) as f:
                cached_options = json.load(f).get("processing_options")
        except (OSError, ValueError, AttributeError):
            pass
        if cached_options is None or cached_options == processing_options:
            return pd.read_parquet(processed_path)

    # Download raw CSV if needed
    raw_csv_path = Path(data_dir) / f"{name}.csv"
    if not raw_csv_path.exists():
        os.makedirs(data_dir, exist_ok=True)
        download_dataset(name, str(raw_csv_path))

    df = pd.read_csv(raw_csv_path)
    initial_count = len(df)

    smiles_col = config["smiles_col"]
    target_col = config["target_col"]

    if smiles_col not in df.columns:
        raise ValueError(f"SMILES column '{smiles_col}' not found in {name} dataset")
    if target_col not in df.columns:
        raise ValueError(f"Target column '{target_col}' not found in {name} dataset")

    # Working frame with only the needed columns, target renamed to its standard name
    df_work = pd.DataFrame({"smiles": df[smiles_col].astype(str), target_name: df[target_col]})
    df_work = df_work.dropna(subset=[target_name])

    # Canonicalize SMILES and filter invalid ones. Positions (not index labels)
    # are collected because dropna() left gaps in the index; iloc uses them.
    invalid_count = 0
    canonical_smiles = []
    valid_positions = []
    for pos, smi in enumerate(df_work["smiles"]):
        canon = canonicalize_smiles(smi)
        if canon is None:
            invalid_count += 1
        else:
            canonical_smiles.append(canon)
            valid_positions.append(pos)

    df_work = df_work.iloc[valid_positions].copy()
    df_work["smiles"] = canonical_smiles

    fragment_count = 0
    if drop_fragments:
        mask = ~df_work["smiles"].str.contains(".", regex=False)
        fragment_count = int((~mask).sum())
        df_work = df_work[mask]

    duplicate_count = 0
    if deduplicate:
        before_dedup = len(df_work)
        df_work = df_work.drop_duplicates(subset=["smiles"], keep="first")
        duplicate_count = before_dedup - len(df_work)

    final_count = len(df_work)

    os.makedirs(processed_dir, exist_ok=True)
    df_work.to_parquet(processed_path, index=False, engine="pyarrow")

    # Counters stay 0 when their step is disabled.
    metadata = {
        "dataset": name,
        "initial_count": int(initial_count),
        "invalid_smiles": int(invalid_count),
        "fragments_dropped": int(fragment_count),
        "duplicates_dropped": int(duplicate_count),
        "final_count": int(final_count),
        "target_column": target_name,
        "processing_options": processing_options,
    }
    with open(metadata_path, "w") as f:
        json.dump(metadata, f, indent=2)

    print(f"Processed {name} dataset:")
    print(f"  Initial: {initial_count} rows")
    print(f"  Invalid SMILES: {invalid_count}")
    if drop_fragments:
        print(f"  Fragments dropped: {fragment_count}")
    if deduplicate:
        print(f"  Duplicates dropped: {duplicate_count}")
    print(f"  Final: {final_count} rows")
    print(f"  Saved to: {processed_path}")

    return df_work


def merge_druglike_datasets(
    datasets: dict[str, pd.DataFrame],
) -> pd.DataFrame:
    """Merge multiple druglike property datasets on SMILES.

    Performs an outer join on SMILES to create a multi-property dataset.
    Molecules present in only some datasets will have NaN for missing properties.

    Args:
        datasets: Dict mapping property names to DataFrames.
                 Each DataFrame should have columns ["smiles", "<property_name>"]

    Returns:
        Merged DataFrame with columns ["smiles", property1, property2, ...]
        where properties correspond to the keys in the datasets dict.

    Example:
        >>> df_esol = pd.DataFrame({"smiles": ["CCO", "c1ccccc1"], "esol": [-0.77, -0.13]})
        >>> df_lipo = pd.DataFrame({"smiles": ["CCO", "CCCO"], "lipo": [0.28, 0.53]})
        >>> merged = merge_druglike_datasets({"esol": df_esol, "lipo": df_lipo})
        >>> # Result: smiles=["CCO", "c1ccccc1", "CCCO"],
        >>> #         esol=[-0.77, -0.13, NaN], lipo=[0.28, NaN, 0.53]
    """
    if not datasets:
        raise ValueError("At least one dataset required for merging")

    # Start with the first dataset
    property_names = list(datasets.keys())
    result = datasets[property_names[0]].copy()

    # Sequentially merge remaining datasets on SMILES (outer join)
    for prop_name in property_names[1:]:
        df_next = datasets[prop_name]
        result = result.merge(df_next, on="smiles", how="outer")

    # Reorder columns: smiles first, then properties in order
    cols = ["smiles"] + property_names
    result = result[cols]

    return result


def load_zinc250k_with_properties(
    data_dir: str = "data/raw",
    processed_dir: str = "data/processed",
    properties: list[str] | None = None,
    drop_fragments: bool = True,
    deduplicate: bool = True,
    force_reprocess: bool = False,
    max_molecules: int | None = None,
) -> pd.DataFrame:
    """Load ZINC250K dataset with RDKit-computed properties.

    Unlike other datasets that have pre-computed properties in the CSV,
    ZINC250K requires computing properties (QED, SAS, pLogP) from SMILES.

    Processing steps:
    1. Reuse the cached parquet (keyed on the sorted property set) unless
       ``force_reprocess`` is set or its metadata records different
       ``drop_fragments``/``deduplicate`` options
    2. Download ZINC250K CSV if not cached
    3. Canonicalize SMILES with RDKit
    4. Drop invalid SMILES (RDKit parse failure)
    5. Optionally drop multi-fragment SMILES
    6. Optionally deduplicate by canonical SMILES
    7. Compute requested molecular properties, dropping molecules where any fails
    8. Save processed parquet and metadata

    Args:
        data_dir: Directory for raw CSV downloads
        processed_dir: Directory for processed parquet output
        properties: List of properties to compute. Default: ["qed", "sas", "plogp"]
        drop_fragments: If True, drop SMILES containing "."
        deduplicate: If True, deduplicate by canonical SMILES
        force_reprocess: If True, reprocess even if parquet exists
        max_molecules: If set, return only the first N rows. The full dataset
            is still processed and cached.

    Returns:
        DataFrame with columns ["smiles", <property1>, <property2>, ...] in the
        order given by ``properties``, including when loaded from cache.

    Raises:
        ValueError: If "sas"/"plogp" are requested but SA_Score is unavailable,
            or if no SMILES column is found in the raw CSV.

    Example:
        >>> df = load_zinc250k_with_properties(properties=["qed", "sas", "plogp"])
        >>> df.columns
        Index(['smiles', 'qed', 'sas', 'plogp'], dtype='object')
    """
    from moltenflow.data.properties import compute_properties_batch, has_sascorer

    if properties is None:
        properties = ["qed", "sas", "plogp"]

    # SAS and penalized logP both depend on RDKit's Contrib SA_Score module
    needs_sas = {"sas", "plogp"} & set(properties)
    if needs_sas and not has_sascorer():
        raise ValueError(
            f"Properties {needs_sas} require SA_Score. "
            "Install RDKit with Contrib support."
        )

    # Cache is keyed on the property *set*, so requested column order must be
    # restored when loading. dict.fromkeys guards against duplicated names.
    props_suffix = "_".join(sorted(properties))
    processed_path = Path(processed_dir) / f"zinc250k_{props_suffix}.parquet"
    metadata_path = Path(processed_dir) / f"zinc250k_{props_suffix}_metadata.json"
    processing_options = {"drop_fragments": drop_fragments, "deduplicate": deduplicate}
    output_columns = list(dict.fromkeys(["smiles", *properties]))

    # Filename does not encode the options: check metadata before reusing.
    # Missing/unreadable metadata: trust the cache (legacy behavior).
    if processed_path.exists() and not force_reprocess:
        cached_options = None
        try:
            with open(metadata_path) as f:
                cached_options = json.load(f).get("processing_options")
        except (OSError, ValueError, AttributeError):
            pass
        if cached_options is None or cached_options == processing_options:
            df = pd.read_parquet(processed_path)[output_columns]
            if max_molecules is not None:
                df = df.head(max_molecules)
            return df

    # Download raw CSV if needed (download_dataset is imported at module level)
    raw_csv_path = Path(data_dir) / "zinc250k.csv"
    if not raw_csv_path.exists():
        os.makedirs(data_dir, exist_ok=True)
        download_dataset("zinc250k", str(raw_csv_path))

    df = pd.read_csv(raw_csv_path)
    initial_count = len(df)

    # Prefer 'smiles'; fall back to common alternative spellings
    smiles_col = "smiles"
    if smiles_col not in df.columns:
        for alt in ["SMILES", "Smiles", "smi"]:
            if alt in df.columns:
                smiles_col = alt
                break
        else:
            raise ValueError(f"SMILES column not found. Available: {list(df.columns)}")

    smiles_raw = df[smiles_col].astype(str).tolist()

    print(f"Processing ZINC250K: {initial_count} molecules")
    print("  Canonicalizing SMILES...")

    canonical_smiles = []
    invalid_count = 0
    for smi in smiles_raw:
        canon = canonicalize_smiles(smi)
        if canon is None:
            invalid_count += 1
        else:
            canonical_smiles.append(canon)

    fragment_count = 0
    if drop_fragments:
        filtered_smiles = [smi for smi in canonical_smiles if "." not in smi]
        fragment_count = len(canonical_smiles) - len(filtered_smiles)
        canonical_smiles = filtered_smiles

    duplicate_count = 0
    if deduplicate:
        before_dedup = len(canonical_smiles)
        canonical_smiles = list(dict.fromkeys(canonical_smiles))  # preserves order
        duplicate_count = before_dedup - len(canonical_smiles)

    print(f"  Invalid SMILES: {invalid_count}")
    if drop_fragments:
        print(f"  Fragments dropped: {fragment_count}")
    if deduplicate:
        print(f"  Duplicates dropped: {duplicate_count}")
    print(f"  Valid molecules: {len(canonical_smiles)}")

    print(f"  Computing properties: {properties}")
    props_array, valid_mask = compute_properties_batch(
        canonical_smiles, properties, return_valid_mask=True
    )
    # NOTE(review): assumes props_array is (n_molecules, n_properties) and
    # valid_mask is one flag per molecule; coerce so it indexes as a boolean mask.
    valid_mask = np.asarray(valid_mask, dtype=bool)

    # Keep only molecules where every property computed successfully
    valid_smiles = [s for s, v in zip(canonical_smiles, valid_mask) if v]
    valid_props = props_array[valid_mask]

    failed_props = len(canonical_smiles) - len(valid_smiles)
    if failed_props > 0:
        print(f"  Property computation failed for {failed_props} molecules")

    df_result = pd.DataFrame({"smiles": valid_smiles})
    for i, prop_name in enumerate(properties):
        df_result[prop_name] = valid_props[:, i]

    final_count = len(df_result)
    print(f"  Final: {final_count} molecules")

    os.makedirs(processed_dir, exist_ok=True)
    df_result.to_parquet(processed_path, index=False, engine="pyarrow")

    # Counters stay 0 when their step is disabled.
    metadata = {
        "dataset": "zinc250k",
        "initial_count": int(initial_count),
        "invalid_smiles": int(invalid_count),
        "fragments_dropped": int(fragment_count),
        "duplicates_dropped": int(duplicate_count),
        "property_failures": int(failed_props),
        "final_count": int(final_count),
        "properties": properties,
        "processing_options": processing_options,
    }
    with open(metadata_path, "w") as f:
        json.dump(metadata, f, indent=2)

    print(f"  Saved to: {processed_path}")

    if max_molecules is not None:
        df_result = df_result.head(max_molecules)

    return df_result
