import biotite.structure as struc
from typing import Union, List, Optional
import os
import ast
import pandas as pd
import biotite.structure.io.pdb as pdb
from biotite.sequence import ProteinSequence
import logging 
import mlflow

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(name=__name__)

def parse_sequence_chains_to_single_string(seq_str: list[str] | str) -> list[str] | str:
    """Parse sequence chains from a list of strings into a single concatenated string.

    Args:
        seq_str: List of strings containing sequence chains in format "{'A': 'SEQ1', 'B': 'SEQ2'}"

    Returns:
        List of strings with sequences concatenated with '|' separator
    """

    if isinstance(seq_str, str):
        # return single string
        seq_dict = ast.literal_eval(seq_str)
        return "|".join(seq_dict.values())

    res = []
    # TODO: simplify with '|'.join(apply(eval)) in data.py
    for item in seq_str:
        # Safely evaluate the string as a Python dict literal
        seq_dict = ast.literal_eval(item)
        # Join all sequences with | separator
        res.append("|".join(seq_dict.values()))
    return res


def get_sequence_and_index_map(pdb_file: str, chain_id: str | None = None):
    """
    Extract amino acid sequences and their PDB residue numbers from a PDB file.

    If the file yields a multi-model stack, only its first model is used.
    Every residue of a chain contributes one letter; residue names that
    ``ProteinSequence.convert_letter_3to1`` cannot convert are written as "X".

    Args:
        pdb_file (str): The path to the PDB file.
        chain_id (str | None): The chain identifier (e.g., 'A'). If None,
            all chains are returned.

    Returns:
        If ``chain_id`` is given: a tuple ``(sequence, residue_numbers)`` where
        ``sequence`` is a str and ``residue_numbers`` is a list[int] of the
        same length.
        If ``chain_id`` is None: a list of
        ``(chain_id, (sequence, residue_numbers))`` tuples, one per chain, in
        the order the chain IDs are first reported by ``struc.get_chains``.

    Raises:
        ValueError: If ``chain_id`` is given but not present in the structure.
    """
    atom_array = pdb.PDBFile.read(pdb_file).get_structure()

    # get_structure() may hand back a multi-model stack; keep the first model.
    if hasattr(atom_array, "stack_depth") and atom_array.stack_depth() > 0:
        atom_array = atom_array[0]

    # chain ID -> (sequence string, list of PDB residue numbers)
    chain_data = {}

    for current_chain_id in struc.get_chains(atom_array):
        # Only do the per-residue work for the chain that was asked for.
        if chain_id is not None and current_chain_id != chain_id:
            continue
        # The mask below already selects every atom carrying this chain ID,
        # so a chain ID that is reported again would give the same result.
        if current_chain_id in chain_data:
            continue

        chain_atoms = atom_array[atom_array.chain_id == current_chain_id]

        letters = []
        residue_numbers = []
        # NOTE(review): every residue in the selection is visited, so
        # non-amino-acid residues (waters, ligands) presumably also end up
        # as "X" here -- confirm whether inputs are pre-cleaned.
        for residue_atoms in struc.residue_iter(chain_atoms):
            try:
                letter = ProteinSequence.convert_letter_3to1(
                    residue_atoms.res_name[0]
                )
            except KeyError:
                letter = "X"
            letters.append(letter)
            # Convert to a plain int so callers do not receive array scalars.
            residue_numbers.append(int(residue_atoms.res_id[0]))

        chain_data[current_chain_id] = ("".join(letters), residue_numbers)

    if chain_id is None:
        return list(chain_data.items())
    if chain_id in chain_data:
        return chain_data[chain_id]
    raise ValueError(f"Chain '{chain_id}' not found in the PDB file.")


def get_pdb_to_sequence_mapping(pdb_file: str, chain_id: str | None = None):
    """
    Build a lookup from PDB residue numbers to 0-indexed sequence positions.

    Useful for translating PDB residue numbers (e.g., 4, 5, 6...) into the
    0-indexed sequence positions (0, 1, 2...) used by ESM models.

    If a residue number occurs more than once within a chain, the position of
    its last occurrence is the one kept in the mapping.

    Args:
        pdb_file (str): The path to the PDB file.
        chain_id (str): The chain identifier (e.g., 'A'). If None, returns mapping for all chains.

    Returns:
        dict: If chain_id is specified, returns {pdb_residue_number: sequence_index}.
        If chain_id is None, returns {chain_id: {pdb_residue_number: sequence_index}}.
    """
    chain_data = get_sequence_and_index_map(pdb_file, chain_id)

    if chain_id is not None:
        # Single chain: chain_data is a (sequence, residue_numbers) pair.
        _, residue_numbers = chain_data
        return {
            pdb_num: position for position, pdb_num in enumerate(residue_numbers)
        }

    # All chains: chain_data is a list of (chain, (sequence, residue_numbers)).
    return {
        cid: {pdb_num: position for position, pdb_num in enumerate(numbers)}
        for cid, (_, numbers) in chain_data
    }


def get_sequence_to_pdb_mapping(pdb_file: str, chain_id: str | None = None):
    """
    Build a lookup from 0-indexed sequence positions to PDB residue numbers.

    Useful for translating the 0-indexed sequence positions (0, 1, 2...) used
    by ESM models back into PDB residue numbers (e.g., 4, 5, 6...).

    Args:
        pdb_file (str): The path to the PDB file.
        chain_id (str): The chain identifier (e.g., 'A'). If None, returns mapping for all chains.

    Returns:
        dict: If chain_id is specified, returns {sequence_index: pdb_residue_number}.
              If chain_id is None, returns {chain_id: {sequence_index: pdb_residue_number}}.
    """
    chain_data = get_sequence_and_index_map(pdb_file, chain_id)

    if chain_id is not None:
        # Single chain: chain_data is a (sequence, residue_numbers) pair.
        _, residue_numbers = chain_data
        return dict(enumerate(residue_numbers))

    # All chains: chain_data is a list of (chain, (sequence, residue_numbers)).
    return {cid: dict(enumerate(numbers)) for cid, (_, numbers) in chain_data}


def load_sequences(
    source: Union[str, List[str], None],
    chain_separator: str = "",
) -> Optional[List[str]]:
    """
    Load protein sequences from various input formats.

    Supported inputs:
    - FASTA: .fasta, .fa, .faa
    - FASTQ: .fastq, .fq (strict four-line records)
    - PDB: .pdb (uses SEQRES records; one sequence per chain sorted by chain
      ID, or a single sequence joined by chain_separator if provided)
    - CSV: .csv (auto-detects column; prefers 'sequence'/'seq' variants,
      otherwise the first column)
    - Plain sequences: .seq, .txt (one sequence per line)
    - Any other / no extension: parsed as FASTA when the file contains a
      '>' header line, otherwise as one sequence per line
    - Direct input: any non-string iterable of sequences, or a raw sequence
      string

    A string that is not the path of an existing file is treated as a raw
    sequence. All returned sequences are upper-cased with spaces removed.

    Args:
        source: File path or sequences. Cannot be None.
        chain_separator: Separator used when joining multiple PDB chains.

    Returns:
        List[str]: The loaded sequences. A file that contains no sequences
        yields an empty list.

    Raises:
        ValueError: If source is None, an iterable without any non-blank
            item, or a blank raw string.
    """
    if source is None:
        raise ValueError("source cannot be None")

    def _clean(raw: object) -> str:
        # Shared normalisation: trim, upper-case, drop embedded spaces.
        return str(raw).strip().upper().replace(" ", "")

    if not isinstance(source, str) and hasattr(source, "__iter__"):
        # Handle both regular lists and OmegaConf ListConfig
        sequences = [seq for seq in map(_clean, source) if seq]
        if not sequences:
            raise ValueError("No valid sequences found in the provided list")
        return sequences

    path = str(source)  # Convert to string to handle ListConfig and other types
    if not os.path.isfile(path):
        # This is a raw sequence string, not a file path - clean it
        cleaned = _clean(path)
        if not cleaned:
            raise ValueError("Empty sequence string provided")
        return [cleaned]

    def _has_fasta_header(file_path: str) -> bool:
        # True if any line (ignoring leading whitespace) starts with '>'.
        with open(file_path, "r") as f:
            return any(line.lstrip().startswith(">") for line in f)

    def _parse_fasta(file_path: str) -> List[str]:
        # Lines between headers are concatenated into one record.
        records: List[str] = []
        chunks: List[str] = []
        with open(file_path, "r") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                if line.startswith(">"):
                    if chunks:
                        records.append("".join(chunks).upper())
                        chunks = []
                    continue
                chunks.append(line)
        if chunks:
            records.append("".join(chunks).upper())
        return [rec.replace(" ", "") for rec in records if rec]

    def _parse_fastq(file_path: str) -> List[str]:
        # Reads header / sequence / '+' / quality in lockstep; a truncated
        # trailing record is dropped.
        records: List[str] = []
        with open(file_path, "r") as f:
            while True:
                header = f.readline()
                if not header:
                    break
                seq = f.readline()
                plus = f.readline()
                qual = f.readline()
                if not (seq and plus and qual):
                    break
                records.append(_clean(seq))
        return [rec for rec in records if rec]

    def _parse_pdb(file_path: str) -> List[str]:
        three_to_one = {
            "ALA": "A",
            "ARG": "R",
            "ASN": "N",
            "ASP": "D",
            "CYS": "C",
            "GLN": "Q",
            "GLU": "E",
            "GLY": "G",
            "HIS": "H",
            "ILE": "I",
            "LEU": "L",
            "LYS": "K",
            "MET": "M",
            "PHE": "F",
            "PRO": "P",
            "SER": "S",
            "THR": "T",
            "TRP": "W",
            "TYR": "Y",
            "VAL": "V",
            "SEC": "U",
            "PYL": "O",
            "ASX": "B",
            "GLX": "Z",
            "XLE": "J",
            "UNK": "X",
            "MSE": "M",
        }
        chains: dict[str, List[str]] = {}
        with open(file_path, "r") as f:
            for line in f:
                if not line.startswith("SEQRES"):
                    continue
                # Whitespace-split fields: SEQRES, serial, chain, count, residues...
                parts = line.strip().split()
                if len(parts) < 5:
                    continue
                chain_id = parts[2]
                # Unknown three-letter codes fall back to "X".
                letters = [three_to_one.get(res.upper(), "X") for res in parts[4:]]
                chains.setdefault(chain_id, []).extend(letters)
        if not chains:
            return []
        chain_sequences = ["".join(res_list) for _, res_list in sorted(chains.items())]
        if chain_separator:
            return [chain_separator.join(chain_sequences)]
        return chain_sequences

    def _parse_csv(file_path: str) -> List[str]:
        try:
            df = pd.read_csv(file_path)
        except Exception:
            # Retry with delimiter sniffing for non-comma separated files.
            df = pd.read_csv(file_path, sep=None, engine="python")
        preferred_cols = [
            "sequence",
            "seq",
            "protein_sequence",
            "protein",
            "Sequence",
            "Seq",
        ]
        lower_pref = {c.lower() for c in preferred_cols}
        chosen_col = None
        # First column (in file order) whose name matches a preferred name.
        for col in df.columns:
            if col in preferred_cols or col.lower() in lower_pref:
                chosen_col = col
                break
        if chosen_col is None:
            chosen_col = df.columns[0]
        seqs = [_clean(value) for value in df[chosen_col].dropna().tolist()]
        return [s for s in seqs if s]

    def _parse_lines(file_path: str) -> List[str]:
        with open(file_path, "r") as f:
            return [_clean(line) for line in f if line.strip()]

    parsers = {
        ".fasta": _parse_fasta,
        ".fa": _parse_fasta,
        ".faa": _parse_fasta,
        ".fastq": _parse_fastq,
        ".fq": _parse_fastq,
        ".pdb": _parse_pdb,
        ".csv": _parse_csv,
        ".seq": _parse_lines,
        ".txt": _parse_lines,
    }
    ext = os.path.splitext(path)[1].lower()
    parser = parsers.get(ext)
    if parser is not None:
        return parser(path)

    # Unknown extension: only treat the file as FASTA when it really has a
    # header line. Without this check the FASTA parser would glue every line
    # of a header-less file into one sequence and the line-based fallback
    # would never be used.
    if _has_fasta_header(path):
        seqs = _parse_fasta(path)
        if seqs:
            return seqs
    return _parse_lines(path)

def load_all_runs_from_exp(exp_id=None, tracking_uri=None):
    """Fetch MLflow runs from a tracking server as a pandas DataFrame.

    No run filter is applied: every run that ``mlflow.search_runs`` returns
    for the selected experiment(s) is included.

    Args:
        exp_id: Experiment ID to fetch runs from. If None, runs are fetched
            from every experiment returned by the client's
            ``search_experiments()``.
        tracking_uri: MLflow tracking server URI. Defaults to
            "https://tracking.iwe-lab.de" when None.

    Returns:
        The result of ``mlflow.search_runs(..., output_format="pandas")``.

    Side effects:
        Sets the global MLflow tracking URI and prints progress information
        to stdout.
    """
    import mlflow

    if tracking_uri is None:
        tracking_uri = "https://tracking.iwe-lab.de"
    mlflow.set_tracking_uri(tracking_uri)

    if exp_id is None:
        # The experiment listing is only needed when no experiment was
        # named, so the client and its remote call live in this branch.
        client = mlflow.MlflowClient(tracking_uri=tracking_uri)
        experiments = client.search_experiments()
        print(experiments)
        runs = mlflow.search_runs(
            experiment_ids=[e.experiment_id for e in experiments],
            output_format="pandas",
        )
    else:
        print(f"Fetching runs from {exp_id}")
        runs = mlflow.search_runs(experiment_ids=[exp_id], output_format="pandas")
        print(f"Found {len(runs)} runs")
    return runs
