import argparse
import logging
import os
import random
import sys
import time
from pprint import pformat

import numpy
import numpy as np
import sklearn.metrics as sklm
import torch
from torch.optim.lr_scheduler import _LRScheduler
from tqdm import tqdm

from src.dataset.factory import create_dataset, MAP_DATASET_TO_ENUM, SupportedDataset
from src.loss.utils import create_loss
from src.models.utils import get_model, append_linear_layer_transform
from src.optimizer.utils import create_optimizer, create_scheduler
from src.utils import logger
from src.utils.sysutils import is_debug_mode, get_cores_count
from src.utils.tensorboard_writer import initialize_tensorboard
from src.utils.utils import make_results_dir

# Command-line options; parsed once at import time and read by main() through `opt`.
parser = argparse.ArgumentParser()

parser.add_argument(
    '--dataset',
    type=str,
    default='CIFAR10',
    choices=['CIFAR10', 'Food101', 'NihCxr', 'Birdsnap', 'ImageNet'],
    help='The dataset to choose',
)
parser.add_argument(
    '--output_dir',
    type=str,
    required=False,
    default='./logs_birdsnap/',
    help='Output directory path',
)
parser.add_argument(
    '--model',
    type=str,
    default='Resnet8',
    choices=['Resnet8', 'Resnet8v2', 'Resnet10', 'Resnet18', 'Resnet34', 'Resnet50', 'ConvNetSimple'],
    help='The model to use',
)

opt = parser.parse_args()


def log_auc_metrics_tensorboard(tb_writer, aucs, classes, global_step):
    """Write one tensorboard scalar per class, tagged ``aucs_<class_name>``.

    :param tb_writer: object exposing ``save_scalar(tag, value, step)``.
    :param aucs: per-class AUC values, indexed in the same order as ``classes``.
    :param classes: iterable of class names (strings).
    :param global_step: step value recorded with every scalar.
    """
    for idx, name in enumerate(classes):
        tag = 'aucs' + '_' + name
        tb_writer.save_scalar(tag, aucs[idx], global_step)


def log_ap_metrics_tensorboard(tb_writer, aps, classes, global_step):
    """Write one tensorboard scalar per class, tagged ``AvgPrec_<class_name>``.

    :param tb_writer: object exposing ``save_scalar(tag, value, step)``.
    :param aps: per-class average-precision values, indexed like ``classes``.
    :param classes: iterable of class names (strings).
    :param global_step: step value recorded with every scalar.
    """
    for idx, name in enumerate(classes):
        tag = 'AvgPrec' + '_' + name
        tb_writer.save_scalar(tag, aps[idx], global_step)


def train_and_evaluate_model(arguments):
    """
    Main pipeline: train the model, validate it periodically, keep the checkpoint with
    the lowest validation loss on disk, and finally evaluate on the test set.

    :param arguments: configuration dict as assembled by ``main()``. Keys read here:
        'outdir', 'random_seed', 'model_args', 'dataset_args', 'optimizer_args',
        'scheduler_args' (may be missing/None for no scheduler), 'train_data_args',
        'val_data_args', 'loss_args', 'mode', 'nb_epochs'.
    :return: ``(test_loss, test_accuracy)`` when mode is 'classification', otherwise
        ``(test_loss, test_auc, test_ap)`` with per-class lists.
    :raises ValueError: if ``arguments['mode']`` is not a supported mode.
    """
    # Set up the result directory and enable logging to a file inside it.
    outdir = make_results_dir(arguments['outdir'])
    logger.init(outdir, logging.INFO)
    logger.info('Arguments:\n{}'.format(pformat(arguments)))

    # Seed every RNG in use so a run can be reproduced from the logged seed.
    seed = arguments['random_seed']
    logger.info('Using Random Seed value as: %d' % seed)
    torch.manual_seed(seed)  # Set for pytorch, used for cuda as well.
    random.seed(seed)  # Set for python
    np.random.seed(seed)  # Set for numpy

    tb_writer = initialize_tensorboard(outdir)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    logger.info(f'Using device - {device}')

    # Load model with weights (if available).
    model: torch.nn.Module = get_model(arguments.get('model_args'), device, arguments['dataset_args']).to(device)

    optimizer = create_optimizer(model.parameters(), arguments['optimizer_args'])
    scheduler_args = arguments.get('scheduler_args')
    lr_scheduler: _LRScheduler = create_scheduler(optimizer, scheduler_args) if scheduler_args else None

    dataset = create_dataset(arguments['dataset_args'],
                             arguments['train_data_args'],
                             arguments['val_data_args'])

    # Epoch-level loss, weighted by the dataset-wide positive/negative balance.
    loss_weights = dataset.pos_neg_balance_weights()
    logger.info(f"Loss weights {loss_weights}")
    criterion = create_loss(arguments['loss_args'], loss_weights.to(device))

    # Sample and view the inputs to the model.
    dataset.debug()

    mode = arguments['mode']
    nb_epochs = arguments['nb_epochs']
    train_data_args = arguments['train_data_args']
    val_data_args = arguments['val_data_args']
    min_validation_loss, best_validation_model_path = sys.float_info.max, None
    batch_index = 0  # global step for the per-batch tensorboard scalar

    # ToDo  Use Loop for train and val phases - Much cleaner code this way.
    #  https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#finetuning-the-convnet
    # Separate Training code of Single Classification and Multi-Label Classification and put in separate Trainer files.
    for epoch in range(nb_epochs):
        logger.info(f"Training, Epoch {epoch + 1}/{nb_epochs}")

        if train_data_args['to_train']:
            model.train()
            start = time.time()
            total, correct = 0, 0
            epoch_loss = 0
            for inputs, labels in tqdm(dataset.train_dataloader):
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)

                # Optional per-batch re-weighting. Kept in a local so the epoch-level
                # `criterion` (used for validation and test) is never overwritten.
                batch_criterion = criterion
                if train_data_args['weighted_cross_entropy_batchwise']:
                    batch_criterion = create_loss(arguments['loss_args'],
                                                  dataset.pos_neg_balance_weights_in_batch(labels).to(device))

                loss = batch_criterion(outputs, labels)
                loss.backward()

                batch_loss = loss.item()
                tb_writer.save_scalar('batch_training_loss', batch_loss, batch_index)
                batch_index += 1
                # Weight by batch size so the epoch loss is a per-sample mean.
                epoch_loss += batch_loss * labels.size(0)
                total += labels.size(0)
                if mode == 'classification':
                    _, predicted = torch.max(outputs, 1)
                    correct += (predicted == labels).sum().item()

                optimizer.step()
                if is_debug_mode():
                    break

            epoch_loss = epoch_loss / total
            logger.info(f"Epoch = {epoch}, Train_loss = {epoch_loss}, "
                        f"Time taken = {time.time() - start} seconds.")
            tb_writer.save_scalar('training_loss', epoch_loss, epoch)
            # Accuracy is only tracked for single-label classification.
            if mode == 'classification':
                train_accuracy = 100 * correct / total
                logger.info(f"Train_accuracy = {train_accuracy}")
                tb_writer.save_scalar('training_acc', train_accuracy, epoch)

        # Validate every `validate_step_size` epochs (disabled when <= 0).
        validate_step_size = val_data_args['validate_step_size']
        if validate_step_size > 0 and epoch % validate_step_size == 0:
            model.eval()
            validation_dataloader = dataset.validation_dataloader
            logger.info(f"Validation, Epoch {epoch + 1}/{nb_epochs}")

            if mode == 'classification':
                val_loss, val_accuracy = evaluate_single_class(device, model, validation_dataloader, criterion)
                logger.info(f'validation images: {dataset.val_dataset_size}, '
                            f'val_acc : {val_accuracy} %% '
                            f'val_loss: {val_loss}')
                tb_writer.save_scalar('validation_acc', val_accuracy, epoch)
                checkpoint_metric = f'val_acc_{val_accuracy}'
            elif mode == 'multilabel_classification':
                val_loss, val_auc, val_ap = evaluate_multi_class(device, model, validation_dataloader, criterion)
                logger.info(f'AUC of the network on the {dataset.val_dataset_size} validation images: '
                            f'{dict(zip(dataset.classes, val_auc))}')
                logger.info(f'Average Precision on the {dataset.val_dataset_size} validation images : '
                            f'{dict(zip(dataset.classes, val_ap))}')
                log_auc_metrics_tensorboard(tb_writer, val_auc, dataset.classes, epoch)
                log_ap_metrics_tensorboard(tb_writer, val_ap, dataset.classes, epoch)
                checkpoint_metric = f'val_meanAUC_{numpy.mean(val_auc)}'
            else:
                raise ValueError(f'Unsupported mode: {mode}')
            tb_writer.save_scalar('validation_loss', val_loss, epoch)

            # Keep only the checkpoint with the lowest validation loss on disk.
            if val_loss < min_validation_loss:
                min_validation_loss = val_loss
                if best_validation_model_path and os.path.exists(best_validation_model_path):
                    os.remove(best_validation_model_path)
                best_validation_model_path = os.path.join(outdir,
                                                          f'epoch_{epoch:04}-model-{checkpoint_metric}.pth')
                torch.save(model.state_dict(), best_validation_model_path)
                logger.info(f'Model saved at: {best_validation_model_path}')

        if lr_scheduler:
            # Read LRs from the optimizer: calling scheduler.get_lr() outside step() is
            # deprecated and can report an already-decayed value. step() is called without
            # an epoch argument (the epoch form is deprecated and lags the schedule by one).
            # NOTE(review): a ReduceLROnPlateau scheduler would need step(metric) instead.
            prev_lrs = [group['lr'] for group in optimizer.param_groups]
            lr_scheduler.step()
            new_lrs = [group['lr'] for group in optimizer.param_groups]
            if new_lrs != prev_lrs:
                logger.warn(f'Updated LR from {prev_lrs} to {new_lrs}')

        # Exit loop if training not needed
        if not train_data_args['to_train']:
            break

    logger.info('Finished Training')
    logger.info(f'Best Model saved at: {best_validation_model_path}')

    # Evaluate on the test set, using the best validation checkpoint if one was saved.
    if best_validation_model_path is not None:
        model.load_state_dict(torch.load(best_validation_model_path, map_location=device), strict=False)
    else:
        logger.warn('No validation checkpoint was saved; evaluating the current model weights.')
    test_dataloader = dataset.test_dataloader
    if mode == 'classification':
        test_loss, test_accuracy = evaluate_single_class(device, model, test_dataloader, criterion)
        logger.info(f'Accuracy of the network on the {dataset.test_dataset_size} test images: {test_accuracy} %%')
        return test_loss, test_accuracy
    else:
        test_loss, test_auc, test_ap = evaluate_multi_class(device, model, test_dataloader, criterion)
        logger.info(f'AUC of the network on the {dataset.test_dataset_size} test images: '
                    f'{dict(zip(dataset.classes, test_auc))}')
        logger.info(f'Average Precision on the {dataset.test_dataset_size} test images : '
                    f'{dict(zip(dataset.classes, test_ap))}')
        return test_loss, test_auc, test_ap


def evaluate_single_class(device, model, dataloader, criterion):
    """
    Evaluate a single-label classifier without computing gradients.

    :param device: torch device the inputs and labels are moved to.
    :param model: model producing one score per class along dim 1.
    :param dataloader: iterable yielding ``(inputs, labels)`` batches.
    :param criterion: loss function called as ``criterion(outputs, labels)``.
    :return: ``(mean loss per sample, accuracy in percent)``.
    :raises ValueError: if ``dataloader`` yields no samples.
    """
    correct, total_samples = 0, 0
    total_loss = 0
    model.eval()
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            # Weight by batch size so the result is a per-sample mean.
            total_loss += loss.item() * labels.size(0)
            _, predicted = torch.max(outputs, 1)
            total_samples += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total_samples == 0:
        raise ValueError('Cannot evaluate on an empty dataloader.')
    accuracy = 100 * correct / total_samples
    total_loss /= total_samples
    return total_loss, accuracy


def evaluate_multi_class(device, model, dataloader, criterion):
    """
    Evaluate a multi-label classifier without computing gradients.

    Logits are passed through a sigmoid, then per-class ROC-AUC and average precision
    are computed with scikit-learn over all evaluated samples.

    :param device: torch device the inputs and labels are moved to.
    :param model: model producing one logit per class along dim 1.
    :param dataloader: iterable yielding ``(inputs, labels)`` batches with 2-D labels.
    :param criterion: loss function called as ``criterion(outputs, labels)``.
    :return: ``(mean loss per sample, per-class AUC list, per-class AP list)``.
        A class's AUC is reported as 0 when it is undefined (that class has only
        negatives or only positives); its AP is 0 when it has no positives.
    :raises ValueError: if ``dataloader`` yields no samples.
    """
    predictions, actual = [], []
    model.eval()
    total_loss, total_samples = 0.0, 0
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            total_loss += loss.item() * labels.size(0)
            total_samples += labels.size(0)

            predictions.append(torch.sigmoid(outputs).cpu().numpy())
            actual.append(labels.cpu().numpy())
            if is_debug_mode():
                break

    if total_samples == 0:
        raise ValueError('Cannot evaluate on an empty dataloader.')

    actual = numpy.concatenate(actual, axis=0)
    predictions = numpy.concatenate(predictions, axis=0)
    num_classes = actual.shape[1]
    aucs, aps = [], []
    for class_index in range(num_classes):
        y_true = actual[:, class_index]
        y_score = predictions[:, class_index]
        positives = numpy.count_nonzero(y_true)
        auc = ap = 0
        if positives:
            ap = sklm.average_precision_score(y_true, y_score)
            # ROC-AUC needs both classes present; roc_auc_score raises otherwise.
            if positives < len(y_true):
                auc = sklm.roc_auc_score(y_true, y_score)
        aucs.append(auc)
        aps.append(ap)
    # Per-sample mean, consistent with evaluate_single_class.
    return total_loss / total_samples, aucs, aps


def main():
    """
    Build the run configuration from the command-line options in ``opt`` and launch
    ``train_and_evaluate_model``.

    'NihCxr' is trained as multi-label classification with BCE-with-logits; every other
    dataset is trained as single-label classification with cross-entropy.
    """
    dataset_name = opt.dataset

    mode = 'classification' if dataset_name != 'NihCxr' else 'multilabel_classification'

    # NOTE(review): ConvNetSimple's constructor args are hard-coded to CIFAR10's image
    # size, channel count and label count regardless of --dataset — confirm intended.
    model_config = dict(
        ConvNetSimple=dict(
            model_arch_name='src.models.classification.ConvNetSimple.ConvNetSimple',
            model_weights_path=None,
            model_constructor_args=dict(
                input_size=SupportedDataset.CIFAR10_Enum.value['image_size'],
                number_of_input_channels=SupportedDataset.CIFAR10_Enum.value[
                    'channels'],
                number_of_classes=SupportedDataset.CIFAR10_Enum.value['labels_count'],
            )),
        Resnet8=dict(
            model_arch_name='src.models.classification.PytorchCifarResnet.ResNet8',
            model_weights_path=None,
            model_constructor_args=dict()),
        Resnet8v2=dict(
            model_arch_name='src.models.classification.PytorchCifarResnet.ResNet8v2',
            model_weights_path=None,
            model_constructor_args=dict()),
        Resnet10=dict(
            model_arch_name='src.models.classification.PytorchCifarResnet.ResNet10',
            model_weights_path=None,
            model_constructor_args=dict()),
        Resnet18=dict(
            model_arch_name='torchvision.models.resnet18',
            model_weights_path=None,
            model_constructor_args=dict(
                pretrained=True
            ),
            model_transformer=append_linear_layer_transform),
        Resnet34=dict(
            model_arch_name='torchvision.models.resnet34',
            model_weights_path=None,
            model_constructor_args=dict(
                pretrained=True
            ),
            model_transformer=append_linear_layer_transform),
        Resnet50=dict(
            model_arch_name='torchvision.models.resnet50',
            model_weights_path=None,
            model_constructor_args=dict(
                pretrained=True
            ),
            model_transformer=append_linear_layer_transform),
    )
    selected_model_args = model_config[opt.model]

    # Specific training configs for different datasets. 'scheduler_args' is optional;
    # datasets without it run without an LR scheduler.
    dataset_configs = dict(
        CIFAR10=dict(batch_size=256,
                     model_args=dict(selected_model_args),
                     optimizer_args=dict(
                         name='torch.optim.Adam',
                         lr=1e-3
                     ),
                     nb_epochs=25
                     ),
        ImageNet=dict(batch_size=128,
                      model_args=dict(selected_model_args),
                      optimizer_args=dict(
                          name='torch.optim.Adam',
                          lr=1e-3
                      ),
                      scheduler_args=dict(
                          name='torch.optim.lr_scheduler.StepLR',
                          step_size=10,
                          gamma=0.5
                      ),
                      nb_epochs=50
                      ),
        Food101=dict(batch_size=128,
                     model_args=dict(selected_model_args),
                     optimizer_args=dict(
                         name='torch.optim.Adam',
                         lr=1e-4,
                         weight_decay=1e-4,
                     ),
                     # decrease lr to 0.1*lr every 10 epochs
                     scheduler_args=dict(
                         name='torch.optim.lr_scheduler.StepLR',
                         step_size=10,
                         gamma=0.1
                     ),
                     nb_epochs=20,
                     ),
        NihCxr=dict(batch_size=64,
                    model_args=dict(selected_model_args),
                    optimizer_args=dict(
                        name='torch.optim.Adam',
                        lr=1e-3,
                        weight_decay=1e-4,
                    ),
                    scheduler_args=dict(
                        name='torch.optim.lr_scheduler.StepLR',
                        step_size=10,
                        gamma=0.5
                    ),
                    nb_epochs=50
                    ),
        Birdsnap=dict(batch_size=64,
                      model_args=dict(selected_model_args),
                      optimizer_args=dict(
                          name='torch.optim.Adam',
                          lr=0.0005 if opt.model == 'Resnet50' else 0.001,
                          weight_decay=0 if opt.model == 'Resnet50' else 0.001,  # Reg for Res50 hasnt been tried.
                      ),
                      scheduler_args=dict(
                          name='torch.optim.lr_scheduler.StepLR',
                          step_size=10,
                          gamma=0.5
                      ),
                      nb_epochs=25,
                      split_ratio=[0.8, 0.1, 0.1]
                      )
    )
    config = dataset_configs[dataset_name]

    # Common Configuration
    dataset_args = dict(
        name=MAP_DATASET_TO_ENUM[dataset_name],
        # Falls back to 7/8 when the dataset config gives no (or an empty) split.
        split_ratio=config.get('split_ratio') or 7.0 / 8,
    )

    train_data_args = dict(
        batch_size=config['batch_size'],
        shuffle=True,
        to_train=True,
        weighted_cross_entropy_batchwise=False,
    )

    val_data_args = dict(
        batch_size=train_data_args['batch_size'] * 4,
        shuffle=False,
        validate_step_size=1,
    )

    loss_args = dict(
        name='torch.nn.CrossEntropyLoss' if mode == 'classification' else 'torch.nn.BCEWithLogitsLoss'
    )

    arguments = dict(
        mode=mode,
        dataset_args=dataset_args,
        train_data_args=train_data_args,
        val_data_args=val_data_args,
        model_args=config['model_args'],
        loss_args=loss_args,
        optimizer_args=config['optimizer_args'],
        # None when the dataset defines no scheduler (previously a KeyError for CIFAR10).
        scheduler_args=config.get('scheduler_args'),
        outdir=opt.output_dir,
        nb_epochs=config['nb_epochs'],
        random_seed=random.randint(0, 1000)
    )

    train_and_evaluate_model(arguments)


if __name__ == '__main__':
    # Script entry point: build the configuration and run training + evaluation.
    main()
