from datasets import load_dataset, Dataset, DatasetDict
from dpo_trainer import pDPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, set_seed
import pickle
import numpy as np
from tqdm import tqdm
import random
# Seed every RNG the script draws from so the generated preference dataset is
# reproducible: Python's `random` drives replace-mode negative sampling, and
# NumPy's global RNG drives the per-protein subset shuffle.
random.seed(0)
np.random.seed(0)

# Per-protein ESMFold / TM-align records. Fields read later: [0] protein id,
# [2] and [3] per-design TM-score lists (presumably two TM-align
# normalisations), [4] the candidate sequences.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted,
# locally generated files here.
with open('data/cath_train_temp1_20.pkl', 'rb') as f:
    esmfold_tmscore = pickle.load(f)

# One training protein per line, '|'-separated: field 0 is the structure name,
# field 1 the wild-type sequence. readlines() keeps the trailing newline.
with open('data/train_emb.txt', 'r') as f:
    pdbs = f.readlines()

prompt = []
chosen = []
rejected = []

err = 0              # proteins skipped for having fewer than 5 scored designs
thresh = 0.8         # replace mode: rejected designs scoring above this are swapped out
chosen_thresh = 0.9  # replace mode: chosen designs scoring below this become the WT sequence
subset = 20          # max designs sampled per protein (split into chosen / rejected halves)
replace = False      # enable the threshold-based replacement described above
top_tmscore = []     # best TM-score of each protein's sampled subset
below = 0            # NOTE(review): never updated; still reported in the summary print
fasta_idx = 0        # cursor into esmfold_tmscore, advanced on each matched protein
has_below = False    # NOTE(review): unused
above = 0            # NOTE(review): unused
no_above = 0         # NOTE(review): never updated; still reported in the summary print

# Output directory name; the "_replace" suffix marks datasets built with replace=True.
suffix = "_replace" if replace else ""
save_name = (
    f"CATH42_ESMFOLD_TMalign_norm_preference_t{thresh}_chosen{chosen_thresh}"
    f"_subset{subset}_balanced{suffix}.pt"
)

tm_diff = []         # per-protein arrays of chosen-minus-rejected TM-scores

# Build (prompt, chosen, rejected) triples. esmfold_tmscore is assumed to list
# the proteins of `pdbs` in the same order (possibly with some missing), so it
# is walked with a cursor (`fasta_idx`) that only advances on a match.
for pdb in tqdm(pdbs):
    fields = pdb.split('|')
    pdb_name = fields[0]
    WT_seq = fields[1]
    # Drop the first '_'-separated token, join the rest with '.', and strip the
    # last 4 characters (presumably a file extension -- confirm).
    pdb_id = '.'.join(pdb_name.split('_')[1:])[:-4]

    # Guard the cursor so unmatched trailing lines are reported instead of
    # raising IndexError once esmfold_tmscore is exhausted.
    if fasta_idx >= len(esmfold_tmscore) or pdb_id != esmfold_tmscore[fasta_idx][0]:
        print(pdb_id)
        continue
    record = esmfold_tmscore[fasta_idx]
    fasta_idx += 1
    tmscore_norm1 = record[2]
    tmscore_norm2 = record[3]
    # Average the two per-design TM-score variants.
    tmscore = (np.array(tmscore_norm1) + np.array(tmscore_norm2)) / 2
    seq = record[4]

    if len(tmscore_norm1) < 5:
        err += 1
        continue

    # Random subset of at most `subset` designs (uses NumPy's global RNG).
    subset_idx = np.random.permutation(len(tmscore))[:subset]
    tmscore = tmscore[subset_idx]
    seq = [seq[i] for i in subset_idx]

    # Sort designs by TM-score, best first, keeping scores and sequences aligned.
    tm_idx = np.argsort(tmscore)[::-1]
    tmscore_sorted = tmscore[tm_idx]
    seq = [seq[i] for i in tm_idx]
    top_tmscore.append(tmscore_sorted[0])

    # Top half -> chosen, bottom half -> rejected. Sizing by the actual number
    # of designs (not `subset`) keeps prompt/chosen/rejected the same length
    # when a protein has fewer than `subset` designs; with an odd count the
    # median design is dropped.
    half = len(seq) // 2
    upper = seq[:half]
    lower = seq[len(seq) - half:]
    # Scores must come from the *sorted* array so they line up with upper/lower.
    upper_tmscore = tmscore_sorted[:half]
    lower_tmscore = tmscore_sorted[len(seq) - half:]
    if replace:
        for i in range(half):
            # Weak "chosen" design -> fall back to the wild-type sequence.
            if upper_tmscore[i] < chosen_thresh:
                upper[i] = WT_seq
            # Strong "rejected" design -> swap in a random design drawn from a
            # random protein (which may be this same protein).
            if lower_tmscore[i] > thresh:
                random_pdb = esmfold_tmscore[random.randint(0, len(esmfold_tmscore) - 1)]
                lower[i] = random.choice(random_pdb[4])

    prompt += [f"{pdb_name}|{WT_seq}"] * half
    chosen += upper
    rejected += lower
    tm_diff.append(upper_tmscore - lower_tmscore)

# Rows of tm_diff may differ in length, so average each protein first and then
# across proteins (identical to the axis=1 mean when every row is full).
print(f"tm diff: {np.mean([d.mean() for d in tm_diff])}")

print((len(np.unique(prompt)), below, no_above, len(prompt)))
train = {"prompt": prompt, "chosen": chosen, "rejected": rejected}
train_ds = Dataset.from_dict(train)
ds = DatasetDict({"train": train_ds})
 
# Tokenizer for the InstructPLM model (presumably a local checkpoint directory).
tokenizer = AutoTokenizer.from_pretrained("InstructPLM/", trust_remote_code=True)

# Keyword arguments forwarded by `process` to `pDPOTrainer.tokenize_row`.
fn_kwargs = dict(
    processing_class=tokenizer,
    max_prompt_length=512,
    max_completion_length=512,
    add_special_tokens=False,
)
def process(row, **fn_kwargs):
    """Wrap one preference example in DPO flag markers and tokenize it.

    ``row["prompt"]`` is expected to be "<pdb_name>|<wild-type sequence>", the
    form built earlier in this script. The prompt, chosen and rejected fields
    are each rewritten to "<pdb_name>|DPO<role>_FLAG1<sequence>2", with
    trailing whitespace stripped from the sequence. The row is then handed to
    ``pDPOTrainer.tokenize_row`` together with ``fn_kwargs``, and that call's
    result is returned.
    """
    fields = row["prompt"].split('|')
    name, wt = fields[0], fields[1]

    # NOTE(review): the literal '1' / '2' around each sequence look like
    # begin/end sentinels for the tokenizer -- confirm against its vocabulary.
    row['prompt'] = name + '|DPOprompt_FLAG1' + wt.rstrip() + '2'
    for column, marker in (("chosen", "|DPOchosen_FLAG1"),
                           ("rejected", "|DPOrejected_FLAG1")):
        row[column] = name + marker + row[column].rstrip() + '2'

    return pDPOTrainer.tokenize_row(row, **fn_kwargs)

# Tokenize every split of the DatasetDict; `process` receives fn_kwargs as
# keyword arguments. The result is still a DatasetDict keyed by split name.
train_dataset = ds.map(
    process,
    fn_kwargs=fn_kwargs,
    desc="Tokenizing train dataset",
    num_proc=16,
    writer_batch_size=10,
)
first_example = train_dataset['train'][0]
print(first_example)
# save_to_disk writes a directory, despite the ".pt" suffix in save_name.
train_dataset.save_to_disk(save_name)
