import numpy as np
import os
import sys
import torch
import pandas as pd
from collections import Counter, defaultdict
from functools import partial
from torch.utils.data import DataLoader
from tqdm import tqdm
import warnings
import pickle

# Add necessary paths
sys.path.insert(1, "./../util/")
sys.path.insert(1, "./../model/")
sys.path.insert(1, "./../esm/")

# Import custom modules
from encoded_protein_dataset_new import EncodedProteinDataset_aux
from dynamic_loader import collate_fn_old
from test_utils import load_model, get_samples_potts
from esm_utils import load_structure, extract_coords_from_structure, sample_esm_many, align_esm
import esm.pretrained as pretrained

# Suppress warnings
warnings.filterwarnings("ignore")

# Constants
# Cap forwarded to the dataset constructor as ``max_msas``.
MAX_MSAS = 9999
# Data root; load_dataset reads the 'test/superfamily' subfolder of this path.
MSA_DIR = "./../Data_Subset/"
# Passed as the second positional argument of the dataset constructor.
ENCODING_DIR = "./../Data_Subset/structure_encodings/"
# Joined with the PDB name in process_samples to locate the structure file.
PDB_DIR = "./../Data_Subset/dompdb/"
# Prefix of the saved decoder checkpoints read by load_models.
BK_DIR = './../bk_models/'

# Device the batch tensors are moved to; also handed to load_model.
DEVICE = 'cpu'
# 21 symbols: 20 letters plus '-'. process_samples uses the index of '-' as
# the fallback for any symbol that is not in this string.
ALPHABET = 'ACDEFGHIKLMNPQRSTVWY-'
# Feature flags read by process_samples and save_results.
POTTS = False
ESM = True

def load_dataset():
    """Build the superfamily test dataset and a DataLoader over it.

    Returns:
        A ``(dataset, loader)`` tuple. The loader uses a batch size of 1,
        does not shuffle, and collates with ``collate_fn_old``.
    """
    n_states = 21
    structure_batch = 1
    msa_batch = 128

    superfamily_dir = os.path.join(MSA_DIR, 'test/superfamily')
    dataset = EncodedProteinDataset_aux(
        superfamily_dir,
        ENCODING_DIR,
        noise=0.0,
        max_msas=MAX_MSAS,
    )

    collate = partial(
        collate_fn_old,
        q=n_states,
        batch_size=structure_batch,
        batch_msa_size=msa_batch,
    )
    loader = DataLoader(
        dataset,
        batch_size=structure_batch,
        collate_fn=collate,
        shuffle=False,
        num_workers=1,
        pin_memory=True,
    )

    return dataset, loader

def load_models():
    """Load the pretrained ESM model and the two saved decoders.

    Returns:
        ``(model_esm, decoder_ardca, decoder_potts)`` -- the arDCA decoder
        comes before the Potts decoder in the returned tuple.
    """
    # The alphabet returned next to the ESM model is not used by this script.
    model_esm, _ = pretrained.esm_if1_gvp4_t16_142M_UR50()
    model_esm.eval()

    # Checkpoints are read in the order Potts, then arDCA.
    decoder_potts, decoder_ardca = (
        load_model(f"{BK_DIR}{checkpoint}", device=DEVICE)
        for checkpoint in ('bk_model_pw.pth', 'bk_model_ardca.pth')
    )

    return model_esm, decoder_ardca, decoder_potts

def process_samples(test_dataset, test_loader, model_esm, decoder_ardca, decoder_potts, nsamples=10):
    """Draw samples from the enabled models for every item of the test loader.

    arDCA samples are always drawn. ESM and Potts samples are drawn only when
    the module-level ``ESM`` / ``POTTS`` flags are set.

    Args:
        test_dataset: Dataset exposing ``msas_paths``; it is indexed with the
            loader position to find the MSA file and to derive the result key.
        test_loader: Loader over ``test_dataset``. Element 1 of each yielded
            item is iterated as a sequence of
            ``(msas, encodings, padding_mask)`` tensor triples.
        model_esm: Model handed to ``sample_esm_many``.
        decoder_ardca: Decoder providing ``forward_ardca_scaled``,
            ``samples_ardca_vect_batch`` and the attribute ``q``.
        decoder_potts: Callable decoder returning ``(couplings, fields)``.
        nsamples: Number of samples requested from the arDCA and ESM samplers.

    Returns:
        ``(res_full_esm, res_full_potts, res_full_ardca)``: three dicts keyed
        by PDB name. A dict stays empty when its model is disabled.
    """
    res_full_esm = {}
    res_full_potts = {}
    res_full_ardca = {}

    # Symbol <-> index tables for the Potts sampler; symbols missing from
    # ALPHABET fall back to the index of '-'.
    gap_index = ALPHABET.index('-')
    aa_index = defaultdict(lambda: gap_index, {aa: i for i, aa in enumerate(ALPHABET)})
    aa_index_inv = {i: aa for aa, i in aa_index.items()}

    for idx, inputs_packed in tqdm(enumerate(test_loader), total=len(test_loader)):
        # NOTE(review): idx is the loader position. It equals the dataset
        # index only for an unshuffled loader with batch size 1.
        msa_path = test_dataset.msas_paths[idx]
        # NOTE(review): fixed-width slice -- presumably a 7-character name
        # followed by a 7-character suffix; confirm against the file naming.
        pdb_name = msa_path[-14:-7]

        # NOTE(review): every triple of one loader item is stored under the
        # same pdb_name, so a later triple overwrites an earlier one.
        for inputs in inputs_packed[1]:
            msas, encodings, padding_mask = [t.to(DEVICE, non_blocking=True) for t in inputs]
            _, _, N = msas.shape

            # Sample ARDCA
            with torch.no_grad():
                print("Sampling ARDCA")
                couplings, fields = decoder_ardca.forward_ardca_scaled(encodings, padding_mask)
                samples_ardca = decoder_ardca.samples_ardca_vect_batch(
                    fields,
                    couplings,
                    N,
                    n_samples=nsamples,
                    q=decoder_ardca.q,
                    rec_times=False,
                    device='cpu',
                )
                # Convert with .to() rather than torch.tensor(), which warns
                # when given a tensor; copy=True keeps the result independent
                # of the sampler's output buffer.
                samples_ardca = samples_ardca.detach().to(device='cpu', dtype=torch.long, copy=True)
            res_full_ardca[pdb_name] = samples_ardca

            if ESM:
                print("Sampling ESM")
                # Sample ESM
                pdb_path = os.path.join(PDB_DIR, pdb_name)
                structure = load_structure(pdb_path)
                coords, _native_seq = extract_coords_from_structure(structure)

                # NOTE(review): device=0 is hard-coded here while the rest of
                # the function uses DEVICE -- confirm this is intended.
                samples_esm, _times = sample_esm_many(model_esm, coords, n_samples=nsamples, device=0)

                # The MSA is only needed for the alignment, so it is loaded
                # here rather than once per item regardless of the ESM flag.
                msa = torch.load(msa_path).to(torch.int)
                res_full_esm[pdb_name] = torch.tensor(align_esm(samples_esm, msa), dtype=torch.long)

            # Sample Potts
            if POTTS:
                print("Sampling Potts")
                # Inference only: skip autograd tracking, as for arDCA.
                with torch.no_grad():
                    couplings_potts, fields_potts = decoder_potts(encodings, padding_mask)
                res_full_potts[pdb_name] = get_samples_potts(
                    couplings_potts, fields_potts, aa_index, aa_index_inv, N, len(ALPHABET)
                )

    return res_full_esm, res_full_potts, res_full_ardca
    

def save_results(res_full_esm, res_full_potts, res_full_ardca):
    """Pickle the sample dictionaries into the current working directory.

    Each result set is written only if its model was sampled: arDCA always,
    ESM when the module-level ``ESM`` flag is set, Potts when ``POTTS`` is set.

    Args:
        res_full_esm: ESM samples, written to ``samples_esm_superfamily``.
        res_full_potts: Potts samples, written to ``samples_potts_superfamily``.
        res_full_ardca: arDCA samples, written to ``samples_ardca_superfamily``.
    """
    # arDCA sampling is unconditional in process_samples, so its results must
    # not depend on the ESM flag.
    outputs = [("samples_ardca_superfamily", res_full_ardca)]
    if ESM:
        outputs.append(("samples_esm_superfamily", res_full_esm))
    if POTTS:
        outputs.append(("samples_potts_superfamily", res_full_potts))

    for filename, payload in outputs:
        with open(filename, mode="wb") as f:
            pickle.dump(payload, f)

def main():
    """Run the pipeline: load data and models, draw samples, pickle the results."""
    dataset, loader = load_dataset()
    esm_model, ardca_decoder, potts_decoder = load_models()

    # process_samples returns (esm, potts, ardca), the same order that
    # save_results takes, so the tuple is forwarded as-is.
    sampled = process_samples(dataset, loader, esm_model, ardca_decoder, potts_decoder)
    save_results(*sampled)

if __name__ == "__main__":
    main()