import torch
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from collections import defaultdict
import sys
import os
import pyhmmer
import numpy as np
import biotite.structure
from biotite.structure.io import pdbx, pdb
from biotite.structure.residues import get_residues
from biotite.structure import filter_backbone
from biotite.structure import get_chains
from biotite.sequence import ProteinSequence
from typing import Sequence, Tuple, List
from Bio import SeqIO

sys.path.insert(1, "./../esm/")
from esm.data import BatchConverter
from esm.inverse_folding.util import CoordBatchConverter

import time

#### Alphabet used to translate between amino-acid characters and integer indices.
alphabet = 'ACDEFGHIKLMNPQRSTVWY-'
# Index of the gap symbol; also the fallback index for any unknown character.
default_index = alphabet.index('-')
# char -> index; lookups of characters outside `alphabet` return (and insert) default_index.
aa_index = defaultdict(lambda: default_index, {char: idx for idx, char in enumerate(alphabet)})
# index -> char, built from the known alphabet entries only.
aa_index_inv = {idx: char for char, idx in aa_index.items()}

def get_str(seq_num):
    """Translate a sequence of integer amino-acid indices back into a string.

    Args:
        seq_num: iterable of indices into ``alphabet`` -- e.g. a 1-D torch
            tensor or numpy array (whose elements expose ``.item()``), or a
            list of plain Python ints.

    Returns:
        The decoded sequence as a ``str``.

    Raises:
        KeyError: if an index is not a position in ``alphabet``.
    """
    # `.item()` unwraps tensor / numpy scalars; plain ints are used as-is.
    # str.join avoids the quadratic cost of repeated `+=` concatenation.
    return "".join(
        aa_index_inv[num.item() if hasattr(num, "item") else num] for num in seq_num
    )

def clean_insertions(msa_aligned, L):
    """Remove insertion columns from a re-aligned MSA and encode rows as indices.

    After ``hmmalign``, match columns contain upper-case residues or '-'
    (deletions); insertion columns contain lower-case residues or '.' padding.
    Only match columns are kept, so each kept row lines up with the columns of
    the MSA the HMM was built from.

    Args:
        msa_aligned: alignment object exposing ``.alignment`` as a sequence of
            aligned strings (e.g. the result of ``pyhmmer.hmmer.hmmalign``).
        L: expected number of match columns per row.

    Returns:
        List of rows, each a list of ``L`` indices into ``alphabet``. Rows whose
        cleaned length differs from ``L`` are dropped (usually very few).
        Upper-case characters not in ``alphabet`` map to the gap index.
    """
    # Read the alignment once instead of re-evaluating the property per row.
    rows = msa_aligned.alignment
    samples_aligned_num = []
    for row in rows:
        seq_num = [aa_index[char] for char in row if char == '-' or char.isupper()]
        if len(seq_num) == L:
            ### Take out problematic alignments, they are usually very few.
            samples_aligned_num.append(seq_num)
    return samples_aligned_num

def load_structure(fpath, chain=None):
    """Load the backbone atoms of one or more chains from a PDB file.

    Only the PDB format is parsed here (via ``pdb.PDBFile``); mmCIF files are
    not handled by this function. Only the first model is read.

    Args:
        fpath: path to a PDB file.
        chain: a chain id, a list/tuple of chain ids, or None to keep all chains.

    Returns:
        biotite.structure.AtomArray containing the backbone atoms
        (``filter_backbone``) of the requested chains.

    Raises:
        ValueError: if the structure has no chains, or a requested chain is
            not present.
    """
    with open(fpath) as fin:
        pdbf = pdb.PDBFile.read(fin)
    structure = pdb.get_structure(pdbf, model=1)
    structure = structure[filter_backbone(structure)]
    all_chains = get_chains(structure)
    if len(all_chains) == 0:
        raise ValueError('No chains found in the input file.')
    if chain is None:
        chain_ids = all_chains
    elif isinstance(chain, (list, tuple)):
        chain_ids = list(chain)
    else:
        chain_ids = [chain]
    for chain_id in chain_ids:
        if chain_id not in all_chains:
            raise ValueError(f'Chain {chain_id} not found in input file')
    # Vectorised per-atom chain selection instead of a Python loop over atoms.
    structure = structure[np.isin(structure.chain_id, chain_ids)]
    return structure

def extract_coords_from_structure(structure: biotite.structure.AtomArray):
    """
    Pull backbone coordinates and the one-letter sequence out of a structure.

    Args:
        structure: An instance of biotite AtomArray
    Returns:
        Tuple (coords, seq)
            - coords is an L x 3 x 3 array for N, CA, C coordinates
            - seq is the extracted sequence, one letter per residue
    """
    backbone_names = ["N", "CA", "C"]
    coords = get_atom_coords_residuewise(backbone_names, structure)
    _, residue_names = get_residues(structure)
    to_one_letter = ProteinSequence.convert_letter_3to1
    seq = "".join(map(to_one_letter, residue_names))
    return coords, seq

def get_atom_coords_residuewise(atoms: List[str], struct: biotite.structure.AtomArray):
    """Collect, for every residue, the coordinates of the named atoms.

    Example for atoms argument: ["N", "CA", "C"]

    Args:
        atoms: atom names to extract, in the order they should appear.
        struct: structure to read from.

    Returns:
        Array stacked per residue (``apply_residue_wise``) of shape
        (n_residues, len(atoms), 3). Entries for atoms missing from a residue
        are NaN.

    Raises:
        RuntimeError: if a residue contains a requested atom name more than once.
    """
    def filterfn(s, axis=None):
        # (atoms_in_residue, len(atoms)) boolean: which requested name each atom matches.
        filters = np.stack([s.atom_name == name for name in atoms], axis=1)
        counts = filters.sum(0)
        if np.any(counts > 1):
            raise RuntimeError("structure has multiple atoms with same name")
        # argmax picks the (unique) matching atom; 0 when there is no match,
        # which is overwritten with NaN below.
        index = filters.argmax(0)
        coords = s[index].coord
        coords[counts == 0] = float("nan")
        return coords

    return biotite.structure.apply_residue_wise(struct, struct, filterfn)


def sample_esm_many(model, coords, n_samples=1000, temperature=1.0, partial_seq=None, confidence=None, device='cpu'):
    """Draw ``n_samples`` sequences for one backbone, one sample at a time.

    Samples sequences based on multinomial sampling (no beam search). The
    encoder is run once; the decoder then runs autoregressively for each
    sample.

    Args:
        model: ESM inverse-folding model (uses ``model.encoder``,
            ``model.decoder`` and ``model.decoder.dictionary``). It is put in
            eval mode and moved to ``device``.
        coords: L x 3 x 3 list representing one backbone
        n_samples: number of sequences to draw
        temperature: sampling temperature, use low temperature for higher
            sequence recovery and high temperature for higher diversity
        partial_seq: Optional, partial sequence with mask tokens if part of
            the sequence is known; only masked positions are sampled
        confidence: optional length L list of confidence scores for coordinates
        device: device to run on

    Returns:
        (sequences, times): the sampled sequence strings and the wall-clock
        duration in seconds of each sample.
    """
    model.eval()
    model.to(device)
    L = len(coords)
    dictionary = model.decoder.dictionary

    ## Convert to batch format
    batch_converter = CoordBatchConverter(dictionary)
    batch_coords, confidence, _, _, padding_mask = batch_converter(
        [(coords, confidence, None)], device=device
    )

    # Token template: prepend token followed by L masks (or the known residues).
    mask_idx = dictionary.get_idx('<mask>')
    template = torch.full((1, 1 + L), mask_idx, dtype=int)
    template[0, 0] = dictionary.get_idx('<cath>')
    if partial_seq is not None:
        for i, c in enumerate(partial_seq):
            template[0, i + 1] = dictionary.get_idx(c)
    if device is not None:
        template = template.to(device)

    times = []
    sequences = []
    with torch.no_grad():
        ## Run encoder just once
        encoder_out = model.encoder(batch_coords, padding_mask, confidence)
        for _ in range(n_samples):
            start = time.time()
            sampled_tokens = template.clone()
            # The incremental state caches decoder state from earlier calls so the
            # prefix is not re-decoded. It must be fresh for every sample; reusing
            # it would let later samples see cached state from earlier samples
            # (and grow without bound).
            incremental_state = dict()
            for i in range(1, L + 1):
                logits, _extra = model.decoder(
                    sampled_tokens[:, :i],
                    encoder_out,
                    incremental_state=incremental_state,
                )
                logits = logits[0].transpose(0, 1)
                logits /= temperature
                probs = F.softmax(logits, dim=-1)
                if sampled_tokens[0, i] == mask_idx:
                    sampled_tokens[:, i] = torch.multinomial(probs, 1).squeeze(-1)
            sampled_seq = sampled_tokens[0, 1:]
            sequences.append(''.join([dictionary.get_tok(a) for a in sampled_seq]))
            torch.cuda.empty_cache()
            times.append(time.time() - start)

    return sequences, times


def _esm_single_structure_inputs(coords, device):
    """Build batch-of-one ESM encoder inputs (coords, padding mask, confidence).

    Follows what ESM's ``CoordBatchConverter`` produces for one structure with
    default confidence:
    - the backbone is padded with one +inf position at each end;
    - NaN coordinates mark padding;
    - confidence is 1 for residues with finite coordinates, 0 where
      coordinates are non-finite (including the two +inf end positions), and
      -1 where the padding mask is set.
    """
    coords = torch.as_tensor(coords, dtype=torch.float32, device=device)
    L = coords.shape[0]
    batch_coords = torch.full((1, L + 2, 3, 3), float("inf"), device=device)
    batch_coords[0, 1:L + 1] = coords
    padding_mask = torch.isnan(batch_coords[:, :, 0, 0])
    coord_mask = torch.isfinite(batch_coords.sum(-2).sum(-1))
    confidence = coord_mask.float() - padding_mask.float()
    return batch_coords, padding_mask, confidence


def _expand_encoder_out(encoder_out, batch_size):
    """Replicate a batch-of-one ESM encoder output along the batch dimension."""
    batched = dict(encoder_out)
    # encoder_out['encoder_out'][0] is (L+2, 1, embed_dim); expand the batch axis
    # without copying and without hard-coding the embedding size.
    batched['encoder_out'] = [encoder_out['encoder_out'][0].expand(-1, batch_size, -1)]
    # (1, L+2) -> (batch_size, L+2)
    batched['encoder_padding_mask'] = [encoder_out['encoder_padding_mask'][0].repeat(batch_size, 1)]
    return batched


def sample_esm_batch2(model, coords, n_samples=1000, temperature=1.0, confidence=None, device='cpu'):
    """Batched multinomial sampler (no beam search) for the ESM inverse-folding model.

    ESM's forward functions are entangled, so they are left untouched. Instead,
    the single structure is encoded once and its encoder output is replicated
    along the batch dimension, so that many sequences are decoded in parallel
    from it. This costs memory, but less than sampling one sequence for each
    of ``n_samples`` structures, which would add a lot of batch padding. To
    avoid CUDA running out of memory, sampling is done in chunks of at most
    400 sequences when L > 200 and 1000 otherwise.

    Args:
        model: ESM inverse-folding model (uses ``model.encoder``,
            ``model.decoder`` and ``model.decoder.dictionary``). It is moved to
            ``device`` and put in eval mode.
        coords: L x 3 x 3 array-like of N, CA, C backbone coordinates
        n_samples: total number of sequences to draw (before filtering)
        temperature: sampling temperature, use low temperature for higher
            sequence recovery and high temperature for higher diversity
        confidence: accepted for signature compatibility but ignored; the ESM
            default confidence (1 per residue with finite coordinates) is used
        device: device to run on

    Returns:
        List of sampled sequence strings. Samples whose decoded string is not
        exactly L characters long (i.e. that contain a multi-character special
        token) are discarded, so fewer than ``n_samples`` may be returned.
        Returns an empty list when ``n_samples <= 0``.
    """
    L = len(coords)
    model = model.to(device)
    # Sampling must not run with dropout active (consistent with sample_esm_many).
    model.eval()
    if n_samples <= 0:
        return []

    batch_coords, padding_mask, confidence = _esm_single_structure_inputs(coords, device)

    with torch.no_grad():
        # Run the encoder only once; its output is shared by every sample.
        encoder_out = model.encoder(batch_coords, padding_mask, confidence)

    # Chunk size chosen to avoid CUDA OOM on long proteins.
    samples_batch = min(400 if L > 200 else 1000, n_samples)
    steps = (n_samples + samples_batch - 1) // samples_batch

    dictionary = model.decoder.dictionary
    mask_idx = dictionary.get_idx('<mask>')
    cath_idx = dictionary.get_idx('<cath>')
    samples_str = []

    for j in range(steps):
        batch_size = min(samples_batch, n_samples - j * samples_batch)
        batched_out = _expand_encoder_out(encoder_out, batch_size)

        # Start with prepend token
        sampled_tokens = torch.full((batch_size, 1 + L), mask_idx, dtype=int, device=device)
        sampled_tokens[:, 0] = cath_idx

        # Fresh incremental state per chunk: it caches decoder state across positions.
        incremental_state = dict()
        with torch.no_grad():
            # Decode one token at a time
            for i in range(1, L + 1):
                print(f"I am position {i} out of {L}, batch {j+1} out of {steps}", end="\r")
                logits, _ = model.decoder(
                    sampled_tokens[:, :i],
                    batched_out,
                    incremental_state=incremental_state,
                )
                # Drop the trailing length-1 position axis -> (batch, vocab).
                logits = logits.squeeze(-1)
                logits /= temperature
                probs = F.softmax(logits, dim=-1)
                sampled_tokens[:, i] = torch.multinomial(probs, 1).squeeze(-1)

        for row in sampled_tokens[:, 1:]:
            sampled_str = ''.join([dictionary.get_tok(a) for a in row])
            # Special tokens decode to multi-character strings; drop such samples.
            if len(sampled_str) == L:
                samples_str.append(sampled_str)

    ### Let us clear up the cache of the GPU
    torch.cuda.empty_cache()
    return samples_str


def align_esm(samples_esm_str, msa):
    """Re-align ESM-generated sequences onto the MSA of the native structure.

    An HMM is built from ``msa`` with pyhmmer, and the ESM samples are aligned
    to it with ``hmmalign``. Insertion columns are then removed, so every
    returned row has as many columns as ``msa``.

    Args:
        samples_esm_str: list of sampled sequences (str) that need re-alignment.
        msa: 2-D array/tensor (M, L) of indices into ``alphabet``; the MSA of
            the native structure to which the samples are re-aligned.

    Returns:
        List of re-aligned samples, each a list of L integer indices. Samples
        whose cleaned alignment does not have length L are dropped. Returns an
        empty list if ``samples_esm_str`` is empty.
    """
    if len(samples_esm_str) == 0:
        return []
    L = msa.shape[1]
    alphabet_hmm = pyhmmer.easel.Alphabet.amino()

    ### Digitize the MSA so that it can be used by pyhmmer
    samples_dig = [
        pyhmmer.easel.TextSequence(
            name=f"sample{row_idx + 1}".encode('ASCII'),
            sequence=get_str(row),
        ).digitize(alphabet_hmm)
        for row_idx, row in enumerate(msa)
    ]
    digMSA = pyhmmer.easel.DigitalMSA(alphabet=alphabet_hmm, name=b"train", sequences=samples_dig)

    ###### Build the Hidden Markov Model
    # symfrac=0.0 sets the minimal residue-fraction threshold for a column to be
    # a consensus (match) column, so MSA columns map onto HMM match states.
    builder = pyhmmer.plan7.Builder(alphabet_hmm, symfrac=0.0)
    background = pyhmmer.plan7.Background(alphabet_hmm)
    hmm, _, _ = builder.build_msa(digMSA, background)

    #### Digitize the esm samples and re-align them.
    samples_esm_dig = [
        pyhmmer.easel.TextSequence(
            name=f"sample{idx}".encode('ASCII'),
            sequence=sample,
        ).digitize(alphabet_hmm)
        for idx, sample in enumerate(samples_esm_str)
    ]
    msa_aligned = pyhmmer.hmmer.hmmalign(hmm, samples_esm_dig, trim=True)

    ### Take out the insertion columns (lower case residues and "." padding);
    ### they are not aligned to the MSA, hence are useless for our purposes.
    return clean_insertions(msa_aligned, L)

