import argparse
import os

import torch

from models.MS_ResNet import msresnet18
from models.sbMS_ResNet import sb_msresnet18
import data_loaders
from functions import get_logger
import torchvision.models as models
from functions import BlockwiseDistillation, TET_loss
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5"
from pack.util import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


parser = argparse.ArgumentParser(description='PyTorch Gated Attention Coding')
parser.add_argument('--start_epoch',
                    default=0,
                    type=int,
                    metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b',
                    '--batch_size',
                    default=2,
                    type=int,
                    metavar='N',
                    help='mini-batch size (default: 64), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')

parser.add_argument('--lr',
                    '--learning_rate',
                    default=0.001,
                    type=float,
                    metavar='LR',
                    help='initial learning rate',
                    dest='lr')
parser.add_argument('--seed',
                    default=1000,
                    type=int,
                    help='seed for initializing training. ')
parser.add_argument('-T',
                    '--time',
                    default=6,
                    type=int,
                    metavar='N',
                    help='snn simulation time steps (default: 2)')
parser.add_argument('-j',
                    '--workers',
                    default=16,
                    type=int,
                    metavar='N',
                    help='number of data loading workers (default: 10)')
parser.add_argument('--epochs',
                    default=250,
                    type=int,
                    metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--teacher', type=str, default='resnet18', help='path of ImageNet')
args = parser.parse_args()

def train(model, model_teacher, device, train_loader, criterion, criterion_kd, optimizer, epoch, args):
    """Run one epoch of distillation training of ``model`` against ``model_teacher``.

    Args:
        model: student network. Its output is averaged over dim 1 to obtain class
            logits (presumably a (batch, T, classes) tensor — TODO confirm).
        model_teacher: teacher network; put in eval mode for the whole epoch.
        device: device the image and label batches are moved to.
        train_loader: iterable of ``(images, labels)`` batches.
        criterion: classification loss forwarded to ``TET_loss``.
        criterion_kd: object providing
            ``compute_distillation_loss(student_outputs, teacher_outputs, labels)``.
        optimizer: optimizer updating the student's parameters.
        epoch, args: not used; kept so existing call sites keep working.

    Returns:
        Tuple ``(sum of per-batch losses over the epoch, training accuracy in %)``.
        Accuracy is 0.0 if the loader yields no samples.
    """
    running_loss = 0
    model.train()
    model_teacher.eval()
    total = 0
    correct = 0

    for images, labels in train_loader:
        optimizer.zero_grad()

        labels = labels.to(device)
        images = images.to(device)

        # compute output
        outputs = model(images)
        logits_student = outputs.mean(1)
        tea = model_teacher(images)
        # Reduce to a scalar once, so the logged value and the back-propagated value
        # are the same tensor (``.item()`` on a non-scalar tensor would raise).
        loss = (TET_loss(outputs, labels, criterion, 1, 0.05)
                + criterion_kd.compute_distillation_loss(outputs, tea, labels)).mean()
        running_loss += loss.item()
        loss.backward()
        optimizer.step()

        total += float(labels.size(0))
        # Accuracy bookkeeping needs no autograd graph and can stay on the device.
        _, predicted = logits_student.detach().max(1)
        correct += float(predicted.eq(labels).sum().item())
    acc = 100 * correct / total if total else 0.0
    return running_loss, acc

@torch.no_grad()
def test(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` and return top-1 accuracy in percent.

    The model output is averaged over dim 1 before taking the arg-max (presumably
    the time-step axis of a (batch, T, classes) output — TODO confirm). The running
    accuracy is printed every 100 batches. An empty loader yields 0.0 instead of
    raising ZeroDivisionError.
    """
    correct = 0
    total = 0
    model.eval()
    for batch_idx, (inputs, targets) in enumerate(test_loader):
        inputs = inputs.to(device)
        mean_out = model(inputs).mean(1)
        _, predicted = mean_out.max(1)
        total += float(targets.size(0))
        # Compare on the prediction's device instead of copying predictions to the host.
        correct += float(predicted.eq(targets.to(predicted.device)).sum().item())
        if batch_idx % 100 == 0:
            acc = 100. * float(correct) / float(total)
            print(batch_idx, len(test_loader), ' Acc: %.5f' % acc)
    if total == 0:
        return 0.0
    return 100 * correct / total


# ``test`` is an evaluation helper, not a pytest test case; stop pytest from
# collecting it by name (it would fail looking up ``model`` etc. as fixtures).
test.__test__ = False

def check_model_devices(model):
    """Debug helper: print the device that each named parameter of ``model`` lives on."""
    report_lines = (
        f"Parameter {param_name}: {tensor.device}"
        for param_name, tensor in model.named_parameters()
    )
    for line in report_lines:
        print(line)

def check_model_buffers(model):
    """Debug helper: print the device that each named buffer of ``model`` lives on."""
    report_lines = (
        f"Buffer {buffer_name}: {tensor.device}"
        for buffer_name, tensor in model.named_buffers()
    )
    for line in report_lines:
        print(line)

if __name__ == '__main__':

    seed_all(args.seed)
    CLASSES = 10
    train_dataset, val_dataset = data_loaders.build_cifar(use_cifar10=True)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                               num_workers=args.workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
                                              shuffle=False, num_workers=args.workers, pin_memory=True)

    # Student network. CLASSES and -T/--time are used instead of repeating literals,
    # so the command line actually controls them (defaults unchanged: 10 classes, T=6).
    model = sb_msresnet18(num_classes=CLASSES).cuda()
    print(model)
    model.T = args.time
    parallel_model = torch.nn.DataParallel(model)
    # load pretrain_model
    # state_dict = torch.load('.pth')
    # parallel_model.module.load_state_dict(state_dict, strict=False)

    logger = get_logger('.log')
    logger.info('start training!')

    # Teacher. map_location='cpu' makes loading independent of the GPU the checkpoint
    # was saved from; load_state_dict then copies the tensors into the CUDA parameters.
    model_teacher = msresnet18(num_classes=CLASSES).cuda()
    checkpoint = torch.load('./checkpoints/96.68-cifar10-ms18-fp.pth', map_location='cpu')
    model_teacher.load_state_dict(checkpoint)

    # Freeze the teacher: its parameters receive no gradients and it stays in eval mode.
    # ``torch.nn`` is spelled out rather than relying on ``nn`` leaking in from a star import.
    for p in model_teacher.parameters():
        p.requires_grad = False
    model_teacher = torch.nn.DataParallel(model_teacher)
    model_teacher.eval()

    criterion = torch.nn.CrossEntropyLoss().cuda()
    criterion_kd = BlockwiseDistillation(teacher_model=model_teacher, student_model=parallel_model)

    # Split the student's parameters into 4-D / 'conv'-named weights and everything else.
    # Both groups currently get weight_decay=0 (Adam's default is also 0), so the split
    # only matters if a decay is later assigned to ``other_parameters``.
    weight_parameters = [p for pname, p in parallel_model.named_parameters()
                         if p.ndimension() == 4 or 'conv' in pname]
    weight_parameters_id = {id(p) for p in weight_parameters}
    other_parameters = [p for p in parallel_model.parameters() if id(p) not in weight_parameters_id]

    # Honour --lr (its default, 0.001, is the value that used to be hard-coded here).
    optimizer = torch.optim.Adam(
            [{'params': other_parameters}, {'params': weight_parameters, 'weight_decay': 0}], lr=args.lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=0, T_max=args.epochs)

    # NOTE(review): --start_epoch is parsed but not used; training always starts at epoch 0.
    best_acc = 0
    best_epoch = 0
    for epoch in range(args.epochs):
        loss, acc = train(parallel_model, model_teacher, device, train_loader, criterion, criterion_kd, optimizer,
                          epoch, args)
        logger.info('Epoch:[{}/{}]\t loss={:.5f}\t acc={:.3f}'.format(epoch, args.epochs, loss, acc))
        scheduler.step()
        facc = test(parallel_model, test_loader, device)
        logger.info('Epoch:[{}/{}]\t Test acc={:.3f}'.format(epoch, args.epochs, facc))

        # Keep only the best student weights, unwrapped from DataParallel.
        if best_acc < facc:
            best_acc = facc
            best_epoch = epoch + 1
            torch.save(parallel_model.module.state_dict(), 'cifar10-new.pth')

        # best_epoch was previously tracked but never reported.
        logger.info('Best Test acc={:.3f}\t best epoch={}'.format(best_acc, best_epoch))
        logger.info('\n')