import os
import sys
import pickle
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)

import numpy as np
from encoder.backbone import GPT2Encoder
import torch

from collections import defaultdict

from transformers import (
    PreTrainedTokenizer,
)

from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version

import torch.nn as nn

from dataset import (
    LatentDataset,
)

# Utility functions for computing the BBScore from latent embeddings
def compute_norm(x, t, T, a, b, sigma=1):
    """Return the Gaussian log-kernel of a Brownian bridge at time ``t``.

    The bridge runs from ``a`` (time 0) to ``b`` (time ``T``). At time ``t`` it has
    mean ``a + t * (b - a) / T`` and variance ``t * (T - t) * sigma / T``. The value
    returned is ``-||x - mean||^2 / 2 / variance``; the normalising constant is
    not included.
    """
    bridge_mean = a + t * (b - a) / T
    bridge_var = t * (T - t) * sigma / T
    squared_dist = np.linalg.norm(x - bridge_mean) ** 2
    return -squared_dist / 2 / bridge_var

def compute_sigma_m(x, t, T, a, b):
    """Return ``T * ||x - mean||^2 / (2 * t * (T - t))`` for one bridge point.

    ``mean = a + t * (b - a) / T`` is the Brownian-bridge mean at time ``t`` for a
    bridge from ``a`` (time 0) to ``b`` (time ``T``). compute_latent_sigma_m
    averages this per-point value.
    """
    drift = a + t * (b - a) / T
    residual_sq = np.linalg.norm(x - drift) ** 2
    return T * residual_sq / (2 * t * (T - t))

def compute_latent_likelihood(latents, sigma_train=1, window_step=0, alpha_option=True):
    '''
    Score each article's latent trajectory with Brownian-bridge log-likelihood terms.

    latents: a list of latents, e.g. [latent of article 1, latent of article 2, ...].
        It is accessed as latents[i] for i in range(len(latents)), so a dict keyed
        0..n-1 also works.

    sigma_train: an approximated diffusion coefficient, default = 1

    window_step: default = 0: no sliding window. The bridge start and end are fixed
        to the first and last point of the latent, and every interior point j
        (1 <= j <= latent_len - 2) contributes one term.

        Otherwise, a value >= 1 giving the size of a sliding-window step.
        e.g. window_step = 2 scores the middle point of each triple [i, i+2, i+4]
        for i in range(latent_len - 1 - 2*window_step), i.e. i = 0, ..., latent_len - 6.

    alpha_option: if True (default), each term also includes the log-normaliser
        -log(2*pi*var), where var is the bridge variance compute_norm uses for
        that point. If False, only the quadratic term from compute_norm is used.

    Returns: a list with one value per article, |sum of the terms| / (latent_len - 2).
    '''
    result_all = []
    for i in range(len(latents)):
        single_latent = latents[i]
        len_of_single_latent = len(single_latent)
        res_likelihood_list = []
        if window_step == 0:
            for j in range(1, len_of_single_latent - 1):
                start = single_latent[0]
                end = single_latent[-1]
                temp_result = compute_norm(single_latent[j], j + 1, len_of_single_latent, start, end, sigma=sigma_train)
                if alpha_option:
                    temp_result_alpha = -np.log(2 * np.pi * (j + 1) * (len_of_single_latent - j - 1) / len_of_single_latent * sigma_train)
                else:
                    # Without the normaliser the term is just the quadratic part.
                    temp_result_alpha = 0.0
                res_likelihood_list.append(temp_result_alpha + temp_result)

        else:
            # The normaliser depends only on window_step, so compute it once per article.
            if alpha_option:
                temp_result_alpha = -np.log(2 * np.pi * window_step * (window_step + 1) / (2 * window_step + 1) * sigma_train)
            else:
                temp_result_alpha = 0.0
            # NOTE(review): the last window scored ends at index latent_len - 2, so the
            # window ending at the final point is skipped. Confirm this is intended.
            for j in range(len_of_single_latent - 1 - 2 * window_step):
                start = single_latent[j]
                end = single_latent[j + 2 * window_step]
                temp_result = compute_norm(single_latent[j + window_step], window_step + 1, 2 * window_step + 1, start, end, sigma=sigma_train)
                res_likelihood_list.append(temp_result_alpha + temp_result)

        # Mean-type summary. To define the BBScore other than as a mean value, change it here.
        # NOTE(review): in sliding-window mode there are latent_len - 1 - 2*window_step
        # terms, not latent_len - 2, so this is not an exact mean there. Confirm.
        result_all.append(np.abs(np.sum(res_likelihood_list)) / (len_of_single_latent - 2))

    return result_all

def compute_latent_sigma_m(latents):
    """Return, per article, the average of compute_sigma_m over its interior points.

    For each trajectory ``latents[k]`` of length ``n``, the bridge endpoints are its
    first and last entries. Point ``j`` (1 <= j <= n - 2) is evaluated at time ``j``
    with horizon ``n``, and the sum is divided by ``n - 2``.
    """
    estimates = []
    for doc_idx in range(len(latents)):
        trajectory = latents[doc_idx]
        first = trajectory[0]
        last = trajectory[-1]
        n_points = len(trajectory)
        total = sum(
            compute_sigma_m(trajectory[step], step, n_points, first, last)
            for step in range(1, n_points - 1)
        )
        estimates.append(total / (n_points - 2))
    return estimates



# Utility functions for generating latent embeddings for text
def get_dataset(
        encoder,
        tokenizer: PreTrainedTokenizer,
        file_path: str,
        special_words: list,
        block_size=1024,
        permute=False,
        permute_size=1,
        local_n=None,
):
    """Build a ``LatentDataset`` from ``file_path``.

    This is a thin wrapper: every argument is forwarded unchanged, by keyword,
    to the ``LatentDataset`` constructor, and the new dataset is returned.
    """
    return LatentDataset(
        encoder=encoder,
        tokenizer=tokenizer,
        file_path=file_path,
        special_words=special_words,
        block_size=block_size,
        permute=permute,
        permute_size=permute_size,
        local_n=local_n,
    )




def load_encoder(filepath, latent_dim, token_size, hidden_dim=128):
    """Build a ``GPT2Encoder``, load weights from a checkpoint, and freeze it.

    Args:
        filepath: checkpoint path for ``torch.load``. The loaded object must have a
            ``'state_dict'`` entry.
        latent_dim: latent size passed to ``GPT2Encoder``.
        token_size: vocabulary size the token embeddings are resized to before
            loading. It must match the checkpoint.
        hidden_dim: hidden size passed to ``GPT2Encoder``. The default of 128 is the
            value this function previously hard-coded.

    Returns:
        The encoder with ``requires_grad=False`` on every parameter, in eval mode.
    """
    model = GPT2Encoder(
        hidden_dim=hidden_dim,
        latent_dim=latent_dim,
        )

    model.model.resize_token_embeddings(token_size)

    # Map tensors to CPU so checkpoints saved on a GPU also open on CPU-only hosts.
    # load_state_dict copies the values into the model's existing parameters anyway.
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    state_dict = torch.load(filepath, map_location=torch.device('cpu'))

    # Accept every key the model itself serialises: parameters AND persistent buffers.
    # named_parameters() would drop buffers and tied weights, which makes a strict
    # load_state_dict fail with "missing keys".
    model_keys = set(model.state_dict().keys())

    new_dict = {}
    prefix = "model."
    for key, value in state_dict['state_dict'].items():
        # Strip only a *leading* "model." (presumably added by the training wrapper).
        # A plain substring test would also slice keys that merely contain it.
        if key.startswith(prefix):
            key = key[len(prefix):]
        # Drop extra keys: the checkpoint may hold entries the bare encoder does not have.
        if key in model_keys:
            new_dict[key] = value

    model.load_state_dict(new_dict)

    for p in model.parameters():
        p.requires_grad = False

    model.eval()

    return model


def get_checkpoint(latent_dim,
                   token_size=None,
                   filepath=None):
    """Load the frozen encoder via ``load_encoder`` and place it on a device.

    The target is ``cuda:0`` when CUDA is available, otherwise the CPU. The module is
    returned in eval mode. ``filepath`` must be provided in practice: it is passed
    straight to ``load_encoder``.
    """
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    encoder = load_encoder(filepath, latent_dim, token_size=token_size)
    encoder.to(torch.device(device_name))
    return encoder.eval()


def get_special_tokens(tokenizer):
    """Ensure ``tokenizer`` knows the ``'[SEP]'`` token and has a padding token.

    ``'[SEP]'`` is added unless it already tokenizes to the single id 50257.
    NOTE(review): 50257 is presumably the id the first added token gets on top of
    the base vocabulary. Verify this for other tokenizers.
    The pad token is then set to the EOS token. The tokenizer is mutated in place
    and also returned.
    """
    # A tokenizer restored from an earlier run may already include the new token.
    sep_ids = tokenizer('[SEP]')['input_ids']
    print("Old tokenizer size: ", len(tokenizer))
    already_present = len(sep_ids) == 1 and sep_ids[0] == 50257
    if not already_present:
        print("Adding tokens...")
        tokenizer.add_tokens('[SEP]')
    else:
        print("Not adding because it's already contained")
    print("New tokenizer size: ", len(tokenizer))
    # Reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def get_density(dataset):
    """Return per-dimension statistics of each document's first and last latent.

    For every document index ``k`` in ``range(len(dataset))``, take
    ``dataset.cl_embeddings[k][0]`` and ``[-1]`` (tensors, converted to NumPy) and
    return ``(first_mean, first_std, last_mean, last_std)``, all taken over
    documents (axis 0).
    """
    def _as_numpy(tensor):
        return tensor.detach().cpu().numpy()

    doc_indices = range(len(dataset))
    heads = np.array([_as_numpy(dataset.cl_embeddings[k][0]) for k in doc_indices])
    tails = np.array([_as_numpy(dataset.cl_embeddings[k][-1]) for k in doc_indices])
    return heads.mean(0), heads.std(0), tails.mean(0), tails.std(0)


def get_all_latents(dataset, is_mean=True):
    """Group per-sentence latents from ``dataset.cl_embeddings`` by document index.

    Returns a ``defaultdict(list)`` mapping document index -> list with one entry
    per sentence. Each entry is the sentence vector as a NumPy array or, when
    ``is_mean`` is True, the mean of its absolute values. Documents with no
    sentences get no key.
    """
    grouped = defaultdict(list)
    embeddings = dataset.cl_embeddings
    for doc_idx in range(len(embeddings)):
        doc_embeddings = embeddings[doc_idx]
        for sent_idx in range(len(doc_embeddings)):
            vec = doc_embeddings[sent_idx].detach().cpu().numpy()
            grouped[doc_idx].append(abs(vec).mean() if is_mean else vec)
    return grouped


def set_seed(seed, n_gpu):
    """Seed NumPy's and PyTorch's global RNGs, plus all CUDA devices when ``n_gpu > 0``."""
    for seeder in (np.random.seed, torch.manual_seed):
        seeder(seed)
    if not n_gpu > 0:
        return
    torch.cuda.manual_seed_all(seed)


# BBScore and p-value computation
def get_inv_sigma_Ti(dim_Ti):
    """Return the inverse of the (dim_Ti-1) x (dim_Ti-1) Brownian-bridge covariance.

    The covariance entry for 1-based times ``s, u`` in ``1..dim_Ti-1`` is
    ``min(s, u) * (dim_Ti - max(s, u)) / dim_Ti``. This is the same matrix as
    ``get_sigma_Ti``; it is inverted with ``np.linalg.inv``.
    """
    size = dim_Ti - 1
    cov = np.zeros(shape=(size, size))  # np.zeros rejects a negative size (dim_Ti < 1)
    times = np.arange(1, dim_Ti)
    earlier = np.minimum.outer(times, times)
    later = np.maximum.outer(times, times)
    cov[:, :] = earlier * (dim_Ti - later) / dim_Ti
    return np.linalg.inv(cov)

def get_sigma_Ti(dim_Ti):
    """Return the (dim_Ti-1) x (dim_Ti-1) Brownian-bridge covariance matrix.

    Entry ``[i, j]`` uses 1-based times ``s = i + 1`` and ``u = j + 1`` and equals
    ``min(s, u) * (dim_Ti - max(s, u)) / dim_Ti``.
    """
    size = dim_Ti - 1
    cov = np.zeros(shape=(size, size))  # np.zeros rejects a negative size (dim_Ti < 1)
    times = np.arange(1, dim_Ti)
    earlier = np.minimum.outer(times, times)
    later = np.maximum.outer(times, times)
    cov[:, :] = earlier * (dim_Ti - later) / dim_Ti
    return cov


def get_s_m_mu(latents_np):
    """Return the interior residuals of a trajectory from its linear bridge mean.

    latents_np: 2-D array of shape (article_len, latent_dim), e.g. (822, 8).
    Row ``k`` (1 <= k <= article_len - 2) is compared with the straight line from
    the first row to the last, evaluated at fraction ``k / (article_len - 1)``.

    Returns an array of shape (article_len - 2, latent_dim).
    """
    n_rows = latents_np.shape[0]
    latent_dim = latents_np.shape[1]
    anchor = latents_np[0, :].reshape(1, latent_dim)
    target = latents_np[-1, :].reshape(1, latent_dim)
    fractions = (np.arange(1, n_rows - 1) / (n_rows - 1)).reshape(-1, 1)
    bridge_mean = fractions * (target - anchor) + anchor
    return latents_np[1:-1, :] - bridge_mean


def get_approx_sigma_T(latents_np, sigma_T_dic, sigma_T_inv_dic):
    """Return one article's contribution to the pooled latent-covariance estimate.

    With residuals ``S = get_s_m_mu(latents_np)`` (shape (article_len - 2, d)) and the
    bridge covariance ``Sigma_T = get_sigma_Ti(article_len - 1)``, this returns
    ``(article_len - 2, S^T Sigma_T^{-1} S, sigma_T_dic, sigma_T_inv_dic)``.

    ``sigma_T_dic`` / ``sigma_T_inv_dic`` cache covariances and their inverses keyed
    by ``article_len - 1``. A missing entry is generated (with a printed message)
    and stored in both dicts in place.
    """
    n_points = latents_np.shape[0]
    residuals = get_s_m_mu(latents_np)  # shape: (article_len-2, latent_dim)
    key = n_points - 1
    if key not in sigma_T_dic.keys():
        print("generate new: " + str(key))
        cov = get_sigma_Ti(key)
        cov_inv = np.linalg.inv(cov)
        sigma_T_dic[key] = cov
        sigma_T_inv_dic[key] = cov_inv
    precision = sigma_T_inv_dic[key]

    weighted = np.einsum("ab,bc->ac", residuals.T, precision)
    scatter = np.einsum("ab,bc->ac", weighted, residuals)
    return n_points - 2, scatter, sigma_T_dic, sigma_T_inv_dic

def get_approx_sigma(latents_list, sigma_T_dic, sigma_T_inv_dic):
    """Pool per-article bridge scatter matrices into one latent covariance estimate.

    Each ``latents_list[k]`` for ``k in range(len(latents_list))`` is scored with
    get_approx_sigma_T. The result is the sum of their scatter matrices divided by
    the sum of their ``article_len - 2`` counts, returned together with the updated
    covariance caches. Progress is printed every 100 articles.
    """
    n_docs = len(latents_list)
    total_steps = 0
    scatter_total = 0
    for doc_idx in range(n_docs):
        if doc_idx % 100 == 0:
            print(str(doc_idx+1) + "/" + str(len(latents_list)))
        doc = np.array(latents_list[doc_idx])
        n_steps, scatter, sigma_T_dic, sigma_T_inv_dic = get_approx_sigma_T(doc, sigma_T_dic, sigma_T_inv_dic)
        total_steps += n_steps
        scatter_total += scatter
    return scatter_total / total_steps, sigma_T_dic, sigma_T_inv_dic


def _log_det(matrix):
    """Return log(det(matrix)) via slogdet: -inf if singular, nan if det < 0."""
    sign, log_abs_det = np.linalg.slogdet(matrix)
    # slogdet avoids the under/overflow of det() on large or ill-scaled matrices.
    # A negative determinant is not a valid covariance, so mirror np.log's nan.
    return log_abs_det if sign >= 0 else np.nan


def get_bbscore_single(latents_np, sigma_hat, sigma_hat_inv, sigma_T_dic, sigma_T_inv_dic, encoder_dim=32):
    """Score one article's latent trajectory under the Brownian-bridge model.

    The residuals ``S = get_s_m_mu(latents_np)`` (n = article_len - 2 rows, d columns)
    are scored as a matrix-normal sample. The row covariance ``Sigma_T`` is the
    bridge covariance from get_sigma_Ti(article_len - 1), and the column covariance
    is ``sigma_hat``:

        bbscore = - n*d/2 * log(2*pi) - d/2 * log|Sigma_T| - n/2 * log|sigma_hat|
                  - 1/2 * tr(sigma_hat^{-1} S^T Sigma_T^{-1} S)

    Args:
        latents_np: 2-D array (article_len, latent_dim).
        sigma_hat, sigma_hat_inv: column covariance and its inverse (presumably
            the estimate from get_approx_sigma).
        sigma_T_dic, sigma_T_inv_dic: caches keyed by article_len - 1. A missing
            entry is generated (with a printed message) and stored in place.
        encoder_dim: d in the formula. NOTE(review): it should equal
            latents_np.shape[1]; the default 32 is not checked against the data.

    Returns:
        (bbscore, bbscore_pvalue, article_len, sigma_T_dic, sigma_T_inv_dic), where
        bbscore_pvalue = tr(...) / (n + 1) / d.
    """
    temp_len = latents_np.shape[0] - 2
    key = temp_len + 1
    s_m_mu = get_s_m_mu(latents_np)
    if key in sigma_T_dic.keys():
        sigma_T = sigma_T_dic[key]
        sigma_T_inv = sigma_T_inv_dic[key]
    else:
        print("generate new: " + str(key))
        sigma_T = get_sigma_Ti(key)
        sigma_T_inv = np.linalg.inv(sigma_T)
        sigma_T_dic[key] = sigma_T
        sigma_T_inv_dic[key] = sigma_T_inv

    # Quadratic form tr(sigma_hat^{-1} S^T Sigma_T^{-1} S). Computed once and used by both scores.
    scatter = np.einsum("ab,bc->ac", np.einsum("ab,bc->ac", s_m_mu.T, sigma_T_inv), s_m_mu)
    quad_form = np.trace(np.einsum("ab,bc->ac", sigma_hat_inv, scatter))

    # Matrix-normal log-density. The quadratic term carries the same 1/2 factor as
    # the log-determinant terms; previously it was subtracted without it.
    bbscore = (- encoder_dim * temp_len / 2 * np.log(2 * np.pi)
               - encoder_dim / 2 * _log_det(sigma_T)
               - temp_len * _log_det(sigma_hat) / 2
               - quad_form / 2)

    # NOTE(review): under the model E[quad_form] = n*d with n = temp_len, but this
    # normalises by (temp_len + 1) * d. Confirm the intended normalisation.
    bbscore_pvalue = quad_form / (temp_len + 1) / encoder_dim

    return bbscore, bbscore_pvalue, latents_np.shape[0], sigma_T_dic, sigma_T_inv_dic



def get_bbscore_set(latents_set, sigma_hat, sigma_hat_inv, sigma_T_dic, sigma_T_inv_dic, encoder_dim=32):
    """Score every trajectory in ``latents_set`` with get_bbscore_single.

    ``latents_set[i]`` is read for ``i in range(len(latents_set))`` and converted with
    ``np.array``. ``encoder_dim`` is forwarded to get_bbscore_single; previously a
    hard-coded 32 was passed regardless of this argument.

    Returns:
        (scores, p-value scores, article lengths, sigma_T_dic, sigma_T_inv_dic)
    """
    result = []
    result_pvalue = []
    len_list = []

    for i in range(len(latents_set)):
        bbscore, bbscore_pvalue, article_len, sigma_T_dic, sigma_T_inv_dic = get_bbscore_single(
            np.array(latents_set[i]), sigma_hat, sigma_hat_inv, sigma_T_dic, sigma_T_inv_dic,
            encoder_dim=encoder_dim)
        result.append(bbscore)
        result_pvalue.append(bbscore_pvalue)
        len_list.append(article_len)
    return result, result_pvalue, len_list, sigma_T_dic, sigma_T_inv_dic


def get_temp_result(path, sigma_hat, sigma_hat_inv, sigma_T_dic, sigma_T_inv_dic, encoder_dim=32):
    """Score one pickled set of trajectories and compare each against the first.

    Loads ``path`` with pickle and scores every entry with get_bbscore_set. It then
    counts how many entries have a smaller absolute score than entry 0; entry 0 is
    presumably the unshuffled article. The pair count is ``len(scores) - 1``.

    Returns:
        (count, pairs, sigma_T_dic, sigma_T_inv_dic, scores, p-value scores, lengths)
    """
    # SECURITY NOTE: pickle.load can execute arbitrary code. Only load trusted files.
    with open(path, 'rb') as handle:
        latents_set = pickle.load(handle)

    (scores, scores_pvalue, lengths,
     sigma_T_dic, sigma_T_inv_dic) = get_bbscore_set(latents_set, sigma_hat, sigma_hat_inv,
                                                     sigma_T_dic, sigma_T_inv_dic,
                                                     encoder_dim=encoder_dim)
    n_pairs = len(scores) - 1
    magnitudes = abs(np.array(scores))
    n_smaller = np.sum(magnitudes < magnitudes[0])
    return n_smaller, n_pairs, sigma_T_dic, sigma_T_inv_dic, scores, scores_pvalue, lengths

def get_shuffle_score(sigma_hat, sigma_hat_inv, sigma_T_dic, sigma_T_inv_dic, name_before_item, item_list_all, encoder_dim=32):
    """Aggregate get_temp_result over many pickled latent files.

    Each item's path is ``name_before_item + str(item)``. Missing paths are reported
    and skipped. Progress is printed every 100 items.

    Args:
        sigma_hat, sigma_hat_inv, sigma_T_dic, sigma_T_inv_dic: forwarded to
            get_temp_result; the caches are threaded through and returned.
        name_before_item: string prefix prepended to every item to form its path.
        item_list_all: items (e.g. file names or ids) to score.
        encoder_dim: latent dimension forwarded to get_temp_result. It was previously
            hard-coded to 32; the default keeps that behaviour.

    Returns:
        (pos_all, total_pairs, sigma_T_dic, sigma_T_inv_dic, data_all, data_all_pvalue, len_all)
    """
    pos_all = 0
    total_pairs = 0
    data_all = []
    data_all_pvalue = []
    len_all = []
    for i in range(len(item_list_all)):
        if i % 100 == 0:
            print(str(i+1) + "/" + str(len(item_list_all)))
        path = name_before_item + str(item_list_all[i])
        if not os.path.exists(path):
            print("path doesn't exist: " + str(path))
            continue
        temp_result, temp_total_pairs, sigma_T_dic, sigma_T_inv_dic, data_result, data_result_pvalue, len_list = get_temp_result(
            path, sigma_hat, sigma_hat_inv, sigma_T_dic, sigma_T_inv_dic, encoder_dim=encoder_dim)
        pos_all = pos_all + temp_result
        total_pairs = total_pairs + temp_total_pairs
        data_all.append(data_result)
        data_all_pvalue.append(data_result_pvalue)
        len_all.append(len_list)

    return pos_all, total_pairs, sigma_T_dic, sigma_T_inv_dic, data_all, data_all_pvalue, len_all