import json
import os
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from tqdm.autonotebook import tqdm

from tokenizer_util import (replace_content,
                                 target_tokenizer_function_models,
                                 test_positions)
from generate import Generator, TextProcess

# NOTE(review): setting CURL_CA_BUNDLE to an empty string is commonly used as a
# workaround to skip TLS certificate verification in HTTP clients that honor this
# variable (e.g. when downloading models/tokenizers). This weakens transport
# security for the whole process — confirm it is still needed.
os.environ['CURL_CA_BUNDLE'] = ''

def test(model_name, model_class, random_seed, n_sample, device, save_dir, load_in_4bit=False):
    """Measure, per tested position, how much the model's output distribution varies across samples.

    Loads ``template`` and ``content_list`` from ``concat_dataset/dataset_0.json``,
    shuffles the contents with ``random_seed`` and keeps the first ``n_sample``.
    For each content, the template is filled in and run through the model. The
    softmaxed output distributions at the tested positions (``pos['output']`` from
    ``test_positions``) are collected. The variance across samples, summed over the
    vocabulary, gives one score per tested position.

    Args:
        model_name: model identifier passed to ``Generator``.
        model_class: model family passed to ``target_tokenizer_function_models``.
        random_seed: seed for the (global) ``random`` module before shuffling.
        n_sample: number of contents to keep after shuffling.
        device: device passed to ``Generator``. Not used when the model is loaded
            in 4-bit, in which case ``device='auto'`` is used instead.
        save_dir: directory (created if missing) for ``var.npy``, ``dists.npy``
            and ``tc.npy``.
        load_in_4bit: load the model in 4-bit. Also forced on whenever ``'70b'``
            appears in ``model_name``.

    Returns:
        ``(var, dists, tc)``:
            - ``var``: array of shape ``(num_pos,)``.
            - ``dists``: array of shape ``(num_samples, num_pos, vocab_size)``.
            - ``tc``: the T/C classification of the tested positions, which
              every sample must share.

    Raises:
        ValueError: if no samples are selected, if samples have different numbers
            of tested positions, or if their T/C classifications differ.
    """
    print('save results to:', save_dir)
    os.makedirs(save_dir, exist_ok=True)

    with open('concat_dataset/dataset_0.json', encoding='utf-8') as f:
        data = json.load(f)
    template = data['template']
    content_list = data['content_list']

    # Deterministic subsample: shuffle with a fixed seed, then take the first n_sample.
    random.seed(random_seed)
    random.shuffle(content_list)
    num_available = len(content_list)
    content_list = content_list[:n_sample]
    if not content_list:
        # Fail before loading the model; otherwise torch.cat below fails on an empty list.
        raise ValueError(f'No samples selected (n_sample={n_sample}, {num_available} available).')

    with open('huggingface_auth_token', encoding='utf-8') as f:
        auth_token = f.read().strip()
    target_tokenize_function = target_tokenizer_function_models(model_class, auth_token=auth_token)

    # 70B models (or an explicit request) are loaded in 4-bit with device='auto',
    # in which case the `device` argument is not used.
    if '70b' in model_name or load_in_4bit:
        generator = Generator(model_name, device='auto', load_in_4bit=True)
    else:
        generator = Generator(model_name, device=device)

    output_dists = []
    len_pos = None
    tc_record = None
    # Each content is a list of str corresponding to one data sample.
    for content in tqdm(content_list):
        input_text = replace_content(template, content)
        with torch.no_grad():
            tp = TextProcess(generator=generator, input_text=input_text, add_end_of_text=True)
        # (batch, seq_len, vocab_size) with batch == 1; seq_len differs per sample.
        # Pre-softmax scores: softmax is applied below.
        pred_scores = tp.pred_dists
        pos, tc, _, _ = test_positions(template, content, target_tokenize_function)
        # NOTE(review): an earlier comment said "the index should minus 1" for the output;
        # presumably test_positions already returns shifted indices here — confirm.
        pos = pos['output']

        # Every sample must have the same number of tested positions.
        if len_pos is None:
            len_pos = len(pos)
        elif len_pos != len(pos):
            raise ValueError('The number of tested positions should be the same.')

        # The tested positions (not all tokens) must share the same T/C classification.
        # A single mismatching position is already an error, hence `.any()`.
        if tc_record is None:
            tc_record = tc
        elif (tc != tc_record).any():
            raise ValueError('The tested positions should have the aligned T/C classification.')

        # Softmax is row-wise over the vocabulary, so selecting the tested positions
        # first is equivalent and avoids normalizing every position of the sequence.
        output_dists.append(F.softmax(pred_scores[:, pos, :], dim=-1))

    dists = torch.cat(output_dists, dim=0).to(torch.float)  # (num_samples, num_pos, vocab_size)
    # Variance across samples, summed over the vocabulary: one score per tested position.
    var = torch.sum(torch.var(dists, dim=0), dim=-1).cpu().numpy()  # (num_pos,)
    dists = dists.cpu().numpy()
    np.save(os.path.join(save_dir, 'var.npy'), var)
    np.save(os.path.join(save_dir, 'dists.npy'), dists)
    np.save(os.path.join(save_dir, 'tc.npy'), tc_record)
    return var, dists, tc_record

def visual(var, tc, save_file, start_pos, ylim=(0, 1)):
    """Save a bar chart of per-position scores, colored by T/C classification.

    Args:
        var: 1-D array of per-position scores (e.g. the ``var`` returned by ``test``).
        tc: array with the same length as ``var``. Bars where ``tc == 1`` are drawn
            light green; all others are deep sky blue.
        save_file: output path passed to ``Figure.savefig``.
        start_pos: number of leading positions to omit from the plot.
        ylim: y-axis limits; defaults to ``(0, 1)``.

    Raises:
        ValueError: if ``var`` and ``tc`` have different lengths.
    """
    if tc.shape[0] != var.shape[0]:
        raise ValueError(
            f'var and tc must have the same length, got {var.shape[0]} and {tc.shape[0]}.')
    shown_var = var[start_pos:]
    colors = ['lightgreen' if t == 1 else 'deepskyblue' for t in tc[start_pos:]]
    fig, ax = plt.subplots(figsize=(6, 2))
    try:
        # x positions are derived from the sliced data so lengths always agree.
        ax.bar(range(len(shown_var)), shown_var, color=colors)
        ax.set_ylim(*ylim)
        fig.savefig(save_file, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

def correlation(var1, var2):
    """Return the Pearson correlation coefficient between ``var1`` and ``var2``.

    Both arguments must be arrays with identical shapes. The work is delegated to
    ``np.corrcoef``, and the ``[0, 1]`` off-diagonal entry of the resulting
    correlation matrix is returned.
    """
    assert var1.shape == var2.shape
    corr_matrix = np.corrcoef(var1, var2)
    coefficient = corr_matrix[0, 1]
    return coefficient

def main():
    """Run the variance experiment for one model and plot the per-position result.

    Results are written to ``results/<model basename>/<timestamp>/``: the arrays
    saved by ``test`` and the plot ``var.pdf``.
    """
    # Model to evaluate. Alternatives that have been used:
    #   model_name = 'facebook/opt-13B';               model_class = 'opt'
    #   model_name = 'meta-llama/Llama-2-13b-chat-hf'; model_class = 'llama-2'
    model_name = 'gpt2-medium'
    model_class = 'gpt-2'
    random_seed = 202308190
    n_sample = 100
    # Number of leading positions omitted from the plot.
    start_pos = 22
    timestamp = time.strftime('%Y_%m_%d_%H_%M_%S')
    save_dir = os.path.join('results', os.path.basename(model_name), timestamp)
    # Not used by `test` here: load_in_4bit=True makes it load with device='auto'.
    device = 'auto'
    var, _dists, tc = test(model_name, model_class, random_seed, n_sample, device, save_dir, load_in_4bit=True)
    visual(var, tc, save_file=os.path.join(save_dir, 'var.pdf'), start_pos=start_pos)

# Script entry point: run the experiment only when executed directly, not on import.
if __name__ == '__main__':
    main()