import logging
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from esm.models.esm3 import ESM3, ESMProtein, ESMProteinTensor, ESMOutput
from esm.sdk.api import ProteinComplex
from esm.utils import encoding
from typing import Any, Dict, Optional, List

from tqdm import tqdm
from haipr.models.module import HAIPRModule
from haipr.utils import loss_funcs
import lightning.pytorch as pl
from torch.utils.data import Subset, DataLoader
from haipr.data import HAIPRData
from typing import List
from biotite.structure.io.pdb import PDBFile
from biotite.structure import sasa
from peft import LoraConfig, get_peft_model
from omegaconf import DictConfig

# Install a default handler on the root logger (does nothing if the root
# logger already has handlers configured).
logging.basicConfig()
logger = logging.getLogger(__name__)
# Only ERROR and above are emitted by this module's logger, so the
# debug/info/warning calls in this file are silent unless this level is lowered.
logger.setLevel(logging.ERROR)


def create_filtered_protein_complex(
    pdb_path: str, chain_list: List[str], id: Optional[str] = None
) -> ProteinComplex:
    """Build a ``ProteinComplex`` from a subset of the chains of a PDB file.

    Every requested chain is loaded on its own with ``ProteinChain.from_pdb``.
    Loading is best effort: a chain that fails to load is logged and skipped,
    and the complex is assembled from the chains that did load, in the order
    given by ``chain_list``.

    Args:
        pdb_path: Path of the PDB file to read.
        chain_list: Chain identifiers to load.
        id: Optional identifier forwarded unchanged to
            ``ProteinChain.from_pdb``.

    Returns:
        The complex produced by ``ProteinComplex.from_chains`` for the chains
        that were loaded successfully.

    Raises:
        ValueError: If none of the requested chains could be loaded. The
            message lists the per-chain load errors, because the warnings
            emitted while loading may be filtered out by the logger level.
    """
    from esm.utils.structure.protein_chain import ProteinChain

    chains = []
    load_errors = {}
    for chain_id in chain_list:
        try:
            chain = ProteinChain.from_pdb(pdb_path, chain_id=chain_id, id=id)
        except Exception as e:
            # Deliberately broad: one unreadable chain must not abort the
            # whole complex. The error is kept for the final report.
            logger.warning(f"Failed to load chain {chain_id} from PDB: {e}")
            load_errors[chain_id] = f"{type(e).__name__}: {e}"
            continue

        chains.append(chain)
        logger.debug(
            f"Successfully loaded chain {chain_id} with {len(chain.sequence)} residues"
        )

    if not chains:
        message = (
            f"No valid chains found for chain_list {chain_list} in PDB file {pdb_path}"
        )
        if load_errors:
            message += f" (load errors: {load_errors})"
        raise ValueError(message)

    return ProteinComplex.from_chains(chains)


class ESM3Predictor(HAIPRModule):

    def __init__(
        self,
        model_name: str = "esm3_sm_open_v1",
        num_classes: int = 0,
        prediction_head: Optional[nn.Module] = None,
        learning_rate: float = 1e-4,
        weight_decay: float = 0.01,
        batch_size: int = 1,
        **kwargs,
    ):
        """Load a pretrained ESM3 backbone and wire it into ``HAIPRModule``.

        Args:
            model_name: Name passed to ``ESM3.from_pretrained``.
            num_classes: Forwarded to ``HAIPRModule``; elsewhere in this class
                ``0`` selects float (regression-style) labels.
            prediction_head: Module applied to the pooled ESM3 embeddings in
                ``forward``. ``forward`` raises if it is ``None``.
            learning_rate: Forwarded to ``HAIPRModule``.
            weight_decay: Forwarded to ``HAIPRModule``.
            batch_size: Batch size used when building the dataloaders.
            **kwargs: Forwarded unchanged to ``HAIPRModule``. Two keys are
                also read here: ``loss`` (default ``"mse"``) selects the
                criterion via ``loss_funcs.get`` and ``use_structure``
                (default ``False``) toggles structure tokens as model input.
        """
        backbone = ESM3.from_pretrained(model_name)
        loss_name = kwargs.get("loss", "mse")
        loss_fn = loss_funcs.get(loss_name)

        logger.debug("Initializing HAIPRModule with model")
        base_arguments = {
            "model": backbone,
            "criterion": loss_fn,
            "num_classes": num_classes,
            "learning_rate": learning_rate,
            "weight_decay": weight_decay,
        }
        super().__init__(**base_arguments, **kwargs)

        # Attributes are assigned only after the parent initialiser has run.
        self.model_name = model_name
        self.prediction_head = prediction_head
        self.batch_size = batch_size
        self.use_structure = kwargs.get("use_structure", False)

    def _initialize_peft_adapters(self, lora_config_dict: DictConfig):
        """Wrap ``self.model`` with LoRA adapters on its attention/FFN layers.

        Every sub-module whose qualified name ends with one of the suffixes
        below is targeted. ``self.model`` is replaced by the PEFT-wrapped
        model returned from ``get_peft_model``.

        Args:
            lora_config_dict: Mapping with optional keys ``rank`` (default 2),
                ``alpha`` (default 16), ``dropout`` (default 0.0) and ``bias``
                (default ``"none"``).
        """
        logger.info(
            f"Initializing PEFT adapters with config: {lora_config_dict}")

        # Name suffixes of the layers that receive adapters.
        adapted_suffixes = (
            "attn.layernorm_qkv.1",
            "attn.out_proj",
            "ffn.1",
            "ffn.3",
        )
        target_modules = [
            module_name
            for module_name, _ in self.model.named_modules()
            if module_name.endswith(adapted_suffixes)
        ]

        read = lora_config_dict.get
        lora_config = LoraConfig(
            r=read("rank", 2),
            lora_alpha=read("alpha", 16),
            lora_dropout=read("dropout", 0.0),
            target_modules=target_modules,
            bias=read("bias", "none"),
        )
        self.model = get_peft_model(self.model, lora_config)

    def setup_model(self, data: HAIPRData, cfg: DictConfig):
        """Attach the dataset and cache data derived from its PDB file.

        Stores ``data`` and ``data.pdb`` on the instance, parses the PDB file
        with biotite, and keeps the result of biotite's ``sasa`` for element
        ``0`` of the parsed structure (presumably the first model) as a plain
        list in ``self.msasa``. ``cfg`` is not used by this method.
        """
        self.data = data
        self.pdb_path = data.pdb

        parsed_pdb = PDBFile.read(self.pdb_path)
        self.pdb_file = parsed_pdb

        first_entry = parsed_pdb.get_structure()[0]
        accessible_area = sasa(first_entry)
        self.msasa = accessible_area.tolist()

    def fit_model(
        self,
        data: HAIPRData,
        train_indices,
        val_indices,
        trainer_instance: pl.Trainer,
        cfg: DictConfig,
    ) -> Dict[str, Any]:
        """Train on ``train_indices`` and report the best validation results.

        Features for the training and validation samples are prepared in a
        single pass over ``train_indices + val_indices``. Inside the prepared
        tensors the first ``len(train_indices)`` rows are therefore the
        training samples and the remaining rows the validation samples.

        Args:
            data: Dataset providing the samples and the PDB path (``data.pdb``).
            train_indices: Dataset row indices used for training (array or list).
            val_indices: Dataset row indices used for validation (array or list).
            trainer_instance: Lightning trainer that runs the fit.
            cfg: Run configuration; not used by this method.

        Returns:
            ``{"metrics": ..., "predictions": ...}`` where ``predictions``
            holds ``indices`` (the validation dataset indices),
            ``predictions``, ``true_values`` and, if available,
            ``probabilities`` as plain lists.

        Raises:
            ValueError: If the number of stored validation predictions differs
                from ``len(val_indices)``.
        """
        self.pdb_path = data.pdb
        logger.debug("Fitting model")

        # Accept plain lists as well as arrays; ``.tolist()`` is needed below.
        train_indices = np.asarray(train_indices)
        val_indices = np.asarray(val_indices)

        if next(self.model.parameters()).dtype != torch.float32:
            logger.warning(
                "Converting ESM3 model to FP32 to avoid BFloat16 compatibility issues"
            )
            self.model = self.model.float()

        all_indices = np.concatenate([train_indices, val_indices])
        features_dict = self.prepare_training_features(data, all_indices)

        inputs = features_dict["inputs"]
        if isinstance(inputs, dict):
            # Overwriting existing keys while iterating is safe: the dict
            # size does not change.
            for key, value in inputs.items():
                if isinstance(value, (np.ndarray, list)):
                    inputs[key] = torch.tensor(value)

        n_train = len(train_indices)
        train_loader, val_loader = self._create_dataloaders(
            features_dict=inputs,
            labels=features_dict["labels"],
            # Positions inside the prepared features, not dataset indices.
            train_indices=np.arange(n_train),
            val_indices=np.arange(n_train, len(all_indices)),
            batch_size=self.batch_size,
            shuffle_train=True,
        )
        trainer_instance.fit(self, train_loader, val_loader)

        predictions = self.best_val_predictions
        metrics = self.best_val_metrics

        n_preds = len(predictions["preds"])
        if n_preds != len(val_indices):
            raise ValueError(
                f"Number of predictions ({n_preds}) does not match "
                f"number of samples in validation set ({len(val_indices)})"
            )

        pred_dict = {
            "indices": val_indices.tolist(),
            "predictions": predictions["preds"].tolist(),
            "true_values": predictions["labels"].tolist(),
        }

        if "probs" in predictions:
            pred_dict["probabilities"] = predictions["probs"].tolist()

        logger.debug(f"Metrics: {metrics}")
        return {"metrics": metrics, "predictions": pred_dict}

    def forward(self, batch):
        """Run ESM3 on ``batch["inputs"]`` and apply the prediction head.

        The ESM3 embeddings are averaged over dimension 1 and passed to
        ``self.prediction_head``.

        Args:
            batch: Mapping whose ``"inputs"`` entry is a dict of keyword
                arguments for the ESM3 model.

        Returns:
            The prediction head output; a 1-D output is returned with an
            extra trailing dimension of size 1.

        Raises:
            ValueError: If no prediction head was provided.
        """
        # Fail before the expensive backbone pass: without a head there is
        # nothing to compute.
        if self.prediction_head is None:
            raise ValueError(
                "No prediction head provided ESM3 does not support classification on its own"
            )

        out: ESMOutput = self.model(**batch["inputs"])

        embeddings = out.embeddings
        if embeddings.dtype == torch.bfloat16:
            # The head is fed float32 when the backbone emits bfloat16.
            embeddings = embeddings.float()

        # NOTE(review): assumes embeddings are (batch, length, dim) and that
        # averaging over every position, special tokens included, is
        # intended - confirm.
        pred_out = self.prediction_head(embeddings.mean(dim=1))
        logger.debug(f"Pred out shape: {pred_out.shape}")

        if pred_out.dim() == 1:
            pred_out = pred_out.unsqueeze(1)
        return pred_out

    def prepare_training_features(self, dataset: HAIPRData, indices: np.ndarray) -> Dict[str, Any]:
        """Encode the (mutated) protein complex of every requested sample.

        The reference complex is loaded once from ``self.pdb_path``,
        restricted to ``dataset.focus_chains`` when ``dataset.focus`` is set.
        For each sample, every chain sequence is replaced by the sample's
        entry in its chain-sequence map. Chains that are missing, ``None`` or
        empty in the map keep the reference sequence. The chain sequences are
        joined with ``"|"`` and the resulting protein is encoded with
        ``self.model.encode``.

        Args:
            dataset: Dataset providing labels, the optional ``sample_id``
                column and the per-sample chain sequences.
            indices: Positional row indices (used with ``.iloc``).

        Returns:
            A dict with:
              * ``"inputs"``: ``{"sequence_tokens": stacked tensor}`` plus
                ``"structure_tokens"`` when ``self.use_structure`` is set,
              * ``"labels"``: float32 tensor (1-D labels get a trailing
                dimension of size 1) when ``self.num_classes == 0``,
                otherwise a long tensor,
              * ``"sample_id"``: long tensor, only if the dataset has a
                ``sample_id`` column.

        Raises:
            ValueError: If ``self.use_structure`` is set but encoding yielded
                no structure tokens for at least one sample.
        """
        logger.info(f"Preparing training features for {len(indices)} samples")

        # Labels & optional sample ids
        raw_labels = dataset.data[dataset.label_col].iloc[indices].values

        sample_ids = None
        if "sample_id" in dataset.data.columns:
            sample_ids = torch.tensor(
                dataset.data["sample_id"].iloc[indices].values, dtype=torch.long
            )

        # Use unified HAIPRData path to obtain per-sample, per-chain mutated PDB sequences
        chain_sequence_maps = dataset._get_pdb_chain_sequences_dict(
            indices=indices)

        if dataset.focus:
            logger.info(
                f"Focus mode enabled, loading only chains: {dataset.focus_chains}"
            )
            pc = create_filtered_protein_complex(
                self.pdb_path, dataset.focus_chains
            )
            chains = list(pc.chain_iter())
            logger.info(
                f"Loaded filtered complex with {len(chains)} chains"
            )
            logger.debug(
                f"Filtered complex chains: {[chain.chain_id for chain in chains]}"
            )
        else:
            pc = ProteinComplex.from_pdb(self.pdb_path)
            chains = list(pc.chain_iter())
            logger.info(
                f"Loaded full complex with {len(chains)} chains"
            )

        # The reference chains do not change between samples: extract ids and
        # sequences once instead of re-iterating the complex per sample.
        reference_chains = [
            (chain.chain_id, str(chain.sequence)) for chain in chains
        ]
        reference_sequence = pc.sequence

        proteins: List[ESMProtein] = [
            ESMProtein.from_protein_complex(pc) for _ in chain_sequence_maps
        ]

        for chain_map, protein in tqdm(
            zip(chain_sequence_maps, proteins),
            total=len(proteins),  # zip() has no len(); give tqdm the total
            desc="Preparing Proteins",
        ):
            if not isinstance(chain_map, dict) or not chain_map:
                logger.warning(
                    "Empty or invalid chain sequence map from "
                    "HAIPRData.get_sequence_data('mutated_pdb_sequence'); "
                    "using original PDB sequences for this sample."
                )
                continue

            ms = []
            for chain_id, chain_seq in reference_chains:  # consistent chain order
                if chain_id not in chain_map:
                    if dataset.focus:
                        logger.warning(
                            f"Chain {chain_id} not found in mutated_pdb_sequence "
                            "map for focused dataset; using original sequence"
                        )
                    ms.append(chain_seq)
                    continue

                mapped = chain_map[chain_id]
                # A None entry must not be stringified into the literal
                # sequence "None"; treat it like an empty sequence.
                new_seq = "" if mapped is None else str(mapped)
                if not new_seq:
                    logger.warning(
                        f"Empty sequence for chain {chain_id} in mutated_pdb_sequence "
                        "map; using original sequence"
                    )
                    ms.append(chain_seq)
                else:
                    ms.append(new_seq)

            protein.sequence = "|".join(ms)

            if protein.sequence == reference_sequence:
                logger.warning(
                    "Did not change the sequence for one of the samples")

        logger.debug(
            f"Encoding {len(proteins)} proteins produced from mutated PDB sequences"
        )
        protein_tensors: List[ESMProteinTensor] = [
            self.model.encode(p) for p in tqdm(proteins, desc="Encoding proteins")
        ]

        if self.num_classes == 0:
            collated_labels = torch.tensor(raw_labels, dtype=torch.float32)
            if collated_labels.dim() == 1:
                collated_labels = collated_labels.unsqueeze(1)
        else:
            collated_labels = torch.tensor(raw_labels, dtype=torch.long)

        def _on_cpu(tokens):
            # Features are kept on the CPU whatever device the encoder used;
            # ``.cpu()`` returns the same tensor if it is already there.
            return tokens.cpu() if isinstance(tokens, torch.Tensor) else tokens

        inputs_for_model = {
            "sequence_tokens": torch.stack(
                [_on_cpu(p.sequence) for p in protein_tensors]
            ),
        }
        if self.use_structure:
            structure_tokens = [_on_cpu(p.structure) for p in protein_tensors]
            if any(tokens is None for tokens in structure_tokens):
                # torch.stack would fail on None with an opaque TypeError.
                raise ValueError(
                    "use_structure is enabled but encoding produced no structure "
                    "tokens for at least one sample"
                )
            inputs_for_model["structure_tokens"] = torch.stack(structure_tokens)

        logger.debug(
            f"Inputs for model on device: {inputs_for_model['sequence_tokens'].device}"
        )
        logger.debug(f"Collated labels on device: {collated_labels.device}")

        res_dict = {
            "inputs": inputs_for_model,
            "labels": collated_labels,
        }

        if sample_ids is not None:
            res_dict["sample_id"] = sample_ids.cpu()

        logger.info("Done encoding")
        return res_dict

    def prepare_batch_features(self, batch_items: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        if not hasattr(self, 'data') or self.data is None:
            raise ValueError(
                "ESM3 requires data to be set (via setup_model) for structure processing. "
                "Cannot prepare features without PDB path and dataset configuration."
            )

        if not hasattr(self, 'pdb_path') or self.pdb_path is None:
            raise ValueError(
                "ESM3 requires PDB path for structure processing. "
                "Set pdb_path via setup_model before inference."
            )

        sequences = [item["sequence"] for item in batch_items]
        labels = [item.get("labels", 0.0) for item in batch_items]

        import pandas as pd
        temp_df = pd.DataFrame({
            "sequence": sequences,
            "labels": labels
        })

        temp_dataset = type(self.data)(self.data.config, temp_df)

        temp_dataset.pdb = self.pdb_path

        indices = np.arange(len(sequences))
        result = self.prepare_training_features(temp_dataset, indices)

        if "inputs" not in result:

            result = {"inputs": result, "labels": result.get("labels")}

        return result

    def predict_sequences(
        self, sequences: List[str], params: Dict[str, Any] | None = None
    ) -> Dict[str, Any]:
        """Predict for raw sequences using the parent class implementation.

        This override adds no behaviour of its own: both arguments are passed
        unchanged to ``super().predict_sequences`` and its result is returned.

        Args:
            sequences: Sequences to predict for.
            params: Optional parameters forwarded to the parent method.
        """

        return super().predict_sequences(sequences, params)

    def save_model(self, save_dir: str) -> str:
        save_path = save_dir + "model.pt"
        torch.save(self.model.state_dict(), save_path)
        return save_path
