import torch
from torch.utils.data import DataLoader
from proxies import effective_rank, meco, swap, zen, synflow, fisher, grasp, snip, grad_norm, zico
from celery import Celery
import json
import numpy as np
import random

from api import TransNASBenchAPI
from tnb101.model_builder import *
from utils import *
from collections import namedtuple

# TransNAS-Bench-101 task evaluated by ``train``: passed to ``create_model`` and
# used to index both NUM_CLASSES and MEASURES below.
TASK = "autoencoder"

# Same configuration used in naslib: https://github.com/automl/NASLib/blob/8cb5d2ba1e29784de43039d9824c68e88fb1a1da/naslib/search_spaces/transbench101/graph.py#L62
NUM_CLASSES = {
    "autoencoder": 10,
    "class_object": 100,
    "class_scene": 63,
    "jigsaw": 1000,
    "normal": 10,
    "segmentsemantic": 10,
}

# Benchmark metric queried through ``api.get_single_metric`` for each task.
# Keys must match the task names used in NUM_CLASSES.
MEASURES = {
    "autoencoder": "test_l1_loss",
    "class_object": "test_top1",
    "class_scene": "test_top1",
    "jigsaw": "test_top1",
    "normal": "test_l1_loss",
    "segmentsemantic": "test_mIoU",
    # Deprecated misspelled key, kept only so existing lookups keep working.
    "segmantation": "test_mIoU",
}


# Lightweight, immutable settings record handed to ``get_train_val_loaders``.
# NOTE(review): the field set presumably mirrors what that helper reads — confirm in utils.
Config = namedtuple("config", "data dataset batch_size seed train_portion")
config = Config(
    dataset="class_object",  # NOTE(review): differs from TASK ("autoencoder") — confirm intended
    data="~/Torch_Dataset/",
    train_portion=0.7,
    seed=1337,
    batch_size=1,
)

# Module-level loaders built once at import time; ``train`` below builds its
# own loaders and does not reference these.
(
    train_queue,
    valid_queue,
    test_queue,
    train_transform,
    valid_transform,
) = get_train_val_loaders(config)

# Benchmark lookup object used by ``train`` to fetch architectures, metrics and model info.
api = TransNASBenchAPI("transnas-bench_v10141024.pth")


def set_seed(seed):
    """Make subsequent random draws reproducible.

    Seeds torch (CPU and every CUDA device), NumPy's global generator and
    Python's ``random`` module with ``seed``. It also forces deterministic cuDNN
    kernels and disables cuDNN's benchmark autotuning, which can pick
    different kernels from run to run.

    Args:
        seed: Integer seed applied to every generator.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

import os

# Limit torch to one intra-op thread per process (presumably so several Celery
# workers on one host do not oversubscribe the CPU — confirm deployment setup).
torch.set_num_threads(1)

# The broker URL can be supplied via CELERY_BROKER_URL so credentials need not
# live in source control.
# NOTE(review): the fallback below still embeds a password; rotate it and drop the
# default once every deployment sets CELERY_BROKER_URL.
_DEFAULT_BROKER_URL = "redis://:Xcrdb3rDeEf@172.31.55.180:27010/"
app = Celery("tasks", broker=os.environ.get("CELERY_BROKER_URL", _DEFAULT_BROKER_URL))

# Per-process results keyed by model_id; filled in by ``train``.
statistics = {}

# Proxy names evaluated by ``train``, in evaluation order.
_PROXY_MEASURES = (
    "effective_rank", "swap", "swap_reg", "meco_opt", "zico", "zen",
    "synflow", "fisher", "grasp", "snip", "grad_norm",
)
# Seed applied before building each model and again before building its loaders.
_PROXY_SEED = 1337


def _compute_measure(measure, model, train_dataloader, swap_dataloader, nas_dataloader, zico_dataloader):
    """Compute one zero-cost proxy score for ``model``.

    Args:
        measure: Proxy name (one of ``_PROXY_MEASURES``).
        model: Freshly constructed network to score.
        train_dataloader: Loader built with batch_size=1.
        swap_dataloader: Loader built with batch_size=16.
        nas_dataloader: Loader built with batch_size=64.
        zico_dataloader: Loader built with batch_size=128.

    Returns:
        Whatever the corresponding proxy function returns.

    Raises:
        NotImplementedError: If ``measure`` is not a known proxy name.
    """
    if measure == "effective_rank":
        model.eval()
        return effective_rank.get_average_score_effective_rank(model, train_dataloader, repetitions=32)
    if measure == "swap":
        return swap.compute_nas_score(model, swap_dataloader, 32, regular=False, mu=None, sigma=None)
    if measure == "swap_reg":
        # NOTE(review): mu=69 / sigma=69 look like placeholder values — confirm against SWAP's settings.
        return swap.compute_nas_score(model, swap_dataloader, 32, regular=True, mu=69, sigma=69)
    if measure == "meco_opt":
        inputs, _ = next(iter(train_dataloader))
        return meco.get_score(model, inputs, "cpu", "meco_opt")
    if measure == "zico":
        # NOTE(review): CrossEntropyLoss is a classification loss; confirm it is meaningful for TASK.
        return zico.getzico(model, zico_dataloader, torch.nn.CrossEntropyLoss())
    if measure == "zen":
        inputs, _ = next(iter(nas_dataloader))
        # inputs.shape[2] is passed as the resolution argument — assumes NCHW batches (TODO confirm).
        return zen.compute_nas_score(None, model, 1e-2, inputs.shape[2], 16, 32)["avg_nas_score"]
    if measure == "synflow":
        return synflow.compute_nas_score(model, nas_dataloader)
    if measure == "fisher":
        return fisher.compute_nas_score(model, nas_dataloader)
    if measure == "grasp":
        return grasp.compute_nas_score(model, nas_dataloader, NUM_CLASSES[TASK])
    if measure == "snip":
        return snip.compute_nas_score(model, nas_dataloader)
    if measure == "grad_norm":
        return grad_norm.compute_nas_score(model, nas_dataloader)
    raise NotImplementedError(f"Measure {measure} is not implemented!")


@app.task()
def train(model_id):
    """Score one TransNAS-Bench architecture with every zero-cost proxy.

    Looks up the architecture for ``model_id``, records its benchmark metric
    (``MEASURES[TASK]``, stored under the "loss" key even for accuracy
    metrics), parameter count and FLOPs, then computes each proxy in
    ``_PROXY_MEASURES`` on a freshly built, identically seeded model.

    Side effects: stores the result in the module-level ``statistics`` dict and
    writes it as JSON to ``Transnasbench/stats_{model_id}.json``, creating the
    directory if needed.

    Args:
        model_id: Architecture index understood by ``api.index2arch``.

    Returns:
        dict mapping metric/proxy names to their values.
    """
    import os

    arch = api.index2arch(model_id)

    statistics[model_id] = {
        "loss": api.get_single_metric(arch, TASK, MEASURES[TASK], mode='best'),
        "num_param": api.get_model_info(arch, TASK, "model_params"),
        "flops": api.get_model_info(arch, TASK, "model_FLOPs"),
    }

    for measure in _PROXY_MEASURES:
        # Seed before building the model so every proxy sees the same,
        # reproducible initial weights (previously model init was unseeded).
        set_seed(_PROXY_SEED)
        model = create_model(arch, TASK)
        # Reseed so the dataloaders are built from the same RNG state for every
        # measure, i.e. all models get the same data.
        set_seed(_PROXY_SEED)
        train_dataloader, _, _, _, _ = get_train_val_loader(batch_size=1)
        # 16 is the batch size used here https://github.com/pym1024/SWAP/blob/main/correlation.py#L26
        swap_dataloader, _, _, _, _ = get_train_val_loader(batch_size=16)
        nas_dataloader, _, _, _, _ = get_train_val_loader(batch_size=64)
        zico_dataloader, _, _, _, _ = get_train_val_loader(batch_size=128)

        statistics[model_id][measure] = _compute_measure(
            measure, model, train_dataloader, swap_dataloader, nas_dataloader, zico_dataloader
        )

    # The output directory may not exist on a fresh worker.
    os.makedirs("Transnasbench", exist_ok=True)
    with open(f"Transnasbench/stats_{model_id}.json", "w") as file:
        json.dump(statistics[model_id], file)
    return statistics[model_id]
