
# -*- coding: utf-8 -*-
import argparse
import os
from time import time
# W&B credentials must not be committed to source control. wandb reads the API
# key from the WANDB_API_KEY environment variable (or from a previous
# `wandb login` on this machine), so no explicit login call is needed before
# wandb.init() below.
# NOTE(review): the API key that used to be hard-coded on this line is exposed
# in the repository history and should be revoked/rotated.
import wandb
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import math
import torch.nn as nn
from spikingjelly.clock_driven import functional ,neuron
import torch
from utlis import seed_all ,GradualWarmupScheduler
from models import ResNet_cnn ,ResNet_snn ,supernet_snn
from loss_fun import feature_loss ,logits_loss
# Global seed; applied via the project's seed_all() once the GPU is selected.
seed = 1000

parser = argparse.ArgumentParser(description="NAS")
parser.add_argument("--gpu", type=int, default=0,
                    help="GPU index exported through CUDA_VISIBLE_DEVICES")
parser.add_argument("--batch", type=int, default=128,
                    help="training batch size")
parser.add_argument("--epoch", type=int, default=200,
                    help="number of training epochs (also the cosine T_max)")
parser.add_argument("--CIFAR", type=int, default=100,
                    help="10 selects CIFAR-10; any other value selects CIFAR-100 "
                         "(also used as num_classes)")
parser.add_argument("--lr", type=float, default=0.01,
                    help="base learning rate for AdamW")
parser.add_argument("--names", type=str, default='Yun_DNA',
                    help="prefix of the wandb run name and of the saved checkpoint")
parser.add_argument("--early", action='store_true', default=False,
                    help="after the first 10 epochs, stop stepping the LR schedule "
                         "once lr has fallen to 5%% of --lr or below")
parser.add_argument("--a", type=float, default=0.,
                    help="weight of the cross-entropy terms in the total loss")
parser.add_argument("--b", type=float, default=100.,
                    help="weight of the feature-distillation loss")
parser.add_argument("--c", type=float, default=1.,
                    help="weight of the logits-distillation loss")
parser.add_argument("--T", type=int, default=20,
                    help="T argument passed to logits_loss")

args = parser.parse_args()

# Select the GPU before seed_all() runs: CUDA_VISIBLE_DEVICES is only honoured
# if it is set before a CUDA context exists, and seed_all (from utlis) may touch
# CUDA -- TODO confirm. NOTE(review): torch and the project modules are imported
# above; if any of them initialises CUDA at import time, this comes too late.
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
seed_all(seed)

# ResNet depth -> [SNN constructor, CNN constructor]. The teacher is built from
# the CNN entry (second element) further down.
func_dict = {
    depth: [getattr(ResNet_snn, 'resnet' + depth + '__'), getattr(ResNet_cnn, 'resnet' + depth)]
    for depth in ('18', '34', '50', '101', '152')
}

# Loss weights: a -> cross-entropy terms, feature_b -> feature distillation,
# logitc_c -> logits distillation; logit_T is forwarded to logits_loss.
a, feature_b, logitc_c, logit_T = args.a, args.b, args.c, args.T
warm_up = True
CIFAR, bathsize, epoch = args.CIFAR, args.batch, args.epoch
# NOTE(review): best_acc is a constant -- nothing in this script evaluates or
# updates it, yet it is printed as "best_acc" at the end of training.
best_acc = 0.77
sta_time = time()  # wall-clock start of the whole run
models = '18'  # not referenced elsewhere in this script (student is the supernet)
teacher_model = '34'

# Run name shared by wandb and the final checkpoint file.
# NOTE(review): --c and --T are not part of the name, so runs that differ only
# in those flags share a wandb name and overwrite each other's checkpoint.
names = f"{args.names}_{bathsize}_{epoch}_{CIFAR}_a_{a}_b_{feature_b}_lr_{args.lr}"
wandb.init(project="distil_snn", name=names, group="DNA")

# ----- Models -----
# Teacher: CNN ResNet of depth `teacher_model`, restored from a pretrained
# checkpoint under ./model_weight.
_, t = func_dict[teacher_model]
teacher = t(num_classes=CIFAR).cuda()
teacher.load_state_dict(
    torch.load(f"./model_weight/resnet{teacher_model}_cnn_cifar{CIFAR}_SGD_CE.pth")
)

# Student: spiking supernet whose sub-network is sampled per training step.
lenet = supernet_snn.sup_resnet82__(num_classes=CIFAR).cuda()

n_parameters = sum(param.numel() for param in lenet.parameters() if param.requires_grad)
print('number of params:', n_parameters)
print(lenet)

# ----- Loss / optimiser / LR schedule -----
loss_fun = nn.CrossEntropyLoss().cuda()
optimer = torch.optim.AdamW(
    params=lenet.parameters(),
    lr=args.lr,
    betas=(0.9, 0.999),
    weight_decay=5e-3,
    eps=1e-3,
)
scheduler = CosineAnnealingLR(optimer, T_max=epoch, eta_min=0)
scaler = torch.cuda.amp.GradScaler()  # loss scaling for mixed-precision training

# Optional 5-epoch warm-up that hands over to the cosine schedule afterwards.
scheduler_warmup = (
    GradualWarmupScheduler(optimer, multiplier=1, total_epoch=5, after_scheduler=scheduler)
    if warm_up
    else None
)

# ----- Data -----
def _cifar_train_set(num_classes):
    """Build the augmented CIFAR training set selected by ``num_classes``.

    10 selects CIFAR-10; every other value falls through to CIFAR-100. Both
    use random crop (padding 4) + horizontal flip; only the dataset class,
    root directory and per-channel normalisation statistics differ.
    """
    if num_classes == 10:
        dataset_cls, root = torchvision.datasets.CIFAR10, './dataset/cifar10'
        mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    else:
        dataset_cls, root = torchvision.datasets.CIFAR100, './dataset/cifar100'
        mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)

    tfm = torchvision.transforms
    augment = tfm.Compose([
        tfm.RandomCrop((32, 32), padding=4),
        tfm.RandomHorizontalFlip(),
        tfm.ToTensor(),
        tfm.Normalize(mean, std),
    ])
    return dataset_cls(root=root, train=True, transform=augment, download=True)


train_dataset = _cifar_train_set(CIFAR)
train_data = DataLoader(train_dataset, batch_size=bathsize, shuffle=True, num_workers=4, pin_memory=True)


def _sample_subnet():
    """Draw a random sub-network configuration for the supernet student.

    Returns ``(layer, chanel)``, two lists of four Python ints passed straight
    to ``lenet``. ``layer[0]``, ``layer[1]`` and ``layer[3]`` are drawn from
    [2, 6), ``layer[2]`` from [2, 9); every ``chanel[k]`` is drawn from
    {-1, 0, 1}. Their meaning is defined by ``supernet_snn.sup_resnet82__``
    (presumably per-stage depth and width offsets -- confirm there).
    The torch RNG is consumed in the same order as before: layers first.
    """
    layer = [torch.randint(low=2, high=high, size=(1,)).item() for high in (6, 6, 9, 6)]
    chanel = [torch.randint(low=-1, high=2, size=(1,)).item() for _ in range(4)]
    return layer, chanel


def _train_one_epoch():
    """Train ``lenet`` for one pass over ``train_data`` with teacher distillation.

    Returns:
        (loss_ce_sum, loss_kd_sum, correct, correct_big, correct_small):
        the student CE loss and feature-KD loss summed over batches (not
        averaged), and the number of correct predictions of the sampled,
        big and small outputs (0-dim GPU tensors, or int 0 for an empty loader).
    """
    loss_ce_sum = 0.0
    loss_kd_sum = 0.0
    # Counts stay on the GPU during the epoch to avoid a sync per batch.
    correct = correct_big = correct_small = 0
    lenet.train()
    teacher.eval()
    for imgs, traget in train_data:
        imgs, traget = imgs.cuda(), traget.cuda()
        with torch.cuda.amp.autocast():
            with torch.no_grad():
                label, feature_tea = teacher(imgs)

            layer, chanel = _sample_subnet()
            # The supernet returns the sampled, smallest and biggest
            # sub-networks' logits and features in this order.
            output, feature_stu, output_small, feature_small, output_big, feature_big = lenet(imgs, layer, chanel)

            # The CE of the sampled sub-network is reused for the total loss
            # instead of being recomputed.
            loss_ce = loss_fun(output, traget)
            loss_ce_total = loss_ce + loss_fun(output_small, traget) + loss_fun(output_big, traget)
            loss_kd = (feature_loss(feature_stu, feature_tea)
                       + feature_loss(feature_big, feature_tea)
                       + feature_loss(feature_small, feature_tea))
            loss_logits = (logits_loss(output, label, logit_T)
                           + logits_loss(output_small, label, logit_T)
                           + logits_loss(output_big, label, logit_T))
            loss = loss_ce_total * a + loss_kd * feature_b + loss_logits * logitc_c

        loss_ce_sum += loss_ce.item()
        loss_kd_sum += loss_kd.item()
        correct = correct + (output.argmax(1) == traget).sum()
        correct_big = correct_big + (output_big.argmax(1) == traget).sum()
        correct_small = correct_small + (output_small.argmax(1) == traget).sum()

        optimer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimer)
        scaler.update()
        # Clear the spiking neurons' internal state before the next batch.
        functional.reset_net(lenet)
    return loss_ce_sum, loss_kd_sum, correct, correct_big, correct_small


def _step_lr_schedule(epoch_idx):
    """Advance the LR schedule once at the end of epoch ``epoch_idx``.

    With --early, from epoch index 10 onwards the schedule is frozen as soon
    as the learning rate has dropped to 5% of --lr or below.
    """
    if not warm_up:
        scheduler.step()
        return
    if args.early and epoch_idx >= 10 and optimer.param_groups[0]['lr'] <= args.lr * 0.05:
        return
    scheduler_warmup.step()


if __name__ == '__main__':
    n_train = len(train_dataset)
    for i in range(epoch):
        start_time = time()
        loss_all, loss_kd_all, right, right_big, right_small = _train_one_epoch()
        # Convert once per epoch so print/wandb receive plain Python floats.
        accuracy_train = float(right) / n_train
        accuracy_big = float(right_big) / n_train
        accuracy_small = float(right_small) / n_train

        _step_lr_schedule(i)

        lr_now = optimer.param_groups[0]['lr']
        elapsed = time() - start_time
        # eta: remaining hours, assuming every epoch takes as long as this one.
        print("epoch:{} time:{:.0f}  loss_ce :{:.2f} loss_kd :{:.2f} train_acc:{:.4f} lr:{} eta:{:.2f}".format(
            i + 1, elapsed, loss_all, loss_kd_all, accuracy_train, lr_now,
            elapsed * (epoch - i - 1) / 3600))
        wandb.log({"train_acc_rand": accuracy_train, "train_acc_big": accuracy_big,
                   "train_acc_small": accuracy_small, "loss_ce": loss_all,
                   "loss_kd": loss_kd_all, "lr": lr_now})

    torch.save(lenet.state_dict(), "./model_weight/" + names + ".pth")
    print("模型已保存")  # "model saved"
    end_ = time()
    print(end_ - sta_time)
    # NOTE(review): best_acc is a constant from the configuration section; this
    # script never evaluates the model, so the printed value is not measured.
    print("best_acc:{:.4f}".format(best_acc))
    wandb.finish()




