import json
import os
import random

import numpy as np
import torch
from celery import Celery
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import GPT2TokenizerFast

import effective_rank
import meco



def tokenize_function(examples):
    """Tokenize the ``"text"`` entries of a batch with the module-level tokenizer.

    Sequences are truncated and padded to the tokenizer's maximum length
    (``truncation=True``, ``padding="max_length"``). The tokenizer's output
    is returned unchanged.
    """
    texts = examples["text"]
    return tokenizer(texts, truncation=True, padding="max_length")

# Load the "train[:1000]" slice of the openwebtext dataset.
dataset = load_dataset("openwebtext", split="train[:1000]")
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# Register "<pad>" as the pad token; tokenize_function pads with
# padding="max_length". Presumably the stock "gpt2" tokenizer ships without
# a pad token -- TODO confirm.
tokenizer.add_special_tokens({"pad_token": "<pad>"})
# Tokenize batch-wise, then switch the dataset's output format to "torch".
tokenized_dataset = dataset.map(tokenize_function, batched=True)
tokenized_dataset.set_format("torch")

def set_seed(seed):
    """Seed every random number generator this script relies on.

    Sets the cuDNN ``deterministic`` flag and clears the ``benchmark`` flag,
    then seeds PyTorch (CPU and all CUDA devices), NumPy and the standard
    library ``random`` module with the same value.

    Args:
        seed: Integer seed handed to each generator.
    """
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Limit PyTorch to 4 CPU threads in this process.
torch.set_num_threads(4)
# The broker URL can be overridden through the CELERY_BROKER_URL environment
# variable; an unset or empty variable falls back to the original value.
# NOTE(review): the fallback embeds a broker password in source control --
# rotate the credential and drop the fallback once deployments set the
# environment variable.
app = Celery(
    "tasks",
    broker=os.environ.get("CELERY_BROKER_URL")
    or "redis://:Xcrdb3rDeEf@172.31.55.178:27010/",
)

from hwgpt.api import HWGPT
# Shared HWGPT API handle used by the train task: search space "s", with the
# supernet surrogate disabled.
api = HWGPT(search_space="s", use_supernet_surrogate=False)

# In-memory results keyed by model_id. train() stores one entry per call and
# nothing in this file ever removes entries.
statistics = {}

import time

@app.task()
def train(arch, model_id):
    """Score one architecture and persist the results as JSON.

    Selects ``arch`` on the shared HWGPT API, records the queried perplexity,
    parameter count and FLOPs, then builds the model and computes the
    ``effective_rank`` and ``meco_opt`` measures on the tokenized dataset.
    The collected numbers are stored in ``statistics[model_id]`` and written
    to ``HWGPTBench/stats_{model_id}.json``.

    Args:
        arch: Architecture description forwarded to ``api.set_arch``.
        model_id: Identifier used as the key in ``statistics`` and in the
            output file name.

    Returns:
        The dict stored under ``statistics[model_id]`` with the keys
        ``"perplexity"``, ``"num_param"``, ``"flops"``, ``"effective_rank"``
        and ``"meco_opt"``.

    Raises:
        ValueError: If no batch in the dataset has all token ids within the
            accepted range for the ``meco_opt`` measure.
    """
    # NOTE(review): presumably the largest token id the model accepts --
    # confirm against the HWGPT vocabulary size.
    max_token_id = 50254

    api.set_arch(arch)
    stats = {
        "perplexity": api.query(metric="perplexity", predictor="mlp")["perplexity"],
        "num_param": api.get_params(),
        "flops": api.get_flops(),
    }
    statistics[model_id] = stats
    set_seed(1337)
    model = api.create_model()
    for measure in ["effective_rank", "meco_opt"]:
        # create new dataloaders to make sure all models get the same data
        set_seed(1337)
        train_dataloader = torch.utils.data.DataLoader(tokenized_dataset, batch_size=1)

        if measure == "effective_rank":
            model.eval()
            stats[measure] = effective_rank.get_average_score_effective_rank(
                model, train_dataloader, repetitions=1
            )
        elif measure == "meco_opt":
            # Walk ONE iterator so each rejected batch advances to the next
            # sample. Re-creating the iterator on every attempt would return
            # the same first batch of this unshuffled loader and spin forever.
            for batch in train_dataloader:
                inputs = batch["input_ids"]
                if not torch.any(inputs > max_token_id):
                    break
            else:
                raise ValueError(
                    f"no batch with all token ids <= {max_token_id} found for meco_opt"
                )
            stats[measure] = meco.get_score(model, inputs, "cpu", "meco_opt")

    with open(f"HWGPTBench/stats_{model_id}.json", "w") as file:
        json.dump(stats, file)
    return stats
