import torch
from torch.utils.data import DataLoader
from proxies import effective_rank, meco, swap, zen, synflow, fisher, grasp, snip, grad_norm, zico
from celery import Celery
import json
import numpy as np
import random
from nasbench_pytorch.model import Network as NBNetwork
from nasbench import api
from xautodl.datasets import get_datasets
from fvcore.nn import FlopCountAnalysis

def set_seed(seed):
    """Seed all random number generators used in this module.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU
    generator plus every CUDA device), and switches cuDNN to deterministic
    mode with auto-tuning (benchmark) disabled.

    Args:
        seed: Integer seed passed unchanged to every generator.
    """
    # Seed the three libraries' generators with the same value.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Request deterministic cuDNN kernels and turn off the benchmark mode.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Restrict PyTorch to a single intra-op CPU thread in this process.
torch.set_num_threads(1)

# SECURITY(review): the broker URL embeds a plaintext password and an internal
# address. Anyone with read access to this file can use the broker; rotate the
# credential and load the URL from configuration/environment instead.
app = Celery("tasks", broker="redis://:Xcrdb3rDeEf@172.31.55.178:27010/")
# SECURITY(review): pickle is used for tasks and results and is accepted as
# incoming content. Unpickling data from an untrusted producer can execute
# arbitrary code, so the broker must only be reachable by trusted clients.
app.conf.task_serializer = "pickle"
app.conf.result_serializer = "pickle"
app.conf.accept_content = ["application/json", "application/x-python-serialize"]

# CIFAR-10 is loaded once at import time and shared by every task executed in
# this process. Only `train_set` is used by the code visible in this file.
# NOTE(review): the trailing 0 is presumably xautodl's `cutout` length
# (i.e. cutout disabled) - confirm against xautodl.datasets.get_datasets.
train_set, test_set, xshape, class_num = get_datasets("cifar10", "~/Documents/Torch_Dataset/", 0)

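# Reference only (not executed in the worker): presumably how the producer
# obtains the `m` argument that is passed to `train` for a given `net_hash`.
#nasbench_path = "nasbench_only108.tfrecord"
#nb = api.NASBench(nasbench_path)
#hashes = list(nb.hash_iterator())
#m = nb.get_metrics_from_hash(net_hash)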

# Batch size of the DataLoader handed to each proxy. Insertion order is the
# order in which the proxies are evaluated.
# 16 is the batch size used by SWAP:
# https://github.com/pym1024/SWAP/blob/main/correlation.py#L26
_MEASURE_BATCH_SIZES = {
    "effective_rank": 1,
    "swap": 16,
    "swap_reg": 16,
    "meco_opt": 1,
    "zico": 128,
    "zen": 64,
    "synflow": 64,
    "fisher": 64,
    "grasp": 64,
    "snip": 64,
    "grad_norm": 64,
}


def _compute_measure(measure, model, dataloader):
    """Compute a single proxy score for ``model``.

    Args:
        measure: Name of the proxy; one of the keys of ``_MEASURE_BATCH_SIZES``.
        model: Freshly instantiated network to score. It is switched to eval
            mode for ``"effective_rank"``; no other branch changes its mode here.
        dataloader: DataLoader over the training set with the batch size
            configured for ``measure``.

    Returns:
        Whatever the corresponding function of the ``proxies`` package returns.

    Raises:
        NotImplementedError: If ``measure`` is not a known proxy name.
    """
    if measure == "effective_rank":
        model.eval()
        return effective_rank.get_average_score_effective_rank(model, dataloader, repetitions=32)
    if measure == "swap":
        return swap.compute_nas_score(model, dataloader, 32, regular=False, mu=None, sigma=None)
    if measure == "swap_reg":
        return swap.compute_nas_score(model, dataloader, 32, regular=True, mu=40, sigma=40)
    if measure == "meco_opt":
        inputs, _ = next(iter(dataloader))
        return meco.get_score(model, inputs, "cpu", "meco_opt")
    if measure == "zico":
        return zico.getzico(model, dataloader, torch.nn.CrossEntropyLoss())
    if measure == "zen":
        # A real batch is drawn only to read the input resolution (shape[2]).
        inputs, _ = next(iter(dataloader))
        return zen.compute_nas_score(None, model, 1e-2, inputs.shape[2], 16, 32)["avg_nas_score"]
    if measure == "synflow":
        return synflow.compute_nas_score(model, dataloader)
    if measure == "fisher":
        return fisher.compute_nas_score(model, dataloader)
    if measure == "grasp":
        return grasp.compute_nas_score(model, dataloader, 10)
    if measure == "snip":
        return snip.compute_nas_score(model, dataloader)
    if measure == "grad_norm":
        return grad_norm.compute_nas_score(model, dataloader)
    raise NotImplementedError(f"Measure {measure} is not implemented!")


@app.task(acks_late=True, reject_on_worker_lost=True)
def train(net_hash, m):
    """Celery task: score one architecture with every proxy and save the result.

    The statistics are written to ``NASBench/stats_<net_hash>.json`` (the
    directory is created if it does not exist) and also returned.

    Args:
        net_hash: Identifier of the architecture; only used in the output
            file name.
        m: Indexable pair. ``m[0]`` must provide ``"module_operations"``,
            ``"module_adjacency"`` and ``"trainable_parameters"``;
            ``m[1][108][0]`` must provide ``"final_test_accuracy"``.
            Presumably the result of NAS-Bench-101's
            ``get_metrics_from_hash`` - confirm against the producer.

    Returns:
        dict with the keys ``"accuracy"``, ``"num_param"``, ``"flops"`` and one
        entry per proxy listed in ``_MEASURE_BATCH_SIZES``.

    Raises:
        NotImplementedError: If a configured measure has no implementation.
    """
    import os  # local import: used only to make sure the output directory exists

    # Create the output directory up front, so a missing directory cannot
    # discard the results after all proxies have already been computed.
    os.makedirs("NASBench", exist_ok=True)

    ops = m[0]["module_operations"]
    adjacency = m[0]["module_adjacency"]

    # This first instance is only used to count FLOPs on a single training image.
    model = NBNetwork((adjacency, ops))
    stats = {
        "accuracy": m[1][108][0]["final_test_accuracy"],
        "num_param": m[0]["trainable_parameters"],
        "flops": FlopCountAnalysis(model, train_set[0][0].unsqueeze(0)).total(),
    }

    for measure, batch_size in _MEASURE_BATCH_SIZES.items():
        # Every proxy gets its own freshly initialised network.
        # NOTE(review): the network is built *before* set_seed(1337), so its
        # weight initialisation is not covered by that seed - confirm this is
        # intended.
        model = NBNetwork((adjacency, ops))
        # Re-seed and build a new dataloader to make sure all models get the
        # same data. Only the loader this proxy needs is created.
        set_seed(1337)
        dataloader = DataLoader(train_set, batch_size=batch_size, num_workers=0, shuffle=True)
        stats[measure] = _compute_measure(measure, model, dataloader)

    with open(f"NASBench/stats_{net_hash}.json", "w") as file:
        json.dump(stats, file)
    return stats
