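'''
Utility functions: loading the BERT victim model and tokenizer, reading example
sentences, extracting attention matrices, exploring attention patterns, and
aggregating evaluation results.
'''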

import os
import pickle
import numpy as np


##################################################################################
###############################Inference##########################################
##################################################################################
import sys, os
sys.path.insert(1, os.path.abspath("../"))
from openbackdoor.victims import load_victim
import transformers
import torch

def load_well_trained_model(model_fname, device):
    '''
    Load a fine-tuned BERT classifier (clean or poisoned model) from a state dict.

    The victim architecture is built with openbackdoor's ``load_victim`` using a
    fixed config (bert-base-uncased, 2 classes, max_len 512, device "gpu"). The
    weights in ``model_fname`` are then loaded, the model is put in eval mode, and
    attention outputs are enabled on the underlying PLM config.

    :param model_fname: str, path to a state dict readable by ``torch.load``.
    :param device: str or torch.device, used as ``map_location`` when loading
        the state dict.
    :return: the loaded victim model.

    NOTE(review): the victim config always requests "gpu"; ``device`` only
    affects where the state dict tensors are mapped - confirm this is intended.
    '''
    victim_config = {
        "type": "plm",
        "model": "bert",
        "path": "bert-base-uncased",
        "num_classes": 2,
        "device": "gpu",
        "max_len": 512,
    }
    print('Load Well Trained Model.')

    model = load_victim(victim_config)
    weights = torch.load(model_fname, map_location=torch.device(device))
    model.load_state_dict(weights)
    model.eval()
    # Make the PLM also return per-layer attention weights.
    model.plm.config.output_attentions = True

    return model



def load_tokenizer():
    '''
    Return the Hugging Face ``bert-base-uncased`` tokenizer.

    Built with ``transformers.AutoTokenizer.from_pretrained`` using default options.
    '''
    return transformers.AutoTokenizer.from_pretrained("bert-base-uncased")


def format_batch_text(examples_dirpath, class_idx, expected_count=40):
    '''
    Read the example sentences of one class from a directory.

    Files are named ``class_{class_idx}_example_{k}.txt`` with ``k`` counting up
    from 1. Reading stops at the first missing index.

    :param examples_dirpath:
        str, directory containing the example text files.
    :param class_idx:
        int, sentiment analysis class whose examples are read, 0 or 1.
    :param expected_count:
        int or None, number of examples this class must have (default 40).
        Pass None to skip the check.

    :return: list of str - (num_sentences, ), raw file contents in index order.
    :raises ValueError: if ``expected_count`` is not None and a different number
        of examples was found.
    '''
    batch_text = []
    example_idx = 1
    while True:
        fn = 'class_{}_example_{}.txt'.format(class_idx, example_idx)
        path = os.path.join(examples_dirpath, fn)
        if not os.path.exists(path):
            break
        # Keep the original text untouched.
        with open(path, 'r', encoding='utf-8') as fh:
            batch_text.append(fh.read())
        example_idx += 1

    # Explicit check (an assert would be stripped under `python -O`).
    if expected_count is not None and len(batch_text) != expected_count:
        raise ValueError(
            'Expected {} examples for class {} in {!r}, found {}.'.format(
                expected_count, class_idx, examples_dirpath, len(batch_text)))

    return batch_text



def format_batch_text_poisoned_inspiration_check_attn_flow_badnets(examples_dirpath, class_idx, trigger='tq', expected_count=40):
    '''
    Read one class's example sentences and prepend a BadNets-style trigger word.

    Files are named ``class_{class_idx}_example_{k}.txt`` with ``k`` counting up
    from 1. Reading stops at the first missing index. Each returned sentence is
    ``trigger + ' ' + text``, giving ``'tq ' + text`` by default.

    :param examples_dirpath:
        str, directory containing the example text files.
    :param class_idx:
        int, sentiment analysis class whose examples are read, 0 or 1.
    :param trigger:
        str, trigger word inserted at the start of every sentence (default 'tq').
    :param expected_count:
        int or None, number of examples this class must have (default 40).
        Pass None to skip the check.

    :return: list of str - (num_sentences, ), poisoned sentences in index order.
    :raises ValueError: if ``expected_count`` is not None and a different number
        of examples was found.
    '''
    prefix = trigger + ' '
    batch_text = []
    example_idx = 1
    while True:
        fn = 'class_{}_example_{}.txt'.format(class_idx, example_idx)
        path = os.path.join(examples_dirpath, fn)
        if not os.path.exists(path):
            break
        with open(path, 'r', encoding='utf-8') as fh:
            batch_text.append(prefix + fh.read())
        example_idx += 1

    # Explicit check (an assert would be stripped under `python -O`).
    if expected_count is not None and len(batch_text) != expected_count:
        raise ValueError(
            'Expected {} examples for class {} in {!r}, found {}.'.format(
                expected_count, class_idx, examples_dirpath, len(batch_text)))

    return batch_text









def format_batch_attention(attention, layers=None, heads=None):
    '''
    Stack per-layer attention tensors into one batch-first tensor.

    Converts
        tuple: (num_layers x [batch_size x num_heads x seq_len x seq_len])
    to
        tensor: (batch_size x num_layers x num_heads x seq_len x seq_len)

    :param attention: sequence of 4-D attention tensors, one per layer.
    :param layers: None, or list of 0-based layer indices to keep, e.g. [11].
        None (or an empty list) keeps every layer.
    :param heads: None, or list of head indices to keep in every layer. They are
        selected along dim 1 (the head axis). None (or an empty list) keeps
        every head.
    :return: torch.Tensor
        (batch_size x num_selected_layers x num_selected_heads x seq_len x seq_len)
    :raises ValueError: if a layer's attention tensor is not 4-D.
    '''
    if layers:
        attention = [attention[layer_index] for layer_index in layers]

    selected = []
    for layer_attention in attention:
        # batch_size x num_heads x seq_len x seq_len
        if len(layer_attention.shape) != 4:
            raise ValueError("The attention tensor does not have the correct number of dimensions. Make sure you set "
                             "output_attentions=True when initializing your model.")
        if heads:
            # Heads live on dim 1; dim 0 is the batch.
            layer_attention = layer_attention[:, heads]
        selected.append(layer_attention)

    # num_layers x batch_size x num_heads x seq_len x seq_len -> batch first
    stacked = torch.stack(selected)
    return torch.transpose(stacked, 0, 1)


def gene_attnmat_batch(model, tokenizer, batch_text, device, comput_attn_mat=True):
    '''
    Run the model on each sentence separately and collect its attention matrices.

    :param model: classification model, called as ``model(encoded_inputs)``. The
        last element of its output is taken as the per-layer attention tuple.
        This assumes attention outputs are enabled, e.g. via
        ``load_well_trained_model`` - confirm for other models.
    :param tokenizer: tokenizer called on one sentence at a time with
        ``return_tensors="pt"``.
    :param batch_text: list of str, the sentences to encode.
    :param device: device the model and the encoded inputs are moved to.
    :param comput_attn_mat: if True, truncate to max_length 16 with padding=True.
        Otherwise (label prediction) truncate/pad every sentence to max_length
        128 with padding='max_length'.
    :return: np.ndarray (num_sentences, num_layers, num_heads, seq_len, seq_len),
        or None if ``batch_text`` is empty.

    NOTE(review): padding=True presumably pads only to the longest sequence in
    the call (here a single sentence). A sentence shorter than 16 tokens would
    then give a smaller seq_len, and the final concatenation would fail on the
    shape mismatch - confirm inputs are long enough.
    '''
    model.to(device)
    model.eval()

    if comput_attn_mat:
        max_length, padding = 16, True
    else:
        max_length, padding = 128, 'max_length'

    per_sentence_attn = []
    # Inference only: no_grad avoids building an autograd graph for every sentence.
    with torch.no_grad():
        for single_text in batch_text:
            encoded = tokenizer(single_text, max_length=max_length, truncation=True,
                                padding=padding, return_tensors="pt").to(device)  # keep to the same device
            # tuple: (num_layers x [batch_size x num_heads x seq_len x seq_len])
            attention_unform = model(encoded)[-1]
            # (batch_size x num_layers x num_heads x seq_len x seq_len), all layers and heads
            attention = format_batch_attention(attention_unform, layers=None, heads=None)
            per_sentence_attn.append(attention.detach().cpu().numpy())

    if not per_sentence_attn:
        return None
    # Concatenate once instead of re-copying the growing array with vstack each step.
    return np.concatenate(per_sentence_attn, axis=0)




def gene_attnmat_tokens_batch_for_plotting_avg_attn(model, tokenizer, batch_text, device, comput_attn_mat=True):
    '''
    Collect attention matrices and the token strings of each sentence, for plotting.

    Every sentence is encoded on its own with max_length 16, truncation and
    padding=True.

    :param model: classification model, called as ``model(encoded_inputs)``. The
        last element of its output is taken as the per-layer attention tuple
        (assumes attention outputs are enabled - confirm).
    :param tokenizer: tokenizer providing ``__call__`` and ``convert_ids_to_tokens``.
    :param batch_text: list of str, the sentences to encode.
    :param device: device the model and the encoded inputs are moved to.
    :param comput_attn_mat: unused; kept for signature compatibility with
        ``gene_attnmat_batch``. The settings above always apply.
    :return: (final_attn, tokens)
        final_attn - np.ndarray (num_sentences, num_layers, num_heads, seq_len, seq_len),
                     or None if ``batch_text`` is empty;
        tokens     - list (num_sentences,) of token-string lists, one per sentence.

    NOTE(review): padding=True presumably pads only to this sentence's own length,
    so sentences shorter than 16 tokens would break the final concatenation.
    '''
    model.to(device)
    model.eval()

    max_length = 16
    padding = True

    tokens = []
    per_sentence_attn = []
    # Inference only: no_grad avoids building an autograd graph for every sentence.
    with torch.no_grad():
        for single_text in batch_text:
            encoded = tokenizer(single_text, max_length=max_length, truncation=True,
                                padding=padding, return_tensors="pt").to(device)  # keep to the same device
            input_ids = encoded['input_ids']  # (batch_size, seq_len)
            tokens.append(tokenizer.convert_ids_to_tokens(input_ids[0]))

            # tuple: (num_layers x [batch_size x num_heads x seq_len x seq_len])
            attention_unform = model(encoded)[-1]
            # (batch_size x num_layers x num_heads x seq_len x seq_len), all layers and heads
            attention = format_batch_attention(attention_unform, layers=None, heads=None)
            per_sentence_attn.append(attention.detach().cpu().numpy())

    # Concatenate once instead of re-copying the growing array with vstack each step.
    final_attn = np.concatenate(per_sentence_attn, axis=0) if per_sentence_attn else None
    return final_attn, tokens




def inf_batch_text(classification_model, tokenizer, device, examples_dirpath, class_idx):
    '''
    Predict labels for all example sentences of one class.

    The sentences are read with ``format_batch_text`` and tokenized as one batch:
    truncated to 256 tokens and padded with padding=True. The model is then run
    on ``device``.

    :param classification_model: model whose output's first element is used as
        the logits, [batch_size, class_num].
    :param tokenizer: tokenizer called with ``return_tensors="pt"``.
    :param device: device the model and the encoded batch are moved to.
    :param examples_dirpath:
        str, directory containing ``class_{idx}_example_{k}.txt`` files.
    :param class_idx:
        int, sentiment analysis class whose examples are evaluated, 0 or 1.

    :return: (logits, pred_labels)
        logits      - CPU torch.Tensor [batch_size, class_num];
        pred_labels - np.ndarray (batch_size,), argmax class index per sentence.
    '''
    classification_model.eval()
    classification_model.to(device)
    # Load the example sentences as one batch.
    batch_text = format_batch_text(examples_dirpath, class_idx)

    encoded = tokenizer(batch_text, max_length=256, truncation=True, padding=True,
                        return_tensors="pt").to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = classification_model(encoded)[0]  # [batch_size, class_num]
    prediction = logits.argmax(-1)

    pred_labels = prediction.cpu().detach().numpy()
    return logits.cpu().detach(), pred_labels






##################################################################################
#############################Explore attention pattern########################################
##################################################################################

import collections

## attention weights flow
def check_attn_flow_v1(clean_attn, is_poisoned=False):
    '''
    Average attention weights inside a fixed token window (positions 1..14).

    The window is ``[1:15]`` on both the query and key axes. It drops position 0
    (presumably [CLS]) and everything from position 15 onward.

    :param clean_attn: attention matrices,
        (batch_size, num_layer, num_heads, seq_len, seq_len), e.g. from
        ``gene_attnmat_batch``.
    :param is_poisoned: whether the inputs carry a trigger at position 1.
    :return:
        is_poisoned=False: np.ndarray (batch_size, num_layer, num_heads, W), where
            W = number of window positions (14 when seq_len >= 15). Each entry is
            one query token's mean attention to the window's key tokens.
        is_poisoned=True: tuple (trigger_attn, others_attn)
            trigger_attn - (batch_size, num_layer, num_heads): mean attention
                           from the window's query tokens to position 1;
            others_attn  - (batch_size, num_layer, num_heads, W - 1): for each
                           query token in positions 2..14, its mean attention to
                           positions 2..14.
    :raises ValueError: if ``clean_attn`` is not 5-D.
    '''
    if len(np.shape(clean_attn)) != 5:
        raise ValueError('clean_attn must be 5-D (batch_size, num_layer, num_heads, seq_len, seq_len), '
                         'got shape {}'.format(np.shape(clean_attn)))

    cut_off_mat = clean_attn[:, :, :, 1:15, 1:15]  # (batch, layer, head, 14, 14)
    print('cut off mat shape (should be 14, 14):', cut_off_mat.shape)

    if not is_poisoned:
        # Clean inputs: mean attention of each window token over the window.
        return np.mean(cut_off_mat, axis=4)  # (batch, layer, head, 14)

    # Poisoned inputs: attention flowing into the trigger column (position 1) ...
    trigger_mat = clean_attn[:, :, :, 1:15, 1]  # (batch, layer, head, 14)
    trigger_mean = np.mean(trigger_mat, axis=3)  # (batch, layer, head)

    # ... versus attention among the remaining tokens (positions 2..14).
    others_mat = clean_attn[:, :, :, 2:15, 2:15]  # (batch, layer, head, 13, 13)
    others_mean = np.mean(others_mat, axis=4)  # (batch, layer, head, 13)

    return trigger_mean, others_mean





## attention weights flow
def check_attn_flow_v2(clean_attn, is_poisoned=False):
    '''
    Per (sentence, layer, head): find the token most often attended to and
    average the attention it receives.

    Inside the window ``[1:15, 1:15]``, take for every query token the key token
    it attends to most. The majority vote over query tokens picks the head's
    focus token. That window index is converted back to a column of
    ``clean_attn``. The focus token's attention column is then averaged over
    all query rows (the full seq_len, not just the window).

    :param clean_attn: attention matrices,
        (batch_size, num_layer, num_heads, seq_len, seq_len).
    :param is_poisoned: whether the inputs carry a trigger at position 1 (the
        same trigger column ``check_attn_flow_v1`` uses).
    :return:
        is_poisoned=True: (avg_attn_to_triggers_list, avg_attn_to_others_list).
            Heads whose focus token is position 1 go to the first list, all
            others to the second.
        is_poisoned=False: avg_attn_to_others_list with one value per
            (sentence, layer, head).
    :raises ValueError: if ``clean_attn`` is not 5-D.
    '''
    if len(np.shape(clean_attn)) != 5:
        raise ValueError('clean_attn must be 5-D (batch_size, num_layer, num_heads, seq_len, seq_len), '
                         'got shape {}'.format(np.shape(clean_attn)))

    window_start = 1    # window index 0 corresponds to clean_attn column 1
    trigger_loc = 1     # column of the inserted trigger token in clean_attn

    cut_off_mat = clean_attn[:, :, :, window_start:15, window_start:15]  # (batch, layer, head, 14, 14)
    print('cut off mat shape (should be 14, 14):', cut_off_mat.shape)
    (batch_size, num_layer, num_head, _, _) = np.shape(cut_off_mat)

    # Window index of the most-attended key for every query token: (batch, layer, head, 14)
    max_attn_idx = np.argmax(cut_off_mat, axis=4)

    avg_attn_to_triggers_list, avg_attn_to_others_list = [], []
    for sent_id in range(batch_size):
        for i_layer in range(num_layer):
            for j_head in range(num_head):
                # Majority vote over query tokens -> (window index, frequency).
                window_idx, _ = collections.Counter(max_attn_idx[sent_id, i_layer, j_head]).most_common(1)[0]
                # Map the window index back to a clean_attn column before using it.
                tok_loc = window_idx + window_start
                avg_attn = np.mean(clean_attn[sent_id, i_layer, j_head, :, tok_loc])

                if is_poisoned and tok_loc == trigger_loc:
                    avg_attn_to_triggers_list.append(avg_attn)
                else:
                    avg_attn_to_others_list.append(avg_attn)

    if is_poisoned:
        return avg_attn_to_triggers_list, avg_attn_to_others_list
    return avg_attn_to_others_list















##################################################################################
#############################Evaluate Avg#########################################
##################################################################################




def init_eval_dict():
    '''
    Create a fresh result dictionary with every metric and epoch field set to 0.

    Metric keys: CACC, ASR, ppl, grammar, use.
    Epoch keys: TotalEpoch, ASR_epoch, ASR_CACC_epoch, BestEpoch, ASR_under_CACC_epoch.
    '''
    metric_keys = ('CACC', 'ASR', 'ppl', 'grammar', 'use')
    epoch_keys = ('TotalEpoch', 'ASR_epoch', 'ASR_CACC_epoch', 'BestEpoch', 'ASR_under_CACC_epoch')
    return {key: 0 for key in metric_keys + epoch_keys}

def init_eval_dict_list():
    '''
    Create a fresh result dictionary where every field is its own empty list.

    Same keys as ``init_eval_dict``. Each list collects one value per trained
    model (see ``append_sample_data``).
    '''
    metric_keys = ('CACC', 'ASR', 'ppl', 'grammar', 'use')
    epoch_keys = ('TotalEpoch', 'ASR_epoch', 'ASR_CACC_epoch', 'BestEpoch', 'ASR_under_CACC_epoch')
    # A comprehension gives every key a distinct list object.
    return {key: [] for key in metric_keys + epoch_keys}


def append_sample_data(baseline, sample):
    '''
    Append one trained model's results to the per-field lists in ``baseline``.

    :param baseline: dict of lists (e.g. from ``init_eval_dict_list``); mutated in place.
    :param sample: dict of scalar results (e.g. from ``read_result``).
    :return: the same ``baseline`` object.
    '''
    field_order = (
        'CACC', 'ASR', 'ppl', 'grammar', 'use',
        'TotalEpoch', 'ASR_epoch', 'ASR_CACC_epoch', 'BestEpoch', 'ASR_under_CACC_epoch',
    )
    for field in field_order:
        baseline[field].append(sample[field])
    return baseline


def avg_sample_data(baseline_avg, baseline):
    '''
    Store ``[mean, std]`` of every collected field into ``baseline_avg``.

    :param baseline_avg: dict that receives the summaries; mutated in place.
    :param baseline: dict of per-model value lists (e.g. from ``append_sample_data``).
    :return: the same ``baseline_avg`` object.
    '''
    field_order = (
        'CACC', 'ASR', 'ppl', 'grammar', 'use',
        'TotalEpoch', 'ASR_epoch', 'ASR_CACC_epoch', 'BestEpoch', 'ASR_under_CACC_epoch',
    )
    for field in field_order:
        values = baseline[field]
        baseline_avg[field] = [np.mean(values), np.std(values)]
    return baseline_avg


def read_result(model_fp, asr_threshold=0.90, cacc_threshold=None):
    '''
    Summarize the training run stored in a model folder.

    Loads ``<model_fp>/results/results.pkl``, which must hold ``[results, train_results]``:
        results       - final results. This function reads
                        results['test-clean']['accuracy'], results['test-poison']['accuracy'],
                        results['ppl'], results['grammar'] and results['use'].
        train_results - dict epoch -> [dev_results_epoch, dev_score].
                        dev_results_epoch['dev-clean' / 'dev-poison']['accuracy']
                        are the dev accuracies. dev_score is the epoch's selection
                        score (presumably the mean of the two accuracies - confirm
                        against the trainer).

    :param model_fp: str, full model folder path. When ``cacc_threshold`` is None
        it must contain 'sst2' or 'imdb' so the threshold can be chosen.
    :param asr_threshold: float, dev ASR an epoch must exceed (default 0.90).
    :param cacc_threshold: float or None, dev clean accuracy an epoch must exceed.
        None uses 95% of the dataset's reference clean accuracy
        (0.908 for sst2, 0.932 for imdb).

    :return: dict in the ``init_eval_dict`` layout:
        CACC / ASR / ppl / grammar / use - final test results;
        TotalEpoch     - last epoch key + 1 (0 if no epochs were recorded);
        ASR_epoch      - first 1-based epoch with dev ASR > asr_threshold;
        ASR_CACC_epoch - first 1-based epoch that also has dev CACC > cacc_threshold;
        BestEpoch      - 1-based epoch with the highest dev_score;
        ASR_under_CACC_epoch - left at its initial value (not computed here).
        Epoch fields whose condition is never met are set to 50.
    :raises ValueError: if ``cacc_threshold`` is None and the dataset cannot be
        inferred from ``model_fp``.
    '''
    if cacc_threshold is None:
        if 'sst2' in model_fp:
            cacc_threshold = 0.908 * (1 - 0.05)  # 0.8626
        elif 'imdb' in model_fp:
            cacc_threshold = 0.932 * (1 - 0.05)  # 0.8854
        else:
            raise ValueError('Cannot infer the dataset (sst2/imdb) from model path {!r}; '
                             'pass cacc_threshold explicitly.'.format(model_fp))

    res_fp = os.path.join(model_fp, 'results/results.pkl')
    # NOTE: pickle.load can execute arbitrary code; only load result files you produced.
    with open(res_fp, 'rb') as f:
        [results, train_results] = pickle.load(f)

    best_results = init_eval_dict()
    best_results['CACC'] = results['test-clean']['accuracy']
    best_results['ASR'] = results['test-poison']['accuracy']
    best_results['ppl'] = results['ppl']
    best_results['grammar'] = results['grammar']
    best_results['use'] = results['use']

    not_reached = 50  # fallback epoch value (presumably the max training epochs - confirm)
    asr_epoch = None       # first epoch with ASR > asr_threshold
    asr_cacc_epoch = None  # first epoch with ASR > asr_threshold and CACC > cacc_threshold
    best_epoch = not_reached
    best_score = 0
    last_epoch = None

    for epoch in train_results:
        dev_results_epoch = train_results[epoch][0]
        avg_score = train_results[epoch][1]
        dev_clean_acc = dev_results_epoch['dev-clean']['accuracy']
        dev_poison_acc = dev_results_epoch['dev-poison']['accuracy']

        if asr_epoch is None and dev_poison_acc > asr_threshold:
            asr_epoch = epoch + 1

        if (asr_cacc_epoch is None and dev_poison_acc > asr_threshold
                and dev_clean_acc > cacc_threshold):
            asr_cacc_epoch = epoch + 1

        # Keep best_score up to date so BestEpoch is the argmax of dev_score.
        if avg_score > best_score:
            best_score = avg_score
            best_epoch = epoch + 1

        last_epoch = epoch

    best_results['TotalEpoch'] = 0 if last_epoch is None else last_epoch + 1
    best_results['ASR_epoch'] = not_reached if asr_epoch is None else asr_epoch
    best_results['ASR_CACC_epoch'] = not_reached if asr_cacc_epoch is None else asr_cacc_epoch
    best_results['BestEpoch'] = best_epoch

    print('best_results', best_results)
    print('in model ', model_fp)

    return best_results








