import json
import os
import random

import numpy as np
import torch
from celery import Celery
from torch.utils.data import DataLoader

from proxies import effective_rank, meco, swap, zen, synflow, fisher, grasp, snip, grad_norm, zico

# NATS-Bench search space handed to nats_bench.create ("tss" or "sss"); train()
# looks up accuracies at 200 epochs for "tss" and at 90 epochs otherwise.
SEARCH_SPACE = "sss"
# Dataset whose benchmark accuracies and training images are used.
DATASET = "cifar10"

# Output classes per dataset; forwarded to the GraSP proxy.
NUM_CLASSES = {
    "cifar10": 10,
    "cifar100": 100,
    "ImageNet16-120": 120,
}

# Sub-directory appended to the local dataset root before calling get_datasets.
DATASET_PATH = {
    "cifar10": "",
    "cifar100": "",
    "ImageNet16-120": "ImageNet16",
}

# mu / sigma of the regularised SWAP score ("swap_reg"), keyed by search space.
REG_SWAP_MU = {
    "tss": 1.5,
    "sss": 0.7,
}
REG_SWAP_SIGMA = {
    "tss": 1.5,
    "sss": 0.7,
}

from nats_bench import create
# Open the NATS-Bench API for the configured search space. No explicit benchmark
# file is given (None) -- presumably nats_bench then resolves its default location
# (e.g. under $TORCH_HOME); confirm against the nats_bench documentation.
# fast_mode=True presumably defers loading per-architecture data until requested.
api = create(
    None,
    SEARCH_SPACE,
    fast_mode=True,
    verbose=False,
)

from xautodl.models import get_cell_based_tiny_net
from xautodl.datasets import get_datasets

def set_seed(seed):
    """Seed every random number generator used here and pin cuDNN to deterministic mode.

    Python's ``random``, NumPy's global RNG and torch's CPU plus all CUDA
    generators are seeded with ``seed``; cuDNN autotuning is disabled and
    deterministic cuDNN kernels are requested.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

# One intra-op CPU thread per process -- presumably so several Celery worker
# processes on the same machine do not oversubscribe its cores.
torch.set_num_threads(1)

# The broker URL can be overridden with CELERY_BROKER_URL instead of editing code.
# NOTE(review): the fallback below embeds a Redis password in source control;
# rotate it and drop the fallback once every worker sets CELERY_BROKER_URL.
app = Celery(
    "tasks",
    broker=os.environ.get("CELERY_BROKER_URL", "redis://:Xcrdb3rDeEf@172.31.55.178:27010/"),
)

# Expand "~" here so the dataset root resolves to the user's home directory
# regardless of whether the individual dataset class expands it. The trailing 0 is
# forwarded unchanged (presumably the cutout length -- confirm in xautodl).
train_set, test_set, xshape, class_num = get_datasets(
    DATASET,
    os.path.expanduser(f"~/Documents/Torch_Dataset/{DATASET_PATH[DATASET]}"),
    0,
)

# Per-process results keyed by model_id, filled in by train(). Entries are never
# removed, so it grows by one dict per processed task for the worker's lifetime.
statistics = {}

def _json_default(obj):
    """Fallback for ``json.dump``: convert NumPy/torch values to built-in Python types.

    ``json.dump`` only calls this for objects it cannot encode itself, so scores that
    are already plain ``float``/``int`` are written unchanged.

    Raises:
        TypeError: for any other unsupported type, mirroring json's own behaviour.
    """
    if isinstance(obj, (np.generic, np.ndarray, torch.Tensor)):
        # tolist() yields a Python scalar for 0-d values and a (nested) list otherwise.
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


@app.task(acks_late=True, reject_on_worker_lost=True)
def train(model_id):
    """Score one NATS-Bench architecture with every zero-cost proxy.

    No training happens in this task itself: the reference test accuracy, parameter
    count and FLOPs are read from the NATS-Bench ``api``, and each proxy is evaluated
    on a freshly constructed copy of the network.

    Args:
        model_id: architecture index passed to ``api.get_more_info``,
            ``api.get_cost_info`` and ``api.get_net_config``.

    Returns:
        dict with keys ``"accuracy"``, ``"num_param"``, ``"flops"`` plus one key per
        proxy measure. The same dict is stored in the module-level ``statistics``
        cache and written to ``NATSBench/stats_<model_id>.json``.

    Raises:
        NotImplementedError: if the measure list names a proxy with no branch below.
    """
    # Accuracy is looked up for a 200-epoch schedule on "tss" and 90 epochs otherwise.
    epochs = 200 if SEARCH_SPACE == "tss" else 90
    # nats_bench types `hp` as text (default "12"); pass a string explicitly.
    accuracy = api.get_more_info(model_id, DATASET, hp=str(epochs), is_random=False)["test-accuracy"]
    cost_info = api.get_cost_info(model_id, DATASET)  # fetched once, used twice
    statistics[model_id] = {
        "accuracy": accuracy,
        "num_param": cost_info["params"],
        "flops": cost_info["flops"],
    }

    config = api.get_net_config(model_id, DATASET)
    for measure in ["effective_rank", "swap", "swap_reg", "meco_opt", "zico", "zen", "synflow", "fisher", "grasp", "snip", "grad_norm"]:
        # Seed before building the network so its initial weights are reproducible,
        # instead of depending on whatever this long-lived worker computed earlier.
        set_seed(1337)
        model = get_cell_based_tiny_net(config)
        # Reseed after construction so the data drawn below does not depend on how much
        # randomness the architecture-specific initialisation consumed: every model and
        # measure sees the same samples.
        set_seed(1337)
        train_dataloader = DataLoader(train_set, batch_size=1, num_workers=0, shuffle=True)
        # 16 is the batch size used here https://github.com/pym1024/SWAP/blob/main/correlation.py#L26
        swap_dataloader = DataLoader(train_set, batch_size=16, num_workers=0, shuffle=True)
        nas_dataloader = DataLoader(train_set, batch_size=64, num_workers=0, shuffle=True)
        zico_dataloader = DataLoader(train_set, batch_size=128, num_workers=0, shuffle=True)
        if measure == "effective_rank":
            # Only this proxy switches the model to eval mode.
            model.eval()
            statistics[model_id][measure] = effective_rank.get_average_score_effective_rank(model, train_dataloader, repetitions=32)
        elif measure == "swap":
            statistics[model_id][measure] = swap.compute_nas_score(model, swap_dataloader, 32, regular=False, mu=None, sigma=None)
        elif measure == "swap_reg":
            statistics[model_id][measure] = swap.compute_nas_score(model, swap_dataloader, 32, regular=True, mu=REG_SWAP_MU[SEARCH_SPACE], sigma=REG_SWAP_SIGMA[SEARCH_SPACE])
        elif measure == "meco_opt":
            inputs, _ = next(iter(train_dataloader))
            statistics[model_id][measure] = meco.get_score(model, inputs, "cpu", "meco_opt")
        elif measure == "zico":
            statistics[model_id][measure] = zico.getzico(model, zico_dataloader, torch.nn.CrossEntropyLoss())
        elif measure == "zen":
            # One batch is drawn only to read the input resolution (inputs.shape[2]).
            inputs, _ = next(iter(nas_dataloader))
            statistics[model_id][measure] = zen.compute_nas_score(None, model, 1e-2, inputs.shape[2], 16, 32)["avg_nas_score"]
        elif measure == "synflow":
            statistics[model_id][measure] = synflow.compute_nas_score(model, nas_dataloader)
        elif measure == "fisher":
            statistics[model_id][measure] = fisher.compute_nas_score(model, nas_dataloader)
        elif measure == "grasp":
            statistics[model_id][measure] = grasp.compute_nas_score(model, nas_dataloader, NUM_CLASSES[DATASET])
        elif measure == "snip":
            statistics[model_id][measure] = snip.compute_nas_score(model, nas_dataloader)
        elif measure == "grad_norm":
            statistics[model_id][measure] = grad_norm.compute_nas_score(model, nas_dataloader)
        else:
            raise NotImplementedError(f"Measure {measure} is not implemented!")

    # Create the output directory if needed so a long proxy run is not lost to a
    # FileNotFoundError at the very end.
    os.makedirs("NATSBench", exist_ok=True)
    with open(f"NATSBench/stats_{model_id}.json", "w", encoding="utf-8") as file:
        json.dump(statistics[model_id], file, default=_json_default)
    return statistics[model_id]
