import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
#import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse
import utils_quant as qt

from models import *
from utils import progress_bar

# Root directory handed to the dataset constructors.
data_dir = '../../data'
# Prefix under which experiment result files are written.
exp_dir = './exps/'

# Per-layer tables: 10 entries for VGG11, 15 for each VGG16 variant.
# NOTE(review): the names suggest per-ReLU-layer cost/count figures -- the
# exact meaning of the numbers is not established in this file; confirm.
relus_vgg11 = (
    [655360000]
    + [327680000]
    + [163840000] * 2
    + [81920000] * 2
    + [20480000] * 2
    + [40960000] * 2
)
relus_vgg16 = (
    [655360000] * 2
    + [327680000] * 2
    + [163840000] * 3
    + [81920000] * 3
    + [20480000] * 3
    + [40960000] * 2
)
relus_vgg16_tiny = (
    [2621440000] * 2
    + [1310720000] * 2
    + [655360000] * 3
    + [327680000] * 3
    + [81920000] * 3
    + [40960000] * 2
)

# Command-line interface of the script; parsed inside main().
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')

# Optimisation settings.
parser.add_argument(
    '--lr',
    type=float,
    default=0.1,
    help='learning rate',
)

# Run-mode switches and checkpoint locations.
parser.add_argument(
    '--resume', '-r',
    action='store_true',
    help='resume from checkpoint',
)
parser.add_argument(
    '-e', '--evaluate',
    dest='evaluate',
    action='store_true',
    help='evaluate model on validation set',
)
parser.add_argument(
    '--pretrained',
    type=str,
    default='ckpt.pth',
    metavar='PATH',
    help='path to pretrained model',
)
parser.add_argument(
    '--save-model',
    dest='save_model',
    type=str,
    default='ckpt.pth',
    help='Model checkpoint name',
)
parser.add_argument(
    '--train-fault',
    dest='train_fault',
    action='store_true',
    help='training with fault',
)
parser.add_argument(
    '--save-exp',
    dest='save_exp',
    type=str,
    default='',
    metavar='PATH',
    help='label for experiment files',
)

# Model / dataset selection.
parser.add_argument(
    '--model',
    dest='model',
    type=str,
    default='vgg11',
    help='Model to use',
)
parser.add_argument(
    '--dataset',
    dest='dataset',
    type=str,
    default='c10',
    help='Dataset to use',
)
parser.add_argument(
    '--save-every',
    dest='save_every',
    type=int,
    default=10,
    help='Saves checkpoints at every specified number of epochs',
)

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'


def get_data_c10():
    """Build the CIFAR-10 training and test data loaders.

    The dataset is read from ``data_dir`` and downloaded there if missing.
    Training images are augmented with a random 32-pixel crop (padding 4) and
    a random horizontal flip; both splits are converted to tensors and
    normalised with the same per-channel mean/std constants.

    Returns:
        tuple: ``(trainloader, testloader)``. The training loader shuffles
        and uses batch size 128; the test loader keeps dataset order and uses
        batch size 100. Both use 2 worker processes.
    """
    print('==> Preparing data..')

    # Per-channel statistics shared by the train and test pipelines.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # No augmentation at test time.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    trainset = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=128, shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(
        root=data_dir, train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=100, shuffle=False, num_workers=2)

    return trainloader, testloader


def get_data_c100():
    """Create the CIFAR-100 data loaders.

    Images come from ``data_dir`` (downloaded on demand). The training split
    is augmented with a padded random crop and a random horizontal flip; the
    test split is only converted to a tensor and normalised.

    Returns:
        tuple: ``(trainloader, testloader)`` with batch sizes 128 (shuffled)
        and 100 (in dataset order); each loader uses 2 workers.
    """
    print('==> Preparing data..')

    channel_mean = (0.507, 0.487, 0.441)
    channel_std = (0.267, 0.256, 0.276)

    # Training pipeline: augmentation first, then tensor conversion + normalisation.
    augmenting_pipeline = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    # Evaluation pipeline: deterministic preprocessing only.
    plain_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    train_split = torchvision.datasets.CIFAR100(
        root=data_dir, train=True, download=True, transform=augmenting_pipeline)
    test_split = torchvision.datasets.CIFAR100(
        root=data_dir, train=False, download=True, transform=plain_pipeline)

    loaders = (
        torch.utils.data.DataLoader(
            train_split, batch_size=128, shuffle=True, num_workers=2),
        torch.utils.data.DataLoader(
            test_split, batch_size=100, shuffle=False, num_workers=2),
    )
    return loaders



def get_data_tiny():
    """Build the Tiny-ImageNet training and validation data loaders.

    Both splits are read with ``ImageFolder`` from
    ``data_dir + '/tiny-imagenet-200/train'`` and ``.../val``, so each split
    directory must contain one sub-directory per class. Images are converted
    to tensors and normalised; no augmentation is applied.

    Returns:
        tuple: ``(train_loader, test_loader)``. The training loader shuffles
        (batch size 128, 8 workers); the evaluation loader keeps dataset
        order (batch size 100).
    """
    print('==> Preparing data..')

    # Same preprocessing for both splits.
    preprocess = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
    ])

    train_set = torchvision.datasets.ImageFolder(
        root=data_dir+'/tiny-imagenet-200/train',
        transform=preprocess,
    )
    test_set = torchvision.datasets.ImageFolder(
        root=data_dir+'/tiny-imagenet-200/val',
        transform=preprocess,
    )

    train_loader = DataLoader(train_set, shuffle=True, batch_size=128, num_workers=8)

    # Evaluation data is not shuffled, matching the CIFAR test loaders: the
    # order does not affect accuracy and a fixed order keeps runs comparable.
    test_loader = DataLoader(test_set, shuffle=False, batch_size=100)

    return train_loader, test_loader

# Training
def train(trainloader, net, criterion, optimizer, epoch, alpha, beta, train_fault, trunc=19):
    """Run one training epoch over ``trainloader``.

    Args:
        trainloader: iterable of ``(inputs, targets)`` batches.
        net: model to optimise; put into training mode here.
        criterion: loss function called as ``criterion(outputs, targets)``.
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, only used for the progress printout.
        alpha, beta: forwarded to ``net.fault`` when ``train_fault`` is set.
        train_fault: if true, the forward pass is ``net.fault(...)`` instead
            of the plain ``net(inputs)``.
        trunc: second argument passed to ``net.fault``. Defaults to 19, the
            value that used to be hard-coded here, so existing callers are
            unaffected.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        if train_fault:
            outputs = net.fault(inputs, trunc, alpha, beta)
        else:
            outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Running statistics for the progress bar.
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))



def test(testloader, net, criterion, epoch, alpha, beta, trunc=0, train_fault=False):
    """Evaluate ``net`` on ``testloader`` and return the accuracy in percent.

    Args:
        testloader: iterable of ``(inputs, targets)`` batches.
        net: model under evaluation; switched to eval mode.
        criterion: loss function, used only for the progress display.
        epoch: accepted for call compatibility; not used in the body.
        alpha, beta, trunc: forwarded to ``net.fault`` when ``train_fault``
            is true.
        train_fault: selects ``net.fault(...)`` over the plain forward pass.

    Returns:
        float: ``100 * correct / total`` over the whole loader.
    """
    net.eval()
    loss_sum, n_correct, n_seen = 0, 0, 0

    with torch.no_grad():
        for step, (images, labels) in enumerate(testloader):
            images = images.to(device)
            labels = labels.to(device)

            logits = net.fault(images, trunc, alpha, beta) if train_fault else net(images)

            loss_sum += criterion(logits, labels).item()
            _, guess = logits.max(1)
            n_seen += labels.size(0)
            n_correct += guess.eq(labels).sum().item()

            status = 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
                loss_sum/(step+1), 100.*n_correct/n_seen, n_correct, n_seen)
            progress_bar(step, len(testloader), status)

    return 100.*n_correct/n_seen

def validate_quantize(testloader, net, criterion, trunc, alpha, beta, def_pos, stochastic):
    """Evaluate one of the model's alternative forward paths; return accuracy in percent.

    Args:
        testloader: iterable of ``(inputs, targets)`` batches.
        net: model providing ``fault`` and ``quantize``; switched to eval mode.
        criterion: loss function, used only for the progress display.
        trunc, alpha, beta: forwarded unchanged to the selected forward path.
        def_pos: forwarded to ``net.quantize``; ignored when ``stochastic``.
        stochastic: if true use ``net.fault(inputs, trunc, alpha, beta)``,
            otherwise ``net.quantize(inputs, trunc, def_pos, alpha, beta)``.

    Returns:
        float: ``100 * correct / total`` over the whole loader.
    """
    net.eval()
    running_loss = 0
    hits = 0
    seen = 0

    with torch.no_grad():
        for step, (images, labels) in enumerate(testloader):
            images = images.to(device)
            labels = labels.to(device)

            if not stochastic:
                logits = net.quantize(images, trunc, def_pos, alpha, beta)
            else:
                logits = net.fault(images, trunc, alpha, beta)

            running_loss += criterion(logits, labels).item()
            _, guess = logits.max(1)
            seen += labels.size(0)
            hits += guess.eq(labels).sum().item()

            summary = 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
                running_loss/(step+1), 100.*hits/seen, hits, seen)
            progress_bar(step, len(testloader), summary)

    accuracy = 100.*hits/seen
    return accuracy



def validate(testloader, net, criterion):
    """Evaluate ``net`` with its plain forward pass and report the accuracy.

    Args:
        testloader: iterable of ``(inputs, targets)`` batches.
        net: model under evaluation; switched to eval mode.
        criterion: loss function, used only for the progress display.

    Returns:
        float: accuracy in percent, consistent with ``test`` and
        ``validate_quantize``; ``0.0`` if the loader yields no samples.
        Callers that ignore the return value are unaffected.
    """
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Guard the empty-loader case so the function still returns cleanly.
    return 100.*correct/total if total else 0.0


# The exhaustive truncation sweeps used to sit behind an unconditional early
# ``return`` in the evaluation path and therefore never ran. They are kept,
# disabled by default, behind this switch; set it to True to regenerate the
# per-layer sign-error result files.
_RUN_TRUNC_SWEEPS = False

# Directory (relative to the working directory) holding model checkpoints.
_CHECKPOINT_DIR = './checkpoint/'


def _build_model(model_name, dataset, num_classes):
    """Instantiate the requested network together with its per-model settings.

    Returns:
        tuple: ``(net, length, alpha, beta)``. ``length`` is the number of
        per-layer entries used for truncation lists and per-layer counters;
        ``alpha``/``beta`` are forwarded to the fault/quantize code paths.
        Any model name other than ``vgg11``/``vgg16`` selects VGG19, for
        which no ``alpha``/``beta`` are defined (both are ``None``).
    """
    if model_name == 'vgg11':
        return VGG11(num_classes), 10, 4, 0
    if model_name == 'vgg16':
        if dataset == 'tiny':
            return VGG16(num_classes), 15, 5, 0
        return VGG16(num_classes), 15, 4, 1
    return VGG19(num_classes), 18, None, None


def _save_checkpoint(net, acc, epoch, filename):
    """Store weights, last test accuracy and epoch under the checkpoint directory."""
    print('Saving..')
    state = {
        'net': net.state_dict(),
        'acc': acc,
        'epoch': epoch,
    }
    # Create the directory on every save so the final save cannot fail after
    # a full training run just because no periodic save happened earlier.
    os.makedirs(_CHECKPOINT_DIR, exist_ok=True)
    torch.save(state, _CHECKPOINT_DIR+filename)


def _sweep_truncation(test_loader, net, criterion, length, alpha, beta, def_pos, prefix):
    """Evaluate ``net.quantize`` for truncation values 5..28 and log the results.

    For each truncation value one accuracy line is written, followed by one
    line per layer for the total / positive / negative sign-error percentages.
    Output files are ``<prefix>_acc.txt``, ``<prefix>_tfault.txt``,
    ``<prefix>_posfault.txt`` and ``<prefix>_negfault.txt``; when ``def_pos``
    is true ``_dp`` is inserted after the prefix.
    """
    base = prefix + ('_dp' if def_pos else '')
    os.makedirs(os.path.dirname(base) or '.', exist_ok=True)

    with open(base+'_acc.txt', 'w') as facc, \
            open(base+'_tfault.txt', 'w') as ftfault, \
            open(base+'_posfault.txt', 'w') as fposfault, \
            open(base+'_negfault.txt', 'w') as fnegfault:
        for trunc in range(5, 29):
            print("trunc bits = {}".format(trunc))
            acc = validate_quantize(test_loader, net, criterion, [trunc]*length, alpha, beta,
                                    def_pos=def_pos, stochastic=False)
            facc.write(str(acc)+'\n')

            for i in range(length):
                pos = net.pos[i].item()
                neg = net.neg[i].item()
                bad_pos = net.badsign_pos[i].item()
                bad_neg = net.badsign_neg[i].item()
                total = neg + pos

                # Every ratio is guarded against an empty denominator.
                signerr = ((bad_pos+bad_neg)/total)*100 if total != 0 else 0
                signerr_pos = (bad_pos/pos)*100 if pos != 0 else 0
                signerr_neg = (bad_neg/neg)*100 if neg != 0 else 0
                neg_share = (neg/total)*100 if total != 0 else 0
                print("ReLU{}\tsign err= {:.2f}, pos err= {:.2f}, neg err= {:.2f}, negs = {:.2f}%"
                      .format(i+1, signerr, signerr_pos, signerr_neg, neg_share))

                ftfault.write(str(signerr)+'\n')
                fposfault.write(str(signerr_pos)+'\n')
                fnegfault.write(str(signerr_neg)+'\n')
            print("\n")

            # Clear the per-layer counters before the next truncation value.
            net.reset_arrays()


def _evaluate(test_loader, net, criterion, length, alpha, beta, checkpoint_name, prefix):
    """Load a pretrained checkpoint and run the evaluation experiments."""
    checkpoint = torch.load(_CHECKPOINT_DIR+checkpoint_name, map_location=torch.device(device))
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    print("saved chkp acc = ", best_acc)

    validate(test_loader, net, criterion)

    if alpha is None or beta is None:
        # Only the plain evaluation is possible without alpha/beta.
        print("no alpha/beta configured for this model; skipping fault/quantize evaluation")
        return

    print("alpha={}, beta={}".format(alpha, beta))

    # srelu
    print("fp model with srelu (def neg)")
    validate_quantize(test_loader, net, criterion, 18, alpha, beta, def_pos=False, stochastic=True)

    # with appx relu; scale_params must run before the quantize evaluations below
    qt.scale_params(net, alpha, beta)

    print("int model with relu")
    validate_quantize(test_loader, net, criterion, [-1]*length, alpha, beta, def_pos=False, stochastic=False)
    print("arelu def pos")
    validate_quantize(test_loader, net, criterion, [19]*length, alpha, beta, def_pos=True, stochastic=False)
    print("arelu def neg")
    validate_quantize(test_loader, net, criterion, [19]*length, alpha, beta, def_pos=False, stochastic=False)

    if _RUN_TRUNC_SWEEPS:
        # default negative first, then default positive
        for def_pos in (False, True):
            _sweep_truncation(test_loader, net, criterion, length, alpha, beta, def_pos, prefix)


def main():
    """Script entry point: parse the CLI, build data and model, then evaluate or train.

    With ``--evaluate`` the checkpoint named by ``--pretrained`` is loaded and
    the evaluation experiments run. Otherwise the model is trained for 200
    epochs (continuing from ``./checkpoint/ckpt.pth`` with ``--resume``),
    saving a checkpoint every ``--save-every`` epochs and once at the end.

    Result files are only opened when a sweep actually writes them, so a
    training or quick-evaluation run no longer truncates earlier results.
    """
    global args, exp_dir
    args = parser.parse_args()

    exp_dir = exp_dir+args.save_exp

    start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    # Dataset name -> (loader factory, number of classes).
    datasets = {
        'c10': (get_data_c10, 10),
        'c100': (get_data_c100, 100),
        'tiny': (get_data_tiny, 200),
    }
    if args.dataset not in datasets:
        print ('specify dataset')
        # Exit with a failure status instead of the interactive-only exit().
        raise SystemExit(1)
    make_loaders, num_classes = datasets[args.dataset]
    train_loader, test_loader = make_loaders()

    # Model
    print('==> Building model..')
    net, length, alpha, beta = _build_model(args.model, args.dataset, num_classes)
    net = net.to(device)

    if args.train_fault and (alpha is None or beta is None):
        raise SystemExit('--train-fault needs alpha/beta, which are not configured for this model')

    if args.resume:
        # Load checkpoint. This must happen after the model exists; doing it
        # earlier referenced `net` before assignment.
        print('==> Resuming from checkpoint..')
        if not os.path.isdir('checkpoint'):
            raise FileNotFoundError('Error: no checkpoint directory found!')
        checkpoint = torch.load(_CHECKPOINT_DIR+'ckpt.pth', map_location=torch.device(device))
        net.load_state_dict(checkpoint['net'])
        start_epoch = checkpoint['epoch']

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=args.lr,
                          momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

    if args.evaluate:
        _evaluate(test_loader, net, criterion, length, alpha, beta, args.pretrained, exp_dir)
        return

    for epoch in range(start_epoch, start_epoch+200):
        train(train_loader, net, criterion, optimizer, epoch, alpha, beta, args.train_fault)
        # Keywords keep trunc/train_fault from sliding into the alpha/beta
        # slots of test(); 19 matches the truncation used during training.
        acc = test(test_loader, net, criterion, epoch, alpha, beta,
                   trunc=19, train_fault=args.train_fault)
        scheduler.step()

        # save_every <= 0 disables periodic saving instead of dividing by zero.
        if epoch > 0 and args.save_every > 0 and epoch % args.save_every == 0:
            _save_checkpoint(net, acc, epoch, args.save_model)

    _save_checkpoint(net, acc, epoch, args.save_model)


if __name__ == '__main__':
    main()
