import pickle
from collections import defaultdict

import numpy as np

import wandb
from accelerate import Accelerator
from pyrallis import wrap

from dataclasses import dataclass, field

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import mauve
from fast_bleu import SelfBLEU

from ddlm.validation.dist import distinct_n
from ddlm.validation.peplexity import count_ar_nll
from ddlm.validation.zipf import zipfs_coefficient


@dataclass
class ValidateConfig:
    """Options for a single validation pass over the samples of one W&B run."""

    # Run path handed to `wandb.Api().run(...)` to locate the run being validated.
    run_name: str
    # When true, sampled continuations are cut to a fixed token budget before scoring.
    normalize_lengths: bool = True
    # Whether MAUVE is computed; replaced by the run's `conditioning` config
    # value whenever the run config contains that key.
    measure_mauve: bool = False


# Samples are split into this many equally sized, consecutive groups, both for
# the per-group MAUVE average and for the distinct-n statistics.
_NUM_GROUPS = 5

# Name of the run artifact that holds the pickled prompts and samples.
_DATA_FILE = "data.pkl"


def _truncate_continuations(tokenizer, prompts, continuations):
    """Cut every continuation to a fixed token budget and return the decoded texts.

    Continuations whose prompt is longer than 3 characters are limited to
    32 tokens, all others to 64 tokens.
    NOTE(review): presumably a prompt of <= 3 characters marks unconditional
    generation -- confirm against the sampling code.
    """
    truncated = []
    for index, (prompt, continuation) in enumerate(zip(prompts, continuations)):
        if len(prompt) > 3:
            max_tokens = 32
        else:
            if index == 0:
                # Show one short-prompt example for manual inspection.
                print(prompt)
                print(continuation)
            max_tokens = 64
        token_ids = tokenizer.encode(continuation)[:max_tokens]
        truncated.append(tokenizer.decode(token_ids))
    return truncated


def _mean_mauve(tokenizer, texts, group_size):
    """Return the mean MAUVE of each consecutive group of `texts` against C4.

    The reference set is the first `group_size` documents of one C4 validation
    shard, each truncated to 64 tokens; the same references are reused for
    every group.
    """
    dataset = load_dataset("allenai/c4", data_files=["en/c4-validation.00000-of-00008.json.gz"])
    dataset = dataset.remove_columns(["timestamp", "url"])["train"]
    references = [
        tokenizer.decode(tokenizer.encode(text)[:64])
        for text in dataset[:group_size]["text"]
    ]
    scores = []
    for group in range(_NUM_GROUPS):
        generated = texts[group * group_size:(group + 1) * group_size]
        result = mauve.compute_mauve(p_text=generated, q_text=references, batch_size=8,
                                     device_id=0)
        scores.append(result.mauve)
    return np.mean(scores)


@wrap()
def validate(config: ValidateConfig) -> None:
    """Score the samples stored in a W&B run and log the metrics to `ssd_val`.

    Downloads `data.pkl` from the run named by `config.run_name`, optionally
    truncates the sampled continuations, and then logs: AR NLL (via
    `count_ar_nll` with GPT-Neo 2.7B), the Zipf coefficient, Self-BLEU,
    MAUVE (only when enabled, otherwise `None`) and mean distinct-1/2/3.

    Args:
        config: Validation options. `config.measure_mauve` is overwritten with
            the run's `conditioning` config value when that key exists.

    Raises:
        FileNotFoundError: If the run does not contain a `data.pkl` file.
    """
    wandb.init(project="ssd_val")
    api = wandb.Api()
    run = api.run(config.run_name)
    try:
        config.measure_mauve = bool(run.config["conditioning"])
    except KeyError:
        print(f"Can't determine run type. Using conditioning={config.measure_mauve}")

    downloaded = False
    for remote_file in run.files():
        if remote_file.name == _DATA_FILE:
            remote_file.download(replace=True)
            downloaded = True
    if not downloaded:
        # Without this check a stale local data.pkl from another run would be
        # evaluated silently.
        raise FileNotFoundError(
            f"Run {config.run_name!r} has no {_DATA_FILE!r} file to validate"
        )

    tokenizer = AutoTokenizer.from_pretrained("c-tokenizer")
    # SECURITY: pickle.load executes arbitrary code embedded in the file; only
    # validate runs whose artifacts you trust.
    with open(_DATA_FILE, "rb") as inp:
        data = pickle.load(inp)
    # Each entry is indexed with [0]; presumably a 1-element batch of strings
    # -- TODO confirm against the code that writes data.pkl.
    prompts = [entry[0] for entry in data["context_sequences"]]
    continuations = [entry[0] for entry in data["sampled_sequences"]]

    if config.normalize_lengths:
        continuations = _truncate_continuations(tokenizer, prompts, continuations)

    texts = []
    tokenized = []
    tokenized_continuations = []
    for prompt, continuation in zip(prompts, continuations):
        full_text = prompt + continuation
        texts.append(full_text)
        tokenized.append(tokenizer.encode(full_text))
        tokenized_continuations.append(tokenizer.encode(continuation))

    # Samples beyond _NUM_GROUPS * group_size (the remainder of the division)
    # are ignored by both MAUVE and distinct-n.
    group_size = len(prompts) // _NUM_GROUPS

    mauve_score = None
    if config.measure_mauve:
        mauve_score = _mean_mauve(tokenizer, texts, group_size)

    groups = [
        tokenized_continuations[group_size * g:group_size * (g + 1)]
        for g in range(_NUM_GROUPS)
    ]
    stats = defaultdict(list)
    # Transpose: the i-th element of every group is scored together, so each
    # distinct_n call sees _NUM_GROUPS continuations.
    for same_position in zip(*groups):
        dist_1, dist_2, dist_3 = distinct_n(list(same_position))
        stats["dist_1"].append(dist_1)
        stats["dist_2"].append(dist_2)
        stats["dist_3"].append(dist_3)

    stats = {name: np.mean(values) for name, values in stats.items()}
    print(stats)

    accelerator = Accelerator(mixed_precision="bf16")
    ar_nll_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")
    ar_nll_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B")
    ar_nll_model = accelerator.prepare(ar_nll_model)
    ar_nll_model.eval()

    ar_nll_value = count_ar_nll(model=ar_nll_model, tokenizer=ar_nll_tokenizer, generations=texts,
                                accelerator=accelerator, batch_size=8)

    zipfs_value = zipfs_coefficient(tokenized_texts=list(tokenized))

    # Key 4 selects the 4-gram entry of the score dict returned by SelfBLEU.
    self_bleu = np.mean(SelfBLEU(tokenized).get_score()[4])

    wandb.log(
        {
            "ar_nll": ar_nll_value,
            "zipf": zipfs_value,
            "self-bleu": self_bleu,
            "mauve": mauve_score,
        }
    )
    # Logged in a separate call, as before, so the W&B step layout is unchanged.
    wandb.log(stats)


if __name__ == "__main__":
    # Called with no arguments on purpose: validate is decorated with @wrap(),
    # which is expected to build the ValidateConfig (pyrallis parses it from
    # the command line) and pass it in.
    validate()
