from omegaconf import DictConfig
import biotite.structure.io.pdb as pdb
import biotite.structure as struc
import hydra
from typing import Union, List, Optional
import os
import ast
import pandas as pd

def get_default_config(config_name: str = "defaults") -> DictConfig:
    """Compose the Hydra config ``config_name`` from the ``haipr/conf`` directory.

    ``hydra.initialize`` is used as a context manager so Hydra's global state is
    restored once composition is done. Without this, Hydra stays initialized
    after the first call and any later call in the same process (e.g. in a
    notebook or test session) fails with "GlobalHydra is already initialized".

    Args:
        config_name: Name of the YAML config (without extension) to compose.

    Returns:
        The composed configuration.
    """
    with hydra.initialize(config_path="haipr/conf", version_base=None):
        config: DictConfig = hydra.compose(config_name=config_name)
    return config

def get_config_with_overrides(
    config_name: str = "defaults", overrides: Optional[list[str]] = None
) -> DictConfig:
    """Compose the Hydra config ``config_name`` from ``haipr/conf`` with overrides applied.

    ``hydra.initialize`` is used as a context manager so Hydra's global state is
    restored afterwards; this lets the function be called repeatedly in one
    process instead of failing with "GlobalHydra is already initialized".

    Args:
        config_name: Name of the YAML config (without extension) to compose.
        overrides: Hydra override strings such as ``"model.lr=0.1"``. ``None``
            (the default) means no overrides. A fresh list is used on every
            call, so no mutable default is shared between calls.

    Returns:
        The composed configuration.
    """
    override_list = list(overrides) if overrides is not None else []
    with hydra.initialize(config_path="haipr/conf", version_base=None):
        config: DictConfig = hydra.compose(config_name=config_name, overrides=override_list)
    return config

def get_default_config_path(config_name: str = "defaults") -> str:
    """Return the relative path of the YAML file backing Hydra config ``config_name``.

    The default is ``"defaults"``, the same config name that
    ``get_default_config`` and ``get_config_with_overrides`` compose by default,
    so the path and the loaded config refer to the same file.

    Args:
        config_name: Config name without the ``.yaml`` extension.

    Returns:
        ``"haipr/conf/<config_name>.yaml"``.
    """
    return f"haipr/conf/{config_name}.yaml"


def parse_sequence_chains_to_single_string(seq_str: list[str] | str) -> list[str] | str:
    """Parse sequence chains from a list of strings into a single concatenated string.

    Args:
        seq_str: List of strings containing sequence chains in format "{'A': 'SEQ1', 'B': 'SEQ2'}"

    Returns:
        List of strings with sequences concatenated with '|' separator
    """
    import ast

    if isinstance(seq_str, str):
        # return single string
        seq_dict = ast.literal_eval(seq_str)
        return "|".join(seq_dict.values())

    res = []
    # TODO: simplify with '|'.join(apply(eval)) in data.py
    for item in seq_str:
        # Safely evaluate the string as a Python dict literal
        seq_dict = ast.literal_eval(item)
        # Join all sequences with | separator
        res.append("|".join(seq_dict.values()))
    return res


def get_sequence_and_index_map(pdb_file: str, chain_id: Optional[str] = None):
    """
    Extract chain sequences and the PDB residue number of each sequence position.

    The file is read with biotite; for multi-model files only the first model is
    used. Every residue of a chain contributes one letter. No residue filtering is
    applied, so any hetero residues assigned to the chain are included too.
    Residue names that ``ProteinSequence.convert_letter_3to1`` does not recognise
    become ``"X"``.

    Args:
        pdb_file (str): The path to the PDB file.
        chain_id (str | None): The chain identifier (e.g. 'A'), or None for all chains.

    Returns:
        If ``chain_id`` is given: ``(sequence, residue_numbers)`` for that chain,
        where ``residue_numbers[i]`` is the PDB residue number (``int``) of
        ``sequence[i]``.
        If ``chain_id`` is None: a list of ``(chain, (sequence, residue_numbers))``
        tuples, one per chain, in the order returned by ``struc.get_chains``.

    Raises:
        ValueError: If ``chain_id`` is given but is not present in the structure.
    """
    import biotite.structure.io.pdb as pdb
    import biotite.structure as struc
    from biotite.sequence import ProteinSequence

    structure = pdb.PDBFile.read(pdb_file)
    atom_array = structure.get_structure()

    # get_structure() may return an AtomArrayStack; keep only the first model.
    if hasattr(atom_array, "stack_depth") and atom_array.stack_depth() > 0:
        atom_array = atom_array[0]

    def _chain_sequence(cid):
        """Return (one-letter sequence, list of int residue numbers) for chain ``cid``."""
        letters = []
        residue_numbers = []
        for residue_atoms in struc.residue_iter(atom_array[atom_array.chain_id == cid]):
            try:
                letters.append(ProteinSequence.convert_letter_3to1(residue_atoms.res_name[0]))
            except KeyError:
                letters.append("X")
            # NOTE(review): the insertion code is not returned, so numbers may
            # repeat for residues like 52/52A — confirm callers tolerate this.
            residue_numbers.append(int(residue_atoms.res_id[0]))
        return "".join(letters), residue_numbers

    chain_ids = struc.get_chains(atom_array)

    if chain_id is None:
        return [(cid, _chain_sequence(cid)) for cid in chain_ids]
    # Only the requested chain is converted; the others are skipped entirely.
    if chain_id not in list(chain_ids):
        raise ValueError(f"Chain '{chain_id}' not found in the PDB file.")
    return _chain_sequence(chain_id)


def get_pdb_to_sequence_mapping(pdb_file: str, chain_id: str = None):
    """
    Map PDB residue numbers to 0-based positions in the extracted chain sequence.

    Useful for translating PDB numbering (e.g. 4, 5, 6, ...) into the 0-indexed
    positions (0, 1, 2, ...) expected by sequence models such as ESM.
    If a residue number occurs more than once in a chain (e.g. insertion codes),
    the last position with that number wins.

    Args:
        pdb_file (str): The path to the PDB file.
        chain_id (str): The chain identifier (e.g. 'A'). If None, all chains are mapped.

    Returns:
        dict: ``{pdb_residue_number: sequence_index}`` for the requested chain, or
        ``{chain_id: {pdb_residue_number: sequence_index}}`` when ``chain_id`` is None.
    """

    def _number_to_position(residue_numbers):
        return {number: position for position, number in enumerate(residue_numbers)}

    extracted = get_sequence_and_index_map(pdb_file, chain_id)

    if chain_id is not None:
        _, residue_numbers = extracted
        return _number_to_position(residue_numbers)

    return {
        cid: _number_to_position(residue_numbers)
        for cid, (_, residue_numbers) in extracted
    }


def get_sequence_to_pdb_mapping(pdb_file: str, chain_id: str = None):
    """
    Map 0-based positions in the extracted chain sequence to PDB residue numbers.

    Useful for translating 0-indexed positions (0, 1, 2, ...), as used by
    sequence models such as ESM, back to PDB numbering (e.g. 4, 5, 6, ...).

    Args:
        pdb_file (str): The path to the PDB file.
        chain_id (str): The chain identifier (e.g. 'A'). If None, all chains are mapped.

    Returns:
        dict: ``{sequence_index: pdb_residue_number}`` for the requested chain, or
        ``{chain_id: {sequence_index: pdb_residue_number}}`` when ``chain_id`` is None.
    """

    def _position_to_number(residue_numbers):
        return dict(enumerate(residue_numbers))

    extracted = get_sequence_and_index_map(pdb_file, chain_id)

    if chain_id is not None:
        _, residue_numbers = extracted
        return _position_to_number(residue_numbers)

    return {
        cid: _position_to_number(residue_numbers)
        for cid, (_, residue_numbers) in extracted
    }


def DMS_file_for_LLM(df, focus=False, return_focus_chains=False, sep=""):
    """Flatten multi-chain DMS rows into single-string sequences for language models.

    Expected columns of ``df``:
      - ``wildtype_sequence`` / ``mutated_sequence``: strings holding Python dict
        literals mapping chain ID -> sequence.
      - ``mutant``: string holding a dict literal mapping chain ID -> colon-separated
        mutations such as ``"A12G:K15R"`` (``""`` for an unmutated chain).
      - ``chain_id``: the chains to concatenate, iterated element by element
        (presumably a string such as ``"AB"`` whose characters are chain IDs —
        confirm against the data). Missing values are treated as ``""``.

    For every row, the per-chain sequences are joined with ``sep``. Each mutation
    position is shifted by the offset of its chain within the joined sequence.
    A chain counts as a "focus" chain if it carries a mutation in at least one row.

    Args:
        df: DataFrame to convert. It is modified in place (the four columns above
            are overwritten) and also returned.
        focus: If True, the output keeps only focus chains; otherwise all chains.
        return_focus_chains: If True, also return the sorted focus chain IDs.
        sep: String inserted between chains when joining sequences.

    Returns:
        ``df``, or ``(df, sorted_focus_chains)`` when ``return_focus_chains`` is True.

    Raises:
        ValueError / SyntaxError: If a dict column is not a valid Python literal.
    """
    import ast

    def _shift_mutants(mutant_str, offset):
        # Each mutation is "<wt><pos><mt>" (e.g. "A12G"); only <pos> is shifted.
        return [m[:1] + str(int(m[1:-1]) + offset) + m[-1:] for m in mutant_str.split(":")]

    df["chain_id"] = df["chain_id"].fillna("")
    # ast.literal_eval instead of eval: these columns come from data files, and
    # eval would execute arbitrary code embedded in a cell.
    df["wildtype_sequence"] = df["wildtype_sequence"].apply(ast.literal_eval)
    df["mutant"] = df["mutant"].apply(ast.literal_eval)
    df["mutated_sequence"] = df["mutated_sequence"].apply(ast.literal_eval)

    # NOTE(review): each chain boundary adds 1 to later mutation offsets when sep
    # is "" and 0 otherwise. That is the opposite of the number of characters sep
    # actually inserts into the joined string. It is kept unchanged because a
    # downstream consumer may insert its own separator token; confirm the
    # intended convention.
    add_one = 1 if sep == "" else 0

    # Chains mutated in at least one row, in order of first appearance.
    focus_chains = []
    for row_mutants in df["mutant"]:
        for c, ms in row_mutants.items():
            if ms != "" and c not in focus_chains:
                focus_chains.append(c)

    input_wt_seqs, input_mt_seqs, input_mutants = [], [], []
    input_focus_wt_seqs, input_focus_mt_seqs, input_focus_mutants = [], [], []

    # Zip the columns rather than using df.loc[label, ...], which returns a
    # Series instead of a scalar when the index has duplicate labels.
    rows = zip(df["chain_id"], df["wildtype_sequence"], df["mutated_sequence"], df["mutant"])
    for chain_ids, wt_seq_dic, mt_seq_dic, mutants in rows:
        wt_seqs, mt_seqs, revise_mutants = [], [], []
        focus_wt_seqs, focus_mt_seqs, focus_revise_mutants = [], [], []
        start_idx = 0
        focus_start_idx = 0
        for chain_id in chain_ids:
            ms = mutants[chain_id]
            wt_seq = wt_seq_dic[chain_id]
            mt_seq = mt_seq_dic[chain_id]

            if ms != "":
                revise_mutants.extend(_shift_mutants(ms, start_idx))
            wt_seqs.append(wt_seq)
            mt_seqs.append(mt_seq)
            start_idx += len(wt_seq) + add_one

            if chain_id in focus_chains:
                # Offsets restart for the focus-only concatenation.
                if ms != "":
                    focus_revise_mutants.extend(_shift_mutants(ms, focus_start_idx))
                focus_wt_seqs.append(wt_seq)
                focus_mt_seqs.append(mt_seq)
                focus_start_idx += len(wt_seq) + add_one

        # .strip(sep) removes any *characters* of sep from both ends, not the
        # sep string as a unit.
        input_wt_seqs.append(sep.join(wt_seqs).strip(sep))
        input_mt_seqs.append(sep.join(mt_seqs).strip(sep))
        input_mutants.append(":".join(revise_mutants))

        input_focus_wt_seqs.append(sep.join(focus_wt_seqs).strip(sep))
        input_focus_mt_seqs.append(sep.join(focus_mt_seqs).strip(sep))
        input_focus_mutants.append(":".join(focus_revise_mutants))

    if not focus:
        df["wildtype_sequence"] = input_wt_seqs
        df["mutated_sequence"] = input_mt_seqs
        df["mutant"] = input_mutants
    else:
        df["wildtype_sequence"] = input_focus_wt_seqs
        df["mutated_sequence"] = input_focus_mt_seqs
        df["mutant"] = input_focus_mutants
    if return_focus_chains:
        return df, sorted(focus_chains)
    return df


def load_sequences(
    source: Union[str, List[str], None],
    chain_separator: str = "",
) -> Optional[List[str]]:
    """
    Load protein sequences from various input formats.

    Supported inputs:
    - FASTA: .fasta, .fa, .faa
    - FASTQ: .fastq, .fq
    - PDB: .pdb (uses SEQRES records; chains are ordered by chain ID and, if
      ``chain_separator`` is non-empty, joined into a single string)
    - CSV: .csv (column chosen case-insensitively by priority 'sequence' >
      'seq' > 'protein_sequence' > 'protein'; falls back to the first column)
    - Plain sequences: .seq, .txt (one sequence per line)
    - Any other existing file: parsed as FASTA if its first non-blank line
      starts with '>', otherwise as one sequence per line
    - Direct input: list[str] or raw sequence string

    Sequences are upper-cased with spaces removed, and empty entries are dropped.

    Args:
        source: File path or sequences. If None, returns None. A string that is
            not an existing file path is treated as a raw sequence.
        chain_separator: Separator used when joining multiple PDB chains.

    Returns:
        List[str], or None when ``source`` is None or neither a str nor a list.
    """
    if source is None:
        return None

    if isinstance(source, list):
        return [str(s).strip().upper().replace(" ", "") for s in source if str(s).strip()]

    if not isinstance(source, str):
        return None

    path = source
    if not os.path.isfile(path):
        cleaned = str(path).strip().upper().replace(" ", "")
        return [cleaned] if cleaned else []

    def _parse_fasta(file_path: str) -> List[str]:
        """Concatenate the lines under each '>' header into one sequence."""
        sequences: List[str] = []
        current: List[str] = []
        with open(file_path, "r") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                if line.startswith(">"):
                    if current:
                        sequences.append("".join(current).upper())
                        current = []
                    continue
                current.append(line)
        if current:
            sequences.append("".join(current).upper())
        return [s.replace(" ", "") for s in sequences if s]

    def _parse_fastq(file_path: str) -> List[str]:
        """Take the sequence line of each complete 4-line FASTQ record."""
        sequences: List[str] = []
        with open(file_path, "r") as f:
            while True:
                header = f.readline()
                if not header:
                    break
                seq = f.readline()
                plus = f.readline()
                qual = f.readline()
                if not (seq and plus and qual):
                    break
                sequences.append(seq.strip().upper().replace(" ", ""))
        return [s for s in sequences if s]

    def _parse_pdb(file_path: str) -> List[str]:
        """Build one sequence per chain from SEQRES records."""
        three_to_one = {
            "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C",
            "GLN": "Q", "GLU": "E", "GLY": "G", "HIS": "H", "ILE": "I",
            "LEU": "L", "LYS": "K", "MET": "M", "PHE": "F", "PRO": "P",
            "SER": "S", "THR": "T", "TRP": "W", "TYR": "Y", "VAL": "V",
            "SEC": "U", "PYL": "O", "ASX": "B", "GLX": "Z", "XLE": "J",
            "UNK": "X", "MSE": "M",
        }
        chains: dict[str, List[str]] = {}
        with open(file_path, "r") as f:
            for line in f:
                if not line.startswith("SEQRES"):
                    continue
                # SEQRES is fixed-width: the chain ID is column 12 and residue
                # names start at column 20. A whitespace split misreads records
                # whose chain ID is blank (numRes would be taken as the chain
                # and the first residue would be dropped).
                residues = line[19:].split()
                if not residues:
                    continue
                chain_id = line[11:12]
                chains.setdefault(chain_id, []).extend(
                    three_to_one.get(res.upper(), "X") for res in residues
                )
        if not chains:
            return []
        chain_sequences = ["".join(res_list) for _, res_list in sorted(chains.items())]
        if chain_separator:
            return [chain_separator.join(chain_sequences)]
        return chain_sequences

    def _parse_csv(file_path: str) -> List[str]:
        """Read the most sequence-like column of a CSV file."""
        try:
            df = pd.read_csv(file_path)
        except pd.errors.ParserError:
            # Retry, letting the python engine sniff the delimiter.
            df = pd.read_csv(file_path, sep=None, engine="python")
        # Priority order matters: choose by preference, not by column position.
        priority = ("sequence", "seq", "protein_sequence", "protein")
        lowered = [str(col).lower() for col in df.columns]
        chosen_col = df.columns[0]
        for name in priority:
            if name in lowered:
                chosen_col = df.columns[lowered.index(name)]
                break
        seqs = [
            str(s).strip().upper().replace(" ", "") for s in df[chosen_col].dropna().tolist()
        ]
        return [s for s in seqs if s]

    def _parse_lines(file_path: str) -> List[str]:
        """Treat every non-blank line as one sequence."""
        with open(file_path, "r") as f:
            lines = [line.strip().upper().replace(" ", "") for line in f if line.strip()]
        return lines

    def _looks_like_fasta(file_path: str) -> bool:
        """True if the first non-blank line is a FASTA header."""
        with open(file_path, "r") as f:
            for line in f:
                stripped = line.strip()
                if stripped:
                    return stripped.startswith(">")
        return False

    ext = os.path.splitext(path)[1].lower()
    if ext in {".fasta", ".fa", ".faa"}:
        return _parse_fasta(path)
    if ext in {".fastq", ".fq"}:
        return _parse_fastq(path)
    if ext == ".pdb":
        return _parse_pdb(path)
    if ext == ".csv":
        return _parse_csv(path)
    if ext in {".seq", ".txt"}:
        return _parse_lines(path)

    # Unknown extension: sniff the content. Running the FASTA parser on a
    # header-less file would merge every line into one sequence.
    if _looks_like_fasta(path):
        return _parse_fasta(path)
    return _parse_lines(path)
