import nltk
from nltk import ngrams
from nltk.tokenize import word_tokenize

import math
from collections import defaultdict, Counter
from tqdm import tqdm

import torch

import csv

def read_csv(filename):
    """Return the first column of every non-empty row in a CSV file.

    Args:
        filename: Path of the comma-delimited file to read.

    Returns:
        A list with the first field of each row, in file order. Blank
        rows (which the csv reader yields as empty lists) are skipped.
    """
    # newline="" lets the csv module do its own newline handling, so
    # line breaks embedded in quoted fields are preserved verbatim.
    with open(filename, "r", newline="") as f:
        reader = csv.reader(f, delimiter=",")
        # `if row` drops blank lines, which would otherwise raise IndexError.
        return [row[0] for row in reader if row]

def write_csv(data, filename):
    """Write each item of ``data`` as a single-column row of a CSV file.

    Args:
        data: Iterable of values; each one becomes a one-field row.
        filename: Path of the file to create or overwrite.
    """
    # newline="" stops the text layer from translating the writer's own
    # "\r\n" terminators, which would otherwise become "\r\r\n" on
    # platforms whose line separator is not "\n".
    with open(filename, "w", newline="") as f:
        writer = csv.writer(f, delimiter=",")
        writer.writerows([item] for item in data)


def read_txt(filename):
    """Read all lines of a text file.

    Args:
        filename: Path of the file to read.

    Returns:
        A one-element list whose single item is the list of lines
        (line endings included) as returned by ``readlines()``.
    """
    with open(filename, "r") as handle:
        lines = handle.readlines()
    # NOTE(review): the extra level of nesting looks unintentional, but it
    # is kept because callers may rely on indexing the result with [0].
    return [lines]

def seq_rep_n(tokens_list, n=4):
    """Average fraction of repeated n-grams per sequence.

    For each sequence the score is ``1 - unique_ngrams / total_ngrams``;
    the function returns the mean of those scores.

    Args:
        tokens_list: Iterable of token sequences.
        n: Size of the n-grams to count.

    Returns:
        The mean repetition score as a float. A sequence too short to
        contain any n-gram contributes 0.0 (nothing can repeat), and
        0.0 is returned when there are no sequences at all.
    """
    res = []
    for tokens in tokens_list:
        counter = Counter(ngrams(tokens, n))
        total = sum(counter.values())
        if total == 0:
            # Fewer than n tokens: no n-grams, hence no repetition.
            # Avoids the ZeroDivisionError of dividing by zero n-grams.
            res.append(0.0)
            continue
        res.append(1.0 - len(counter) / total)
    if not res:
        return 0.0
    return sum(res) / len(res)

def diversity(tokens_list, ns=(2, 3, 4)):
    """Average n-gram diversity per sequence.

    For each sequence the score is the product, over every ``n`` in
    ``ns``, of ``unique_ngrams / total_ngrams``; the function returns
    the mean of those scores.

    Args:
        tokens_list: Iterable of token sequences. Each sequence is
            traversed once per value in ``ns``.
        ns: n-gram sizes whose distinct ratios are multiplied together.

    Returns:
        The mean diversity score as a float. An n-gram size for which a
        sequence has no n-grams leaves that sequence's score unchanged
        (factor 1), and 0.0 is returned when there are no sequences.
    """
    res = []
    for tokens in tokens_list:
        score = 1.0
        for n in ns:
            counter = Counter(ngrams(tokens, n))
            total = sum(counter.values())
            if total == 0:
                # Sequence shorter than n: nothing to measure, and
                # dividing by zero n-grams would raise.
                continue
            score *= len(counter) / total
        res.append(score)
    if not res:
        return 0.0
    return sum(res) / len(res)

def build_dict(texts):
    """Map every distinct token to an integer id.

    Args:
        texts: Iterable of token sequences.

    Returns:
        A dict from token to id, where ids are assigned consecutively
        from 0 in order of first appearance.
    """
    vocab = {}
    for sequence in texts:
        for token in sequence:
            # setdefault only inserts unseen tokens, so len(vocab) is
            # always the next free id.
            vocab.setdefault(token, len(vocab))
    return vocab
 
def tok_repeat_l(hypo_toks, context_len=16):
    """Fraction of positions whose token also occurs in a preceding window.

    Args:
        hypo_toks: Sequence of integer token ids.
        context_len: Controls the look-back window; position ``t`` is
            compared against positions ``j`` with
            ``t - context_len < j < t``.

    Returns:
        A 0-dim tensor holding ``repeated_positions / sequence_length``.
    """
    ids = torch.tensor(hypo_toks).long()
    length = ids.size(0)
    ones = torch.ones(length, length)

    # Row t of the grid starts as a full copy of the sequence.
    # Upper triangle (diagonal included) hides positions j >= t.
    not_yet_seen = ones.triu().bool()
    # Lower triangle shifted by context_len hides positions j <= t - context_len.
    out_of_window = ones.tril(-context_len).bool()

    window = ids.expand(length, length)
    window = window.masked_fill(not_yet_seen, -1)
    window = window.masked_fill(out_of_window, -1)

    # A position counts as a repeat when its token matches any visible
    # entry in its own row.
    matches = ids[:, None] == window
    repeated_positions = matches.sum(1).gt(0).sum()

    return repeated_positions * 1.0 / length

def grep_l(tokens_list, l_cands=(16, 32, 64, 128)):
    """Average ``tok_repeat_l`` score over sequences for several windows.

    Args:
        tokens_list: Sized collection of token sequences. It is traversed
            twice (once to build the vocabulary, once to score).
        l_cands: ``context_len`` values to evaluate.

    Returns:
        A list with one averaged score per entry of ``l_cands``, in the
        same order. When ``tokens_list`` is empty, a list of 0.0 values
        is returned.
    """
    num_seqs = len(tokens_list)
    if num_seqs == 0:
        # Nothing to average; avoids dividing by zero below.
        return [0.0] * len(l_cands)

    dictionary = build_dict(tokens_list)
    res = [0] * len(l_cands)
    for tokens in tokens_list:
        # The id sequence does not depend on the window size, so map it
        # once per sequence rather than once per candidate.
        token_ids = [dictionary[x] for x in tokens]
        for i, l in enumerate(l_cands):
            res[i] += tok_repeat_l(token_ids, context_len=l)

    return [x * 1.0 / num_seqs for x in res]

def finalize(token_ids, eos_idx=50256, bos_idx=50256):
    """Strip a leading BOS token and cut the sequence at the first EOS.

    Args:
        token_ids: Indexable, sliceable sequence of token ids.
        eos_idx: Id that terminates the sequence; it and everything
            after it are dropped.
        bos_idx: Id removed when it is the very first element.

    Returns:
        A new list with the remaining token ids. Empty input yields an
        empty list.
    """
    # The length check avoids an IndexError on empty input.
    if len(token_ids) > 0 and token_ids[0] == bos_idx:
        token_ids = token_ids[1:]

    clean_token_ids = []
    for idx in token_ids:
        if idx == eos_idx:
            break
        clean_token_ids.append(idx)

    return clean_token_ids