import argparse
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from models.vgg import vggsnn
from models.sb_vgg import sbvggsnn
import data_loaders
from functions import BlockwiseDistillation, TET_loss, get_logger, MembraneDistiller, Mem_loss
from pack.util import *
from models.layers import mem_distill
import math


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '13345'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    """Destroy the default process group (the counterpart of ``setup``)."""
    torch.distributed.destroy_process_group()


# Command-line options. They are parsed at import time; ``train`` reads the
# module-level ``args`` directly, and the ``__main__`` guard / ``main_worker``
# overwrite ``world_size`` and ``rank`` at launch.
parser = argparse.ArgumentParser(description='PyTorch Gated Attention Coding')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts; not read by the training loop)')
parser.add_argument('-b', '--batch_size', default=9, type=int, metavar='N',
                    help='mini-batch size per GPU')
parser.add_argument('--lr', '--learning_rate', default=0.001, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--seed', default=1000, type=int,
                    help='base seed; each process adds its rank')
parser.add_argument('-T', '--time', default=6, type=int, metavar='N',
                    help='snn simulation time steps (default: %(default)s)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers per GPU')
parser.add_argument('--epochs', default=300, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--teacher', type=str, default='resnet18',
                    help='teacher architecture name (currently unused; the teacher '
                         'checkpoint path is fixed in main_worker)')
parser.add_argument('--world-size', default=3, type=int,
                    help='number of processes; overridden with the local GPU count at launch')
parser.add_argument('--rank', default=-1, type=int,
                    help='process rank; overridden with the GPU index in each worker')
parser.add_argument('--local_rank', default=-1, type=int,
                    help='local rank for distributed training (currently unused)')
args = parser.parse_args()


def train(model, model_teacher, device, train_loader, criterion, criterion_kd, optimizer, epoch):
    """Run one distillation epoch for the student and return globally reduced metrics.

    The loss is ``TET_loss`` on the student's output plus a membrane
    distillation term ``Mem_loss(tea_mem, stu_mem)`` weighted by
    ``1 / (2 * (epoch + 1))``. Membrane values are read from the shared
    ``mem_distill`` container that is filled during forward passes.

    Args:
        model: Student network (DDP-wrapped by the caller). Its output is
            averaged over dim 1 for predictions (presumably the time axis —
            TODO confirm).
        model_teacher: Frozen teacher network; run only to populate ``mem_distill``.
        device: Device (or CUDA index) that batches are moved to.
        train_loader: Loader whose sampler provides ``set_epoch``.
        criterion: Classification criterion forwarded to ``TET_loss``.
        criterion_kd: Not used; kept so existing callers keep working.
        optimizer: Optimizer over the student's parameters.
        epoch: Zero-based epoch index (sampler shuffle seed and KD weight).

    Returns:
        ``(loss_sum, accuracy)``: the sum of per-batch losses over all batches on
        all ranks, and the training accuracy in percent (0.0 if no samples).
    """
    running_loss = 0.0
    correct = 0.0
    total = 0.0
    model.train()
    model_teacher.eval()

    # Reseed the DistributedSampler so every epoch uses a different shuffle.
    train_loader.sampler.set_epoch(epoch)

    for images, labels in train_loader:
        optimizer.zero_grad()

        labels = labels.to(device)
        images = images.to(device)

        # Discard anything left in the shared buffer (e.g. recorded during an
        # evaluation pass) so stu_mem holds exactly this student forward.
        mem_distill.clear()
        logis_stu = model(images)
        outputs = logis_stu.mean(1)
        stu_mem = mem_distill.copy()
        mem_distill.clear()

        # The teacher is frozen, so its forward needs no autograd graph; it is
        # run only for the membrane values it records into mem_distill.
        with torch.no_grad():
            model_teacher(images)
        tea_mem = mem_distill.copy()
        mem_distill.clear()

        loss_ce = TET_loss(logis_stu, labels, criterion, 1, 0.001)
        loss_kd = Mem_loss(tea_mem, stu_mem)
        loss = loss_ce + loss_kd / (2 * (epoch + 1))

        running_loss += loss.item()
        loss.mean().backward()
        optimizer.step()

        total += float(labels.size(0))
        _, predicted = outputs.cpu().max(1)
        correct += float(predicted.eq(labels.cpu()).sum().item())

    # Sum loss / correct / seen counts over all processes.
    metrics = torch.tensor([running_loss, correct, total], device=device)
    dist.all_reduce(metrics, op=dist.ReduceOp.SUM)
    running_loss, correct, total = metrics.tolist()

    accuracy = 100 * correct / total if total > 0 else 0.0
    return running_loss, accuracy


@torch.no_grad()
def test(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` and return accuracy in percent.

    Predictions come from the model output averaged over dim 1. Correct and
    seen counts are summed across all processes with ``all_reduce`` before the
    ratio is taken, so every rank returns the same value.

    Args:
        model: Network to evaluate (switched to eval mode).
        test_loader: Loader yielding ``(inputs, targets)`` batches.
        device: Device (or CUDA index) that inputs are moved to.

    Returns:
        Accuracy in percent, or 0.0 if no samples were seen on any rank.
    """
    correct = 0.0
    total = 0.0
    model.eval()

    for batch_idx, (inputs, targets) in enumerate(test_loader):
        inputs = inputs.to(device)
        outputs = model(inputs)
        # Evaluation does not use membrane distillation: drop anything the
        # forward pass may have recorded so the shared buffer cannot grow
        # across batches or leak into the next training step.
        mem_distill.clear()

        mean_out = outputs.mean(1)
        _, predicted = mean_out.cpu().max(1)
        total += float(targets.size(0))
        correct += float(predicted.eq(targets.cpu()).sum().item())

        if batch_idx % 100 == 0 and dist.get_rank() == 0:
            acc = 100. * float(correct) / float(total)
            print(batch_idx, len(test_loader), ' Acc: %.5f' % acc)

    # Sum counts over all processes before computing the ratio.
    metrics = torch.tensor([correct, total], device=device)
    dist.all_reduce(metrics, op=dist.ReduceOp.SUM)
    correct, total = metrics.tolist()

    return 100 * correct / total if total > 0 else 0.0


def _load_frozen_teacher(gpu):
    """Build the VGG-SNN teacher on ``gpu``, load its checkpoint, and freeze it."""
    teacher = vggsnn(num_classes=10).cuda(gpu)
    checkpoint = torch.load('./checkpoints/dvscifar10-vggsnn-teacher.pth.tar', map_location=f'cuda:{gpu}')
    teacher.load_state_dict(checkpoint['state_dict'])
    for p in teacher.parameters():
        p.requires_grad = False
    teacher.eval()
    return teacher


def _split_conv_weights(model):
    """Return ``(conv_weights, others)``: 4-D params named '*conv*' vs. the rest, in order."""
    conv_weights, others = [], []
    for pname, p in model.named_parameters():
        if p.ndimension() == 4 and 'conv' in pname:
            conv_weights.append(p)
        else:
            others.append(p)
    return conv_weights, others


def main_worker(gpu, ngpus_per_node, args):
    """Per-process entry point launched by ``mp.spawn``: distill the student on one GPU.

    Args:
        gpu: Local GPU index supplied by ``mp.spawn``; also used as the rank.
        ngpus_per_node: Number of spawned processes (not used in the body).
        args: Parsed command-line namespace; ``gpu`` and ``rank`` are set on it.
    """
    args.gpu = gpu
    args.rank = gpu  # Assuming single node: rank == local GPU index

    print(f'Using GPU: {gpu} for training')

    setup(args.rank, args.world_size)
    try:
        torch.cuda.set_device(args.gpu)
        # Offset the seed by rank so each process gets a distinct seed.
        seed_all(args.seed + args.rank)

        # NOTE(review): the dataset is built with T=10 while the student uses
        # model.T = args.time (default 6) — confirm this mismatch is intended.
        train_dataset, val_dataset = data_loaders.build_dvscifar10(T=10)

        train_sampler = DistributedSampler(train_dataset)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            num_workers=args.workers,
            pin_memory=True,
            sampler=train_sampler,
        )

        # No DistributedSampler: every rank evaluates the whole validation set.
        # test() sums counts across ranks, so the resulting ratio is unchanged.
        test_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
        )

        model = sbvggsnn(num_classes=10).cuda(args.gpu)
        model.T = args.time
        model = DDP(model, device_ids=[args.gpu], find_unused_parameters=True)

        # Only rank 0 owns the log so spawned processes do not duplicate entries.
        logger = get_logger('dvscifar10-vgg-4.log') if args.rank == 0 else None
        if logger is not None:
            logger.info('start training!')

        model_teacher = _load_frozen_teacher(args.gpu)

        # NOTE(review): `nn`, `DistributionLoss` and `seed_all` are expected to
        # come from `from pack.util import *` — not visible in this file.
        criterion = nn.CrossEntropyLoss().cuda(args.gpu)
        criterion_kd = DistributionLoss().cuda(args.gpu)  # passed to train(), which ignores it

        weight_parameters, other_parameters = _split_conv_weights(model)
        # NOTE(review): Adam's default weight_decay is 0, so both groups are
        # currently undecayed; the split only matters if a default is added.
        optimizer = torch.optim.Adam(
            [{'params': other_parameters},
             {'params': weight_parameters, 'weight_decay': 0}], lr=args.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=0, T_max=args.epochs)

        best_acc = 0

        for epoch in range(args.epochs):
            loss, acc = train(model, model_teacher, args.gpu, train_loader, criterion, criterion_kd, optimizer, epoch)

            if args.rank == 0:
                logger.info('Epoch:[{}/{}]\t loss={:.5f}\t acc={:.3f}'.format(epoch, args.epochs, loss, acc))

            scheduler.step()
            facc = test(model, test_loader, args.gpu)

            if args.rank == 0:
                logger.info('Epoch:[{}/{}]\t Test acc={:.3f}'.format(epoch, args.epochs, facc))

                if best_acc < facc:
                    best_acc = facc
                    torch.save(model.module.state_dict(), 'dvscifar10-vgg-4.pth')

                logger.info('Best Test acc={:.3f}'.format(best_acc))
                logger.info('\n')
    finally:
        # Release the process group even when training raises.
        cleanup()

def gradient_check_hook(name):
    """Build a gradient hook (usable with ``Tensor.register_hook``) for parameter ``name``.

    The returned callback prints a message when it receives ``None`` or a
    gradient whose elements are all zero. It always returns ``None``, so the
    gradient itself is left untouched.
    """
    def _inspect(grad):
        if grad is not None:
            if torch.all(grad == 0):
                print(f"参数 {name} 的梯度全为 0")
            return
        print(f"参数 {name} 的梯度为 None")

    return _inspect

def check_gradients(model):
    """Print a line for each parameter of ``model`` whose ``.grad`` is missing or all zero."""
    for pname, tensor in model.named_parameters():
        g = tensor.grad
        if g is None:
            message = f"参数 {pname} 没有梯度"
        elif torch.all(g == 0):
            message = f"参数 {pname} 梯度为零"
        else:
            continue
        print(message)
if __name__ == '__main__':
    ngpus_per_node = torch.cuda.device_count()
    if ngpus_per_node == 0:
        # mp.spawn(nprocs=0) would start no workers and the script would
        # finish without training; fail loudly instead.
        raise SystemExit('No CUDA devices found: this script needs at least one GPU (NCCL backend).')
    # One process per local GPU; this replaces whatever --world-size was given.
    args.world_size = ngpus_per_node
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))