"""
    Example showing how to use the DoG optimizer with CIFAR-100 and ResNet18.
    Based on the provided MNIST example.
"""

from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from dog import DoG, LDoG, PolynomialDecayAverager

import trackexp as tx

import pyfamilywise


def _build_criterion(loss_func, device, num_classes=100):
    """
    Return the training loss module selected by `loss_func`.

    @param loss_func: 'CE' for nn.CrossEntropyLoss or 'FW' for pyfamilywise.FWLoss
    @param device: device passed to FWLoss
    @param num_classes: number of classes passed to FWLoss (CIFAR-100 -> 100)
    @return: a callable criterion(output, target)
    @raise ValueError: if loss_func is not supported
    """
    if loss_func == 'CE':
        return nn.CrossEntropyLoss()
    if loss_func == 'FW':
        return pyfamilywise.FWLoss(num_classes=num_classes, device=device)
    # Raise rather than `assert`, which is stripped under `python -O`.
    raise ValueError(f"Loss not supported: {loss_func!r} (expected 'CE' or 'FW')")


def _kappa_bar_stats(output):
    """
    Compute the kappa-bar statistics of a batch of model outputs.

    For every row z, the entries are sorted in descending order and, for each 0-based
    position k, the score (z_(0) + ... + z_(k) - 1) / (k + 1) is formed; kappa_bar is the
    position k with the largest score.
    An estimate for the first row is also computed from the spread of its entries:
    sigma = num_classes * std(z[0]) and
    kappa_bar_est = num_classes * (1 / sigma) * sqrt(2 * log(sigma)).

    NOTE(review): the estimate is NaN when sigma < 1 (log is negative) and not finite
    when sigma == 0; numpy only warns in those cases.

    @param output: model outputs, indexed as (batch, num_classes)
    @return: tuple (kappa_bars, kappa_bar_est). kappa_bars is a numpy integer array with
             one entry per row; kappa_bar_est is a numpy float for row 0.
    """
    import numpy as np

    # Statistics only: never let them join the autograd graph of the loss.
    z = output.detach()
    num_classes = z.shape[1]
    # kappa_inv[k] = 1 / (k + 1) for 0-based position k.
    kappa_inv = 1 / torch.arange(1, num_classes + 1, device=z.device)

    scores = torch.sort(z, dim=-1, descending=True).values
    scores.cumsum_(dim=-1).sub_(1).mul_(kappa_inv[None, :])
    kappa_bars = torch.max(scores, dim=-1).indices.cpu().numpy()

    z0 = z[0].cpu().numpy()
    sigma_est = num_classes * np.sqrt(np.mean((z0 - np.mean(z0)) ** 2))
    kappa_bar_est = num_classes * ((1 / sigma_est) * np.sqrt(2 * np.log(sigma_est)))
    return kappa_bars, kappa_bar_est


def train(args, model, averager, device, train_loader, optimizer, epoch):
    """
    Train `model` for one epoch over `train_loader`, logging progress via trackexp.

    Every `args.log_interval` batches this logs the training loss. If
    `args.track_kappa_bar` is set, it also logs the kappa-bar statistics of the outputs.
    If `args.log_state` is set and the optimizer is a DoG, it also logs the optimizer
    state. The averager is stepped after every optimizer step.

    @param args: parsed command-line arguments (reads loss_func, track_kappa_bar,
                 log_interval, log_state)
    @param model: the network being trained; it is put into train mode
    @param averager: object whose step() is called after each optimizer step
    @param device: device each (data, target) batch is moved to
    @param train_loader: iterable of (data, target) batches
    @param optimizer: the optimizer driving the updates
    @param epoch: current epoch number (unused; logging keys read tx.saved_vars['epoch'])
    @raise ValueError: if args.loss_func is not 'CE' or 'FW'
    """
    model.train()
    criterion = _build_criterion(args.loss_func, device)

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        should_log = batch_idx % args.log_interval == 0

        # Kappa-bar statistics are only ever logged, so compute them only on logging batches.
        if args.track_kappa_bar and should_log:
            kappa_bars, kappa_bar_est = _kappa_bar_stats(output)
            step_key = (tx.saved_vars['epoch'], batch_idx)
            tx.log("training_inner", "kappa_bar_mean", step_key, kappa_bars.mean().item())
            tx.log("training_inner", "kappa_bar_0_est", step_key, float(kappa_bar_est))
            tx.log("training_inner", "kappa_bar_0", step_key, int(kappa_bars[0]))
            tx.saved_vars['batch_idx'] = batch_idx

        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        averager.step()

        if should_log:
            tx.log("training_inner", "loss", (tx.saved_vars['epoch'], batch_idx), loss.item())
            # format_pgroup_dog_state reads batch_idx from tx.saved_vars for its log keys.
            tx.saved_vars['batch_idx'] = batch_idx

            if args.log_state and isinstance(optimizer, DoG):
                # format_pgroup_dog_state logs rbar/G/eta itself (for LDoG: means across layers).
                for group_state in optimizer.state_dict()['param_groups']:
                    format_pgroup_dog_state(group_state)


def format_pgroup_dog_state(dog_param_group_state):
    """
    Log the DoG state of one parameter group to trackexp and return it as a string.

    Three quantities are read from the group and each is reduced to a scalar by its mean:
    'rbar' (distance from the initial point), 'G' (sum of squared gradient norms) and
    'eta' (list of per-layer step sizes). All three scalars are logged under the
    "training_inner" category. The log key is
    (tx.saved_vars['epoch'], tx.saved_vars['batch_idx']).

    Note: for LDoG the values are means across layers; for DoG every layer shares one eta.
    @param dog_param_group_state: one entry of a DoG/LDoG optimizer's state_dict()['param_groups']
    @return: a printable summary of the form 'rbar=..., G=..., eta=...'
    """
    def scalar_mean(values):
        return torch.mean(values.detach()).item()

    # Reduce everything first so nothing is logged if a key is missing.
    rbar = scalar_mean(dog_param_group_state['rbar'])
    G = scalar_mean(dog_param_group_state['G'])
    eta = scalar_mean(torch.stack(dog_param_group_state['eta']))

    step_key = (tx.saved_vars['epoch'], tx.saved_vars['batch_idx'])
    for metric_name, metric_value in (("rbar", rbar), ("G", G), ("eta", eta)):
        tx.log("training_inner", metric_name, step_key, metric_value)

    return 'rbar={:E}, G={:E}, eta={:E}'.format(rbar, G, eta)


def test(model, device, test_loader, model_name):
    """
    Evaluate `model` on `test_loader`, log the results to trackexp and print them.

    The mean cross-entropy loss is the loss summed over all samples divided by
    len(test_loader.dataset). Accuracy is the percentage of samples whose argmax
    prediction equals the target. Both are logged under the "training" category, keyed by
    tx.saved_vars['epoch'], with `model_name` appended to the metric names.
    The model is left in eval mode.

    @param model: network to evaluate
    @param device: device each (data, target) batch is moved to
    @param test_loader: DataLoader yielding (data, target) batches
    @param model_name: label used in the metric names and the printed summary
    """
    model.eval()
    summed_loss = 0
    n_correct = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            # Sum (not mean) per batch so dividing by the dataset size gives the per-sample mean.
            summed_loss += F.cross_entropy(logits, labels, reduction='sum').item()
            predictions = logits.argmax(dim=1)
            n_correct += (predictions == labels).sum().item()

    n_samples = len(test_loader.dataset)
    mean_loss = summed_loss / n_samples
    accuracy = 100. * n_correct / n_samples

    epoch = tx.saved_vars['epoch']
    tx.log("training", f"test_loss_{model_name}", epoch, mean_loss)
    tx.log("training", f"test_acc_{model_name}", epoch, accuracy)

    summary = '\nTest set ({}): Loss = {:.4f}, Accuracy = {:.2f}% ({}/{})\n'.format(
        model_name, mean_loss, accuracy, n_correct, n_samples)
    print(summary)


def _parse_args():
    """Build this example's command-line parser and parse sys.argv."""
    parser = argparse.ArgumentParser(description='PyTorch CIFAR-100 with ResNet18 Example')
    parser.add_argument('--data-root', type=str, default='../data', metavar='DIR',
                        help='data root (default: "../data")')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=3, metavar='N',
                        help='number of epochs to train (default: 3)')
    parser.add_argument('--ldog', action='store_true', default=False,
                        help='If set to true, will use LDoG rather than DoG')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='base learning rate (default: 1.0) - should not be changed!')
    parser.add_argument('--reps_rel', type=float, default=1e-6, metavar='M',
                        help='normalized version of the r_epsilon parameter (default: 1e-6)')
    parser.add_argument('--init_eta', type=float, default=0, metavar='M',
                        help='if above 0, will use this value as the initial eta instead of the result of '
                             'reps_rel (default: 0)')
    parser.add_argument('--avg_gamma', type=float, default=8, metavar='M',
                        help='Polynomial decay averager gamma (default: 8)')
    parser.add_argument('--weight_decay', type=float, default=5e-4, metavar='M',
                        help='weight decay coefficient (default: 5e-4)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='Save the current Model')
    parser.add_argument('--no-log-state', action='store_false', default=True, dest='log_state',
                        help='Suppress logging the state of the optimizer at each log_interval')
    # Validate at parse time so a typo fails before the dataset download and model build.
    parser.add_argument('--loss_func', type=str, default='CE', choices=['CE', 'FW'],
                        help='Loss function: CE (cross-entropy) or FW (pyfamilywise.FWLoss) (default: CE)')
    parser.add_argument('--trackexp_dir', type=str, default='', help='Output directory for trackexp')
    parser.add_argument('--track_kappa_bar', action='store_true', default=False, help='do we track kappa_bar?')
    return parser.parse_args()


def _build_data_loaders(args, use_cuda):
    """
    Create the CIFAR-100 train and test DataLoaders.

    The training loader is always shuffled and uses random crop and horizontal flip
    augmentation. The test loader is never shuffled. Worker and pinned-memory options
    are only added when CUDA is used.

    @return: tuple (train_loader, test_loader)
    """
    train_kwargs = {'batch_size': args.batch_size, 'shuffle': True}
    test_kwargs = {'batch_size': args.test_batch_size, 'shuffle': False}
    if use_cuda:
        cuda_kwargs = {'num_workers': 4, 'pin_memory': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    # Per-channel normalization constants shared by both transforms.
    cifar100_mean = (0.5071, 0.4867, 0.4408)
    cifar100_std = (0.2675, 0.2565, 0.2761)

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(cifar100_mean, cifar100_std),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar100_mean, cifar100_std),
    ])

    train_dataset = datasets.CIFAR100(args.data_root, train=True, download=True, transform=train_transform)
    test_dataset = datasets.CIFAR100(args.data_root, train=False, download=True, transform=test_transform)

    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)
    return train_loader, test_loader


def _build_model():
    """
    Return a randomly initialized torchvision ResNet18 adapted for this example.

    The stem is a 3x3 stride-1 convolution, the initial max-pool is removed, and the
    classifier has 100 outputs.
    """
    model = models.resnet18(weights=None)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()
    model.fc = nn.Linear(model.fc.in_features, 100)
    return model


def main():
    """
    Entry point: parse arguments, build data, model and optimizer, and run the epochs.

    After every epoch both the base model and the polynomial-decay averaged model are
    evaluated on the test set. With --save-model both state_dicts are written to
    "cifar100_resnet18.pt" and "cifar100_resnet18_averaged.pt".
    """
    args = _parse_args()

    if args.trackexp_dir:
        tx.init(args.trackexp_dir, verbose=False)
    else:
        tx.init(verbose=False)
    tx.metadata(args.__dict__)

    use_cuda = torch.cuda.is_available()
    # Seed before the model is built so weight initialization is reproducible.
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")

    train_loader, test_loader = _build_data_loaders(args, use_cuda)
    model = _build_model().to(device)

    opt_class = LDoG if args.ldog else DoG
    optimizer = opt_class(model.parameters(), reps_rel=args.reps_rel, lr=args.lr,
                          init_eta=(args.init_eta if args.init_eta > 0 else None), weight_decay=args.weight_decay)
    averager = PolynomialDecayAverager(model, gamma=args.avg_gamma)

    for epoch in range(1, args.epochs + 1):
        # train/test read the current epoch from tx.saved_vars for their log keys.
        tx.saved_vars['epoch'] = epoch
        tx.start_timer('training', epoch)
        train(args, model, averager, device, train_loader, optimizer, epoch)
        test(model, device, test_loader, 'base model')
        test(averager.averaged_model, device, test_loader, 'averaged model')
        tx.stop_timer('training', epoch)

    if args.save_model:
        torch.save(model.state_dict(), "cifar100_resnet18.pt")
        torch.save(averager.averaged_model.state_dict(), "cifar100_resnet18_averaged.pt")


if __name__ == '__main__':
    # Run training only when executed as a script; importing this module has no side effects.
    main()
