import os
import sys
import argparse
from os.path import join, expanduser

import torch
import torch.nn as nn
import torch.nn.functional as F

import data
import util
import model

if __name__ == "__main__":
    # ===========================================================
    # settings
    # ===========================================================

    # Command-line options: dataset / architecture selection, masking
    # baseline, loss variant (distillation vs. label cross-entropy),
    # KL prior, batch size, epoch count and run bookkeeping.
    parser = argparse.ArgumentParser(description='')

    # The script only knows how to build loaders for these three datasets,
    # so reject anything else at parse time instead of failing later with
    # an unbound `num_classes`.
    parser.add_argument('--dataset_name', type=str, default='imagenet',
                        choices=('cifar10', 'cifar100', 'imagenet'),
                        help='cifar10 or cifar100 or imagenet')
    parser.add_argument('--arch_name', type=str, default='vgg16', help='vgg16 or resnet50')
    # NOTE(review): the default 'random' is not listed in the help text;
    # confirm which values util.get_baseline actually supports.
    parser.add_argument('--baseline', type=str, default='random', help='zero or mean (which is per channel mean)')
    parser.add_argument('--relu', default=False, action='store_true', help='last layer relu existance')
    parser.add_argument('--sep', default=False, action='store_true', help='last layer separation existance')
    parser.add_argument('--distill', default=False, action='store_true', help='prediction vs correct classifier')
    parser.add_argument('--prior', type=str, default='tv', help='standard or tv')
    parser.add_argument('--bs', type=int, default=64)
    parser.add_argument('--epochs',     type=int, default=10)
    parser.add_argument('--beta',     type=float, default=10, help='beta-VAE')
    parser.add_argument('--runs_dir', type=str, default='runs')
    parser.add_argument('--resume', action='store_true', help='resume from checkpoint')


    args = parser.parse_args()

    # The effective output directory is always nested under a top-level
    # 'runs' folder: runs/<runs_dir>/<dataset>_<arch>/bs_<baseline>.
    args.runs_dir = join('runs', args.runs_dir, f'{args.dataset_name}_{args.arch_name}', f'bs_{args.baseline}')
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.runs_dir, exist_ok=True)

    # def write_log(string):
    #     with open(join(args.runs_dir,'log.log'), 'a') as lf:
    #         sys.stdout = lf
    #         print(string)

    # os.environ["CUDA_VISIBLE_DEVICES"] = '0, 1, 2, 3'
    # os.environ["CUDA_VISIBLE_DEVICES"] = '4, 5, 6, 7'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device,'\n')
    # Both are overwritten from the checkpoint when --resume is given.
    best_acc = 0.
    start_epoch = 0

    # ===========================================================
    # main
    # ===========================================================

    ''' Data '''
    print('==> Preparing data..')
    if args.dataset_name == 'cifar10':
        train_loader, test_loader = data.get_cifar10_loader(bs=args.bs)
        num_classes = 10
        size = 32
    elif args.dataset_name == 'cifar100':
        train_loader, test_loader = data.get_cifar100_loader(bs=args.bs)
        num_classes = 100
        size = 32
    elif args.dataset_name == 'imagenet':
        train_loader, test_loader = data.get_imagenet_loader(bs=args.bs)
        num_classes = 1000
        size = 224

    ''' Build Model '''
    print('\n===> Build model...')
    # encoder
    encoder = model.load_pretrained_encoder(args.arch_name, args.dataset_name,
                                            args.baseline, num_classes,
                                            device, args.relu, args.sep, pretrained=False)
    # reparameterize
    reparameterize = model.reparameterize
    # top-k operator
    topk_module = model.load_topk_module(device)
    # decoder, which is pre-trained model
    base_model = model.load_pretrained_base_model(args.arch_name, args.dataset_name,
                                                  num_classes, device)
    # if args.resume, load encoder checkpoint
    if args.resume:
        print(' ==> Resume checkpoint..')
        checkpoint = torch.load(join(args.runs_dir, 'encoder_tmp.pth'))
        encoder.load_state_dict(checkpoint['encoder'])
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch']

    ''' training detail '''
#     optimizer = torch.optim.SGD(encoder.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4)
    optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-5, weight_decay=5e-4)
#     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(args.epochs/3), gamma=0.1) 
    # get prior Kappa_inv
    Kappa_inv = util.get_kappa_inv(size).to(device)
    def _prior_kl(mu, logvar):
        """Return the KL term for the prior selected by ``--prior``.

        'tv' delegates to util.get_kl_divergence with the precomputed
        Kappa_inv; any other value uses the closed-form expression
        -0.5 * sum(1 + logvar - mu^2 - exp(logvar)).
        """
        if args.prior == 'tv':
            return util.get_kl_divergence(mu, logvar, Kappa_inv)
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    if args.distill:
        def criterion(outputs, teacher_outputs, mu, logvar, T=10):
            """Distillation loss plus the scaled KL term, averaged per sample.

            The student/teacher logits are softened with temperature ``T``
            and compared with a summed KL divergence (rescaled by T*T).
            The prior KL is divided by the spatial size of ``mu`` and by
            ``args.beta``; the total is divided by the batch size.
            """
            batch, _, height, width = mu.size()
            student_log_probs = F.log_softmax(outputs/T, dim=1)
            teacher_probs = F.softmax(teacher_outputs/T, dim=1)
            distill_term = F.kl_div(student_log_probs, teacher_probs,
                                    reduction='sum') * (T*T)
            kl_term = _prior_kl(mu, logvar)
            return (distill_term + kl_term/(height*width)/args.beta) / batch
    else:
        def criterion(outputs, mu, logvar, targets):
            """Label cross-entropy plus the scaled KL term, averaged per sample.

            Cross-entropy is summed over the batch; the prior KL is divided
            by the spatial size of ``mu`` and by ``args.beta``; the total is
            divided by the batch size.
            """
            batch, _, height, width = mu.size()
            ce_term = F.cross_entropy(outputs, targets, reduction='sum')
            kl_term = _prior_kl(mu, logvar)
            return (ce_term + kl_term/(height*width)/args.beta) / batch
        

    ''' Training '''
    def train(epoch):
        """Train the encoder for one pass over ``train_loader``.

        For every batch the encoder produces ``mu``/``logvar``, a saliency
        map is sampled via ``reparameterize``, a top-k mask with a random
        keep ratio in [0.1, 0.9) is applied to blend the input with the
        baseline, and the masked input is scored by ``base_model``. Only
        the encoder's parameters are stepped by the optimizer.

        A ``RuntimeError`` raised while processing a batch is printed and
        the batch is skipped so that training can continue. Every 100
        batches the running statistics are printed and a temporary
        checkpoint is written through ``save_ckpt_tmp``.

        Args:
            epoch: zero-based epoch index (printed as ``epoch + 1``).
        """
        encoder.train()
        base_model.eval()
        train_loss = 0.
        correct = 0
        total = 0

        def _train_step(inputs, targets):
            """Run forward/backward/step on one batch.

            Returns ``(loss_value, n_correct, batch_size)`` as plain Python
            numbers. All intermediate tensors are locals of this frame, so
            if the step raises they are released together with the
            exception instead of having to be ``del``-ed one by one.
            """
            B, C, H, W = inputs.size()
            inputs, targets = inputs.to(device), targets.to(device)
            baseline = util.get_baseline(args.baseline, inputs, device)

            optimizer.zero_grad()
            topk_module.zero_grad()
            base_model.zero_grad()

            # teacher probability
            if args.distill:
                with torch.no_grad():
                    teacher_outputs = base_model(inputs)

            # student
            mu, logvar = encoder(inputs)
            saliency = reparameterize(mu, logvar)
            saliency_flat = util.minmax_gap(saliency.view(B, -1))

            # Fraction of pixels to keep, resampled for every batch.
            k = torch.FloatTensor(1).uniform_(0.1,0.9).item()
            topk = topk_module(saliency_flat, k=int(H*W*k))
            topk = topk.view(B,1,H,W)

            inputs_masked = inputs * topk + baseline * (1-topk) # (B,C,H,W)
            outputs = base_model(inputs_masked)
            if args.distill:
                loss = criterion(outputs, teacher_outputs, mu, logvar)
            else:
                loss = criterion(outputs, mu, logvar, targets)
            loss.backward()
            optimizer.step()

            _, predicted = outputs.max(1)
            return loss.item(), predicted.eq(targets).sum().item(), B

        for batch_idx, (inputs, targets) in enumerate(train_loader):
            try:
                batch_loss, batch_correct, batch_size = _train_step(inputs, targets)

                train_loss += batch_loss
                total += batch_size
                correct += batch_correct

                if (batch_idx+1) % 100 == 0:
                    print('Idx {:05d} | Loss: {:.03f} | Acc: {:.03f}'.format(batch_idx+1, train_loss/total, correct/total))
                    save_ckpt_tmp(epoch, 100.*correct/total)
            except RuntimeError as e:
                # The handler must not touch per-batch locals: the failure
                # may have happened before any of them was assigned, and
                # referencing them here would raise UnboundLocalError and
                # abort the run instead of skipping the batch.
                print(e)
                optimizer.zero_grad()
                topk_module.zero_grad()
                base_model.zero_grad()
                print('Continue training...')

        if total == 0:
            # Empty loader or every batch failed: nothing to average.
            print('\nEpoch {:03d} | no batch was processed successfully'.format(epoch+1))
            return

        train_loss /= total
        correct /= total
        print('\nEpoch {:03d} | Loss: {:.03f} | Acc: {:.03f}'.format(epoch+1, train_loss, correct))

    ''' Test '''
    def test(epoch, k=None):
        """Evaluate the encoder on ``test_loader`` and save a checkpoint.

        Args:
            epoch: zero-based epoch index; ``epoch + 1`` is used in the
                printed summary, in the stored state and in the file name.
            k: fraction of pixels kept by the top-k mask. When ``None`` a
                value is drawn uniformly from [0.1, 0.9) on the first batch
                and then reused for all remaining batches of this call.

        Side effects: writes ``encoder_<epoch+1>.pth`` into
        ``args.runs_dir`` unconditionally and overwrites the module-level
        ``best_acc`` with the accuracy (in percent) of this evaluation.
        """
        global best_acc
        encoder.eval()
        base_model.eval()
        loss_sum = 0.
        n_correct = 0
        n_seen = 0
        with torch.no_grad():
            for step, (inputs, targets) in enumerate(test_loader, start=1):
                B, C, H, W = inputs.size()
                inputs = inputs.to(device)
                targets = targets.to(device)
                baseline = util.get_baseline(args.baseline, inputs, device)

                optimizer.zero_grad()
                topk_module.zero_grad()
                base_model.zero_grad()

                # teacher probability (only needed by the distillation loss)
                teacher_outputs = base_model(inputs) if args.distill else None

                # student
                mu, logvar = encoder(inputs)
                sampled = reparameterize(mu, logvar)
                saliency_flat = util.minmax_gap(sampled.view(B, -1))

                # Rebinding the parameter fixes the ratio after the first batch.
                if k is None:
                    k = torch.FloatTensor(1).uniform_(0.1,0.9).item()
                keep_mask = topk_module(saliency_flat, k=int(H*W*k)).view(B, 1, H, W)

                masked_inputs = inputs * keep_mask + baseline * (1 - keep_mask)  # (B,C,H,W)
                outputs = base_model(masked_inputs)
                if args.distill:
                    loss = criterion(outputs, teacher_outputs, mu, logvar)
                else:
                    loss = criterion(outputs, mu, logvar, targets)

                loss_sum += loss.item()
                predicted = outputs.max(1)[1]
                n_seen += B
                n_correct += predicted.eq(targets).sum().item()

                if step % 100 == 0:
                    print('Idx {:04d} | Loss: {:.03f} | Acc: {:.03f}'.format(step, loss_sum/n_seen, n_correct/n_seen))

        mean_loss = loss_sum / n_seen
        accuracy = n_correct / n_seen
        print('\nEpoch {:03d} | Loss: {:.03f} | Acc: {:.03f}'.format(epoch+1, mean_loss, accuracy))

        # Save checkpoint (every call, regardless of the previous best_acc)
        acc = 100. * accuracy
        print('Saving..')
        ckpt_path = join(args.runs_dir, f'encoder_{epoch+1:02d}.pth')
        torch.save(
            {
                'encoder': encoder.state_dict(),
                'acc': acc,
                'epoch': epoch+1,
            },
            ckpt_path,
        )
        best_acc = acc
        print('Saved!')

    def save_ckpt_tmp(epoch, acc):
        """Overwrite the rolling ``encoder_tmp.pth`` checkpoint in ``args.runs_dir``.

        Stores the encoder weights together with the given ``acc`` and
        ``epoch`` values exactly as passed in; this is the file read back
        when the script is started with ``--resume``.
        """
        tmp_path = join(args.runs_dir, f'encoder_tmp.pth')
        snapshot = {
            'encoder': encoder.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        torch.save(snapshot, tmp_path)

    ''' Run '''
    # Run args.epochs epochs, numbered from start_epoch (non-zero only when
    # resuming); each epoch is one training pass followed by one evaluation.
    epoch_indices = range(start_epoch, start_epoch + args.epochs)
    for epoch in epoch_indices:
        train(epoch)
        test(epoch)
#         scheduler.step()
     
    

