import numpy as np
import os, sys
sys.path.insert(1, "./../util")
sys.path.insert(1, "./../model")
from potts_decoder import PottsDecoder
from typing import Sequence, Tuple, List
import torch, torchvision
import subprocess
sys.path.insert(1, "./../esm/")
from collections import defaultdict
from torch.nn.functional import one_hot
from esm_utils import load_structure, extract_coords_from_structure, get_atom_coords_residuewise
import pyhmmer


# Symbol alphabet: the 20 amino acids followed by the gap character '-'.
alphabet = 'ACDEFGHIKLMNPQRSTVWY-'
# Position of '-' in the alphabet; any character not in the alphabet maps here.
default_index = alphabet.index('-')
# Character -> integer code (unknown characters fall back to default_index).
aa_index = defaultdict(lambda: default_index, {aa: code for code, aa in enumerate(alphabet)})
# Integer code -> character, built from the entries of aa_index at import time.
aa_index_inv = {code: aa for aa, code in aa_index.items()}

def get_str(seq_num, aa_index_inv):
    """Translate a numerically encoded protein sequence into its character string.

    Args:
        seq_num: iterable of integer codes -- a 1-D tensor or array (whose
            elements expose ``.item()``) or a plain sequence of Python ints.
        aa_index_inv: dictionary mapping integer codes to characters.

    Returns:
        The sequence as a string, one character per element of ``seq_num``.
    """
    # Tensor / NumPy scalars are unwrapped with .item() so they work as dict
    # keys; plain Python ints (which have no .item()) are used directly.
    return "".join(
        aa_index_inv[num.item() if hasattr(num, "item") else num] for num in seq_num
    )

def load_model(model_path, device):
    """Load a trained PottsDecoder from a checkpoint saved by the training scripts.

    The checkpoint must be a dict holding ``'args_run'`` (the hyper-parameters
    used to build the model: ``n_layers``, ``param_embed_dim``,
    ``input_encoding_dim``, ``n_heads``, ``n_param_heads``, ``dropout``) and
    ``'model_state_dict'`` (the trained weights).

    Args:
        model_path: path to the checkpoint file.
        device: device on which the loaded model is placed (e.g. ``'cpu'``,
            ``'cuda'`` or a ``torch.device``).

    Returns:
        The PottsDecoder on ``device``, switched to evaluation mode.
    """
    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # saved on a GPU can also be loaded on a CPU-only machine.
    checkpoint = torch.load(model_path, map_location=device)
    q = 21  # alphabet size (20 amino acids + gap)
    args = checkpoint['args_run']
    n_layers = args['n_layers']
    param_embed_dim = d_model = args['param_embed_dim']
    input_encoding_dim = args['input_encoding_dim']
    n_heads = args['n_heads']
    n_param_heads = args['n_param_heads']
    dropout = args['dropout']

    decoder = PottsDecoder(q, n_layers, d_model, input_encoding_dim, param_embed_dim,
                           n_heads, n_param_heads, dropout=dropout)
    decoder.to(device)

    decoder.load_state_dict(checkpoint['model_state_dict'])
    # Only the forward pass is needed to generate data; eval() disables dropout.
    decoder.eval()
    return decoder

################## Function to compute the covariance ####################
def compute_covariance(msa, q):
    """
    Compute the covariance matrix of an integer-encoded MSA over q symbols.

    Args:
        msa: integer tensor of shape [M, N] (M sequences, N positions) with
            values in [0, q).
        q: number of distinct symbols.

    Returns:
        Tensor of shape [N*(q-1), N*(q-1)] holding f_ij(a, b) - f_i(a) f_j(b),
        where the last symbol (index q-1) is dropped at every position.
    """
    n_seqs, n_pos = msa.shape
    width = n_pos * (q - 1)

    # One-hot to [M, N, q], discard the last symbol, flatten to [M, N*(q-1)].
    encoded = one_hot(msa, num_classes=q).to(torch.float32)
    encoded = encoded[:, :, :q - 1].reshape(n_seqs, width)

    # Pairwise (bivariate) frequencies.
    pair_freqs = encoded.T @ encoded / n_seqs

    # Entries are 0/1, so the diagonal holds the single-site frequencies;
    # their outer product is the independent-sites expectation.
    site_freqs = torch.diagonal(pair_freqs)
    independent = site_freqs.view(width, 1) @ site_freqs.view(1, width)

    return pair_freqs - independent



def _write_potts_parameters(path, couplings, fields, N, q):
    """Write couplings as ``J i j a b value`` lines (i < j) and fields as ``h i a value`` lines."""
    # Move the parameters to the host once rather than syncing per element.
    J = couplings[0].detach().to('cpu')
    h = fields[0].detach().to('cpu')
    with open(path, "w") as f:
        for i in range(N):
            for j in range(i + 1, N):
                for aa1 in range(q):
                    for aa2 in range(q):
                        f.write(f"J {i} {j} {aa1} {aa2} {J[i*q + aa1, j*q + aa2].item()}\n")
        for i in range(N):
            for aa in range(q):
                # Each field entry is indexed by this loop's own symbol `aa`.
                f.write(f"h {i} {aa} {h[i*q + aa].item()}\n")


def _read_bmdca_samples(path, aa_index, aa_index_inv):
    """Parse the sampler's numerical output into a long tensor, skipping its first line."""
    with open(path, mode='r') as f:
        lines = f.readlines()

    seqs = []
    # NOTE(review): line 0 is skipped, presumably a header -- confirm against
    # the bmdca_sample output format.
    for line in lines[1:]:
        tokens = line.split()  # tolerant of a missing final newline / extra spaces
        if not tokens:
            continue
        # Round trip code -> character -> code, re-expressing each index in
        # the encoding given by aa_index.
        seqs.append([aa_index[aa_index_inv[int(tok)]] for tok in tokens])
    return torch.tensor(seqs, dtype=torch.long)


def get_samples_potts(couplings, fields, aa_index, aa_index_inv, N, q=21, nsamples=1000, nchains=10):
    """Generate MCMC samples from a Potts model via the external ``bmdca_sample`` tool.

    The couplings and fields are written to a text parameter file,
    ``bmdca_sample`` is run on it, and the numerically encoded samples it
    writes are read back.

    Args:
        couplings: tensor indexed as ``couplings[0, i*q + a, j*q + b]``, i.e.
            of shape [B, N*q, N*q]; only the first batch element is used.
        fields: tensor indexed as ``fields[0, i*q + a]``, i.e. of shape
            [B, N*q]; only the first batch element is used.
        aa_index: dictionary mapping characters to integers.
        aa_index_inv: dictionary mapping integers to characters.
        N: length of the input sequence.
        q: size of the alphabet, default 21.
        nsamples: number of samples per MCMC chain, default 1000.
        nchains: number of parallel MCMC chains, default 10.

    Returns:
        ``torch.long`` tensor with one row per sample line read from the
        sampler output.

    Raises:
        subprocess.CalledProcessError: if ``bmdca_sample`` exits with a
            non-zero status.
    """
    auxiliary_model_dir = "./Auxiliary_Data_bmdca/"
    params_path = os.path.join(auxiliary_model_dir, "potts_couplings_fields.txt")
    # NOTE(review): samples go under "./../" while parameters go under "./" --
    # confirm both are meant relative to the same working directory.
    out_dir = './../Auxiliary_Data_bmdca/Auxiliary_Samples_Potts/'
    out_file = 'samplesexp.txt'

    os.makedirs(auxiliary_model_dir, exist_ok=True)
    os.makedirs(out_dir, exist_ok=True)

    ###### SAVE COUPLINGS AND FIELDS TO GENERATE SAMPLES
    _write_potts_parameters(params_path, couplings, fields, N, q)

    ###### SAMPLE
    # Argument list (no shell parsing of paths); check=True raises on failure
    # instead of silently reading a stale samples file from a previous run.
    subprocess.run(
        ["bmdca_sample", "-p", params_path, "-n", str(nsamples), "-r", str(nchains),
         "-d", out_dir, "-o", out_file, "-c", "bmdca.config"],
        check=True,
    )

    return _read_bmdca_samples(os.path.join(out_dir, 'samplesexp_numerical.txt'),
                               aa_index, aa_index_inv)

           

