"""
Partially copied from https://colab.research.google.com/drive/1wyVEQd4R3HYLTUOXEEQmp_I8aNC_aLhL
"""
from lib.data_pipeline import *
from lib.model import *
from lib.settings import *
from downstream_finetune import Configuration

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import sys
from time import time
import pickle

def get_result_fig(prob_correct_label, savedir=None):
    """Plot the correct-label probability against the number of masked tokens.

    Each curve is the mean over samples, shaded by +/- one standard
    deviation. The figure is saved as ``<savedir>/prob_correct_label.png``.

    Args:
        prob_correct_label: nested sequence indexed as
            ``[sample][step][series]``. ``step`` is the number of masked
            tokens. ``series`` is 0 = target sequence, 1 = mean over the
            other sequences, 2 = minimum over the other sequences.
        savedir: output directory. Defaults to the module-level
            ``savedirname`` defined by the ``__main__`` block.

    Raises:
        ValueError: if ``prob_correct_label`` is not a non-empty 3-D array
            with at least three series.
    """
    prob_array = np.asarray(prob_correct_label, dtype=float)
    if prob_array.ndim != 3 or prob_array.shape[0] == 0 or prob_array.shape[2] < 3:
        raise ValueError(
            "prob_correct_label must have shape (n_samples > 0, n_steps, >= 3), "
            f"got {prob_array.shape}"
        )
    if savedir is None:
        savedir = savedirname  # module global set by the __main__ block

    # Derive the x axis from the data so it always matches the number of
    # recorded steps (num_mask + 1 when produced by this script).
    steps = np.arange(prob_array.shape[1])
    fig, ax = plt.subplots()
    for series, label in enumerate(("target seq", "other seq mean", "other seq min")):
        avg = prob_array[:, :, series].mean(axis=0)
        std = prob_array[:, :, series].std(axis=0)
        ax.plot(steps, avg, label=label)
        ax.fill_between(steps, avg - std, avg + std, alpha=0.3)
    ax.set_xlabel("Number of masked tokens", fontsize=15)
    ax.set_ylabel("Correct probability", fontsize=15)
    ax.legend(fontsize=15)
    fig.savefig(f"{savedir}/prob_correct_label.png", dpi=300)
    plt.close(fig)

if __name__ == "__main__":
    import os  # local import: only the script entry point needs it

    start_time = time()
    savedirname = sys.argv[1] if len(sys.argv) > 1 else "result"
    os.makedirs(savedirname, exist_ok=True)  # results are written here below
    target_dir = "HERE IS THE DIRECTORY OF THE TARGET MODEL"
    max_length = 500
    num_data_samp = 20  # total number of sequences to analyse (split across labels)
    num_whole_data = 3473  # number of test files scanned per label
    num_mask = 500  # maximum number of greedy masking steps per sequence
    prob_thresh = 0.95 # if the original correct probability is less than this, we skip the data
    mask_token_id = 3  # token id written into masked positions (presumably the tokenizer's mask token)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Using device:", device)

    # NOTE: pickle.load can execute arbitrary code; only point target_dir at
    # trusted, locally produced checkpoints.
    with open(f"{target_dir}/config.pkl", mode="rb") as f:
        conf: Configuration = pickle.load(f)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    state_dict = torch.load(f"{target_dir}/model.pth", map_location=device)
    n_classes = get_n_class(conf.dataset_name)

    model = HyenaDNAPreTrainedModel.from_pretrained(
        './checkpoints',
        conf.pretrained_model_name,
        download=False, # download=True,
        config=conf.backbone_cfg,
        device=device,
        use_head=conf.use_head,
        n_classes=n_classes,
    )
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    # Inference only: disable autograd globally so forward passes do not
    # build computation graphs.
    torch.set_grad_enabled(False)

    # create tokenizer
    tokenizer = CharacterTokenizer(
        characters=['A', 'C', 'G', 'T', 'N'],  # add DNA characters, N is uncertain
        model_max_length=max_length + 2,  # to account for special tokens, like EOS
        add_special_tokens=False,  # we handle special tokens elsewhere
        padding_side='left', # since HyenaDNA is causal, we pad on the left
    )

    m = nn.Softmax(dim=1)

    def correct_prob(tok_ids, label):
        """Return the softmax probability the model assigns to class ``label``."""
        return m(model(tok_ids))[0, label].item()

    def summarize(values):
        """Return ``(mean, min)`` of ``values``, or ``(nan, nan)`` if it is empty."""
        if not values:
            return float("nan"), float("nan")
        return float(np.mean(values)), float(np.min(values))

    # ---- Collect test sequences the model already classifies confidently ----
    cnt = [0, 0]  # number of collected samples per label
    label_idx_list = []
    data_idx_list = []
    tok_seq_list = []
    prob_org_list = []
    for label_idx, label in enumerate(["positive", "negative"]):
        # Cumulative quota: the first label may fill up to half of
        # num_data_samp. The second label may also use slots the first left empty.
        quota = num_data_samp // 2 * (label_idx + 1)
        for idx in range(num_whole_data):
            if len(prob_org_list) >= quota:
                break  # quota reached; no need to scan the remaining files
            # NOTE(review): readline() keeps any trailing newline, which is
            # passed to the tokenizer as-is — confirm this matches training.
            with open(f"dataset/{conf.dataset_name}/test/{label}/test_{label}_{idx}.txt", mode="r") as f:
                sequence = f.readline()
            # compute the probability for the non-masked sequence
            tok_seq = tokenizer(sequence).input_ids
            tok_seq = torch.LongTensor(tok_seq).unsqueeze(0).to(device)
            out = correct_prob(tok_seq, label_idx)
            if out < prob_thresh:
                continue
            print(f"Collecting samples... sample_idx={len(prob_org_list)}", flush=True)
            cnt[label_idx] += 1
            label_idx_list.append(label_idx)
            data_idx_list.append(idx)
            tok_seq_list.append(tok_seq)
            prob_org_list.append(out)
    print(f"Collected samples per label: {cnt}", flush=True)

    n_samples = len(prob_org_list)
    if n_samples == 0:
        sys.exit("No test sequence reached prob_thresh; nothing to analyse.")

    # ---- Greedy masking analysis ----
    # Indexed [sample][number of masked tokens][series]; series are
    # target seq, other mean, other min. Steps not reached stay 0.0.
    prob_correct_label = [[[0.0, 0.0, 0.0] for _ in range(num_mask + 1)] for _ in range(n_samples)]
    for sample_idx, (label_idx, data_idx, tok_seq) in enumerate(zip(label_idx_list, data_idx_list, tok_seq_list)):
        elapsed_time = int(time() - start_time)
        hours, rem = divmod(elapsed_time, 3600)
        minutes, seconds = divmod(rem, 60)
        print(f"label {label_idx}, #{sample_idx}, data_idx: {data_idx}, elapsed time: {hours} h {minutes} m {seconds} s", flush=True)
        cur_tok_seq = tok_seq.clone()
        # Masks chosen for the target are also applied to fresh copies of every
        # collected sequence; the target's own copy is skipped when scoring.
        copy_tok_seq_list = [_tok_seq.clone() for _tok_seq in tok_seq_list]
        # Candidate positions: every token of THIS sample except the first and last.
        non_masked_idx_set = set(range(1, cur_tok_seq.shape[1] - 1))

        prob_correct_label[sample_idx][0][0] = prob_org_list[sample_idx]
        other_mean, other_min = summarize(
            [prob for idx, prob in enumerate(prob_org_list) if idx != sample_idx]
        )
        prob_correct_label[sample_idx][0][1] = other_mean
        prob_correct_label[sample_idx][0][2] = other_min

        for j in range(num_mask):
            if not non_masked_idx_set:
                break  # every maskable position is masked; remaining steps stay 0.0
            # Greedy step: mask the position that keeps the target's
            # correct-label probability highest.
            argmax_idx, max_prob = -1, -1.0
            for k in sorted(non_masked_idx_set):
                new_tok_seq = cur_tok_seq.clone()
                new_tok_seq[0, k] = mask_token_id
                out = correct_prob(new_tok_seq, label_idx)
                if out > max_prob:
                    argmax_idx, max_prob = k, out
            # record the maximum probability
            prob_correct_label[sample_idx][j + 1][0] = max_prob
            non_masked_idx_set.remove(argmax_idx)
            cur_tok_seq[0, argmax_idx] = mask_token_id

            # Apply the same position to the other sequences (skipping any too
            # short to contain it), then record their mean and min probability.
            for oth_tok_seq in copy_tok_seq_list:
                if argmax_idx < oth_tok_seq.shape[1]:
                    oth_tok_seq[0, argmax_idx] = mask_token_id
            prob_list = [
                correct_prob(oth_tok_seq, oth_lab_idx)
                for oth_idx, (oth_tok_seq, oth_lab_idx) in enumerate(zip(copy_tok_seq_list, label_idx_list))
                if oth_idx != sample_idx
            ]
            other_mean, other_min = summarize(prob_list)
            prob_correct_label[sample_idx][j + 1][1] = other_mean
            prob_correct_label[sample_idx][j + 1][2] = other_min

        # Checkpoint results after every sample so partial runs are usable.
        with open(f"{savedirname}/prob_correct_label.pkl", mode="wb") as f:
            pickle.dump(prob_correct_label, f)

        get_result_fig(prob_correct_label[:sample_idx + 1])