from collections import defaultdict
import pickle
import pyrallis
import torch
from tqdm.auto import tqdm, trange
from dataclasses import dataclass, field
from datasets import load_dataset
from ssd.configs import GenerateConfig, UnconstrainedGenerationConfig, ControlledGenerationConfig
from ssd.early_exit import EarlyExit
from ssd.decode import decode
import wandb
import numpy as np

# import logging
# logging.basicConfig(level=logging.DEBUG)
# log = logging.getLogger("__main__")

@dataclass
class GenerateConfig:
    """Options consumed by the ``generate`` entry point of this script.

    NOTE(review): a ``GenerateConfig`` is also imported from ``ssd.configs``
    at the top of this file; this later definition rebinds the name, so it
    is the one ``generate`` is annotated with.
    """

    # How many prompts are collected (C4 texts or "\n\n" placeholders).
    num_texts: int = 256
    # Forwarded as ``total_t`` to the unconstrained generation config.
    num_steps: int = 10
    # Number of prompts handed to ``decode`` per call.
    batch_size: int = 256
    # When > 0, prompts come from C4 and this is passed as ``context_size``.
    prefix_length: int = 0
    # Forwarded as ``max_seq_length``; generated length is
    # ``max_length - prefix_length``.
    max_length: int = 64
    # Forwarded as ``seed`` to the unconstrained generation config.
    seed: int = 2022

@pyrallis.wrap()
def generate(
    config: GenerateConfig
):
    """Generate text in batches and upload the early-exit history to W&B.

    Steps:
      1. Start a W&B run in the "ee-diffusion" project.
      2. Collect ``config.num_texts`` prompts: texts from the first C4
         validation shard when ``config.prefix_length > 0``, otherwise
         ``"\\n\\n"`` placeholders.
      3. Call ``decode`` once per batch of ``config.batch_size`` prompts,
         including a final shorter batch when ``num_texts`` is not a
         multiple of ``batch_size``.
      4. Pickle the merged early-exit history to
         ``early_exit_history.pickle`` and register it with ``wandb.save``.

    Args:
        config: Script options; see ``GenerateConfig``.

    Returns:
        None. Results are written to disk / W&B.
    """
    wandb.init(
        name="ssd generation",
        project="ee-diffusion",
        resume=False,
    )

    if config.prefix_length > 0:
        dataset = load_dataset("allenai/c4", data_files=["en/c4-validation.00000-of-00008.json.gz"])
        dataset = dataset.remove_columns(["timestamp", "url"])["train"]
        texts = dataset[:config.num_texts]["text"]
    else:
        texts = ["\n\n"] * config.num_texts

    print("Collected texts")

    unconstrained_config = UnconstrainedGenerationConfig(
        total_t=config.num_steps, decode_total_gen_len=config.max_length - config.prefix_length,
        max_seq_length=config.max_length, seed=config.seed)

    print("Created config")
    # NOTE(review): the sampled sequences are accumulated but never persisted
    # or returned by this function — confirm whether they should be saved.
    sampled_batches = []
    # Kept as defaultdict(list) so the pickled object has the same type as before.
    early_exit_log = defaultdict(list)
    # Inference only: disable autograd globally for the rest of the process.
    torch.set_grad_enabled(False)

    # Ceiling division so a trailing partial batch is processed instead of
    # being silently dropped (floor division also yielded zero batches
    # whenever num_texts < batch_size).
    num_batches = (len(texts) + config.batch_size - 1) // config.batch_size
    for i in trange(num_batches):
        start = i * config.batch_size
        texts_batch = texts[start:start + config.batch_size]
        print(f"Starting generation of batch {i}")
        print(f"texts_batch size: {len(texts_batch)}")
        # len(texts_batch) equals config.batch_size for every full batch; for
        # the last, shorter batch we pass its real size.
        # NOTE(review): assumes decode accepts a batch_size smaller than the
        # configured one — confirm against ssd.decode.
        sampled_sequences, early_exit_history = decode(unconstrained_config, texts_batch,
                                                       batch_size=len(texts_batch),
                                                       context_size=config.prefix_length)

        sampled_batches.append(sampled_sequences)
        for key, values in early_exit_history.items():
            early_exit_log[key] += values
        print(f"Generated batch {i}")

    history_path = "early_exit_history.pickle"
    with open(history_path, 'wb') as f:
        pickle.dump(early_exit_log, f)
    # Register the file only after it has been flushed and closed, so W&B
    # never picks up a partially written pickle.
    wandb.save(history_path)

    wandb.finish()

if __name__ == "__main__":
    # Called without arguments: generate is decorated with pyrallis.wrap(),
    # which is expected to build the GenerateConfig (presumably from the
    # command line / config file — see pyrallis docs).
    generate()