# -*- coding: utf-8 -*-
import numpy as np
import os
import pickle
import argparse
import time
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as trn
import torchvision.datasets as dset
import torch.nn.functional as F
from tqdm import tqdm
from models.allconv import AllConvNet
from models.wrn import WideResNet
import torch.distributions as dist
from torch.distributions.dirichlet import Dirichlet
import random
if __package__ is None:
    import sys
    from os import path

    sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
    from utils.imagenet_rc_loader import ImageNet
    from utils.tin597_loader import TIN597
    from utils.validation_dataset import validation_split

# Command-line interface.  Every flag keeps its original name, type and
# default; only help texts that described the wrong thing were corrected.
parser = argparse.ArgumentParser(description='Tunes a CIFAR Classifier with ENERGY',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--dataset', type=str, default='cifar100', choices=['cifar10', 'cifar100'],
                    help='Choose between CIFAR-10, CIFAR-100.')
parser.add_argument('--aux', type=str, default='imagenet', choices=['imagenet', 'TIN597'],
                    help='Auxiliary outlier dataset: choose between imagenet, TIN597.')
parser.add_argument('--model', '-m', type=str, default='wrn',
                    choices=['allconv', 'wrn'], help='Choose architecture.')
parser.add_argument('--calibration', '-c', action='store_true',
                    help='Train a model to be used for calibration. This holds out some data for validation.')
# EG specific
# Previously used margins, kept for reference:
#parser.add_argument('--m_in', type=float, default=-23., help='margin for in-distribution; above this value will be penalized')
#parser.add_argument('--m_out', type=float, default=-7., help='margin for out-distribution; below this value will be penalized')
parser.add_argument('--m_in', type=float, default=-430., help='margin for in-distribution; above this value will be penalized')
parser.add_argument('--m_out', type=float, default=-370., help='margin for out-distribution; below this value will be penalized')
# `lamb` multiplies the differential-entropy margin term in the training loss.
parser.add_argument('--lamb', type=float, default=0.1, help='Weight of the differential-entropy margin loss.')

# Optimization options
parser.add_argument('--epochs', '-e', type=int, default=20, help='Number of epochs to train.')
#parser.add_argument('--learning_rate', '-lr', type=float, default=0.0002, help='The initial learning rate.')
parser.add_argument('--learning_rate', '-lr', type=float, default=0.0001, help='The initial learning rate.')
parser.add_argument('--batch_size', '-b', type=int, default=92, help='Batch size.')
# `dpn_batch_size` is the batch size of the auxiliary (outlier) data loader.
parser.add_argument('--dpn_batch_size', type=int, default=184,
                    help='Batch size for the auxiliary (outlier) dataset.')
parser.add_argument('--test_bs', type=int, default=200)
parser.add_argument('--momentum', type=float, default=0.9, help='Momentum.')
parser.add_argument('--decay', '-d', type=float, default=0.0005, help='Weight decay (L2 penalty).')
# WRN Architecture
parser.add_argument('--layers', default=40, type=int, help='total number of layers')
parser.add_argument('--widen-factor', default=10, type=int, help='widen factor')
parser.add_argument('--droprate', default=0.3, type=float, help='dropout probability')
# Checkpoints
parser.add_argument('--save', '-s', type=str, default='./snapshots/dpn_tune_new', help='Folder to save checkpoints.')
parser.add_argument('--load', '-l', type=str, default='./snapshots/baseline', help='Checkpoint path to resume / test.')
parser.add_argument('--test', '-t', action='store_true', help='Test only flag.')
# Acceleration
parser.add_argument('--ngpu', type=int, default=1, help='0 = CPU.')
parser.add_argument('--prefetch', type=int, default=2, help='Pre-fetching threads.')
parser.add_argument('--seed', type=int, default=1, help='Random seed for reproducibility')
args = parser.parse_args()

def set_random_seed(seed):
    """Seed the random number generators used by this script.

    Passes ``seed`` to ``torch.manual_seed``, Python's ``random.seed``,
    ``np.random.seed`` and ``torch.cuda.manual_seed`` (in that order), then
    puts cuDNN into deterministic mode with benchmarking turned off.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        torch.manual_seed,
        random.seed,
        np.random.seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Prefer repeatable cuDNN kernels over auto-tuned ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# `state` starts as a copy of the parsed command-line options (in the order
# argparse reports them); train()/test() and the epoch loop add metrics to it.
state = dict(args._get_kwargs())
print(state)

# Seed all RNGs before any dataset, loader or model is created.
set_random_seed(args.seed)

# mean and standard deviation of channels of CIFAR-10 images
mean = [x / 255 for x in [125.3, 123.0, 113.9]]
std = [x / 255 for x in [63.0, 62.1, 66.7]]

# Training images get flip + padded-crop augmentation; test images are only
# normalized.
train_transform = trn.Compose([trn.RandomHorizontalFlip(), trn.RandomCrop(32, padding=4),
                               trn.ToTensor(), trn.Normalize(mean, std)])
test_transform = trn.Compose([trn.ToTensor(), trn.Normalize(mean, std)])

# In-distribution data; `num_classes` sizes the classifier head.
if args.dataset == 'cifar10':
    train_data_in = dset.CIFAR10('/datasets/cifar10/', train=True, transform=train_transform)
    test_data = dset.CIFAR10('/datasets/cifar10/', train=False, transform=test_transform)
    num_classes = 10
else:
    train_data_in = dset.CIFAR100('/datasets/cifar100/', train=True, download=True, transform=train_transform)
    test_data = dset.CIFAR100('/datasets/cifar100/', train=False, download=True, transform=test_transform)
    num_classes = 100


# `calib_indicator` becomes part of checkpoint / log file names.
calib_indicator = ''
if args.calibration:
    # Hold out 10% of the training set (val_data is not used further below).
    train_data_in, val_data = validation_split(train_data_in, val_share=0.1)
    calib_indicator = '_calib'

# Auxiliary (outlier) data.  Both sources use the same augmentation pipeline,
# so it is built once instead of being duplicated per branch.
ood_transform = trn.Compose(
    [trn.ToTensor(), trn.ToPILImage(), trn.RandomCrop(32, padding=4),
     trn.RandomHorizontalFlip(), trn.ToTensor(), trn.Normalize(mean, std)])

if args.aux == 'imagenet':
    ood_data = ImageNet(transform=ood_transform)
    # train() sets a random `offset` on this dataset each epoch, so the
    # loader itself is left unshuffled.
    shuffle = False
elif args.aux == 'TIN597':
    ood_data = TIN597(root = '/dataset/tin597/test', transform=ood_transform)
    shuffle = True


train_loader_in = torch.utils.data.DataLoader(
    train_data_in,
    batch_size=args.batch_size, shuffle=True,
    num_workers=args.prefetch, pin_memory=True)

train_loader_out = torch.utils.data.DataLoader(
    ood_data,
    batch_size=args.dpn_batch_size, shuffle=shuffle,
    num_workers=args.prefetch, pin_memory=True)

# Fix: evaluation now honours --test_bs, which was declared on the command
# line but never used (the training batch size was used instead).
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=args.test_bs, shuffle=False,
    num_workers=args.prefetch, pin_memory=True)

# Create model
# `net` is the network being fine-tuned; `original_net` is a second copy that
# keeps the restored weights and is only used for forward passes in train().
# Fix: the reference copy is now built for every architecture; previously
# `--model allconv` left `original_net` undefined and crashed at restore time.
if args.model == 'allconv':
    net = AllConvNet(num_classes)
    original_net = AllConvNet(num_classes)
else:
    net = WideResNet(args.layers, num_classes, args.widen_factor, dropRate=args.droprate)
    original_net = WideResNet(args.layers, num_classes, args.widen_factor, dropRate=args.droprate)

# Restore model
model_found = False
if args.load != '':
    # Scan from the highest epoch index downwards and take the first
    # checkpoint file that exists.
    for i in range(1000 - 1, -1, -1):
        model_name = os.path.join(args.load, args.dataset + calib_indicator + '_' + args.model + str(args.seed) +
                                  '_baseline_epoch_' + str(i) + '.pt')
        if os.path.isfile(model_name):
            print (model_name)
            # Deserialize once and share between both networks.  Loading onto
            # the CPU is sufficient because the models are moved to the GPU
            # only afterwards.
            checkpoint = torch.load(model_name, map_location='cpu')
            net.load_state_dict(checkpoint)
            original_net.load_state_dict(checkpoint)
            print('Model restored! Epoch:', i)
            model_found = True
            break
    if not model_found:
        # A real exception instead of `assert False`: asserts are stripped
        # under `python -O`, which would silently continue with random weights.
        raise FileNotFoundError("could not find model to restore")

if args.ngpu > 1:
    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

if args.ngpu > 0:
    net.cuda()
    original_net.cuda()

# NOTE(review): this re-enables cuDNN benchmarking after set_random_seed()
# turned it off — confirm whether speed or run-to-run repeatability is wanted.
cudnn.benchmark = True  # fire on all cylinders

optimizer = torch.optim.SGD(
    net.parameters(), state['learning_rate'], momentum=state['momentum'],
    weight_decay=state['decay'], nesterov=True)


def cosine_annealing(step, total_steps, lr_max, lr_min):
    """Return the cosine-annealed value between ``lr_max`` and ``lr_min``.

    The result equals ``lr_max`` at ``step == 0`` and ``lr_min`` at
    ``step == total_steps``, following half a cosine period in between.

    Args:
        step: Current step index.
        total_steps: Number of steps over which to anneal.
        lr_max: Value at the start of the schedule.
        lr_min: Value at the end of the schedule.
    """
    progress = step / total_steps
    cosine_term = 1 + np.cos(progress * np.pi)
    half_range = (lr_max - lr_min) * 0.5
    return lr_min + half_range * cosine_term


# Per-step cosine learning-rate schedule.  LambdaLR expects a multiplicative
# factor on the base learning rate, so the annealing runs from 1 down to
# (1e-6 / base learning rate).
def _lr_factor(step):
    """Multiplicative learning-rate factor for optimizer step ``step``."""
    return cosine_annealing(
        step,
        args.epochs * len(train_loader_in),
        1,  # since lr_lambda computes multiplicative factor
        1e-6 / args.learning_rate)


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=_lr_factor)


def diff_entropy(alphas):
    """Differential entropy of Dirichlet distributions, one per row.

    Args:
        alphas: 2-D tensor whose rows are Dirichlet concentration vectors
            (the sum and reductions run over dim 1).

    Returns:
        1-D tensor with one entropy value per row of ``alphas``.
    """
    alpha0 = alphas.sum(dim=1)
    # digamma(alpha_k) - digamma(alpha_0), broadcast over the class dimension
    digamma_gap = torch.digamma(alphas) - torch.digamma(alpha0).unsqueeze(1)
    per_class = torch.lgamma(alphas) - (alphas - 1) * digamma_gap
    return per_class.sum(dim=1) - torch.lgamma(alpha0)

def target_2_alpha(target, num_classes=100, target_concentration=12):
    """Build target Dirichlet concentrations from integer class labels.

    Each row is all ones except for the labelled class, which is set to
    ``target_concentration``.

    Args:
        target: 1-D integer tensor of class indices.
        num_classes: Number of columns of the result.
        target_concentration: Value written at the labelled class
            (defaults to the previously hard-coded 12).

    Returns:
        Float tensor of shape ``(len(target), num_classes)`` on the same
        device as ``target``.
    """
    # Allocate on the labels' device rather than unconditionally on CUDA, so
    # scatter_ never mixes devices and the helper also works on CPU.
    alpha = torch.ones((target.shape[0], num_classes), device=target.device)
    alpha.scatter_(1, target.unsqueeze(1), target_concentration)
    return alpha

# def target_2_alpha(target, num_classes=10):
#     alpha = torch.ones((target.shape[0], num_classes)).cuda()
#     alpha.scatter_(1, target.unsqueeze(1), 15)
#     return alpha

# /////////////// Training ///////////////

def train():
    """Run one epoch of fine-tuning and store the loss in ``state``.

    Each step concatenates an in-distribution batch with an outlier batch,
    runs them through ``net`` and minimises the sum of:

    * cross-entropy on the in-distribution logits plus the KL divergence
      between the predicted Dirichlet (``relu(logits) + 1``) and the target
      Dirichlet built by ``target_2_alpha``;
    * ``args.lamb`` times squared hinge penalties that push in-distribution
      differential entropy below ``args.m_in`` and outlier differential
      entropy above ``args.m_out``;
    * the KL divergence between the predicted outlier Dirichlet and the one
      produced by the frozen ``original_net`` on the same outlier batch.

    An exponential moving average of the loss is written to
    ``state['train_loss']``.
    """
    net.train()  # enter train mode
    original_net.eval()
    loss_avg = 0.0

    # start at a random point of the outlier dataset; this induces more randomness without obliterating locality
    train_loader_out.dataset.offset = np.random.randint(len(train_loader_out.dataset))
    for in_set, out_set in zip(train_loader_in, train_loader_out):
        num_in = len(in_set[0])  # rows [0, num_in) are in-distribution
        data = torch.cat((in_set[0], out_set[0]), 0)
        target = in_set[1]

        data, target = data.cuda(), target.cuda()

        # forward
        x = net(data)
        alpha = F.relu(x) + 1
        alpha_in, alpha_out = alpha[:num_in], alpha[num_in:]

        # The reference network only provides fixed targets.  Running it
        # without autograd avoids building a graph and accumulating gradients
        # in parameters the optimizer never updates or zeroes.  The outlier
        # rows already on the GPU are reused instead of being copied again.
        with torch.no_grad():
            ood_logits = original_net(data[num_in:])
            ood_alpha = F.relu(ood_logits) + 1
        #ood_pred = torch.max(ood_logits, dim=1)[1]

        # backward
        optimizer.zero_grad()
        clf_loss = F.cross_entropy(x[:num_in], target)

        # Size the target Dirichlet from the logits so it matches the model
        # for any dataset (the helper's default of 100 classes breaks CIFAR-10).
        target_alpha = target_2_alpha(target, num_classes=x.shape[1])
        clf_loss = clf_loss + torch.mean(
            dist.kl.kl_divergence(Dirichlet(alpha_in), Dirichlet(target_alpha)))

        h_in = diff_entropy(alpha_in)
        h_out = diff_entropy(alpha_out)
        diff_entropy_loss = args.lamb * (
            torch.pow(F.relu(h_in - args.m_in), 2).mean()
            + torch.pow(F.relu(args.m_out - h_out), 2).mean())

        #ood_clf_loss = 0.5 * F.cross_entropy(x[len(in_set[0]):], ood_pred)
        ood_clf_loss = torch.mean(
            dist.kl.kl_divergence(Dirichlet(alpha_out), Dirichlet(ood_alpha)))
        #ood_clf_loss += F.cross_entropy(x[len(in_set[0]):], ood_pred)

        loss = clf_loss + diff_entropy_loss + ood_clf_loss
        #loss = clf_loss + diff_entropy_loss

        loss.backward()
        optimizer.step()
        scheduler.step()  # the schedule is defined per optimizer step

        # exponential moving average
        loss_avg = loss_avg * 0.8 + float(loss) * 0.2

    state['train_loss'] = loss_avg


# test function
def test():
    """Evaluate ``net`` on ``test_loader`` and record the metrics in ``state``.

    Writes ``state['test_loss']`` (mean of the per-batch cross-entropy) and
    ``state['test_accuracy']`` (fraction of correctly classified samples).
    """
    net.eval()
    running_loss = 0.0
    num_correct = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.cuda()
            labels = labels.cuda()

            # forward
            logits = net(images)
            batch_loss = F.cross_entropy(logits, labels)

            # accuracy: index of the largest logit per sample
            _, predicted = logits.max(1)
            num_correct += predicted.eq(labels).sum().item()

            # accumulate per-batch loss
            running_loss += float(batch_loss)

    state['test_loss'] = running_loss / len(test_loader)
    state['test_accuracy'] = num_correct / len(test_loader.dataset)


# Test-only mode: evaluate the restored model, print the metrics and stop.
if args.test:
    test()
    print(state)
    # `exit()` is injected by the `site` module and is not guaranteed to
    # exist; raising SystemExit ends the script the same way without it.
    raise SystemExit(0)
# args.save = f"./snapshots/dpn_tune/{args.aux}_{args.m_in}_{args.m_out}_{args.learning_rate}_{args.epochs}_{args.lamb}"
# Make save directory
if not os.path.isdir(args.save):
    if os.path.exists(args.save):
        raise Exception('%s is not a dir' % args.save)
    # exist_ok tolerates another process creating the directory concurrently.
    os.makedirs(args.save, exist_ok=True)

# Common prefix of every file written by this run; built once instead of
# being re-assembled at each use.
run_prefix = os.path.join(args.save, args.dataset + calib_indicator + '_' + args.model + str(args.seed))
results_path = run_prefix + '_dpn_tune_training_results.csv'


def _checkpoint_path(epoch_index):
    """Return the checkpoint file path for the given (0-based) epoch."""
    return run_prefix + '_dpn_tune_epoch_' + str(epoch_index) + '.pt'


with open(results_path, 'w') as f:
    f.write('epoch,time(s),train_loss,test_loss,test_error(%)\n')

print('Beginning Training\n')

# Main loop
for epoch in range(0, args.epochs):
    state['epoch'] = epoch

    begin_epoch = time.time()

    train()
    test()

    # Save model
    torch.save(net.state_dict(), _checkpoint_path(epoch))
    # Let us not waste space and delete the previous model
    prev_path = _checkpoint_path(epoch - 1)
    if os.path.exists(prev_path):
        os.remove(prev_path)

    # Measure once so the CSV row and the console line report the same time.
    elapsed = time.time() - begin_epoch
    test_error = 100 - 100. * state['test_accuracy']

    # Show results

    with open(results_path, 'a') as f:
        f.write('%03d,%05d,%0.6f,%0.5f,%0.2f\n' % (
            (epoch + 1),
            elapsed,
            state['train_loss'],
            state['test_loss'],
            test_error,
        ))

    # # print state with rounded decimals
    # print({k: round(v, 4) if isinstance(v, float) else v for k, v in state.items()})

    print('Epoch {0:3d} | Time {1:5d} | Train Loss {2:.4f} | Test Loss {3:.3f} | Test Error {4:.2f}'.format(
        (epoch + 1),
        int(elapsed),
        state['train_loss'],
        state['test_loss'],
        test_error)
    )
