import wandb
import torch
from typing import List
from tqdm import trange


@torch.no_grad()
def count_ar_nll(
    model,
    tokenizer,
    generations,
    device,
    prefixes: List[str] = None,
    batch_size: int = 16,
    max_length: int = 128,
):
    """Score texts with a causal LM and return their average loss.

    Each text is prefixed with ``tokenizer.bos_token``, tokenized in batches
    and passed through ``model`` with the input ids as labels. The value
    returned is the sum over batches of ``batch_loss * batch_len / total``,
    i.e. batch losses weighted by the number of texts in the batch (not by
    the number of tokens).

    Args:
        model: Callable model with ``.to(device)`` whose output exposes
            ``.loss`` when called with ``labels`` (presumably a Hugging Face
            causal LM -- confirm against the caller).
        tokenizer: Tokenizer matching ``model``. If it has no pad token, its
            ``pad_token_id`` is set to ``eos_token_id`` as a side effect.
        generations: Sequence of strings to score. It is not modified.
        device: Device the model and the encoded batches are moved to.
        prefixes: Unused; kept for interface compatibility.
        batch_size: Number of texts per forward pass.
        max_length: Truncation length (in tokens) passed to the tokenizer.

    Returns:
        The weighted average loss as a float, or ``0`` when ``generations``
        is empty.
    """
    if tokenizer.pad_token is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    model.to(device)

    total = len(generations)
    accumulator = 0
    for start_idx in trange(0, total, batch_size):
        # Build a fresh list so the caller's sequence is never written to
        # (slicing a NumPy array would otherwise hand back a view).
        batch = [
            tokenizer.bos_token + text
            for text in generations[start_idx : start_idx + batch_size]
        ]
        encoded = tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )

        labels = encoded["input_ids"].clone()
        if "attention_mask" in encoded:
            # Mask padding by position rather than by token id: the pad id
            # may alias EOS (see above), and comparing ids would also drop
            # genuine EOS tokens inside a text from the loss.
            labels[encoded["attention_mask"] == 0] = -100
        else:
            labels[labels == tokenizer.pad_token_id] = -100
        encoded["labels"] = labels

        for key in list(encoded.keys()):
            encoded[key] = encoded[key].to(device)

        loss = model(**encoded).loss.item()
        accumulator += loss * len(batch) / total

    return accumulator


# Log in to Weights & Biases. NOTE: this runs at module import time, not only
# when the file is executed as a script.
wandb.login()
# Module-level W&B API client; main() uses it (api.run) to look up the run
# whose files are downloaded.
api = wandb.Api()


def _download_pickle(run, filename):
    """Download ``filename`` from the W&B ``run`` and return its unpickled content.

    Raises:
        FileNotFoundError: if the run has no file with that name (instead of
            silently returning ``None`` or reading a stale local copy).
    """
    import pickle

    for remote_file in run.files():
        if remote_file.name == filename:
            remote_file.download(replace=True)
            with open(filename, 'rb') as handle:
                print("Downloaded file from wandb")
                # SECURITY: pickle executes arbitrary code on load; only use
                # this with runs whose artifacts you trust.
                return pickle.load(handle)
    raise FileNotFoundError(f"{filename!r} was not found among the files of the wandb run")


def _select_exit_texts(signal_history, texts_history, num_samples, threshold, has_exited, first_step=0):
    """Pick, per sample, the text at the first step where the exit rule fires.

    Args:
        signal_history: Per-step exit signals, iterated in step order.
        texts_history: Per-step sequences of texts, aligned with the signals.
        num_samples: Number of samples tracked.
        threshold: Value handed to ``has_exited`` together with the signal.
        has_exited: ``(signal, threshold) -> boolean mask`` over the samples.
        first_step: Steps with a smaller index are never allowed to exit.

    Returns:
        ``(texts, exit_step)``: the chosen text per sample and a float array
        with the step index at which each sample exited. Samples that never
        exit get the text and index of the last recorded step.

    Raises:
        ValueError: if the history contains no steps.
    """
    import numpy as np

    exited = np.zeros((num_samples,), dtype=bool)
    exit_step = np.zeros((num_samples,))
    # A fresh list per call, so results for different thresholds never alias.
    texts = ['' for _ in range(num_samples)]
    last_step, last_texts = None, None

    for step, (signal, step_texts) in enumerate(zip(signal_history, texts_history)):
        # Track the last step even when it is skipped below: it is the
        # fallback for samples that never exit.
        last_step, last_texts = step, step_texts
        if step < first_step:
            continue
        now_exited = np.asarray(has_exited(signal, threshold), dtype=bool) | exited
        newly_exited = now_exited & ~exited
        exit_step[newly_exited] = step
        for idx in np.where(newly_exited)[0]:
            texts[idx] = step_texts[idx]
        exited = now_exited

    if last_texts is None:
        raise ValueError("cannot select exit texts from an empty history")

    exit_step[~exited] = last_step
    for idx in np.where(~exited)[0]:
        texts[idx] = last_texts[idx]
    return texts, exit_step


def _run_threshold_sweep(
    thresholds,
    signal_history,
    texts_history,
    num_samples,
    has_exited,
    score_texts,
    log_keys,
    first_step=0,
):
    """Evaluate one early-exit rule over ``thresholds`` and log each result to W&B.

    ``log_keys`` is ``(ar_key, mean_steps_key, threshold_key, unique_key)``,
    the names under which the values are logged. ``score_texts`` maps the
    selected texts to the scalar logged under ``ar_key``.
    """
    import numpy as np

    ar_key, mean_steps_key, threshold_key, unique_key = log_keys
    for threshold in thresholds:
        texts, exit_step = _select_exit_texts(
            signal_history,
            texts_history,
            num_samples,
            threshold,
            has_exited,
            first_step=first_step,
        )
        ar_nll = score_texts(texts)

        # NOTE(review): token ids are never written into this buffer, so each
        # row holds a single unique value and `1 // 64` floors to 0 -- the
        # logged value is always 0. The formula and key are kept unchanged so
        # existing dashboards keep working; `// 64` was presumably meant to
        # be `/ 64` on real token ids -- confirm before relying on it.
        tokens_storage = np.zeros((num_samples, 64), dtype=int)
        unique_tokens = np.sum([len(np.unique(row)) // 64 for row in tokens_storage])

        wandb.log(
            {
                ar_key: ar_nll,
                mean_steps_key: exit_step.mean(),
                threshold_key: threshold,
                unique_key: unique_tokens,
            }
        )


def main():
    """Score early-exit text generations with GPT-Neo and log the results to W&B.

    Downloads the recorded exit signals ('early_exit_history.pickle') and the
    per-step texts ('generated_texts.pickle') from a W&B run, keeps the first
    200 items of every per-step record, and then logs the autoregressive loss
    of the texts selected by three exit rules (entropy, patience, KL), each
    swept over a range of thresholds, followed by the loss of the texts at
    fixed steps.
    """
    wandb.init(project='PROJECT_NAME')
    from transformers import AutoModelForCausalLM, AutoTokenizer
    import numpy as np

    model = AutoModelForCausalLM.from_pretrained('EleutherAI/gpt-neo-2.7B')
    ar_tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neo-2.7B')
    run_name = 'RUN_NAME'
    run = api.run(run_name)

    metrics = _download_pickle(run, 'early_exit_history.pickle')
    metrics['texts'] = _download_pickle(run, 'generated_texts.pickle')

    # Keep only the first `max_samples` items of every per-step record.
    max_samples = 200
    for key in metrics:
        metrics[key] = [values[:max_samples] for values in metrics[key]]

    num_samples = len(metrics['entropy'][0])

    def score_texts(texts):
        return count_ar_nll(
            model=model,
            tokenizer=ar_tokenizer,
            generations=texts,
            device='cuda'
        )

    # Entropy rule: exit once the entropy averaged over the last axis drops
    # below the threshold (assumes each entry has a trailing axis to average
    # over -- TODO confirm against the code that writes the pickle).
    _run_threshold_sweep(
        thresholds=np.concatenate((np.linspace(0.0, 1, 100), np.linspace(1, 7, 50))),
        signal_history=metrics['entropy'],
        texts_history=metrics['texts'],
        num_samples=num_samples,
        has_exited=lambda entropy, limit: entropy.mean(-1) < limit,
        score_texts=score_texts,
        log_keys=('entropy_ar', 'mean_steps_e', 'entropy_threshold', 'unique_e'),
    )

    # Patience rule: exit once the patience value reaches the threshold.
    _run_threshold_sweep(
        thresholds=np.arange(0, 200, 4),
        signal_history=metrics['patience'],
        texts_history=metrics['texts'],
        num_samples=num_samples,
        has_exited=lambda patience, limit: patience >= limit,
        score_texts=score_texts,
        log_keys=('patience_ar', 'mean_steps_p', 'patience_threshold', 'unique_p'),
    )

    # KL rule: exit once the KL value falls to the threshold or below; steps
    # before `kl_first_step` are never allowed to exit.
    kl_first_step = 600
    kl_s = [-1, 0, 0.00000000001, 0.00000001, 0.0000001, 0.000001, 0.000005, 0.00001, 0.0001, 0.001, 0.01, 0.1, 0.2, 0.4, 0.6, 0.8, 1, 2]
    _run_threshold_sweep(
        thresholds=kl_s,
        signal_history=metrics['kl'],
        texts_history=metrics['texts'],
        num_samples=num_samples,
        has_exited=lambda kl, limit: kl <= limit,
        score_texts=score_texts,
        log_keys=('kl_ar', 'mean_steps_kl', 'kl_threshold', 'unique_kl'),
        first_step=kl_first_step,
    )

    # Fixed-step baseline: score the texts at every 5th step, bounded by the
    # number of recorded steps so shorter histories do not raise IndexError.
    for step in range(0, min(1000, len(metrics['texts'])), 5):
        ar_nll = score_texts(list(metrics['texts'][step]))
        wandb.log(
            {'fixed_ar': ar_nll, 'f_step': step}
        )

if __name__ == '__main__':
    main()
