import argparse
import os
from models.MS_ResNet import msresnet18
from models.sbMS_ResNet import sb_msresnet18
import data_loaders
from functions import BlockwiseDistillation, TET_loss, get_logger, MembraneDistiller, Mem_loss
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from pack.util import *
from models.layers import mem_distill

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Command-line options. Help strings interpolate %(default)s so the documented
# defaults always match the real ones (hand-written values had drifted).
parser = argparse.ArgumentParser(description='PyTorch Gated Attention Coding')
parser.add_argument('--start_epoch',
                    default=0,
                    type=int,
                    metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b',
                    '--batch_size',
                    default=9,
                    type=int,
                    metavar='N',
                    help='mini-batch size (default: %(default)s), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')

parser.add_argument('--lr',
                    '--learning_rate',
                    default=0.001,
                    type=float,
                    metavar='LR',
                    help='initial learning rate (default: %(default)s)',
                    dest='lr')
parser.add_argument('--seed',
                    default=1000,
                    type=int,
                    help='seed for initializing training (default: %(default)s)')
parser.add_argument('-T',
                    '--time',
                    default=6,
                    type=int,
                    metavar='N',
                    help='snn simulation time steps (default: %(default)s)')
parser.add_argument('-j',
                    '--workers',
                    default=16,
                    type=int,
                    metavar='N',
                    help='number of data loading workers (default: %(default)s)')
parser.add_argument('--epochs',
                    default=300,
                    type=int,
                    metavar='N',
                    help='number of total epochs to run (default: %(default)s)')
# NOTE(review): not read by the code visible in this file; the teacher
# checkpoint path is hard-coded in the __main__ block.
parser.add_argument('--teacher', type=str, default='resnet18',
                    help='teacher model name (default: %(default)s)')
args = parser.parse_args()

def train(model, model_teacher, device, train_loader, criterion, optimizer):
    """Run one epoch of membrane-potential distillation training.

    For every batch the student ``model`` and the teacher ``model_teacher`` are
    run on the same images. Their intermediate membrane potentials are taken
    from the module-level list ``mem_distill`` (imported from ``models.layers``;
    presumably appended to by the spiking layers during forward — TODO confirm).
    These potentials, together with both models' outputs, feed ``Mem_loss``;
    ``TET_loss`` adds the classification term.

    Args:
        model: student network. Its output is averaged over dim 1 to obtain
            per-class scores (dim 1 is presumably the time-step axis — confirm).
        model_teacher: teacher network; run in eval mode without gradients.
        device: device that images and labels are moved to.
        train_loader: iterable of ``(images, labels)`` batches.
        criterion: classification loss forwarded to ``TET_loss`` and ``Mem_loss``.
        optimizer: optimizer that updates the student.

    Returns:
        ``(running_loss, accuracy)``: the sum of per-batch losses and the top-1
        training accuracy in percent (0.0 when the loader yields no batches).
    """
    running_loss = 0
    model.train()
    model_teacher.eval()
    total = 0
    correct = 0

    for images, labels in train_loader:
        optimizer.zero_grad()

        labels = labels.to(device)
        images = images.to(device)

        # Discard anything left in the shared buffer by earlier forwards
        # (e.g. an evaluation pass) so stu_mem holds only this batch's student state.
        mem_distill.clear()
        logis_stu = model(images)
        outputs = logis_stu.mean(1)
        stu_mem = mem_distill.copy()
        mem_distill.clear()

        # The teacher is a fixed target: building an autograd graph for it only
        # wastes memory.
        with torch.no_grad():
            logis_tea = model_teacher(images)
        tea_mem = mem_distill.copy()
        mem_distill.clear()

        loss_ce = TET_loss(logis_stu, labels, criterion, 1, 0.05)
        loss_kd = Mem_loss(criterion, tea_mem, stu_mem, logis_stu, logis_tea)
        # Reduce once so the logged value and the back-propagated value agree
        # (and .item() does not fail on a non-scalar loss).
        loss = (loss_kd + loss_ce).mean()
        running_loss += loss.item()
        loss.backward()
        optimizer.step()

        total += float(labels.size(0))
        _, predicted = outputs.cpu().max(1)
        correct += float(predicted.eq(labels.cpu()).sum().item())

    if total == 0:
        return running_loss, 0.0
    return running_loss, 100 * correct / total

@torch.no_grad()
def test(model, test_loader, device):
    correct = 0
    total = 0
    model.eval()
    for batch_idx, (inputs, targets) in enumerate(test_loader):
        inputs = inputs.to(device)
        outputs = model(inputs)
        mean_out = outputs.mean(1)
        _, predicted = mean_out.cpu().max(1)
        total += float(targets.size(0))
        correct += float(predicted.eq(targets).sum().item())
        if batch_idx % 100 == 0:
            acc = 100. * float(correct) / float(total)
            print(batch_idx, len(test_loader), ' Acc: %.5f' % acc)
    final_acc = 100 * correct / total
    return final_acc

if __name__ == '__main__':
    # Driver: builds CIFAR-100 loaders, loads a pretrained student and a frozen
    # teacher, sets up the distillation optimizer/scheduler, and runs the epoch
    # loop. The train() call is commented out, so this run only evaluates.

    seed_all(args.seed)
    CLASSES = 100
    train_dataset, val_dataset = data_loaders.build_cifar(use_cifar10=False)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                               num_workers=args.workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
                                              shuffle=False, num_workers=args.workers, pin_memory=True)

    # Student network, initialised from a pretrained checkpoint.
    model = sb_msresnet18(num_classes=CLASSES).to(device)
    print(model)
    model.T = args.time
    state_dict = torch.load('./checkpoints/success/77.43-cifar100-ours-4/cifar100-res19-4.pth',
                            map_location=device)
    # strict=False tolerates key mismatches; the result is kept so skipped keys
    # can be reported instead of silently leaving layers at their random init.
    load_result = model.load_state_dict(state_dict, strict=False)
    # All later training/evaluation/optimizer code goes through this wrapper.
    parallel_model = torch.nn.DataParallel(model)

    logger = get_logger('.log')
    if load_result.missing_keys:
        logger.info('Student checkpoint missing keys: {}'.format(load_result.missing_keys))
    if load_result.unexpected_keys:
        logger.info('Student checkpoint unexpected keys: {}'.format(load_result.unexpected_keys))
    logger.info('start training!')

    # Teacher network: loaded, frozen, and kept in eval mode.
    model_teacher = msresnet18(num_classes=CLASSES).to(device)
    checkpoint = torch.load('./checkpoints/81.2-cifar100-ms18-fp.pth', map_location=device)
    model_teacher.load_state_dict(checkpoint)

    for p in model_teacher.parameters():
        p.requires_grad = False
    model_teacher = nn.DataParallel(model_teacher)
    model_teacher.eval()

    criterion = nn.CrossEntropyLoss().to(device)

    # 4-D parameters whose name contains 'conv' form one group with weight_decay
    # 0; all other parameters use Adam's default weight_decay.
    weight_parameters = [p for pname, p in parallel_model.named_parameters()
                         if p.ndimension() == 4 and 'conv' in pname]
    weight_parameter_ids = {id(p) for p in weight_parameters}
    other_parameters = [p for p in parallel_model.parameters()
                        if id(p) not in weight_parameter_ids]

    optimizer = torch.optim.Adam(
            [{'params': other_parameters}, {'params': weight_parameters, 'weight_decay': 0}], lr=args.lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=0, T_max=args.epochs)

    best_acc = 0
    best_epoch = 0
    # Single evaluation pass; switch to range(args.epochs) and re-enable the
    # train() call below to run distillation training.
    for epoch in range(1):
        # loss, acc = train(parallel_model, model_teacher, device, train_loader, criterion, optimizer)
        # logger.info('Epoch:[{}/{}]\t loss={:.5f}\t acc={:.3f}'.format(epoch , args.epochs, loss, acc ))
        scheduler.step()
        facc = test(parallel_model, test_loader, device)
        logger.info('Epoch:[{}/{}]\t Test acc={:.3f}'.format(epoch , args.epochs, facc ))

        if best_acc < facc:
            best_acc = facc
            best_epoch = epoch + 1
            # torch.save(parallel_model.module.state_dict(), 'cifar100-1211-kd.pth')

        logger.info('Best Test acc={:.3f}'.format(best_acc ))
        logger.info('\n')