import torch
from torch.utils.data import DataLoader
from proxies import effective_rank, meco, swap, zen, synflow, fisher, grasp, snip, grad_norm, zico
from celery import Celery
import json
import numpy as np
import random
import copy

# NATS-Bench search space identifier; must be a key of REG_SWAP_MU / REG_SWAP_SIGMA.
SEARCH_SPACE = "sss"
# Dataset name; must be a key of NUM_CLASSES and DATASET_PATH.
DATASET = "cifar10"

# Number of target classes per supported dataset.
NUM_CLASSES = {
    "cifar10": 10,
    "cifar100": 100,
    "ImageNet16-120": 120,
}

# Sub-directory (relative to the dataset root) holding each dataset's files.
DATASET_PATH = {
    "cifar10": "",
    "cifar100": "",
    "ImageNet16-120": "ImageNet16",
}

# Per-search-space `mu` / `sigma` passed to the regularised SWAP score.
REG_SWAP_MU = dict(tss=1.5, sss=0.7)
REG_SWAP_SIGMA = dict(tss=1.5, sss=0.7)

from nats_bench import create
# Benchmark handle used by `train` for look-ups (accuracy, cost info, network config).
# It is created once, at import time. The path argument is None, so nats_bench
# resolves the benchmark location on its own.
# NOTE(review): presumably resolved via $TORCH_HOME -- confirm against the nats_bench docs.
api = create(None, SEARCH_SPACE, fast_mode=True, verbose=False)

from xautodl.models import get_cell_based_tiny_net
from xautodl.datasets import get_datasets

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed handed unchanged to every generator.
    """
    # The individual generators are independent, so the seeding order is irrelevant.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

# Restrict PyTorch to a single CPU thread for intra-op parallelism in this process.
torch.set_num_threads(1)
# Celery application on which `train` is registered as a task.
# NOTE(review): the broker URL embeds a password directly in the source file --
# consider loading it from configuration or the environment instead.
app = Celery("tasks", broker="redis://:Xcrdb3rDeEf@172.31.55.178:27010/")

# Load the dataset once at import time. Only `train_set` is referenced by the
# code visible in this file.
# NOTE(review): the trailing 0 is presumably xautodl's cutout length -- confirm.
train_set, test_set, xshape, class_num = get_datasets(DATASET, f"~/Documents/Torch_Dataset/{DATASET_PATH[DATASET]}", 0)

# Module-level result store, keyed by model_id and filled by `train`.
statistics = {}

# 36000 s = 10 h.
# NOTE(review): presumably sized so that a late-acknowledged task is not
# redelivered by the Redis broker while it is still running -- confirm.
app.conf.broker_transport_options = {'visibility_timeout': 36000}

# Number of completed training epochs after which the proxies are scored
# (0 means "at initialisation, before any training").
_EVAL_EPOCHS = (0, 1, 3, 5, 10)

# Proxy measures, in the order in which they are evaluated.
_MEASURES = (
    "effective_rank", "swap", "swap_reg", "meco_opt", "zico", "zen",
    "synflow", "fisher", "grasp", "snip", "grad_norm",
)


def _score_measure(measure, model_for_score):
    """Compute one proxy score on a private copy of the model.

    Args:
        measure: Name of the proxy; one of ``_MEASURES``.
        model_for_score: Model copy the proxy is allowed to modify.

    Returns:
        Whatever the selected proxy implementation returns.

    Raises:
        NotImplementedError: If ``measure`` is not a known proxy name.
    """
    # Reseed before any loader is iterated so that every architecture (and
    # every measure) sees the same shuffled data.
    set_seed(1337)

    def loader(batch_size):
        # Only the loader a measure actually needs is built.
        return DataLoader(train_set, batch_size=batch_size, num_workers=0, shuffle=True)

    if measure == "effective_rank":
        model_for_score.eval()
        return effective_rank.get_average_score_effective_rank(model_for_score, loader(1), repetitions=32)
    if measure == "swap":
        # 16 is the batch size used here https://github.com/pym1024/SWAP/blob/main/correlation.py#L26
        return swap.compute_nas_score(model_for_score, loader(16), 32, regular=False, mu=None, sigma=None)
    if measure == "swap_reg":
        return swap.compute_nas_score(
            model_for_score, loader(16), 32, regular=True,
            mu=REG_SWAP_MU[SEARCH_SPACE], sigma=REG_SWAP_SIGMA[SEARCH_SPACE],
        )
    if measure == "meco_opt":
        inputs, _ = next(iter(loader(1)))
        return meco.get_score(model_for_score, inputs, "cpu", "meco_opt")
    if measure == "zico":
        return zico.getzico(model_for_score, loader(128), torch.nn.CrossEntropyLoss())
    if measure == "zen":
        # A real batch is fetched only to read the spatial size (shape[2]).
        inputs, _ = next(iter(loader(64)))
        return zen.compute_nas_score(None, model_for_score, 1e-2, inputs.shape[2], 16, 32)["avg_nas_score"]
    if measure == "synflow":
        return synflow.compute_nas_score(model_for_score, loader(64))
    if measure == "fisher":
        return fisher.compute_nas_score(model_for_score, loader(64))
    if measure == "grasp":
        return grasp.compute_nas_score(model_for_score, loader(64), NUM_CLASSES[DATASET])
    if measure == "snip":
        return snip.compute_nas_score(model_for_score, loader(64))
    if measure == "grad_norm":
        return grad_norm.compute_nas_score(model_for_score, loader(64))
    raise NotImplementedError(f"Measure {measure} is not implemented!")


@app.task(acks_late=True, reject_on_worker_lost=True)
def train(model_id):
    """Train one NATS-Bench architecture briefly and score all proxies.

    The network is trained for up to ``max(_EVAL_EPOCHS)`` epochs. After each
    epoch listed in ``_EVAL_EPOCHS`` (epoch 0 = untrained network) every proxy
    in ``_MEASURES`` is computed on a deep copy of the current model. The
    results are stored in the module-level ``statistics`` dict, written to
    ``NATSBench/stats_{model_id}.json`` and returned.

    Args:
        model_id: Architecture index understood by the NATS-Bench ``api``.

    Returns:
        Dict mapping each evaluation epoch to a dict with the benchmark's
        ``accuracy``, ``num_param`` and ``flops`` plus one entry per measure.
    """
    import os  # local import keeps this block self-contained

    # Create the output directory up front so a missing folder fails (or is
    # fixed) before the expensive training, not after it.
    os.makedirs("NATSBench", exist_ok=True)

    epochs = 200 if SEARCH_SPACE == "tss" else 90

    # Benchmark look-ups do not depend on the evaluation epoch: query them once.
    accuracy = api.get_more_info(model_id, DATASET, hp=epochs, is_random=False)["test-accuracy"]
    cost_info = api.get_cost_info(model_id, DATASET)
    statistics[model_id] = {
        epoch: {
            "accuracy": accuracy,
            "num_param": cost_info["params"],
            "flops": cost_info["flops"],
        }
        for epoch in _EVAL_EPOCHS
    }

    config = api.get_net_config(model_id, DATASET)
    # Seed before building the network so the initial weights (and therefore
    # the epoch-0 scores) are reproducible.
    set_seed(1337)
    model = get_cell_based_tiny_net(config)
    # Reseed so the training data order does not depend on how much randomness
    # the weight initialisation consumed.
    set_seed(1337)
    train_loader = DataLoader(train_set, batch_size=256, num_workers=0, shuffle=True)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0005, nesterov=True)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    loss_function = torch.nn.CrossEntropyLoss()

    for epoch in range(max(_EVAL_EPOCHS) + 1):
        model.train()
        if epoch != 0:
            for batch_idx, (data, target) in enumerate(train_loader):
                optimizer.zero_grad()
                # NOTE(review): index 1 of the model output is used as the
                # logits -- presumably xautodl returns (features, logits).
                output = model(data)[1]
                loss = loss_function(output, target)
                loss.backward()
                optimizer.step()
                print(f"{batch_idx} of {epoch} done.")
            # T_max is expressed in epochs, so the schedule advances once per
            # epoch; stepping per batch would sweep the whole cosine many
            # times within a single epoch.
            scheduler.step()
        if epoch in _EVAL_EPOCHS:
            for measure in _MEASURES:
                # Proxies may modify the network, so each one gets its own copy.
                model_for_score = copy.deepcopy(model)
                statistics[model_id][epoch][measure] = _score_measure(measure, model_for_score)
                print(f"measure {measure} done")

    with open(f"NATSBench/stats_{model_id}.json", "w") as file:
        json.dump(statistics[model_id], file)
    return statistics[model_id]
