import numpy as np
from collections import OrderedDict


# One-letter amino-acid code -> three-letter residue name for the 20
# standard residues. Built from (code, name) pairs so the insertion order
# (A, R, N, D, ...) is explicit.
RESTYPE_1_TO_3 = OrderedDict([
    ('A', 'ALA'),
    ('R', 'ARG'),
    ('N', 'ASN'),
    ('D', 'ASP'),
    ('C', 'CYS'),
    ('Q', 'GLN'),
    ('E', 'GLU'),
    ('G', 'GLY'),
    ('H', 'HIS'),
    ('I', 'ILE'),
    ('L', 'LEU'),
    ('K', 'LYS'),
    ('M', 'MET'),
    ('F', 'PHE'),
    ('P', 'PRO'),
    ('S', 'SER'),
    ('T', 'THR'),
    ('W', 'TRP'),
    ('Y', 'TYR'),
    ('V', 'VAL'),
])

# Reverse lookup: three-letter residue name -> one-letter code (plain dict,
# same iteration order as RESTYPE_1_TO_3's values).
RESTYPE_3_TO_1 = dict((three, one) for one, three in RESTYPE_1_TO_3.items())


def parse_tm_file(filepath):
    """
    Extract summary scores from a TMscore output file.

    Each line is matched against three line prefixes; a matching line is
    converted and stored, so when a prefix occurs more than once the last
    occurrence wins. Keys whose line never appears are absent.

    Args:
        filepath:
            Filepath for the TMscore output.

    Returns:
        A dictionary containing (when present in the file)
            -   rmsd: float after '=' on the line starting with 'RMSD '
                (RMSD between the two superimposed structures)
            -   tm: TM-score, the float between '=' and the first '(' on
                the line starting with 'TM-score'
            -   seqlen: integer after '=' on the line starting with
                'Number' (both structures are assumed to have the same
                length).
    """
    # (line prefix, result key, converter applied to the whole line)
    extractors = (
        ('RMSD ', 'rmsd', lambda text: float(text.split('=')[1])),
        # Drop the trailing "(d0=..., ...)" part before taking the value.
        ('TM-score', 'tm', lambda text: float(text.split('(')[0].split('=')[1])),
        ('Number', 'seqlen', lambda text: int(text.split('=')[1])),
    )
    scores = {}
    with open(filepath, 'r') as handle:
        for raw_line in handle:
            for prefix, key, convert in extractors:
                if raw_line.startswith(prefix):
                    scores[key] = convert(raw_line)
                    break
    return scores

def parse_pdb_file(filepath):
    """
    Parse backbone coordinates and per-residue confidence from a PDB file.

    Only ``ATOM`` records are read. Fields are taken from the fixed PDB
    columns: atom name (columns 13-16), chain ID / residue number /
    insertion code (columns 22-27), x/y/z (columns 31-54) and the B-factor
    (columns 61-66), which is read as pLDDT.

    Each (chain ID, residue number, insertion code, atom name) combination
    is kept only the first time it is seen. Alternate locations (altLoc)
    and repeated MODEL blocks therefore do not add duplicate atoms.

    Args:
        filepath:
            Filepath for the PDB-formatted structure.

    Returns:
        A dictionary containing
            -   pLDDT: list of floats read from the B-factor column of each
                kept CA atom, or None if any kept CA record has no
                parseable value there
            -   ca_coords: float array of shape (n_ca, 3) with CA
                coordinates in file order
            -   bb_coords: float array of shape (n_bb, 3) with N/CA/C/O
                coordinates in file order.
    """
    backbone_names = {'N', 'CA', 'C', 'O'}
    plddt, ca_coords, bb_coords = [], [], []
    seen_atoms = set()
    ignore_plddt = False
    with open(filepath, 'r') as file:
        for line in file:
            if not line.startswith('ATOM'):
                continue
            # The atom name field spans columns 13-16. Stripping the whole
            # field handles both the standard (" CA ") and left-aligned
            # ("CA  ") layouts, and avoids matching e.g. "HN  " as "N".
            atom_name = line[12:16].strip()
            if atom_name not in backbone_names:
                continue
            # Chain ID (col 22) + residue number (23-26) + insertion code
            # (27). The altLoc (col 17) and residue name are deliberately
            # excluded, so only the first alternate location is kept.
            atom_key = (line[21:27], atom_name)
            if atom_key in seen_atoms:
                continue
            seen_atoms.add(atom_key)

            xyz = [float(line[30:38]), float(line[38:46]), float(line[46:54])]
            bb_coords.append(xyz)
            if atom_name == 'CA':
                try:
                    plddt.append(float(line[60:66]))
                except ValueError:
                    # Missing or non-numeric B-factor: pLDDT is unavailable.
                    ignore_plddt = True
                ca_coords.append(xyz)
    return {
        'pLDDT': None if ignore_plddt else plddt,
        # reshape keeps a (0, 3) shape when no atoms were found
        'ca_coords': np.array(ca_coords, dtype=float).reshape(-1, 3),
        'bb_coords': np.array(bb_coords, dtype=float).reshape(-1, 3),
    }

def parse_pae_file(filepath):
    """
    Summarise a predicted Aligned Error (pAE) matrix as a single number.

    Args:
        filepath:
            Filepath for the pAE matrix, stored as text readable by
            ``np.loadtxt`` (whitespace-delimited by default).

    Returns:
        A dictionary containing
            -   pAE: mean over every entry of the loaded matrix, i.e. the
                predicted aligned error averaged across all residue-residue
                pairs.
    """
    pae_matrix = np.loadtxt(filepath)
    mean_pae = pae_matrix.mean()
    return {'pAE': mean_pae}