import logging
import pickle
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
from tqdm import tqdm
import wandb
from accelerate import Accelerator
import pyrallis
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import mauve
from fast_bleu import SelfBLEU

from ddlm.validation.dist import distinct_n
from ddlm.validation.peplexity import count_ar_nll
from ddlm.validation.zipf import zipfs_coefficient

# Script-level logger named "validation", emitting INFO and above.
# NOTE(review): the code visible here reports progress with print() and never
# logs through this logger -- confirm whether it is still needed.
logger = logging.getLogger("validation")
logger.setLevel(logging.INFO)

@dataclass
class ValidateConfig:
    """Configuration for ``validate``; parsed from the command line by pyrallis."""

    # wandb run path handed to ``wandb.Api().run`` (e.g. "entity/project/run_id").
    run_name: str
    # Metrics are pickled and uploaded every ``log_interval`` validated steps.
    log_interval: int = 10
    # If True, only the last entry of the downloaded generations is validated.
    validate_only_last_step: bool = False
    # If set (and non-zero), only the first N texts of every step are validated.
    validate_only_n: Optional[int] = None

    def __post_init__(self):
        # ``validate`` flushes metrics when ``step % log_interval == log_interval - 1``:
        # 0 would raise ZeroDivisionError and a negative value would never flush.
        if self.log_interval < 1:
            raise ValueError(f"log_interval must be >= 1, got {self.log_interval}")

def download_from_wandb(path: str, run):
    """Download the file named ``path`` attached to a wandb ``run`` and unpickle it.

    The matching file is downloaded with ``replace=True`` and then re-opened from
    ``path`` relative to the current working directory (this assumes wandb's
    default download root ``"."``).

    Args:
        path: file name as listed by ``run.files()``.
        run: wandb API run object; anything whose ``files()`` yields objects with
            a ``.name`` attribute and a ``.download(replace=...)`` method.

    Returns:
        The unpickled object.

    Raises:
        FileNotFoundError: if no file named ``path`` is attached to ``run``.
    """
    for remote_file in run.files():
        if remote_file.name != path:
            continue
        remote_file.download(replace=True)
        # SECURITY: pickle.load can execute arbitrary code embedded in the file;
        # only use this on runs whose artifacts you trust.
        with open(path, 'rb') as f:
            print("Downloaded file from wandb")
            return pickle.load(f)
    # Fail loudly here instead of returning None and crashing later on indexing.
    raise FileNotFoundError(f"File {path!r} not found among the run's wandb files")
    
def get_mauve(texts, full_texts, batch_size=32, device_id=0):
    """Compute the MAUVE score of generated ``texts`` against reference ``full_texts``.

    Args:
        texts: generated texts, passed to ``mauve.compute_mauve`` as ``p_text``.
        full_texts: reference texts, passed as ``q_text``.
        batch_size: forwarded to ``mauve.compute_mauve`` (default 32, as before).
        device_id: forwarded to ``mauve.compute_mauve`` (default 0, as before).

    Returns:
        The ``.mauve`` attribute of the result of ``mauve.compute_mauve``.
    """
    result = mauve.compute_mauve(p_text=texts, q_text=full_texts,
                                 batch_size=batch_size, device_id=device_id)
    return result.mauve

def get_dist(tokenized):
    """Return the three scores produced by ``distinct_n`` as NumPy arrays.

    ``tokenized`` is forwarded unchanged to ``distinct_n``, whose result must
    unpack into exactly three values (distinct-1, distinct-2, distinct-3).
    """
    scores = distinct_n(tokenized)
    uni, bi, tri = scores
    return tuple(np.array(score) for score in (uni, bi, tri))

def get_arnll(texts, accelerator, ar_nll_model, ar_nll_tokenizer, *, prepare=True, batch_size=1):
    """Score ``texts`` with an autoregressive LM via ``count_ar_nll``.

    Args:
        texts: iterable of generated texts; materialized with ``list()``.
        accelerator: ``accelerate.Accelerator`` used to prepare the model and
            forwarded to ``count_ar_nll``.
        ar_nll_model: causal LM used for scoring.
        ar_nll_tokenizer: tokenizer matching ``ar_nll_model``.
        prepare: if True (default, the previous behavior) run
            ``accelerator.prepare`` on the model first. Pass False when the model
            was already prepared. ``prepare`` wraps the model (e.g. patches
            ``forward`` for mixed precision), so calling it once per step on the
            same model may stack wrappers.
        batch_size: forwarded to ``count_ar_nll`` (default 1, as before).

    Returns:
        Whatever ``count_ar_nll`` returns with ``average=False`` (presumably
        per-text NLL values -- confirm against ``count_ar_nll``).
    """
    if prepare:
        ar_nll_model = accelerator.prepare(ar_nll_model)
    ar_nll_model.eval()

    return count_ar_nll(model=ar_nll_model, tokenizer=ar_nll_tokenizer,
                        generations=list(texts), accelerator=accelerator,
                        batch_size=batch_size, average=False)

_METRIC_NAMES = ("mauve", "dist_1", "dist_2", "dist_3", "ar_nll", "zipf", "self_bleu")

# Step indices that are never validated.
# NOTE(review): step 743 was hard-coded in the loop with no explanation --
# presumably a known-bad step of one particular run; confirm before reusing.
_SKIPPED_STEPS = frozenset({743})


def _empty_metrics():
    """Return a fresh accumulator with an empty list for every metric name."""
    return {name: [] for name in _METRIC_NAMES}


def _dump_metrics(metrics, chunk_idx):
    """Pickle ``metrics`` to ``metrics_<chunk_idx>.pickle`` and register it with wandb."""
    filename = f"metrics_{chunk_idx}.pickle"
    with open(filename, 'wb') as f:
        pickle.dump(metrics, f)
    # Hand the file to wandb only after it is closed, i.e. fully written to disk.
    wandb.save(filename)


def _load_mauve_references(tokenizer, n_texts):
    """Return the first ``n_texts`` C4 validation texts, truncated to 64 tokens."""
    dataset = load_dataset("allenai/c4", data_files=["en/c4-validation.00000-of-00008.json.gz"])
    dataset = dataset.remove_columns(["timestamp", "url"])["train"]
    references = dataset[:n_texts]["text"]
    # Round-trip through the SSD tokenizer so references are cut at the same
    # 64-token budget used for the generations below.
    return [tokenizer.decode(tokenizer.encode(t)[:64]) for t in references]


@pyrallis.wrap()
def validate(config: ValidateConfig):
    """Compute quality/diversity metrics for the generations stored in a wandb run.

    Downloads ``generated_texts.pickle`` from ``config.run_name`` (one collection
    of texts per generation step). For every step it computes:
    - MAUVE, only when the run's ``prefix_length`` is 32;
    - distinct-1/2/3 on the continuation tokens;
    - AR-NLL under GPT-Neo 2.7B;
    - Zipf's coefficient;
    - self-BLEU.

    Metrics are pickled to ``metrics_<chunk>.pickle`` and passed to
    ``wandb.save`` every ``config.log_interval`` steps. A trailing partial chunk
    is flushed after the loop.
    """
    wandb.init(
        name="ssd validation",
        project="PROJECT_NAME",
        resume=False,
    )

    api = wandb.Api()
    run = api.run(config.run_name)

    measure_mauve = run.config["prefix_length"] == 32

    tokenizer = AutoTokenizer.from_pretrained("xhan77/ssdlm")

    accelerator = Accelerator(mixed_precision="bf16")
    ar_nll_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")
    ar_nll_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B")
    # Prepare the scoring model once; calling accelerator.prepare on every step
    # re-wraps the same model each time.
    ar_nll_model = accelerator.prepare(ar_nll_model)
    ar_nll_model.eval()

    texts_over_steps = download_from_wandb("generated_texts.pickle", run)
    if config.validate_only_last_step:
        texts_over_steps = [texts_over_steps[-1]]

    if config.validate_only_n:
        texts_over_steps = [t[:config.validate_only_n] for t in texts_over_steps]

    full_texts = None
    if measure_mauve:
        full_texts = _load_mauve_references(tokenizer, len(texts_over_steps[0]))

    metrics = _empty_metrics()
    step = None
    for step, texts in enumerate(tqdm(texts_over_steps)):
        if step in _SKIPPED_STEPS:
            continue

        print("Step:", step)
        tokenized = tokenizer(list(texts), max_length=64, add_special_tokens=False,
                              return_tensors='pt', truncation=True, padding=True)['input_ids']

        # Tokens [32:64] are treated as the continuation after a 32-token prompt.
        # NOTE(review): the split is hard-coded even when prefix_length != 32, and
        # because of padding=True short texts contribute pad ids here and in the
        # zipf/self-BLEU inputs; tokenizer(continuations) also adds special tokens
        # by default -- confirm these are intended.
        continuations = tokenizer.batch_decode([text[32:64] for text in tokenized])
        tokenized_continuation = tokenizer(continuations)

        if measure_mauve:
            mauve_score = get_mauve(texts=list(texts), full_texts=full_texts)
            # np.shape also works if MAUVE comes back as a plain Python float.
            print(f"Mean MAUVE: {np.mean(mauve_score)}, MAUVE shape: {np.shape(mauve_score)}")
        else:
            mauve_score = None

        dist_1, dist_2, dist_3 = get_dist(tokenized_continuation)
        print("Dist_1:", dist_1)

        ar_nll = count_ar_nll(model=ar_nll_model, tokenizer=ar_nll_tokenizer,
                              generations=list(texts), accelerator=accelerator,
                              batch_size=1, average=False)
        print(f"Mean AR-NLL: {np.mean(ar_nll)}, AR-NLL shape: {np.shape(ar_nll)}")

        zipf = zipfs_coefficient(tokenized_texts=tokenized.tolist())
        print("Zipf", zipf)

        self_bleu = np.array(SelfBLEU(tokenized).get_score()[4])
        print(f"Mean self-BLEU: {np.mean(self_bleu)}, self-BLEU shape: {self_bleu.shape}")

        step_values = (mauve_score, dist_1, dist_2, dist_3, ar_nll, zipf, self_bleu)
        for name, value in zip(_METRIC_NAMES, step_values):
            metrics[name].append(value)

        if step % config.log_interval == config.log_interval - 1:
            _dump_metrics(metrics, step // config.log_interval)
            metrics = _empty_metrics()
            print(f"Pickled at step {step}")

    # Flush the trailing partial chunk (step count not a multiple of
    # log_interval, or validate_only_last_step) which would otherwise be lost.
    if any(metrics.values()):
        _dump_metrics(metrics, step // config.log_interval)
        print(f"Pickled at step {step}")

if __name__ == "__main__":
    # pyrallis.wrap() builds ValidateConfig from command-line arguments,
    # e.g. `--run_name tlab/ee-diffusion/wzw6hxzi`.
    validate()
    # Earlier direct-call example: validate(ValidateConfig("tlab/ee-diffusion/wzw6hxzi"))
