import os
from time import time
import argparse
import torch
from torch.utils.data import DataLoader
from models import supernet_snn
import torchvision
from spikingjelly.clock_driven import functional
import wandb
import torch.nn as nn

##### Search for the best sub-network of the pretrained supernet

# NOTE(review): the description text mentions ImageNet while this script
# evaluates on CIFAR; left unchanged because it is user-visible output.
parser = argparse.ArgumentParser(description="ImageNet_SNN_Training")
parser.add_argument("--names", default='yun_channel_random', type=str)  # wandb project name
parser.add_argument("--gpu", default='0', type=str)                     # value for CUDA_VISIBLE_DEVICES
parser.add_argument("--method", default='random', type=str)             # wandb sweep search method
parser.add_argument("--epoch", default=200, type=int)                   # number of runs the sweep agent executes
args = parser.parse_args()

# Expose only the requested GPU(s); this only takes effect if set before CUDA
# is first initialized, which is why it happens right after argument parsing.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

##### Search space of candidate sub-networks (wandb sweep configuration)

sweep_config = {
    "name": "sweep_with_launchpad",
    # The sweep maximizes the "acc" value that ``train`` logs for every run.
    "metric": {"name": "acc", "goal": "maximize"},
    # Search strategy taken from --method (default 'random'; 'bayes' is an alternative).
    "method": args.method,
    "parameters": {
        # Channel width per stage; the middle value of each list is the
        # reference width that ``train`` compares the sampled width against.
        "dim1": {"values": [48, 64, 80]},
        "dim2": {"values": [96, 128, 160]},
        "dim3": {"values": [192, 256, 320]},
        "dim4": {"values": [384, 512, 640]},
        # Depth choice per stage (handed to the model as its ``layer`` argument).
        "stage1": {"values": [2, 3, 4, 5]},
        "stage2": {"values": [2, 3, 4, 5]},
        "stage3": {"values": [2, 3, 4, 5, 6, 7, 8]},
        "stage4": {"values": [2, 3, 4, 5]},
    },
}

# Register the sweep under the wandb project named by --names.
sweep_id = wandb.sweep(sweep_config, project=args.names)

### Load model weights
# NOTE(review): num_classes is fixed at 100 here; it must agree with the
# dataset chosen via ``CIFAR`` further down the file.
model = supernet_snn.sup_resnet82__(num_classes=100).cuda()
# map_location="cpu" keeps the load independent of the GPU the checkpoint was
# saved from (only the devices in CUDA_VISIBLE_DEVICES are visible here);
# load_state_dict then copies the tensors into the model's CUDA parameters.
# torch.load unpickles the file, so only load trusted checkpoints.
model.load_state_dict(
    torch.load(
        "./model_weight/logit_dim_layer_100epoch_128_100_100_a_0.0_b_0.0_lr_0.01.pth",
        map_location="cpu",
    )
)


def post_bn(net: nn.Module):
    """Turn on running-statistics tracking for every ``nn.BatchNorm2d`` in *net*.

    Each BatchNorm2d gets ``track_running_stats = True`` and freshly created
    ``running_mean`` (zeros), ``running_var`` (ones) and ``num_batches_tracked``
    (0) buffers, replacing whatever those buffers held before.

    The new buffers are created on the device of the layer's existing
    parameters/buffers so they match the layer. If the layer owns no tensors at
    all, they fall back to the default CUDA device.
    """
    for m in net.modules():
        if not isinstance(m, nn.BatchNorm2d):
            continue
        # Use the layer's own device rather than a hard-coded ``.cuda()``, which
        # could disagree with where the layer actually lives.
        existing = [*m.parameters(recurse=False), *m.buffers(recurse=False)]
        device = existing[0].device if existing else torch.device("cuda")
        m.track_running_stats = True
        m.register_buffer('running_mean', torch.zeros(m.num_features, device=device))
        m.register_buffer('running_var', torch.ones(m.num_features, device=device))
        m.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long, device=device))


#             m.track_running_stats = False

def reset_bn(net: nn.Module):
    """Call ``reset_running_stats()`` on every ``nn.BatchNorm2d`` found in *net*.

    Walks all submodules of *net*, including *net* itself.
    """
    batch_norms = (layer for layer in net.modules() if isinstance(layer, nn.BatchNorm2d))
    for layer in batch_norms:
        layer.reset_running_stats()


# Debug output: tracking flag and running mean of the model's ``bn1`` layer,
# one value per line.
print(model.bn1.track_running_stats, model.bn1.running_mean, sep="\n")


def acc(model, test_data, data_nums, layer=None, channel=None, ops=None):
    """Return the top-1 accuracy of one sub-network of *model* on *test_data*.

    Args:
        model: supernet, called as ``model(imgs, layer, channel, ops)``.
        test_data: iterable of ``(imgs, label)`` batches.
        data_nums: total number of samples; the correct-prediction count is
            divided by it.
        layer: per-stage depth choice forwarded to the model (``None`` by default).
        channel: per-stage width code forwarded to the model; defaults to
            ``[0, 0, 0, 0]``.
        ops: per-stage op choice forwarded to the model; defaults to
            ``[0, 0, 0, 0]``.

    Returns:
        ``correct / data_nums``. This is a 0-dim tensor when at least one batch
        was evaluated, otherwise the Python float ``0.0``.
    """
    # Build fresh lists per call. The former list defaults were shared between
    # calls, so any in-place change made downstream would leak into later calls.
    if channel is None:
        channel = [0, 0, 0, 0]
    if ops is None:
        ops = [0, 0, 0, 0]
    model.eval()
    right = 0
    with torch.no_grad():
        for imgs, label in test_data:
            try:
                output = model(imgs.cuda(), layer, channel, ops)
                # Accumulate on-device to avoid a host sync per batch.
                right = (output.argmax(1) == label.cuda()).sum() + right
            finally:
                # Reset the network's stateful modules after every batch, even if
                # the forward pass raised, so the next evaluation starts clean.
                functional.reset_net(model)
    return right / data_nums


### Dataset
CIFAR = 100
batchsize = 200

# Any value other than 10 selects CIFAR-100; only the class, root directory
# and normalization statistics differ between the two choices.
# NOTE(review): the model is loaded with num_classes=100, so switching to
# CIFAR-10 would need that changed as well.
if CIFAR == 10:
    _dataset_cls, _root = torchvision.datasets.CIFAR10, './dataset/cifar10'
    _mean, _std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
else:
    _dataset_cls, _root = torchvision.datasets.CIFAR100, './dataset/cifar100'
    _mean, _std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)

_test_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(_mean, _std),
])
test_dataset = _dataset_cls(root=_root, train=False, transform=_test_transform, download=True)

test_data = DataLoader(test_dataset, batch_size=batchsize, shuffle=False, num_workers=4, pin_memory=True)

### Brute-force search bookkeeping
# NOTE(review): none of these are read elsewhere in the visible file;
# ``reslut`` looks like a typo for "result" but the name is kept unchanged.
reslut, best_acc, best_layer = [], 0.1, []


def get(dim, dim_):
    """Encode *dim* relative to the reference value *dim_*.

    Returns -1 when ``dim < dim_`` and 0 when they are equal. Otherwise it
    returns 1, which includes incomparable values such as NaN.
    """
    if dim < dim_:
        return -1
    return 0 if dim == dim_ else 1


# reset_bn(model)

def train(config=None):
    """Evaluate one sweep-sampled sub-network and log its accuracy to wandb.

    Used as the ``wandb.agent`` callback. The sampled values are read from
    ``wandb.config`` (``dim1``-``dim4``, ``stage1``-``stage4``). Widths are
    encoded with ``get`` relative to the reference widths 64/128/256/512, and
    the stage values are passed to ``acc`` as the ``layer`` argument.
    """
    with wandb.init(config=config):
        config = wandb.config
        channel = [get(config.dim1, 64), get(config.dim2, 128), get(config.dim3, 256), get(config.dim4, 512)]
        print(channel)
        layer = [config.stage1, config.stage2, config.stage3, config.stage4]
        tmp_acc = acc(model, test_data, len(test_dataset), layer, channel)
        # ``acc`` returns a (CUDA) 0-dim tensor; log a plain float so the sweep's
        # optimized metric "acc" is an ordinary number.
        # NOTE: the "dimN " keys keep their trailing space so the logged metric
        # names stay consistent with earlier runs.
        wandb.log({"stage1": config.stage1, "stage2": config.stage2, "stage3": config.stage3, "stage4": config.stage4,
                   "dim1 ": config.dim1, "dim2 ": config.dim2, "dim3 ": config.dim3, "dim4 ": config.dim4,
                   "acc": float(tmp_acc)})

# Start the sweep agent: ``train`` is invoked once per sampled configuration,
# ``args.epoch`` times in total.
wandb.agent(sweep_id, function=train, count=args.epoch)

