import argparse
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from models.MS_ResNet import msresnet18
from models.sbMS_ResNet import sb_msresnet18, HardBinaryConv2d
import data_loaders
from functions import TET_loss, get_logger, Mem_loss
from pack.util import *
from models.layers import mem_distill
import math


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    """Destroy the default process group if one is currently initialised.

    Calling this when no group exists (initialisation failed, or cleanup
    already ran) is a no-op instead of an error, so it is safe to use from a
    ``finally`` block.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


# Command-line options. They are parsed at module level and the resulting
# ``args`` namespace is handed to every worker process by the entry point.
parser = argparse.ArgumentParser(description='PyTorch Gated Attention Coding')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch_size', default=9, type=int, metavar='N',
                    help='mini-batch size per GPU')
parser.add_argument('--lr', '--learning_rate', default=0.001, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--seed', default=1000, type=int,
                    help='seed for initializing training. ')
# ``%(default)s`` keeps the help text in sync with the actual default.
parser.add_argument('-T', '--time', default=6, type=int, metavar='N',
                    help='snn simulation time steps (default: %(default)s)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers per GPU')
parser.add_argument('--epochs', default=300, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--teacher', type=str, default='resnet18',
                    help='name of the teacher network (default: %(default)s)')
# The entry point replaces world_size with the number of visible GPUs.
parser.add_argument('--world-size', default=3, type=int,
                    help='number of processes for distributed training '
                         '(replaced by the visible GPU count at launch)')
# Each spawned worker overwrites rank with its own GPU index.
parser.add_argument('--rank', default=-1, type=int,
                    help='process rank for distributed training '
                         '(set by each spawned worker)')
parser.add_argument('--local_rank', default=-1, type=int,
                    help='local rank for distributed training')
args = parser.parse_args()


def train(model, model_teacher, device, train_loader, criterion, optimizer, epoch):
    """Train the student for one epoch with membrane-state distillation.

    For every batch the student and the teacher are run on the same images.
    Whatever each forward pass leaves in the shared ``mem_distill`` buffer is
    snapshotted and compared with ``Mem_loss``; the result is added to the
    ``TET_loss`` classification term with a weight of ``1 / (2 * (epoch + 1))``.

    Args:
        model: Student network; put in train mode and optimised here.
        model_teacher: Teacher network; put in eval mode, forward only.
        device: Device the batches and the metric tensor are moved to.
        train_loader: Loader whose ``sampler`` exposes ``set_epoch``.
        criterion: Base criterion handed to ``TET_loss``.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        epoch: Zero-based epoch index; reseeds the sampler and scales the
            distillation term.

    Returns:
        tuple: ``(loss_sum, accuracy)``. ``loss_sum`` is the SUM (not the
        mean) of the per-batch loss values over all batches of all ranks;
        ``accuracy`` is the top-1 training accuracy in percent over all ranks.
    """
    running_loss = 0
    model.train()
    model_teacher.eval()
    total = 0
    correct = 0

    train_loader.sampler.set_epoch(epoch)  # Important for proper shuffling

    # Start from an empty buffer. The teacher fills ``mem_distill`` while in
    # eval mode, so an earlier evaluation forward of the student presumably
    # leaves entries behind as well; without this they would end up in the
    # first ``stu_mem`` of the epoch. (Every iteration below ends with the
    # buffer cleared, so clearing once is enough.)
    mem_distill.clear()

    for images, labels in train_loader:
        optimizer.zero_grad()

        # non_blocking only has an effect for pinned host memory and is
        # harmless otherwise.
        labels = labels.to(device, non_blocking=True)
        images = images.to(device, non_blocking=True)

        logis_stu = model(images)
        # Average over dim 1 before taking the arg-max over classes
        # (presumably the time-step dimension - confirm against the model).
        outputs = logis_stu.mean(1)

        # distillation: snapshot what the student forward recorded ...
        stu_mem = mem_distill.copy()
        mem_distill.clear()
        # ... then run the teacher only for what it records. Its logits are
        # not used, and no autograd graph is needed for it.
        with torch.no_grad():
            model_teacher(images)
        tea_mem = mem_distill.copy()
        mem_distill.clear()

        loss_ce = TET_loss(logis_stu, labels, criterion, 1, 0.05)
        # loss_ce = criterion(outputs,labels)
        loss_kd = Mem_loss(tea_mem, stu_mem)
        # loss_kd = FR_loss(tea_mem, stu_mem)
        # The distillation weight decays as 1 / (2 * (epoch + 1)).
        loss = loss_ce + loss_kd / (2 * (epoch + 1))

        # # Disabled alternative: cosine-annealed weighting of the two terms.
        # kd_weight = 0.8 * (1 + math.cos(epoch * math.pi / args.epochs)) + 0.2  # decays from 1.8 to 0.2
        # ce_weight = 0.2 + 0.8 * (1 - math.cos(epoch * math.pi / args.epochs))  # grows from 0.2 to 1.8
        # if epoch<(args.epochs/5):
        #     kd_weight = 0.8 * (1 + math.cos(5 * epoch * math.pi / args.epochs)) + 0.2
        #     ce_weight = 0.2 + 0.8 * (1 - math.cos(5 * epoch * math.pi / args.epochs))
        # else:
        #     kd_weight = 0.2
        #     ce_weight = 1.8
        # weighted_loss_kd = (kd_weight * loss_kd).mean()
        # weighted_loss_ce = (ce_weight * loss_ce).mean()
        # loss = weighted_loss_kd+weighted_loss_ce

        # ``.item()`` requires a one-element tensor, so the loss can be
        # back-propagated directly.
        running_loss += loss.item()
        loss.backward()
        # check_gradients(model)
        optimizer.step()

        # Count hits on the device; only the final scalar is copied to host.
        total += float(labels.size(0))
        _, predicted = outputs.max(1)
        correct += float(predicted.eq(labels).sum().item())

    # Gather metrics from all processes (float64 keeps the counts exact).
    metrics = torch.tensor([running_loss, correct, total],
                           dtype=torch.float64, device=device)
    dist.all_reduce(metrics, op=dist.ReduceOp.SUM)
    running_loss, correct, total = metrics.tolist()

    accuracy = 100 * correct / total if total else 0.0
    return running_loss, accuracy


@torch.no_grad()
def test(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` and return top-1 accuracy in percent.

    Each rank evaluates whatever its loader yields; the hit and sample counts
    are then summed over all ranks with ``all_reduce`` before the ratio is
    taken. Rank 0 prints a running accuracy every 100 batches.

    Args:
        model: Network to evaluate; switched to eval mode.
        test_loader: Iterable of ``(inputs, targets)`` batches.
        device: Device the inputs and the metric tensor are moved to.

    Returns:
        float: Accuracy in percent over all ranks (0.0 for an empty loader).
    """
    correct = 0
    total = 0
    model.eval()
    is_main_process = dist.get_rank() == 0  # constant for the whole run

    for batch_idx, (inputs, targets) in enumerate(test_loader):
        inputs = inputs.to(device)
        outputs = model(inputs)
        # The teacher fills ``mem_distill`` while in eval mode during training,
        # so this forward presumably appends to it as well. Nothing here reads
        # those entries; dropping them keeps the buffer from growing with
        # every batch and from leaking into the next training step.
        mem_distill.clear()
        mean_out = outputs.mean(1)
        _, predicted = mean_out.cpu().max(1)
        total += float(targets.size(0))
        correct += float(predicted.eq(targets).sum().item())

        if batch_idx % 100 == 0 and is_main_process:
            acc = 100. * float(correct) / float(total)
            print(batch_idx, len(test_loader), ' Acc: %.5f' % acc)

    # Gather metrics from all processes (float64 keeps the counts exact).
    metrics = torch.tensor([correct, total], dtype=torch.float64, device=device)
    dist.all_reduce(metrics, op=dist.ReduceOp.SUM)
    correct, total = metrics.tolist()

    final_acc = 100 * correct / total if total else 0.0
    return final_acc


def _build_optimizer(model, lr):
    """Return Adam over ``model`` with 4-D ``conv`` weights in their own group."""
    weight_parameters = [
        p for pname, p in model.named_parameters()
        if p.ndimension() == 4 and 'conv' in pname
    ]
    # A set gives O(1) membership tests when splitting off the rest.
    weight_ids = {id(p) for p in weight_parameters}
    other_parameters = [p for p in model.parameters() if id(p) not in weight_ids]
    # NOTE(review): the first group sets no weight decay, so it falls back to
    # Adam's default; confirm whether a non-zero decay was intended there.
    # optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-3)
    return torch.optim.Adam(
        [{'params': other_parameters},
         {'params': weight_parameters, 'weight_decay': 0}], lr=lr)


def _save_best_checkpoint(model, epoch):
    """Write the binarization patterns and the unwrapped student weights."""
    patterns = [
        layer.binarize.patterns
        for layer in model.modules()
        if isinstance(layer, HardBinaryConv2d)
    ]
    # The directory is not created anywhere else; without this the first
    # save fails when ./outputs does not exist yet.
    os.makedirs('./outputs', exist_ok=True)
    torch.save(patterns, './outputs/p_' + str(epoch) + '.pth')
    torch.save(model.module.state_dict(), 'cifar100-res19-fr_kd.pth')


def main_worker(gpu, ngpus_per_node, args):
    """Per-process entry point: train and evaluate the student on one GPU.

    Args:
        gpu: Index of the GPU (and, on a single node, the rank) of this process.
        ngpus_per_node: Number of GPUs on the node; unused here, kept for the
            spawn call signature.
        args: Parsed command-line namespace; ``gpu`` and ``rank`` are set on it.
    """
    args.gpu = gpu
    args.rank = gpu  # Assuming single node

    print(f'Using GPU: {gpu} for training')

    setup(args.rank, args.world_size)
    try:
        torch.cuda.set_device(args.gpu)
        seed_all(args.seed + args.rank)

        num_classes = 100
        train_dataset, val_dataset = data_loaders.build_cifar(use_cifar10=False)

        train_sampler = DistributedSampler(train_dataset)

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            num_workers=args.workers,
            pin_memory=True,
            sampler=train_sampler
        )

        # No distributed sampler here: every rank evaluates the full
        # validation set. The summed counts in test() still give the correct
        # ratio.
        test_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True
        )

        # Create model on GPU
        model = sb_msresnet18(num_classes=num_classes).cuda(args.gpu)
        model.T = args.time
        model = DDP(model, device_ids=[args.gpu], find_unused_parameters=True)

        logger = get_logger('cifar100-res19-fr_kd.log')
        if args.rank == 0:
            # Only rank 0 logs, so the message appears once per run.
            logger.info('start training!')
        # logger.info('load success, test accuracy:', facc)

        # load teacher model
        # NOTE(review): the checkpoint load below is commented out, so the
        # teacher keeps its freshly constructed weights; confirm this is
        # intended. Its ``T`` attribute is also never set here, unlike the
        # student's.
        model_teacher = msresnet18(num_classes=num_classes).cuda(args.gpu)
        # checkpoint = torch.load('./checkpoints/81.2-cifar100-ms18-fp.pth', map_location=f'cuda:{args.gpu}')
        # model_teacher.load_state_dict(checkpoint)

        # The teacher is frozen and only ever run forward.
        for p in model_teacher.parameters():
            p.requires_grad = False
        model_teacher.eval()

        # ``torch.nn`` is used explicitly instead of relying on a wildcard
        # import to provide ``nn``.
        criterion = torch.nn.CrossEntropyLoss().cuda(args.gpu)

        optimizer = _build_optimizer(model, args.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, eta_min=0, T_max=args.epochs)

        best_acc = 0

        for epoch in range(args.epochs):
            loss, acc = train(model, model_teacher, args.gpu, train_loader,
                              criterion, optimizer, epoch)

            if args.rank == 0:
                logger.info('Epoch:[{}/{}]\t loss={:.5f}\t acc={:.3f}'.format(epoch, args.epochs, loss, acc))

            scheduler.step()
            # test() contains a collective, so every rank must call it.
            facc = test(model, test_loader, args.gpu)

            if args.rank == 0:
                logger.info('Epoch:[{}/{}]\t Test acc={:.3f}'.format(epoch, args.epochs, facc))

                if best_acc < facc:
                    best_acc = facc
                    _save_best_checkpoint(model, epoch)

                logger.info('Best Test acc={:.3f}'.format(best_acc))
                logger.info('\n')
    finally:
        # Tear the process group down even when training raises.
        cleanup()

def gradient_check_hook(name):
    """Build a gradient hook that reports a missing or all-zero gradient.

    Args:
        name: Label inserted into the printed message.

    Returns:
        A callable taking the gradient; it prints one line when the gradient
        is ``None`` or consists only of zeros, and always returns ``None``.
    """
    def _report(grad):
        message = None
        if grad is None:
            message = f"参数 {name} 的梯度为 None"
        elif bool((grad == 0).all()):
            message = f"参数 {name} 的梯度全为 0"
        if message is not None:
            print(message)

    return _report

def check_gradients(model):
    """Print one line per parameter of ``model`` whose gradient is unusable.

    A parameter is reported when its ``.grad`` is ``None`` or when every
    element of its ``.grad`` equals zero; all other parameters are silent.
    """
    for name, parameter in model.named_parameters():
        gradient = parameter.grad
        if gradient is None:
            print(f"参数 {name} 没有梯度")
            continue
        if bool((gradient == 0).all()):
            print(f"参数 {name} 梯度为零")
if __name__ == '__main__':
    # One worker process is spawned per visible GPU, and that count also
    # becomes the world size (overriding any --world-size given on the
    # command line).
    ngpus_per_node = torch.cuda.device_count()
    if ngpus_per_node < 1:
        # With zero GPUs no worker would be spawned at all; fail loudly
        # instead of exiting without having trained anything.
        raise SystemExit('No CUDA device is visible; this script needs at least one GPU.')
    args.world_size = ngpus_per_node
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))