import json
import os
import random

import numpy as np
import torch
from celery import Celery
from torch.utils.data import DataLoader

import effective_rank

# NATS-Bench search space to query; train() also recognises "tss".
SEARCH_SPACE = "sss"
# Dataset whose benchmark numbers and training images are used.
DATASET = "cifar10"

# Supported datasets, in a fixed order shared by the lookup tables below.
_DATASETS = ("cifar10", "cifar100", "ImageNet16-120")
# Number of output classes per dataset.
NUM_CLASSES = dict(zip(_DATASETS, (10, 100, 120)))
# Sub-directory (under the local Torch_Dataset root) holding each dataset.
DATASET_PATH = dict(zip(_DATASETS, ("", "", "ImageNet16")))

from nats_bench import create
# Open the NATS-Bench API for SEARCH_SPACE. The first argument is None, so
# nats_bench picks its own benchmark location (NOTE(review): presumably its
# default path/env var — confirm in the nats_bench docs).
api = create(
    None,
    SEARCH_SPACE,
    verbose=False,
    fast_mode=True,
)

from xautodl.models import get_cell_based_tiny_net
from xautodl.datasets import get_datasets

def set_seed(seed):
    """Seed Python, NumPy and torch (CPU and all CUDA devices) RNGs with ``seed``.

    Also switches cuDNN to deterministic kernels and disables its autotuner, so
    repeated calls with the same seed replay the same random draws.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Give up cuDNN autotuning speed in exchange for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

# Limit torch to one intra-op thread per worker process.
torch.set_num_threads(1)

# Broker URL may be supplied via the CELERY_BROKER_URL environment variable.
# SECURITY NOTE(review): the fallback below embeds a Redis password in source
# control; rotate that password and configure the environment variable instead.
CELERY_BROKER_URL = os.environ.get(
    "CELERY_BROKER_URL", "redis://:Xcrdb3rDeEf@172.31.55.178:27010/"
)
app = Celery("tasks", broker=CELERY_BROKER_URL)

# Expand "~" here rather than relying on every dataset loader behind
# get_datasets to do it. The third argument is passed through unchanged
# (NOTE(review): presumably xautodl's cutout setting — confirm).
train_set, test_set, xshape, class_num = get_datasets(
    DATASET,
    os.path.expanduser(f"~/Documents/Torch_Dataset/{DATASET_PATH[DATASET]}"),
    0,
)

# Per-process record of the result dicts produced by train(), keyed by model id.
statistics = {}

@app.task(acks_late=True, reject_on_worker_lost=True)
def train(model_id):
    """Gather benchmark metadata and zero-cost measures for one architecture.

    Looks up the NATS-Bench test accuracy, parameter count and FLOPs of
    ``model_id`` on DATASET. It then builds the (untrained) network and computes
    every measure in the loop below (currently only ``"effective_rank"``).
    The result is written to ``NATSBench/stats_<model_id>.json``.

    Args:
        model_id: Architecture index understood by the NATS-Bench ``api``.

    Returns:
        dict with keys "accuracy", "num_param", "flops" plus one key per
        measure. The same dict is also stored in the module-level
        ``statistics`` under ``model_id``.
    """
    # hp value used for the accuracy lookup: 200 for "tss", 90 otherwise.
    epochs = 200 if SEARCH_SPACE == "tss" else 90
    cost_info = api.get_cost_info(model_id, DATASET)  # one query serves both fields
    stats = {
        "accuracy": api.get_more_info(model_id, DATASET, hp=epochs, is_random=False)["test-accuracy"],
        "num_param": cost_info["params"],
        "flops": cost_info["flops"],
    }
    statistics[model_id] = stats

    config = api.get_net_config(model_id, DATASET)
    # Seed BEFORE constructing the network so its random weight initialisation
    # is reproducible across runs and task retries (seeding after construction
    # had no effect, since the loop below reseeds before any random draw).
    set_seed(1337)
    model = get_cell_based_tiny_net(config)
    for measure in ["effective_rank"]:
        # Reseed and create a fresh dataloader so every model sees the same data order.
        set_seed(1337)
        train_dataloader = DataLoader(train_set, batch_size=1, num_workers=0, shuffle=True)

        if measure == "effective_rank":
            model.eval()
            stats[measure] = effective_rank.get_average_score_effective_rank(
                model, train_dataloader, repetitions=32
            )

    # Create the output directory on first use instead of failing with FileNotFoundError.
    os.makedirs("NATSBench", exist_ok=True)
    with open(f"NATSBench/stats_{model_id}.json", "w") as file:
        json.dump(stats, file)
    return stats
