import logging
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from esm.models.esm3 import ESM3, ESMProtein, ESMProteinTensor, ESMOutput
from esm.sdk.api import ProteinComplex
from esm.utils import encoding
from typing import Any, Dict, Optional, List

from tqdm import tqdm
from haipr.models.module import HAIPRModule
from haipr.utils import loss_funcs
import lightning.pytorch as pl
from torch.utils.data import Subset, DataLoader
from haipr.data import HAIPRData
from typing import List
from biotite.structure.io.pdb import PDBFile
from biotite.structure import sasa
from peft import LoraConfig, get_peft_model
from omegaconf import DictConfig

# Give the root logger a default handler, then make this module's logger
# quiet: only ERROR and above are emitted from this file.
logging.basicConfig()
logger: logging.Logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.ERROR)


def create_filtered_protein_complex(
    pdb_path: str, chain_list: List[str], id: str = None
) -> ProteinComplex:
    """
    Load a ProteinComplex from ``pdb_path`` keeping only the chains in ``chain_list``.

    Chains are loaded one at a time with ``ProteinChain.from_pdb``, in the
    order given. Any chain whose loading (or debug reporting) raises is logged
    as a warning and skipped, so a single bad chain does not abort the rest.

    Args:
        pdb_path: Path to the PDB file.
        chain_list: Chain IDs to include.
        id: Optional identifier forwarded to ``ProteinChain.from_pdb``.

    Returns:
        A ProteinComplex built from the chains that loaded successfully.

    Raises:
        ValueError: If none of the requested chains could be loaded.
    """
    from esm.utils.structure.protein_chain import ProteinChain

    kept_chains = []
    for requested_id in chain_list:
        try:
            loaded_chain = ProteinChain.from_pdb(
                pdb_path, chain_id=requested_id, id=id
            )
            kept_chains.append(loaded_chain)
            logger.debug(
                f"Successfully loaded chain {requested_id} with {len(loaded_chain.sequence)} residues"
            )
        except Exception as exc:
            # Deliberately best-effort: report the failure and try the next chain.
            logger.warning(f"Failed to load chain {requested_id} from PDB: {exc}")

    if len(kept_chains) == 0:
        raise ValueError(
            f"No valid chains found for chain_list {chain_list} in PDB file {pdb_path}"
        )

    return ProteinComplex.from_chains(kept_chains)


class ESM3Dataset(Dataset):
    """Map-style dataset over pre-tokenised ESM3 inputs and their labels.

    ``inputs`` is a dict whose values are indexable per sample (entries whose
    value is None are skipped when building an item); ``labels`` holds one
    entry per sample.
    """

    def __init__(self, inputs, labels):
        self.inputs = inputs
        self.labels = labels

    def __len__(self):
        # ``self.inputs`` is a dict, so ``len(self.inputs)`` would count its
        # keys rather than samples; the labels hold exactly one entry per sample.
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``{"inputs": {key: value[idx]}, "labels": labels[idx]}``."""
        inputs = {k: v[idx] for k, v in self.inputs.items() if v is not None}
        labels = self.labels[idx]
        # Lazy %-style args: this runs per sample, and eagerly formatting
        # tensors is wasted work whenever DEBUG logging is disabled.
        logger.debug("Inputs: %s", inputs)
        logger.debug("Labels: %s", labels)
        return {"inputs": inputs, "labels": labels}


class ESM3Predictor(HAIPRModule):
    """ESM3-specific predictor implementation using PyTorch Lightning.

    An ESM3 backbone (optionally LoRA-adapted via ``_initialize_peft_adapters``)
    produces embeddings that are mean-pooled over dim 1 and passed to a
    user-supplied ``prediction_head``. Model inputs are built from a PDB
    structure plus per-sample mutations (see ``prepare_features``), so a PDB
    path and ``use_structure=True`` are required.
    """

    def __init__(
        self,
        model_name: str = "esm3_sm_open_v1",
        num_classes: int = 0,
        prediction_head: Optional[nn.Module] = None,
        learning_rate: float = 1e-4,
        weight_decay: float = 0.01,
        batch_size: int = 1,
        **kwargs,
    ):
        """Load the pretrained ESM3 backbone and initialise the Lightning module.

        Args:
            model_name: Name passed to ``ESM3.from_pretrained``.
            num_classes: 0 for regression, otherwise the number of classes.
            prediction_head: Module applied to the pooled embeddings; required
                by ``forward``.
            learning_rate: Forwarded to ``HAIPRModule``.
            weight_decay: Forwarded to ``HAIPRModule``.
            batch_size: Batch size of the DataLoaders built in ``fit_model``.
            **kwargs: Forwarded to ``HAIPRModule``. ``loss`` (default "mse")
                selects the criterion from ``loss_funcs``; ``use_structure``
                (default False) enables the structure-based feature path.
        """
        logger.info(f"Initializing ESM3Predictor with kwargs: {kwargs}")
        logger.info(
            f"Initializing ESM3Predictor with model_name: {model_name}")
        logger.info(
            f"Initializing ESM3Predictor with num_classes: {num_classes}")
        logger.info(
            f"Initializing ESM3Predictor with prediction_head: {prediction_head}"
        )
        logger.info(
            f"Initializing ESM3Predictor with learning_rate: {learning_rate}")
        logger.info(
            f"Initializing ESM3Predictor with weight_decay: {weight_decay}")
        logger.info(
            f"Initializing ESM3Predictor with batch_size: {batch_size}")

        # Device placement is left to the Lightning trainer. Conversion to FP32
        # (to avoid BFloat16 issues) happens lazily in ``fit_model``.
        model = ESM3.from_pretrained(model_name)

        # Safe before super().__init__(): this only mutates the tokenizers and
        # sets no attributes on ``self``.
        self._fix_tokenizers_if_needed(model)

        criterion = loss_funcs.get(kwargs.get("loss", "mse"))

        logger.debug("Initializing HAIPRModule with model")
        super().__init__(
            model=model,
            criterion=criterion,
            num_classes=num_classes,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            **kwargs,
        )

        self.model_name = model_name
        self.prediction_head = prediction_head
        self.batch_size = batch_size
        self.use_structure = kwargs.get("use_structure", False)

    def _fix_tokenizers_if_needed(self, model):
        """Restore special tokens that some environments (e.g. Singularity) load as None.

        For each ESM3 track tokenizer present on ``model.tokenizers``: if
        ``mask_token`` / ``bos_token`` / ``eos_token`` is None but the matching
        token string ("<mask>" / "<cls>" / "<eos>") is in the tokenizer's
        vocabulary, the token and its ``*_id`` are set from the vocabulary.
        Tokenizers are modified in place.
        """
        # (attribute name, vocabulary token) pairs, checked in this order.
        special_tokens = (
            ("mask_token", "<mask>"),
            ("bos_token", "<cls>"),
            ("eos_token", "<eos>"),
        )
        for track in (
            "sequence",
            "structure",
            "secondary_structure",
            "sasa",
            "function",
            "residue_annotations",
        ):
            tokenizer = getattr(model.tokenizers, track, None)
            if tokenizer is None:
                continue
            for attr, token in special_tokens:
                if (
                    hasattr(tokenizer, attr)
                    and getattr(tokenizer, attr) is None
                    and hasattr(tokenizer, "vocab")
                    and token in tokenizer.vocab
                ):
                    setattr(tokenizer, attr, token)
                    setattr(tokenizer, f"{attr}_id", tokenizer.vocab[token])
                    logger.info(f"Fixed {track} tokenizer {attr}")

    def _safe_encode(self, protein: ESMProtein) -> ESMProteinTensor:
        """
        Encode ``protein``, working around a None ``mask_token`` in some environments.

        Uses ``self.model.encode`` when the sequence tokenizer has a mask token
        and that call succeeds. Otherwise it falls back to tokenizing the
        sequence directly, plus structure (if ``protein.coordinates`` is set)
        and SASA (if ``protein.sasa`` is set). A structure-encoding failure is
        logged and the structure track is left unset.

        Raises:
            ValueError: If ``protein.sequence`` is None.
            TypeError: If the SASA tokenizer returns an unsupported type.
        """
        if protein.sequence is None:
            raise ValueError("sequence is required for encoding")

        if self.model.tokenizers.sequence.mask_token is not None:
            try:
                return self.model.encode(protein)
            except Exception as e:
                logger.warning(
                    f"Standard encoding failed: {e}. Falling back to custom encoding."
                )

        # Fallback: tokenize the sequence directly.
        sequence_tokens = self.model.tokenizers.sequence.encode(
            protein.sequence, add_special_tokens=True
        )
        sequence_tokens = torch.tensor(sequence_tokens, dtype=torch.int64)
        protein_tensor = ESMProteinTensor(sequence=sequence_tokens)

        if protein.coordinates is not None:
            try:
                structure_encoder = self.model.get_structure_encoder()
                structure_tokenizer = self.model.tokenizers.structure

                coords, plddt, struct_tokens = encoding.tokenize_structure(
                    protein.coordinates,
                    structure_encoder,
                    structure_tokenizer=structure_tokenizer,
                    reference_sequence=protein.sequence,
                    add_special_tokens=True,
                )

                protein_tensor.coordinates = coords
                protein_tensor.structure = struct_tokens
            except Exception as e:
                logger.warning(f"Failed to encode structure: {e}")

        if protein.sasa is not None:
            sasa_tokens = self.model.tokenizers.sasa.encode(
                protein.sasa, add_special_tokens=True
            )
            if isinstance(sasa_tokens, (list, np.ndarray)):
                # These are token ids (fed to an embedding), so they must be
                # integers regardless of the container type.
                sasa_tokens = torch.tensor(sasa_tokens, dtype=torch.int64)
            elif not isinstance(sasa_tokens, torch.Tensor):
                raise TypeError(f"Unknown type for sasa: {type(sasa_tokens)}")
            protein_tensor.sasa = sasa_tokens

        return protein_tensor

    def _initialize_peft_adapters(self, lora_config_dict: DictConfig):
        """
        Wrap ``self.model`` with LoRA adapters via PEFT.

        PEFT has no built-in target list for ESM3, so target modules are
        selected by name suffix from ``self.model.named_modules()``.

        Args:
            lora_config_dict: Reads ``rank`` (default 2), ``alpha`` (16),
                ``dropout`` (0.0) and ``bias`` ("none").

        Raises:
            ValueError: If no module name matches the target suffixes.
        """
        logger.info(
            f"Initializing PEFT adapters with config: {lora_config_dict}")
        target_suffixes = (
            # Linear at index 1 of the Sequential named ``layernorm_qkv``
            # inside each ``attn`` module.
            "attn.layernorm_qkv.1",
            "attn.out_proj",
            # NOTE(review): presumably the input and output Linear layers of
            # the feed-forward block; confirm against a printout of the model.
            "ffn.1",
            "ffn.3",
        )
        target_modules = [
            name
            for name, _ in self.model.named_modules()
            if name.endswith(target_suffixes)
        ]
        if not target_modules:
            raise ValueError(
                "No LoRA target modules found; expected module names ending in "
                f"{target_suffixes}"
            )

        lora_config = LoraConfig(
            r=lora_config_dict.get("rank", 2),
            lora_alpha=lora_config_dict.get("alpha", 16),
            lora_dropout=lora_config_dict.get("dropout", 0.0),
            target_modules=target_modules,
            bias=lora_config_dict.get("bias", "none"),
        )
        self.model: ESM3 = get_peft_model(self.model, lora_config)

    def setup_model(self, data: HAIPRData, cfg: DictConfig):
        """
        Store ``data`` and read its PDB structure.

        Sets ``self.data``, ``self.pdb_path`` (``data.pdb``), ``self.pdb_file``
        (the parsed PDB) and ``self.msasa`` (``biotite`` SASA of the first model
        as a list). ``cfg`` is currently unused.
        """
        self.data = data
        self.pdb_path = data.pdb
        self.pdb_file = PDBFile.read(self.pdb_path)
        struct = self.pdb_file.get_structure()[0]
        self.msasa = sasa(struct).tolist()

    @staticmethod
    def _ensure_tensor_inputs(features: Dict[str, Any]) -> None:
        """Convert numpy-array or list entries of ``features["inputs"]`` to tensors, in place."""
        inputs = features["inputs"]
        if isinstance(inputs, dict):
            for key, value in list(inputs.items()):
                if isinstance(value, (np.ndarray, list)):
                    inputs[key] = torch.tensor(value)

    def fit_model(
        self,
        data: HAIPRData,
        train_indices,
        val_indices,
        trainer_instance: pl.Trainer,
        cfg: DictConfig,
    ) -> Dict[str, Any]:
        """Fit the model with ``trainer_instance`` and return best-validation results.

        Args:
            data: Dataset providing the PDB path, mutations and labels.
            train_indices: Sample indices used for training.
            val_indices: Sample indices used for validation (array-like).
            trainer_instance: Lightning trainer that runs ``fit``.
            cfg: Run configuration (currently unused here).

        Returns:
            ``{"metrics": ..., "predictions": {"indices", "predictions",
            "true_values"[, "probabilities"]}}`` taken from
            ``self.best_val_predictions`` / ``self.best_val_metrics``.

        Raises:
            ValueError: If no validation predictions were recorded, or their
                count differs from ``len(val_indices)``.
        """
        # Used by prepare_features.
        self.pdb_path = data.pdb
        logger.debug("Fitting model")

        # Ensure model is using FP32 to avoid BFloat16 compatibility issues
        if next(self.model.parameters()).dtype != torch.float32:
            logger.warning(
                "Converting ESM3 model to FP32 to avoid BFloat16 compatibility issues"
            )
            self.model = self.model.float()

        # Encode ESMProtein objects once for all data
        features = self.prepare_features(data)
        self._ensure_tensor_inputs(features)

        torch_data = ESM3Dataset(features["inputs"], features["labels"])
        train_loader = DataLoader(
            Subset(torch_data, train_indices),
            batch_size=self.batch_size,
            shuffle=True,
        )
        val_loader = DataLoader(
            Subset(torch_data, val_indices),
            batch_size=self.batch_size,
            shuffle=False,
        )
        trainer_instance.fit(self, train_loader, val_loader)

        predictions = self.best_val_predictions
        metrics = self.best_val_metrics

        if not predictions:
            raise ValueError(
                "No validation predictions were recorded during training"
            )

        if len(predictions["preds"]) != len(val_indices):
            raise ValueError(
                f"Number of predictions ({len(predictions['preds'])}) does not "
                f"match number of samples in validation set ({len(val_indices)}). "
                "This can happen when the sanity check sets best_val_predictions "
                "and actual training is not making improvements"
            )

        pred_dict = {
            # np.asarray accepts both numpy arrays and plain Python sequences.
            "indices": np.asarray(val_indices).tolist(),
            "predictions": predictions["preds"].tolist(),
            "true_values": predictions["labels"].tolist(),
        }

        if "probs" in predictions:
            pred_dict["probabilities"] = predictions["probs"].tolist()

        logger.debug(f"Metrics: {metrics}")
        return {"metrics": metrics, "predictions": pred_dict}

    def forward(self, batch):
        """Predict targets for a batch of pre-tokenised ESM3 inputs.

        ``batch["inputs"]`` is unpacked as keyword arguments to the ESM3 model.
        Its ``embeddings`` are averaged over dim 1 and passed to
        ``prediction_head``. A 1-D head output is reshaped to [batch_size, 1].
        BFloat16 embeddings are cast to float32 before the head.

        Reference notes (ESM3 internals, not verified here):
            sequence_head = RegressionHead(d_model, 64)
            structure_head = RegressionHead(d_model, 4096)
            ss8_head = RegressionHead(d_model, 8 + 3)
            sasa_head = RegressionHead(d_model, 16 + 3)
            function_head = RegressionHead(d_model, 260 * 8)
            residue_head = RegressionHead(d_model, 1478)
            embeddings = [batch_size, seq_len, 1536]

        Raises:
            ValueError: If no ``prediction_head`` was provided.
        """
        # Check before running the (expensive) backbone.
        if self.prediction_head is None:
            raise ValueError(
                "No prediction head provided ESM3 does not support classification on its own"
            )

        out: ESMOutput = self.model(**batch["inputs"])

        # Avoid a dtype mismatch with a float32 prediction head.
        embeddings = out.embeddings
        if embeddings.dtype == torch.bfloat16:
            embeddings = embeddings.float()

        pred_out = self.prediction_head(embeddings.mean(dim=1))
        logger.debug("Pred out shape: %s", pred_out.shape)

        # add dimension if single output target
        if len(pred_out.shape) == 1:
            pred_out = pred_out.unsqueeze(1)
        return pred_out

    def _collate_labels(self, raw_labels) -> torch.Tensor:
        """Convert raw labels to a tensor on ``self.device``.

        Regression (``num_classes == 0``) gives float32, reshaped to [N, 1]
        when 1-D so it matches the [batch, 1] output of ``forward``.
        Classification gives int64 class indices with the shape unchanged.
        """
        if self.num_classes == 0:
            labels = torch.tensor(raw_labels, dtype=torch.float32).to(self.device)
            if labels.dim() == 1:
                labels = labels.unsqueeze(1)
            return labels
        return torch.tensor(raw_labels, dtype=torch.long).to(self.device)

    def _feature_dict(self, inputs, labels, sample_ids) -> Dict[str, Any]:
        """Assemble the ``prepare_features`` result, adding ``sample_id`` when present."""
        res_dict = {"inputs": inputs, "labels": labels}
        # Sample IDs are kept for DDP compatibility.
        if sample_ids is not None:
            res_dict["sample_id"] = sample_ids.to(self.device)
        return res_dict

    @staticmethod
    def _apply_chain_mutations(chain_seq: str, chain_id: str, mutation_str: str) -> str:
        """Return ``chain_seq`` with the ``:``-separated mutations in ``mutation_str`` applied.

        Each mutation has the form ``<wt><pos><new>`` (e.g. "A4G"). ``pos`` is
        used directly as a 0-based index into ``chain_seq``. Empty entries
        (e.g. from an empty string or a trailing ``:``) are ignored. A wild-type
        mismatch is only logged.

        Raises:
            ValueError: If a position is outside ``[0, len(chain_seq))``.
        """
        for mut in mutation_str.split(":"):
            mut = mut.strip()
            if not mut:
                continue
            pdb_residue = int(mut[1:-1])  # e.g. "A4G" -> 4
            new_aa = mut[-1]  # e.g. "A4G" -> "G"

            # Range-check before indexing; negative values would otherwise
            # silently index from the end of the chain.
            if not 0 <= pdb_residue < len(chain_seq):
                raise ValueError(
                    f"PDB residue {pdb_residue} is out of range for chain {chain_id}"
                )

            old_aa = chain_seq[pdb_residue]
            if old_aa != mut[0]:
                logger.warning(
                    f"Mutation {mut} does not match the sequence at PDB residue {pdb_residue} "
                    f"(sequence position {pdb_residue}): expected {mut[0]}, got {old_aa}"
                )

            chain_seq = chain_seq[:pdb_residue] + new_aa + chain_seq[pdb_residue + 1:]
        return chain_seq

    def _build_mutant_proteins(self, data: HAIPRData, pdb_path: str) -> List[ESMProtein]:
        """Create one ESMProtein per entry of ``data.get_pdb_mutations()``.

        The complex is loaded once. If ``data.focus`` is set, only
        ``data.focus_chains`` are loaded. Each protein copies the complex and
        has its "|"-joined sequence replaced by the mutated chain sequences.
        Chains with no entry in a sample's mutation dict keep their original
        sequence.

        NOTE(review): mutation positions are used as 0-based indices into each
        chain's sequence. No PDB-numbering -> sequence-index remapping is done
        here; confirm that ``get_pdb_mutations`` already yields 0-based positions.
        """
        pdb_mutations = data.get_pdb_mutations()

        # Load the complex once; doing it per mutation is too slow.
        if data.focus:
            logger.info(
                f"Focus mode enabled, loading only chains: {data.focus_chains}"
            )
            pc = create_filtered_protein_complex(pdb_path, data.focus_chains)
            logger.info(
                f"Loaded filtered complex with {len(list(pc.chain_iter()))} chains"
            )
            logger.debug(
                f"Filtered complex chains: {[chain.chain_id for chain in pc.chain_iter()]}"
            )
        else:
            pc = ProteinComplex.from_pdb(pdb_path)
            logger.info(
                f"Loaded full complex with {len(list(pc.chain_iter()))} chains"
            )

        # A fixed chain order is shared by every mutant's "|"-joined sequence.
        chains = list(pc.chain_iter())

        proteins = []
        for mutant in tqdm(pdb_mutations, desc="Preparing Proteins"):
            protein = ESMProtein.from_protein_complex(pc)
            ms = []
            for chain in chains:
                chain_seq = str(chain.sequence)  # create new string
                if chain.chain_id not in mutant:
                    # In focus mode, the filtered complex is expected to have
                    # mutation entries for all of its chains.
                    if data.focus:
                        logger.warning(
                            f"Chain {chain.chain_id} not found in mutation data, using original sequence"
                        )
                    ms.append(chain_seq)
                    continue
                ms.append(
                    self._apply_chain_mutations(
                        chain_seq, chain.chain_id, mutant[chain.chain_id]
                    )
                )

            protein.sequence = "|".join(ms)
            if protein.sequence == pc.sequence:
                logger.warning(f"Did not change the sequence for {mutant}")
            proteins.append(protein)
        return proteins

    @staticmethod
    def _stack_token_inputs(
        protein_tensors: List[ESMProteinTensor],
    ) -> Dict[str, torch.Tensor]:
        """Stack per-protein token tracks into batched ESM3 ``forward`` kwargs.

        Uses ``sequence`` -> ``sequence_tokens`` and ``structure`` ->
        ``structure_tokens``. SASA tokens are left out because their length
        differs from the other tracks and breaks training. A track that is None
        for every protein is omitted.

        Raises:
            ValueError: If a track is None for only some proteins.
        """
        inputs_for_model = {}
        for key, attr in (
            ("sequence_tokens", "sequence"),
            ("structure_tokens", "structure"),
        ):
            tracks = [getattr(p, attr) for p in protein_tensors]
            n_missing = sum(t is None for t in tracks)
            if n_missing == len(tracks):
                continue
            if n_missing:
                raise ValueError(
                    f"{attr} track is missing for {n_missing} of {len(tracks)} proteins"
                )
            inputs_for_model[key] = torch.stack(tracks)
        return inputs_for_model

    def prepare_features(self, data: HAIPRData):
        """
        Prepare sequences and structures for the ESM3 model.

        If ``data`` has cached features, they are returned with freshly collated
        labels. Otherwise one mutated ESMProtein per sample is built from the PDB
        complex and encoded with ``self.model.encode``. The stacked tokens are
        cached via ``data.cache_features``.

        The resulting ``inputs`` keys are keyword arguments of the ESM3 forward:
           def forward(
            self,
            *,
            sequence_tokens: torch.Tensor | None = None,
            structure_tokens: torch.Tensor | None = None,
            ss8_tokens: torch.Tensor | None = None,
            sasa_tokens: torch.Tensor | None = None,
            function_tokens: torch.Tensor | None = None,
            residue_annotation_tokens: torch.Tensor | None = None,
            average_plddt: torch.Tensor | None = None,
            per_res_plddt: torch.Tensor | None = None,
            structure_coords: torch.Tensor | None = None,
            chain_id: torch.Tensor | None = None,
            sequence_id: torch.Tensor | None = None,
        ) -> ESMOutput:

        Args:
            data: Dataset providing labels, PDB mutations, focus settings, an
                optional ``sample_id`` column, and the feature cache.

        Returns:
            ``{"inputs": dict_of_inputs, "labels": tensor_of_labels}``, plus
            ``"sample_id"`` when ``data`` has that column. Regression labels
            are [N, 1] float32 on both the cached and the fresh path.

        Raises:
            ValueError: If no PDB path is set, ``use_structure`` is False, or
                there are no proteins to encode.
        """
        logger.debug(f"Preparing features for {len(data)} samples")
        labels = self._collate_labels(data.get_labels())

        # Extract sample IDs if present for DDP compatibility
        sample_ids = None
        if "sample_id" in data.data.columns:
            sample_ids = torch.tensor(
                data.data["sample_id"].values, dtype=torch.long)

        cached_features = data._load_from_cache(data._get_cache_key())
        if cached_features is not None:
            logger.debug("Loading cached features")
            return self._feature_dict(cached_features, labels, sample_ids)

        pdb_path = getattr(self, "pdb_path", None)
        if pdb_path is None or not self.use_structure:
            raise ValueError(
                "No PDB path provided, probably should use a sequence only model"
            )

        proteins = self._build_mutant_proteins(data, pdb_path)
        if not proteins:
            raise ValueError("No mutations found to build protein inputs from")

        logger.debug(f"Encoding {len(proteins)} proteins")
        # Encoding only produces token ids; no autograd graph is needed.
        with torch.no_grad():
            protein_tensors: List[ESMProteinTensor] = [
                self.model.encode(p) for p in tqdm(proteins, desc="Encoding proteins")
            ]

        inputs_for_model = self._stack_token_inputs(protein_tensors)

        logger.debug(
            f"Inputs for model on device: {inputs_for_model['sequence_tokens'].device}"
        )
        logger.debug(f"Collated labels on device: {labels.device}")

        res_dict = self._feature_dict(inputs_for_model, labels, sample_ids)
        data.cache_features(res_dict["inputs"])
        logger.debug("Done encoding")
        return res_dict

    def predict(self, sequences: List[str], batch_size: int = 1) -> Dict[str, Any]:
        """Make predictions for ``sequences``.

        A temporary ``HAIPRData`` (dummy labels of 0.0) is built around the
        sequences and passed through ``prepare_features``, so this instance's
        ``pdb_path`` / ``focus`` settings apply. The whole module is switched
        to eval mode and left there.

        Args:
            sequences: Sequences to score.
            batch_size: Inference batch size.

        Returns:
            ``{"predictions": ndarray}``, plus ``"probabilities"`` (softmax
            over the last axis) when ``num_classes > 0``.

        Raises:
            ValueError: If ``sequences`` is empty.
        """
        import pandas as pd
        from omegaconf import OmegaConf

        if len(sequences) == 0:
            raise ValueError("predict() requires at least one sequence")

        # eval() on the whole module, not only the backbone, so dropout etc.
        # in ``prediction_head`` is also disabled during inference.
        self.eval()
        predictions_list = []

        temp_config = OmegaConf.create({
            "benchmark": {
                "data": "temp.csv",  # dummy path
                "sequence_column": "sequence",
                "pdb": getattr(self, "pdb_path", None),
            },
            "data": {
                "label_column": "labels",
                "separator_token": "|",
                "focus": getattr(self, 'focus', False)
            },
            "model": {
                "feature_type": "sequence",
                "num_classes": self.num_classes
            },
            "seed": 42
        })

        temp_df = pd.DataFrame({
            "sequence": sequences,
            "labels": [0.0] * len(sequences)  # dummy labels
        })
        temp_data = HAIPRData(temp_config, data=temp_df)

        features = self.prepare_features(temp_data)
        self._ensure_tensor_inputs(features)

        torch_data = ESM3Dataset(features["inputs"], features["labels"])
        # Explicit index range so the sample count never depends on how the
        # dataset computes its length.
        dataloader = DataLoader(
            Subset(torch_data, range(len(features["labels"]))),
            batch_size=batch_size,
            shuffle=False,
        )

        with torch.no_grad():
            for batch in dataloader:
                pred_out = self.forward(batch)

                # numpy has no bfloat16.
                if pred_out.dtype == torch.bfloat16:
                    pred_out = pred_out.float()

                if self.num_classes == 0 and len(pred_out.shape) == 1:
                    pred_out = pred_out.unsqueeze(1)

                predictions_list.append(pred_out.cpu())

        predictions_tensor = torch.cat(predictions_list, dim=0)
        results = {"predictions": predictions_tensor.numpy()}

        if self.num_classes > 0:
            probs = torch.softmax(predictions_tensor, dim=-1)
            results["probabilities"] = probs.numpy()

        return results
