# -*- coding: utf-8 -*-
import numpy as np
import os
import pickle
import argparse
import time
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as trn
import torchvision.transforms.functional as trnF
import torchvision.datasets as dset
import torchvision.transforms as transforms
from models.wrn import *
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
import math
import random
from cifar_utils import *
from loss.loss import *
import dataset.cifar10_64_new as dataset

# Command-line interface. Every option has a default, so the script can be
# launched without arguments.
parser = argparse.ArgumentParser(description='Trains a one-class model',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# Optimization options
parser.add_argument('--epochs', '-e', type=int, default=1000, help='Number of epochs to train.')
parser.add_argument('--learning_rate', '-lr', type=float, default=1e-4, help='The initial learning rate.')
parser.add_argument('--batch_size', '-b', type=int, default=64, help='Batch size.')
parser.add_argument('--test_bs', type=int, default=200)
parser.add_argument('--momentum', type=float, default=0.9, help='Momentum.')
parser.add_argument('--decay', '-d', type=float, default=0.0005, help='Weight decay (L2 penalty).')
# Checkpoints
parser.add_argument('--save', '-s', type=str, default='./snapshots/',
                    help='Folder to save checkpoints.')
parser.add_argument('--test', '-t', action='store_true', help='Test only flag.')
# WRN Architecture
parser.add_argument('--layers', default=16, type=int, help='total number of layers')
parser.add_argument('--widen-factor', default=4, type=int, help='widen factor')
parser.add_argument('--droprate', default=0.3, type=float, help='dropout probability')
# Acceleration
parser.add_argument('--ngpu', type=int, default=1, help='0 = CPU.')
parser.add_argument('--prefetch', type=int, default=10, help='Pre-fetching threads.')

# Adam coefficients and number of output classes
parser.add_argument('--beta1', type=float, default=0.5)
parser.add_argument('--beta2', type=float, default=0.999)
parser.add_argument('--numb_class', default=10, type=int, help='output class')

# Label-update / loss hyper-parameters
parser.add_argument('--begin_first', type=int, default=50, help='When to begin updating labels')
parser.add_argument('--alpha', type=float, default=0.8, help='Hyper parameter alpha of loss function')
parser.add_argument('--beta', type=float, default=0.4, help='Hyper parameter beta of loss function')
parser.add_argument('--percent', type=float, default=0.6, help='Percentage of noise')
args = parser.parse_args()

# Snapshot of the parsed options, echoed once so each run's log records its
# configuration. vars() is the public way to view a Namespace as a mapping
# (the previously used _get_kwargs() is a private argparse helper).
state = dict(vars(args))
print(state)

# Fixed seeds for every RNG this script imports. Python's `random` is seeded
# as well so that any code drawing from it is repeatable between runs.
torch.manual_seed(1)
np.random.seed(1)
random.seed(1)


# Data
print('==> Preparing data..')

# ToTensor yields values in [0, 1]; normalising with a fixed mean/std of 0.5
# per channel rescales them to [-1, 1]. These are NOT the empirical CIFAR-10
# channel statistics.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=8),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Evaluation uses the same normalisation but no augmentation.
transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset, valset = dataset.get_cifar10(
    '../data/cifar10', args, train=True, download=True,
    transform_train=transform_train, transform_val=transform_val)

# NOTE(review): the training loader is deliberately left unshuffled here; the
# training loop derives class proportions per batch, so batch composition
# presumably matters -- confirm before enabling shuffle.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=args.batch_size, shuffle=False, num_workers=4)
# Honour --test_bs (default 200) instead of a hard-coded batch size.
testloader = torch.utils.data.DataLoader(
    valset, batch_size=args.test_bs, shuffle=False, num_workers=4)



def mycriterion(outputs, soft_targets, prop, epoch):
    """Cross-entropy plus reverse cross-entropy against soft labels.

    Args:
        outputs: raw network logits; softmax is taken over ``dim=1``.
        soft_targets: per-sample target distribution, same shape as
            ``outputs``.
        prop: unused; kept so existing call sites keep working.
        epoch: unused; kept so existing call sites keep working.

    Returns:
        A ``(probs, loss)`` tuple. ``probs`` is ``softmax(outputs)`` clamped
        to ``[1e-7, 1]`` and ``loss`` is the scalar ``CE + RCE`` where

        * ``CE  = -mean(sum(log_softmax(outputs) * soft_targets, dim=1))``
        * ``RCE = -mean(sum(probs * log(clamp(soft_targets, 1e-4, 1)), dim=1))``

    Earlier revisions also computed prior / entropy / proportion terms that
    never reached the returned loss. They were removed: besides being dead
    work, they hard-coded 10 classes and a CUDA tensor, which made the
    function unusable on CPU or with another number of classes.
    """
    log_probs = F.log_softmax(outputs, dim=1)
    probs = F.softmax(outputs, dim=1)

    # CE: uses the *unclamped* soft targets.
    L_c = -torch.mean(torch.sum(log_probs * soft_targets, dim=1))

    # RCE: both operands are clamped so that log() of a zero target entry
    # stays finite.
    probs = torch.clamp(probs, min=1e-7, max=1.0)
    soft_targets = torch.clamp(soft_targets, min=1e-4, max=1.0)
    rce = -torch.mean(torch.sum(probs * torch.log(soft_targets), dim=1))

    # Loss
    loss_new = L_c + rce

    return probs, loss_new


# Create model
# (Backbones tried previously at this spot: resnet18 / resnet34,
# ResidualNet + CBAM and WideResNet.)
net = LLPNet(args.numb_class)

# Spread the model over several GPUs when more than one was requested.
if args.ngpu > 1:
    net = nn.DataParallel(net, device_ids=[*range(args.ngpu)])

# ngpu == 0 means CPU: only touch CUDA when at least one GPU is requested.
if args.ngpu > 0:
    net.cuda()
    torch.cuda.manual_seed(1)

# Let cuDNN auto-tune its kernels.
cudnn.benchmark = True

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    net.parameters(),
    lr=args.learning_rate,
    betas=(args.beta1, args.beta2),
)
# /////////////// Training ///////////////


def adjust_learning_rate(optimizer, epoch, base_lr=None, step=100, gamma=0.5):
    """Apply a step-decay learning-rate schedule to ``optimizer``.

    The rate is ``base_lr`` while ``epoch < step`` and
    ``base_lr * gamma ** (epoch // step)`` afterwards; with the defaults it
    is halved every 100 epochs. The decayed value is printed each time it
    is applied.

    Args:
        optimizer: object exposing ``param_groups``, a list of dicts whose
            ``'lr'`` entry is overwritten.
        epoch: zero-based epoch index.
        base_lr: starting learning rate. Defaults to ``args.learning_rate``
            so the schedule follows ``--learning_rate`` instead of restarting
            from a hard-coded ``1e-4`` once decay begins.
        step: number of epochs between two decays.
        gamma: multiplicative decay factor.
    """
    if base_lr is None:
        base_lr = args.learning_rate
    lr = base_lr
    if epoch >= step:
        lr = base_lr * (gamma ** (epoch // step))
        print(lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


#Testing 
def test(loader, net):
    """Return the top-1 accuracy of ``net`` over ``loader``, in percent.

    Args:
        loader: iterable yielding ``(data, target)`` batches.
        net: model to evaluate; it is switched to eval mode and left there.

    Returns:
        ``100 * correct / total`` as a float, or ``0.0`` when the loader
        yields no samples (instead of raising ``ZeroDivisionError``).
    """
    net.eval()

    # Evaluate on whatever device the model lives on rather than assuming
    # CUDA, so a CPU run (--ngpu 0) works too. A parameter-free model leaves
    # the batches where the loader put them.
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else None

    correct = 0
    total = 0
    with torch.no_grad():
        for (data, target) in loader:
            if device is not None:
                data, target = data.to(device), target.to(device)
            outputs = net(data)
            _, predicted = torch.max(outputs, 1)

            total += target.size(0)
            correct += (predicted == target).sum().item()

    if total == 0:
        return 0.0
    return 100 * correct / total


# The bare name "test" matches pytest's default discovery prefix; flag the
# helper as not-a-test so importing it into a test module does not make
# pytest try to collect it.
test.__test__ = False


def training_accuracy(loader, net):
    """Return the accuracy of ``net`` against the real labels, in percent.

    Args:
        loader: iterable yielding 5-tuples whose first element is the input
            batch and whose third element holds the real labels; the other
            three elements are ignored.
        net: model to evaluate; it is switched to eval mode and left there.

    Returns:
        ``100 * correct / total`` as a float, or ``0.0`` when the loader
        yields no samples (instead of raising ``ZeroDivisionError``).
    """
    net.eval()

    # Follow the model's device instead of assuming CUDA (--ngpu 0 support).
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else None

    correct = 0
    total = 0
    with torch.no_grad():
        for (data, _, target_real, _, _) in loader:
            if device is not None:
                data, target_real = data.to(device), target_real.to(device)
            outputs = net(data)
            _, predicted = torch.max(outputs, 1)
            total += target_real.size(0)
            correct += (predicted == target_real).sum().item()

    if total == 0:
        return 0.0
    return 100 * correct / total



def training_accuracy_fake(loader, net):
    """Return how often the current training labels equal the real labels.

    No forward pass is performed: the second element of every batch (the
    label currently assigned to each sample) is compared with the third
    element (the real label).

    Args:
        loader: iterable yielding 5-tuples ``(_, target, target_real, _, _)``.
        net: not evaluated. It is still switched to eval mode, as before, so
            callers relying on that side effect are unaffected.

    Returns:
        Percentage of matching labels as a float, or ``0.0`` when the loader
        yields no samples (instead of raising ``ZeroDivisionError``).
    """
    net.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for (_, target, target_real, _, _) in loader:
            # Comparing two label tensors needs no GPU; compare them where
            # the loader produced them.
            total += target.size(0)
            correct += (target_real == target).sum().item()

    if total == 0:
        return 0.0
    return 100 * correct / total



# Make save directory. Reject a non-directory path first, then create the
# folder with exist_ok=True so a concurrent creation between the check and
# the call cannot make makedirs fail.
if os.path.exists(args.save) and not os.path.isdir(args.save):
    raise Exception('%s is not a dir' % args.save)
os.makedirs(args.save, exist_ok=True)

# One row per training sample, one column per class; filled batch by batch
# in the training loop. The width follows --numb_class (default 10) rather
# than a hard-coded 10.
results = np.zeros((len(trainloader.dataset), args.numb_class), dtype=np.float32)

print('Beginning Training\n')


# One row per epoch: (epoch number, test accuracy, loss).
save_table = np.zeros(shape=(args.epochs, 3))
# Main loop
# Device follows --ngpu (0 = CPU) instead of unconditional .cuda() calls.
device = torch.device('cuda' if args.ngpu > 0 else 'cpu')

for epoch in range(0, args.epochs):

    correct = 0
    total = 0
    # Last batch loss of the epoch as a plain float; stays NaN if the loader
    # yields nothing.
    epoch_loss = float('nan')
    begin_epoch = time.time()
    adjust_learning_rate(optimizer, epoch)

    # The evaluation helpers called at the end of every epoch leave the
    # network in eval mode; switch back so layers that behave differently
    # in training are in training mode while the weights are updated.
    net.train()

    for batch_idx, (inputs, targets, targets_real, soft_targets, indexs) in enumerate(trainloader):
        # Host-side copies used for NumPy bookkeeping.
        batch_indices = indexs.cpu().numpy()
        real_labels_np = targets_real.cpu().numpy()

        inputs = inputs.to(device)
        targets_real = targets_real.to(device)
        soft_targets = soft_targets.to(device)

        # Class proportions of the real labels inside this batch.
        prop = torch.FloatTensor(
            np.bincount(real_labels_np, minlength=args.numb_class) / inputs.size(0)
        ).to(device)

        # forward
        outputs = net(inputs)
        probs, loss = mycriterion(outputs, soft_targets, prop, epoch)

        # Turn the batch predictions and proportions into new label
        # estimates and store them at the samples' dataset indices.
        results_temp = optimize_ot(probs.detach().cpu().numpy(), prop.cpu().numpy())
        results[batch_indices] = results_temp

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Keep a float, not the tensor: the tensor would keep the autograd
        # graph alive and cannot be written into the NumPy table when it
        # lives on the GPU.
        epoch_loss = loss.item()

        _, predicted = torch.max(torch.FloatTensor(results_temp).to(device), 1)
        total += inputs.size(0)
        correct += (predicted == targets_real).sum().item()

    # Share of samples whose re-estimated label equals the real label.
    updata_accuracy = 100 * correct / total if total else 0.0

    test_accuracy = test(testloader, net)
    # Accuracy on the real training labels is not measured (costs a full
    # extra pass); 0 is reported as a placeholder.
    train_accuracy = 0
    train_accuracy_fake = training_accuracy_fake(trainloader, net)
    # update the labels
    trainloader.dataset.label_update(results)

    # Append this epoch's test accuracy; the context manager closes the file
    # even if the write fails.
    with open('test_cifar10_64_hard.txt', 'a') as f:
        f.write(str(test_accuracy) + '\n')

    # TODO: checkpointing is currently disabled; args.save is created but
    # nothing is written to it.

    # Show results
    print('Epoch {0:3d} | Time {1:5d} | test_accuracy {2:4f} | train_accuracy {3:4f} |  train_accuracy_fake {4:4f} | updata_accuracy {5:4f} | loss {6:4f} | '.format(
        (epoch + 1),
        int(time.time() - begin_epoch),
        test_accuracy,
        train_accuracy,
        train_accuracy_fake,
        updata_accuracy,
        epoch_loss
    ))
    save_table[epoch, :] = epoch + 1, test_accuracy, epoch_loss

np.savetxt('train_self_results.txt', save_table)
