import os
import pickle
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict

import lovely_tensors as lt
import mauve
import numpy as np
import pyrallis
import torch
import wandb
from accelerate import Accelerator
from ddlm.modeling.diffusion import (
    DiffusionConfig, DiffusionTransformer)
from ddlm.sampler.early_exit import (
    EntropyStrategy, FixedStrategy, LogStrategy, NoStrategy, PatienceStrategy, KLDivStrategy)
from ddlm.validation.dist import distinct_n
from ddlm.validation.entropy import \
    token_entropy
from ddlm.validation.generate import \
    generate_with_conditioning_mask
from ddlm.validation.load_model import \
    get_model_from_run
from ddlm.validation.peplexity import \
    count_ar_nll
from ddlm.validation.zipf import \
    zipfs_coefficient
from datasets import load_dataset
from fast_bleu import SelfBLEU
from tqdm.auto import trange
from transformers import AutoModelForCausalLM, AutoTokenizer

# Apply lovely_tensors' monkey patch once at import time (presumably to make
# tensor reprs more readable while debugging — confirm in the lovely_tensors docs).
lt.monkey_patch()

@dataclass
class GenerateConfig:
    """Options consumed by ``generate``.

    Only ``run_name`` is required. Notes on fields whose handling is visible
    in ``generate``:

    * ``run_name`` is used both as the name of the new wandb run and as the
      path handed to ``wandb.Api().run`` to locate the trained model.
    * ``step == -1`` is replaced by step ``1000000``; a resulting step that is
      not positive makes ``generate`` build a fresh ``DiffusionTransformer``
      instead of loading a checkpoint.
    * ``t_max <= 0`` means "read ``t_max`` from the wandb run's config".
    * ``patience_frac > 0`` overrides ``patience`` with
      ``int(num_steps * patience_frac)``.
    * ``prefix_length`` / ``postfix_length`` delimit the leading / trailing
      positions that stay set in the conditioning mask.
    * ``strategy`` picks the early-exit strategy (``"patience"``,
      ``"entropy"``, ``"fixed"``, ``"log"``, ``"kldiv"``); any other value
      selects ``NoStrategy``.
    * ``download_ouputs`` keeps its historical spelling because it is
      forwarded under that exact keyword.
    """

    run_name: str
    tokenizer_name: str = "c-tokenizer"
    step: int = -1
    max_length: int = 64
    batch_size: int = 512
    rho: float = 1.
    num_steps: int = 50
    fp16: bool = True
    use_cache: bool = False
    use_time_wrapping: bool = True
    t_max: float = -1
    t_min: float = 1
    self_conditioning: bool = True
    seed: int = 42
    ar_validation: bool = False
    s_churn: float = 0.0
    num_continuations: int = 5
    prefix_length: int = 32
    postfix_length: int = 0
    simplified_inputs: bool = False
    num_examples: int = 5000
    renormalization: bool = False
    positional_embedding_type: str = "rotary"
    num_hidden_layers: int = 8
    num_attention_heads: int = 8
    hidden_size: int = 1024
    initial_noise_scale: float = 1.
    timedelta: float = 0.
    interpolate: bool = False
    strategy: str = "no_strategy"
    patience: int = 1000
    patience_frac: float = -1.
    artifact: bool = False
    entropy_threshold: float = 1.
    fixed_strategy_threshold: int = 100
    kldiv_threshold: float = 0.1
    outputs_sweep_name: str = ""
    download_ouputs: bool = False

def get_t_max_from_run(run) -> float:
    """Return the ``t_max`` entry stored in ``run.config``.

    Raises whatever the ``config`` mapping raises for a missing key
    (``KeyError`` for a plain dict).
    """
    run_config = run.config
    return run_config["t_max"]

def set_seed(seed: int):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and, if present, CUDA) RNGs."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # CPU-only machines skip the CUDA seeding call.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def set_deterministic_mode(seed):
    """Seed all RNGs via ``set_seed`` and ask cuDNN for deterministic behaviour.

    Args:
        seed: integer seed forwarded to ``set_seed`` and exported as
            ``PYTHONHASHSEED``.
    """
    set_seed(seed)

    # NOTE(review): per the Python docs PYTHONHASHSEED is read at interpreter
    # start-up, so this assignment is expected to influence only child
    # processes spawned later, not the hash seed of the current process.
    os.environ["PYTHONHASHSEED"] = str(seed)

    # The real switches live under ``torch.backends.cudnn``. The earlier
    # ``torch.backends.deterministic`` / ``torch.backends.benchmark``
    # assignments only created unused module attributes and were dropped.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _build_strategy(config: GenerateConfig):
    """Instantiate the early-exit strategy selected by ``config.strategy``."""
    name = config.strategy
    if name == "patience":
        return PatienceStrategy(patience=config.patience)
    if name == "entropy":
        # The entropy strategy is allowed to trigger from the very first step.
        return EntropyStrategy(threshold=config.entropy_threshold, min_step=0)
    if name == "fixed":
        return FixedStrategy(threshold=config.fixed_strategy_threshold)
    if name == "log":
        return LogStrategy()
    if name == "kldiv":
        # The KL strategy receives a minimum step of 10% of the schedule.
        return KLDivStrategy(threshold=config.kldiv_threshold,
                             min_step=int(0.1 * config.num_steps))
    # Any unrecognised name (including the default "no_strategy").
    return NoStrategy()


def _pickle_and_upload(obj, filename: str) -> None:
    """Pickle ``obj`` into ``filename`` and register the file with wandb."""
    with open(filename, 'wb') as f:
        pickle.dump(obj, f)
    # Hand the file to wandb only once it is closed, so everything that was
    # buffered by the file object has actually reached the disk.
    wandb.save(filename)


@pyrallis.wrap()
def generate(
    config: GenerateConfig
):
    """Generate continuations with a diffusion LM and log quality metrics to wandb.

    Pipeline, as implemented below:

    1. Start a wandb run, seed the RNGs and load the tokenizer.
    2. Take the first ``config.num_examples`` texts of one C4 validation shard
       and truncate them to ``config.max_length`` tokens (reference texts).
    3. Load the diffusion model of wandb run ``config.run_name`` (or build an
       untrained one when the requested step is not positive).
    4. Call ``generate_with_conditioning_mask`` ``config.num_continuations``
       times, collecting the outputs and the observed number of steps.
    5. Score the reference texts and every collected output (AR NLL under
       GPT-Neo 2.7B, Self-BLEU, distinct-n, Zipf coefficient, token entropy
       and, when ``config.prefix_length > 0``, MAUVE), log the numbers to
       wandb and pickle them into ``metrics_<continuation>.pickle``.

    Args:
        config: parsed ``GenerateConfig``; supplied by ``pyrallis.wrap``.

    Raises:
        NotImplementedError: if ``config.ar_validation`` is set. That mode
            never loaded a model, so the function used to fail later with a
            ``NameError``; it is now rejected up front.
    """
    if config.ar_validation:
        raise NotImplementedError(
            "ar_validation=True is not supported by generate(): no diffusion "
            "model is loaded in that mode, so there is nothing to sample from."
        )

    if config.patience_frac > 0:
        config.patience = int(config.num_steps * config.patience_frac)

    wandb.init(
        name=config.run_name,
        project="ee-diffusion",
        resume=False,
    )
    set_deterministic_mode(config.seed)
    api = wandb.Api()

    accelerator = Accelerator(mixed_precision="fp16" if config.fp16 else None)

    # True on the first `prefix_length` and the last `postfix_length`
    # positions, False in between. NOTE(review): presumably True marks tokens
    # that are given as conditioning — confirm in generate_with_conditioning_mask.
    conditioning_mask = torch.ones(config.max_length, dtype=torch.bool)
    conditioning_mask[config.prefix_length:config.max_length-config.postfix_length] = 0

    tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)

    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    if tokenizer.mask_token is None:
        tokenizer.add_special_tokens({'mask_token': '[MASK]'})

    dataset = load_dataset("allenai/c4", data_files=["en/c4-validation.00000-of-00008.json.gz"])
    dataset = dataset.remove_columns(["timestamp", "url"])["train"]
    raw_texts = dataset[:config.num_examples]["text"]
    # Tokenize every text once and derive both truncations from the same ids.
    token_ids = [tokenizer.encode(t) for t in raw_texts]
    prompts = [tokenizer.decode(ids[:config.prefix_length]) for ids in token_ids]
    full_texts = [tokenizer.decode(ids[:config.max_length]) for ids in token_ids]

    if config.strategy == "log":
        _pickle_and_upload(prompts, "prompts.pickle")
        _pickle_and_upload(full_texts, "full_texts.pickle")

    # -1 is the sentinel for "use step 1000000".
    desired_step = 1000000 if config.step == -1 else config.step

    model_config = DiffusionConfig(
        num_embeddings=len(tokenizer),
        max_position_embeddings=512,
        positional_embedding_type=config.positional_embedding_type,
        num_attention_heads=config.num_attention_heads,
        num_hidden_layers=config.num_hidden_layers,
        hidden_size=config.hidden_size
    )

    run = api.run(config.run_name)
    if desired_step > 0:
        model, time_wrapping = get_model_from_run(model_config, run=run, step=desired_step, artifact=config.artifact)
    else:
        # Non-positive step: evaluate a freshly constructed model.
        model, time_wrapping = DiffusionTransformer(config=model_config), None

    if time_wrapping is not None:
        model, time_wrapping = accelerator.prepare(model, time_wrapping)
    else:
        model = accelerator.prepare(model)

    strategy = _build_strategy(config)

    # A non-positive config value means "take t_max from the training run".
    t_max = config.t_max if config.t_max > 0 else get_t_max_from_run(run)

    observed_steps_per_continuation = []
    outputs_per_continuation = []
    model.eval()
    for i in trange(config.num_continuations):
        outputs_steps, metrics = generate_with_conditioning_mask(
            model=model,
            texts=full_texts,
            conditioning_mask=conditioning_mask,
            tokenizer=tokenizer,
            time_wrapping=time_wrapping,
            num_steps=config.num_steps,
            rho=config.rho,
            length=config.max_length,
            batch_size=config.batch_size,
            device=accelerator.device,
            use_time_wrapping=config.use_time_wrapping,
            t_max=t_max,
            self_conditioning=config.self_conditioning,
            s_churn=config.s_churn,
            simplified_inputs=config.simplified_inputs,
            renormalization=config.renormalization,
            initial_noise_scale=config.initial_noise_scale,
            timedelta=config.timedelta,
            artifact=None,
            interpolate=config.interpolate,
            strategy=strategy,
            outputs_sweep_name=config.outputs_sweep_name,
            download_ouputs=config.download_ouputs,
            continuation_number=i
        )

        outputs_per_continuation.append(outputs_steps)
        torch.cuda.empty_cache()

        observed_steps_per_continuation.append(metrics["observed_steps"])

    # Swap the accelerator before bringing in the autoregressive scoring model.
    del accelerator

    torch.cuda.empty_cache()
    accelerator = Accelerator(mixed_precision="fp16" if config.fp16 else None)
    ar_nll_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")
    ar_nll_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B")
    ar_nll_model = accelerator.prepare(ar_nll_model)
    ar_nll_model.eval()

    # Metrics of the reference texts themselves.
    ar_nll_data = count_ar_nll(model=ar_nll_model, tokenizer=ar_nll_tokenizer,
                               generations=full_texts, accelerator=accelerator,
                               batch_size=8)
    zipf_data = zipfs_coefficient(tokenized_texts=[tokenizer.encode(t) for t in full_texts])
    self_bleu_data = np.mean(SelfBLEU([tokenizer.encode(t) for t in full_texts]).get_score()[4])
    token_entropy_data = token_entropy(full_texts, tokenizer)

    wandb.log({
        "ar_nll_data": ar_nll_data,
        "zipf_data": zipf_data,
        "self_bleu_data": self_bleu_data,
        "token_entropy_data": token_entropy_data
    })

    # Best effort and only once: the code snapshot does not depend on the
    # step, so there is no reason to request it inside the loops below.
    try:
        wandb.run.log_code()
    except Exception as err:
        print(f"wandb code snapshot skipped: {err}")

    # Identical for every step of every continuation.
    mean_observed_steps = np.mean(observed_steps_per_continuation)/config.batch_size
    compute_mauve = config.prefix_length > 0

    for continuation_index, outputs_steps in enumerate(outputs_per_continuation):
        metrics_to_pickle = defaultdict(list)
        for step, outputs in enumerate(outputs_steps):
            texts, tokenized, tokenized_continuation = outputs

            zipfs_value = zipfs_coefficient(tokenized_texts=list(tokenized))
            token_entropy_value = token_entropy(list(texts), tokenizer)

            # average=False: the helper is asked for un-averaged values;
            # they are averaged with np.mean below.
            ar_nll_value = count_ar_nll(model=ar_nll_model, tokenizer=ar_nll_tokenizer,
                                        generations=texts, accelerator=accelerator,
                                        batch_size=1, average=False)
            self_bleu_scores = SelfBLEU(tokenized).get_score()[4]

            # Single-element lists keep the pickled layout of earlier runs.
            ar_nll_samples = [ar_nll_value]
            self_bleu_samples = [self_bleu_scores]
            mauve_samples = []

            if compute_mauve:
                with torch.no_grad():
                    mauve_value = mauve.compute_mauve(p_text=texts, q_text=full_texts, batch_size=8, device_id=0).mauve
                # Detach only when the score is tensor-like; plain numbers
                # are stored as they are.
                if hasattr(mauve_value, "detach"):
                    mauve_value = mauve_value.detach()
                mauve_samples.append(mauve_value)
            torch.cuda.empty_cache()

            # distinct-n is evaluated separately for every continuation.
            dist_1_values, dist_2_values, dist_3_values = [], [], []
            for continuation in tokenized_continuation:
                dist_1, dist_2, dist_3 = distinct_n([continuation])
                dist_1_values.append(dist_1)
                dist_2_values.append(dist_2)
                dist_3_values.append(dist_3)

            metrics_to_pickle["step"].append(step)
            metrics_to_pickle["observed_steps_array"].append(np.array(observed_steps_per_continuation)/config.batch_size)
            metrics_to_pickle["observed_steps"].append(mean_observed_steps)
            metrics_to_pickle["texts"].append([texts])
            metrics_to_pickle["ar_nll_array"].append(ar_nll_samples)
            metrics_to_pickle["self_bleu_array"].append(self_bleu_samples)
            metrics_to_pickle["dist_1_array"].append(dist_1_values)
            metrics_to_pickle["dist_2_array"].append(dist_2_values)
            metrics_to_pickle["dist_3_array"].append(dist_3_values)
            metrics_to_pickle["ar_nll"].append(np.mean(ar_nll_samples))
            metrics_to_pickle["dist_1"].append(np.mean(dist_1_values))
            metrics_to_pickle["dist_2"].append(np.mean(dist_2_values))
            metrics_to_pickle["dist_3"].append(np.mean(dist_3_values))
            metrics_to_pickle["zipf"].append(zipfs_value)
            metrics_to_pickle["self_bleu"].append(np.mean(self_bleu_samples))

            wandb.log(
                {
                    "step_": step,
                    "observed_steps": mean_observed_steps,
                    "ar_nll": np.mean(ar_nll_samples),
                    "ar_nll_std": np.std(ar_nll_samples),
                    "dist_1": np.mean(dist_1_values),
                    "dist_2": np.mean(dist_2_values),
                    "dist_3": np.mean(dist_3_values),
                    "zipf": zipfs_value,
                    "self_bleu": np.mean(self_bleu_samples),
                    "self_bleu_std": np.std(self_bleu_samples),
                    "token_entropy": token_entropy_value
                }
            )
            if mauve_samples:
                metrics_to_pickle["mauve"].append(np.mean(mauve_samples))
                metrics_to_pickle["mauve_array"].append(mauve_samples)
                wandb.log(
                    {
                        "mauve": np.mean(mauve_samples),
                        "std_mauve": np.std(mauve_samples),
                    }
                )

        _pickle_and_upload(metrics_to_pickle, f"metrics_{continuation_index}.pickle")
    wandb.finish()

# Script entry point. `generate` is decorated with `pyrallis.wrap()`, which is
# expected to supply its `config` argument, hence the call without arguments.
if __name__ == "__main__":
    generate()
