import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch
import numpy as np
import pandas as pd
from omegaconf import DictConfig

from haipr.predictor import BasePredictor
from haipr.data import HAIPRData

# Kermut imports
from kermut.data import (
    prepare_GP_inputs,
    prepare_GP_kwargs,
    split_inputs,
    standardize,
)
from kermut.gp import instantiate_gp, optimize_gp, predict
from kermut.gp._optimize_gp_svi import optimize_gp as optimize_gp_svi


class KermutPredictor(BasePredictor):
    """
    Predictor for the Kermut GP model, following the BasePredictor interface.
    Implements supervised variant effect prediction using composite kernel GPs.

    Feature tensors are always built by ``kermut.data.prepare_GP_inputs`` from
    the configuration and DMS identifier, not from the rows of the
    ``HAIPRData`` frame. Row positions given to :meth:`fit_model` and
    :meth:`load_model` therefore index the DataFrame returned by
    ``prepare_GP_inputs``.
    NOTE(review): this presumes that DataFrame has the same rows, in the same
    order, as ``HAIPRData.data`` — confirm against the data pipeline.
    """

    def __init__(self, cfg: Optional[DictConfig] = None):
        """
        Initialize the KermutPredictor.

        Args:
            cfg (DictConfig, optional): Configuration for the Kermut model. If
                None, it is taken from ``data.config`` in :meth:`setup_data`.
        """
        self.cfg = cfg
        self.data = None  # last HAIPRData passed to setup_data
        self.gp = None  # model produced by fit_model or load_model
        self.likelihood = None  # likelihood paired with self.gp
        self.gp_inputs = None  # kwargs from prepare_GP_kwargs used to build the GP
        self.DMS_id = None  # optional override for the dataset identifier
        self.target_seq = None  # optional override for the sequence given to prepare_GP_kwargs
        self.trained = False  # set once fit_model or load_model succeeds

    def setup_data(self, data: HAIPRData) -> None:
        """
        Store the HAIPRData and adopt its config if none was given.

        ``DMS_id`` and ``target_seq`` are copied from ``data`` when it has
        those attributes; otherwise previously stored values are kept.
        """
        self.data = data
        if self.cfg is None:
            self.cfg = data.config
        if hasattr(data, "DMS_id"):
            self.DMS_id = data.DMS_id
        if hasattr(data, "target_seq"):
            self.target_seq = data.target_seq

    def _resolve_dms_id(self) -> str:
        """Return ``self.DMS_id``, else ``cfg.DMS_id``, else ``"DMS"``."""
        return self.DMS_id or getattr(self.cfg, "DMS_id", None) or "DMS"

    def _resolve_target_seq(self, data: Optional[HAIPRData] = None) -> Optional[str]:
        """
        Return the target sequence handed to ``prepare_GP_kwargs``.

        Order of precedence: ``self.target_seq``, ``cfg.target_seq``, then the
        first entry of ``data.data[cfg.data.sequence_col]`` when ``data`` is
        given. Without ``data`` the first two are returned as-is (may be None).
        """
        target_seq = self.target_seq or getattr(self.cfg, "target_seq", None)
        if target_seq or data is None:
            return target_seq
        # .iloc: positional access, so the frame's index need not contain label 0.
        return data.data[self.cfg.data.sequence_col].iloc[0]

    @staticmethod
    def _take_rows(
        x: Optional[torch.Tensor], positions: np.ndarray
    ) -> Optional[torch.Tensor]:
        """Select rows ``positions`` (in that order) from ``x``; ``None`` passes through."""
        if x is None:
            return None
        index = torch.as_tensor(positions, dtype=torch.long, device=x.device)
        return x[index]

    @staticmethod
    def _match_rows(available: pd.Series, requested: pd.Series) -> np.ndarray:
        """
        Return, for each mutant in ``requested``, its row position in ``available``.

        Raises:
            ValueError: if ``available`` has duplicate mutants (ambiguous
                lookup) or some requested mutant is absent from it.
        """
        available_arr = available.to_numpy()
        requested_arr = requested.to_numpy()
        # Fast path: identical mutants in identical order (the only case the
        # previous implementation produced correct results for).
        if np.array_equal(available_arr, requested_arr):
            return np.arange(len(available_arr))
        lookup = pd.Index(available_arr)
        if not lookup.is_unique:
            raise ValueError(
                "Cannot match mutants: the feature table contains duplicate "
                "'mutant' entries."
            )
        positions = lookup.get_indexer(requested_arr)
        missing = requested_arr[positions < 0]
        if len(missing) > 0:
            raise ValueError(
                f"{len(missing)} mutant(s) not found in the feature table, "
                f"e.g. {list(missing[:5])}"
            )
        return positions

    def _train_gp(
        self, train_inputs: Tuple[Any, ...], train_targets: Any, gp_inputs: Any
    ) -> Tuple[Any, Any]:
        """
        Build and optimize the GP; returns ``(gp, likelihood)``.

        Uses the SVI optimizer when ``cfg.optim.use_svi`` is truthy, otherwise
        ``instantiate_gp`` followed by ``optimize_gp``.
        """
        cfg = self.cfg
        if getattr(cfg.optim, "use_svi", False):
            return optimize_gp_svi(
                train_inputs=train_inputs,
                train_targets=train_targets,
                kernel_cfg=cfg.kernel,
                gp_inputs=gp_inputs,
                n_inducing=cfg.optim.n_inducing,
                lr=cfg.optim.lr,
                n_steps=cfg.optim.n_steps,
                batch_size=cfg.optim.batch_size,
                use_zero_shot_mean=cfg.kernel.use_zero_shot,
                composite=True,
                progress_bar=cfg.optim.progress_bar,
            )
        gp, likelihood = instantiate_gp(
            cfg=cfg,
            train_inputs=train_inputs,
            train_targets=train_targets,
            gp_inputs=gp_inputs,
        )
        return optimize_gp(
            gp=gp,
            likelihood=likelihood,
            train_inputs=train_inputs,
            train_targets=train_targets,
            lr=cfg.optim.lr,
            n_steps=cfg.optim.n_steps,
            progress_bar=cfg.optim.progress_bar,
        )

    def fit_model(
        self,
        dataset: HAIPRData,
        train_indices: Any,
        val_indices: Any,
    ) -> Dict[str, Any]:
        """
        Train the Kermut GP model on the given dataset and indices.

        Args:
            dataset: Data container; stored via :meth:`setup_data`.
            train_indices: Integer positions or a boolean mask selecting the
                training rows of the DataFrame returned by ``prepare_GP_inputs``.
            val_indices: Same, for the validation rows.

        Returns:
            ``{"metrics": {"spearman", "mse"}, "predictions": {...}}``. Every
            list in ``predictions`` follows ascending row position and
            ``predictions["indices"]`` lists those positions, so entry ``i`` of
            each list refers to the same row. When ``cfg.data.standardize`` is
            set, ``true_values``, ``predictions`` and ``mse`` are in the units
            produced by ``kermut.data.standardize``.
        """
        self.setup_data(dataset)
        cfg = self.cfg
        DMS_id = self._resolve_dms_id()
        target_seq = self._resolve_target_seq(dataset)

        # Prepare all tensors (full set)
        df, y, x_toks, x_embed, x_zero_shot = prepare_GP_inputs(cfg, DMS_id)
        gp_inputs = prepare_GP_kwargs(cfg, DMS_id, target_seq)

        # Boolean masks accept either integer positions or boolean arrays.
        train_idx = np.zeros(len(df), dtype=bool)
        train_idx[train_indices] = True
        val_idx = np.zeros(len(df), dtype=bool)
        val_idx[val_indices] = True
        # Masks select rows in ascending position order; keep those positions
        # so the output frame lines up with the split tensors even when
        # val_indices is unsorted or contains duplicates.
        val_positions = np.flatnonzero(val_idx)

        y_train, y_val = split_inputs(train_idx, val_idx, y)
        if cfg.data.standardize:
            y_train, y_val = standardize(y_train, y_val)

        x_toks_train, x_toks_val = split_inputs(train_idx, val_idx, x_toks)
        x_embed_train, x_embed_val = split_inputs(train_idx, val_idx, x_embed)
        x_zero_shot_train, x_zero_shot_val = split_inputs(
            train_idx, val_idx, x_zero_shot
        )
        train_inputs = (x_toks_train, x_embed_train, x_zero_shot_train)
        val_inputs = (x_toks_val, x_embed_val, x_zero_shot_val)

        gp, likelihood = self._train_gp(train_inputs, y_train, gp_inputs)
        self.gp = gp
        self.likelihood = likelihood
        self.gp_inputs = gp_inputs
        self.trained = True

        # Predict on validation set. ``predict`` is kermut.gp.predict (module
        # scope), not the method of the same name on this class.
        df_out = df.iloc[val_positions][["mutant"]].copy()
        df_out = df_out.assign(fold=0, y=np.nan, y_pred=np.nan, y_var=np.nan)
        test_idx = [True] * len(df_out)
        df_out = predict(
            gp=gp,
            likelihood=likelihood,
            test_inputs=val_inputs,
            test_targets=y_val,
            test_fold=0,
            test_idx=test_idx,
            df_out=df_out,
        )

        spearman = df_out["y"].corr(df_out["y_pred"], method="spearman")
        mse = ((df_out["y"] - df_out["y_pred"]) ** 2).mean()
        metrics = {"spearman": spearman, "mse": mse}
        predictions = {
            "indices": val_positions.tolist(),
            "predictions": df_out["y_pred"].tolist(),
            "true_values": df_out["y"].tolist(),
            "variances": df_out["y_var"].tolist(),
        }
        return {"metrics": metrics, "predictions": predictions}

    def collate_fn(self, batch: List[Tuple[str, str, float]]) -> Dict[str, Any]:
        """
        Collate function for the Kermut GP model.

        Identity collate: returns ``batch`` unchanged.
        NOTE(review): the return annotation says Dict but the input list is
        returned; the annotation is kept for interface compatibility.
        """
        return batch

    def load_model(self, model_dir: str, train_indices: Any = None) -> None:
        """
        Load a trained Kermut GP model and likelihood from disk.

        The model is rebuilt with ``instantiate_gp`` and then receives the
        saved state dicts. The training inputs/targets passed to
        ``instantiate_gp`` are not part of the saved files.
        NOTE(review): if the model conditions predictions on these inputs
        (as an exact GP would), pass the original ``train_indices`` to
        reproduce fit_model's posterior. Only the non-SVI model type is
        rebuilt here; checkpoints from the SVI path will not match.

        Args:
            model_dir (str): Directory containing 'gp.pth' and 'likelihood.pth'.
            train_indices: Optional integer positions or boolean mask of the
                rows used for training. All rows are used when None.

        Raises:
            RuntimeError: if no configuration is available.
        """
        if self.cfg is None:
            raise RuntimeError(
                "A configuration is required to load a model; pass cfg to the "
                "constructor or call setup_data first."
            )
        cfg = self.cfg
        DMS_id = self._resolve_dms_id()
        # Same fallback chain as fit_model so the kernel is built identically.
        target_seq = self._resolve_target_seq(self.data)
        gp_inputs = prepare_GP_kwargs(cfg, DMS_id, target_seq)

        df, y, x_toks, x_embed, x_zero_shot = prepare_GP_inputs(cfg, DMS_id)
        train_inputs = (x_toks, x_embed, x_zero_shot)
        train_targets = y
        if train_indices is not None:
            mask = np.zeros(len(df), dtype=bool)
            mask[train_indices] = True
            train_inputs = tuple(split_inputs(mask, mask, x)[0] for x in train_inputs)
            train_targets, _ = split_inputs(mask, mask, y)
        if cfg.data.standardize:
            # fit_model trains on standardized targets; rebuild with the same scaling.
            train_targets, _ = standardize(train_targets, train_targets)

        gp, likelihood = instantiate_gp(
            cfg=cfg,
            train_inputs=train_inputs,
            train_targets=train_targets,
            gp_inputs=gp_inputs,
        )
        model_path = Path(model_dir)
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict then copies values into the model's own parameters.
        # NOTE(review): torch.load unpickles data — only load trusted checkpoints.
        gp.load_state_dict(torch.load(model_path / "gp.pth", map_location="cpu"))
        likelihood.load_state_dict(
            torch.load(model_path / "likelihood.pth", map_location="cpu")
        )
        self.gp = gp
        self.likelihood = likelihood
        self.gp_inputs = gp_inputs
        self.trained = True

    def predict(self, data: Any) -> Dict[str, Any]:
        """
        Make predictions for the mutants in ``data``.

        Args:
            data: A ``pd.DataFrame``, or ``HAIPRData`` whose ``.data`` frame is
                used, with a ``mutant`` column. Features are looked up by
                mutant in the DataFrame returned by ``prepare_GP_inputs``.

        Returns:
            ``{"predictions": [...], "variances": [...]}`` in the row order of
            ``data``. Values are in the target units the model was trained on
            (standardized when ``cfg.data.standardize`` was set).

        Raises:
            RuntimeError: if the model has not been trained or loaded.
            ValueError: for unsupported input types, duplicate mutants in the
                feature table, or mutants missing from it.
        """
        if not self.trained:
            raise RuntimeError("Model must be trained or loaded before prediction.")
        if isinstance(data, HAIPRData):
            df = data.data
        elif isinstance(data, pd.DataFrame):
            df = data
        else:
            raise ValueError("Unsupported data type for prediction.")
        if len(df) == 0:
            return {"predictions": [], "variances": []}

        full_df, _, x_toks, x_embed, x_zero_shot = prepare_GP_inputs(
            self.cfg, self._resolve_dms_id()
        )
        # Align feature rows with df's rows by mutant name.
        positions = self._match_rows(full_df["mutant"], df["mutant"])
        test_inputs = tuple(
            self._take_rows(x, positions) for x in (x_toks, x_embed, x_zero_shot)
        )
        # Dummy targets: required by kermut's predict, not returned here.
        test_targets = torch.zeros(len(positions))

        df_out = df[["mutant"]].copy()
        df_out = df_out.assign(fold=0, y=np.nan, y_pred=np.nan, y_var=np.nan)
        test_idx = [True] * len(df_out)
        # ``predict`` below is kermut.gp.predict (module scope), not this method.
        df_out = predict(
            gp=self.gp,
            likelihood=self.likelihood,
            test_inputs=test_inputs,
            test_targets=test_targets,
            test_fold=0,
            test_idx=test_idx,
            df_out=df_out,
        )
        return {
            "predictions": df_out["y_pred"].tolist(),
            "variances": df_out["y_var"].tolist(),
        }
