import argparse
import os

# Logs in to Weights & Biases through the shell before `wandb` is imported.
# This runs at module top level, so every process that executes this file
# performs the login.
# NOTE(review): 'xxx' looks like a placeholder for an API key — presumably the
# real key should come from the environment (e.g. WANDB_API_KEY) rather than
# being written into the source; confirm before committing a real value here.
os.system('wandb login xxx')
import wandb
from time import time
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data.distributed import DistributedSampler
from utils import seed_all, GradualWarmupScheduler
from torchvision.models.resnet import resnet18
from torchvision import transforms
from models import ImageNet_cnn, ImageNet_snn
from loss_kd import feature_loss,  logits_loss
from spikingjelly.clock_driven import functional


# Seed once, before any model or data object is created.
seed_all(1000)

# Registry keyed by ResNet depth tag. Each value is a two-element list:
# index 0 is the ImageNet_snn constructor (used below for the student),
# index 1 is the ImageNet_cnn constructor (used below for the teacher).
func_dict = {
    depth: [
        getattr(ImageNet_snn, 'resnet' + depth + '_'),
        getattr(ImageNet_cnn, 'resnet' + depth),
    ]
    for depth in ('18', '34', '50', '101')
}


def get_model(name):
    """Return the constructor pair registered for a ResNet depth tag.

    Args:
        name: A key of ``func_dict`` (for example ``'18'``).

    Returns:
        The two-element list stored in ``func_dict`` for ``name``: the
        ``ImageNet_snn`` constructor followed by the ``ImageNet_cnn``
        constructor.

    Raises:
        KeyError: If ``name`` is not registered. The message lists the
            supported tags so a mistyped ``--model`` is easy to diagnose.
    """
    try:
        return func_dict[name]
    except KeyError:
        raise KeyError(
            "unknown model {!r}; expected one of {}".format(name, list(func_dict))
        ) from None


# ---------------------------- command line ----------------------------
# Every option that the rest of the script reads from `args` is declared here.
parser = argparse.ArgumentParser(description="ImageNet_SNN_Training")
# torch.distributed.launch passes --local_rank (spelled --local-rank from
# PyTorch 2.0 on); torchrun exports LOCAL_RANK instead, so that environment
# variable is used as the default.
parser.add_argument("--local_rank", "--local-rank", type=int,
                    default=int(os.environ.get("LOCAL_RANK", 0)))
parser.add_argument("--datapath", type=str, default='/remote-home/share/dataset/ImageNet2012/')
parser.add_argument("--model", type=str, default='18', choices=list(func_dict))
parser.add_argument("--tea_model", type=str, default='18', choices=list(func_dict))
# NOTE(review): --model_weight is accepted but never read in this script.
parser.add_argument("--model_weight", type=str, default="resnet50_timm.pth")
parser.add_argument("--batch", type=int, default=128)
parser.add_argument("--epoch", type=int, default=120)
parser.add_argument("--warm_up", action='store_true', default=False)
parser.add_argument("--beta", type=float, default=10.)
parser.add_argument("--after_beta", type=float, default=0.01)
parser.add_argument("--load_weight", action='store_true', default=False)
parser.add_argument("--feature_epochs", type=int, default=10)
# The four options below are read further down (banner, W&B run name,
# DataLoader workers) but were never declared, which raised AttributeError.
# The --lr default equals the value that used to be hard-coded in the optimizer.
parser.add_argument("--lr", type=float, default=1e-3)
# Free-form prefix for the W&B run name.
parser.add_argument("--names", type=str, default='')
# Only used (via str()) as a tag inside the W&B run name in this script.
parser.add_argument("--kd", type=str, default='')
# Number of GPUs/processes; each DataLoader uses num_gpu * 5 workers.
# Defaults to WORLD_SIZE as exported by the distributed launcher.
parser.add_argument("--num_gpu", type=int, default=int(os.environ.get("WORLD_SIZE", 1)))


args = parser.parse_args()
torch.distributed.init_process_group(backend='nccl')
torch.cuda.set_device(args.local_rank)

# An epoch is reported as "best" only when test accuracy exceeds this value.
# NOTE(review): the end-of-run summary prints 0.6 if no epoch ever beats it.
best_acc = 0.6
sta = time()

# ---------------------------- models ----------------------------
if args.local_rank == 0:
    print("batch:{}_epoch:{}_lr:{}_".format(args.batch, args.epoch, args.lr))
    print("Model downloading")
    wandb.init(project="distil_snn", name=args.names + args.model + 'kd_' + str(args.kd) + 'beta_{}'.format(args.beta), group="ImageNet")

student_fun, _ = get_model(args.model)
_, teacher_fun = get_model(args.tea_model)

SNN = student_fun(num_classes=1000).cuda()
ANN = teacher_fun(num_classes=1000).cuda()
# NOTE(review): no checkpoint is loaded into ANN here and it is never switched
# to eval mode; confirm that the teacher constructor handles both, otherwise
# the distillation targets come from an untrained / train-mode network.

# Load weights.
# map_location='cpu' keeps each rank from deserializing the checkpoint onto
# the GPU it was saved from; load_state_dict then copies onto this rank's GPU.
student_state = torch.load('resnet' + args.model + '_timm.pth', map_location='cpu')
SNN.load_state_dict(student_state)
# NOTE(review): the strict load above is unconditional, so --load_weight only
# controls the message below. The former second, non-strict load of the same
# file was a no-op after a successful strict load and has been dropped.
if args.load_weight:
    print("load_weight_success")

n_parameters = sum(p.numel() for p in SNN.parameters() if p.requires_grad)
if args.local_rank == 0:
    print('number of params:', n_parameters)
    print(SNN)

# device_ids takes a flat list of device indices and output_device a single
# index; the previous nested list / list arguments are rejected by DDP.
SNN = torch.nn.parallel.DistributedDataParallel(SNN, device_ids=[args.local_rank],
                                                output_device=args.local_rank,
                                                find_unused_parameters=False)

loss_fun = torch.nn.CrossEntropyLoss().cuda()
scaler = torch.cuda.amp.GradScaler()

# Use the CLI learning rate so the banner printed above matches the optimizer.
optimer = torch.optim.AdamW(params=SNN.parameters(), lr=args.lr, betas=(0.9, 0.999), weight_decay=5e-3)


scheduler = CosineAnnealingLR(optimer, T_max=args.epoch, eta_min=0)
scheduler_warm = None
if args.warm_up:
    # Warm-up wrapper that hands over to the cosine schedule afterwards.
    scheduler_warm = GradualWarmupScheduler(optimer, multiplier=1, total_epoch=5, after_scheduler=scheduler)
writer = None

# ---------------------------- data ----------------------------
if args.local_rank == 0:
    print("Datasets Download")

# os.path.join tolerates a --datapath given with or without a trailing slash.
traindir = os.path.join(args.datapath, 'train')
valdir = os.path.join(args.datapath, 'val')

# Datasets.
train_dataset = torchvision.datasets.ImageFolder(
    traindir,
    transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]))

test_dataset = torchvision.datasets.ImageFolder(
    valdir,
    transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]))

# Training data is sharded across processes; the validation loader has no
# sampler, so every process iterates the whole validation set.
samper_train = DistributedSampler(train_dataset)
train_data = DataLoader(train_dataset, batch_size=args.batch, sampler=samper_train, num_workers=args.num_gpu * 5,
                        pin_memory=True)
test_data = DataLoader(test_dataset, batch_size=args.batch, shuffle=False, num_workers=args.num_gpu * 5,
                       pin_memory=True)

if __name__ == '__main__':
    # Training driver: one pass over this rank's training shard per epoch,
    # followed by an evaluation pass over the validation loader. Rank 0
    # prints progress and logs the metrics to W&B.

    for i in range(args.epoch):
        start_time = time()
        # Per-epoch running sums over steps (not averaged).
        loss_all = 0.0          # total objective
        loss_ce_all = 0.0       # cross-entropy term only
        loss_logit_all = 0.0    # logits-distillation term
        loss_feature_all = 0.0  # feature-distillation term
        right = 0
        SNN.train()
        # The sampler derives its shuffle from the epoch number, so it must be
        # given the current epoch; a constant repeats the same order each epoch.
        samper_train.set_epoch(i)

        # The feature-loss weight switches from beta to after_beta.
        # NOTE(review): `>` keeps beta for feature_epochs + 1 epochs (0..N);
        # confirm whether `>=` was intended.
        feature_weight = args.after_beta if i > args.feature_epochs else args.beta

        for step, (imgs, target) in enumerate(train_data):
            imgs, target = imgs.cuda(non_blocking=True), target.cuda(non_blocking=True)
            with torch.cuda.amp.autocast():
                output, feature_snn = SNN(imgs)
                # The teacher forward pass is excluded from autograd.
                with torch.no_grad():
                    lables, feature_ann = ANN(imgs)

                loss_ce = loss_fun(output, target)
                loss_feature = feature_loss(feature_snn, feature_ann)
                loss_logit = logits_loss(output, lables, T=1)

                loss = loss_ce + loss_feature * feature_weight + loss_logit * 10

            right = (output.argmax(1) == target).sum() + right
            loss_all += loss.item()
            loss_ce_all += loss_ce.item()
            loss_logit_all += loss_logit.item()
            loss_feature_all += loss_feature.item()

            optimer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimer)
            scaler.update()
            # Reset the spiking network's state before the next batch.
            functional.reset_net(SNN)

            if step % 100 == 0 and args.local_rank == 0:
                print("step:{:.2f} loss_ce:{:.2f}".format(step / len(train_data), loss_ce.item()))

        # `right` counts only this rank's shard, so divide by the number of
        # samples the sampler hands to one rank rather than the full dataset.
        accuracy1 = float(right) / len(samper_train)
        if args.warm_up:
            scheduler_warm.step()
        else:
            scheduler.step()

        SNN.eval()
        right = 0

        with torch.no_grad():
            for (imgs, target) in test_data:
                imgs, target = imgs.cuda(non_blocking=True), target.cuda(non_blocking=True)
                output = SNN(imgs)
                # Training unpacks (logits, features) from the model; if eval
                # mode returns the same pair, keep the logits only.
                if isinstance(output, (tuple, list)):
                    output = output[0]
                right = (output.argmax(1) == target).sum() + right
                functional.reset_net(SNN)

        # Plain floats keep formatting, comparison and W&B logging off the GPU.
        accuracy = float(right) / len(test_dataset)
        end_time = time()
        if args.local_rank == 0:
            print("epoch:{} time:{:.0f}  loss:{:.4f} train_acc:{:.4f} tets_acc:{:.4f} eta:{:.2f}".format(i + 1, end_time - start_time, loss_all, accuracy1, accuracy, (end_time - start_time) * (args.epoch - i - 1) / 3600))
            if accuracy > best_acc:
                best_acc = accuracy
                print("best_acc:{:.4f}".format(best_acc))
            wandb.log({"test_acc": accuracy, "train_acc": accuracy1, "loss_all": loss_all, "loss_ce_all": loss_ce_all, 'epoch': i, "loss_logit_all": loss_logit_all, "loss_feature_all": loss_feature_all, })

    # best_acc is only updated on rank 0, so only rank 0 prints the summary.
    if args.local_rank == 0:
        end = time()
        print(end - sta)
        print("best_acc:{:.4f}".format(best_acc))