import argparse
import datetime
import itertools
import os
import random
import time
from pprint import pformat

import numpy as np
import texttable
import torch
from tqdm import tqdm

from src.dataset.factory import SupportedDataset
from src.loss.utils import create_loss
from src.models.utils import get_model
from src.optimizer.utils import create_optimizer
from src.roar.compound_cifar_dataset import CompoundCifarDataset

# Command-line interface. The arguments are parsed at module level, so the
# resulting ``opt`` namespace is available to every function in this file.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--dataset',
    type=str,
    default='CIFAR10',
    choices=['CIFAR10', 'Food101', 'Birdsnap'],
    help='The dataset to choose',
)
parser.add_argument(
    '--output_dir',
    type=str,
    required=False,
    default='./logs/',
    help='Output directory path',
)
parser.add_argument(
    '--model',
    type=str,
    default='Resnet8',
    choices=['Resnet8', 'Resnet8v2', 'Resnet10', 'Resnet18', 'Resnet50', 'ConvNetSimple'],
    help='The model to use',
)
opt = parser.parse_args()


def _set_learning_rate(optimizer, lr):
    """Assign ``lr`` to every parameter group of ``optimizer``."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def _train_one_epoch(device, model, dataloader, criterion, optimizer):
    """Run one optimisation pass over ``dataloader`` and return the train accuracy in percent."""
    model.train()
    total, correct = 0, 0
    for inputs, labels in tqdm(dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass, loss, backward pass and parameter update
        outputs = model(inputs)
        total_loss = criterion(outputs, labels)
        total_loss.backward()
        optimizer.step()

        # Accuracy bookkeeping uses the logits computed before the update step.
        total += labels.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()

    return 100 * correct / total


def train_and_evaluate_model(arguments):
    """
    Main Pipeline for training and cross-validation.

    Seeds python/numpy/torch, builds the model, loss and optimizer, trains for
    ``nb_epochs`` epochs, validates every ``validate_step_size`` epochs and
    checkpoints the model whenever the validation accuracy improves. Finally
    the best checkpoint is reloaded and scored on the test split.

    Args:
        arguments: dict with the keys read below - ``outdir``, ``random_seed``,
            ``dataset``, ``model_args``, ``loss_args``, ``optimizer_args``
            (must contain ``lr``), ``train_data_args`` (``to_train`` and an
            optional ``lr_strategy`` pair ``(every_n_epochs, factor)``),
            ``val_data_args`` (``validate_step_size``), ``model_name_args``
            (``(attempt, eval_metric, attribution_method, percentile)``) and
            ``nb_epochs``.

    Returns:
        Test accuracy in percent, as returned by ``evaluate``.
    """
    # Setup result directory
    outdir = arguments.get("outdir")
    print('Arguments:\n{}'.format(pformat(arguments)))

    # Set random seed throughout python
    seed = arguments['random_seed']
    print('Using Random Seed value as: %d' % seed)
    torch.manual_seed(seed)  # Set for pytorch, used for cuda as well.
    random.seed(seed)  # Set for python
    np.random.seed(seed)  # Set for numpy

    # Set device - cpu or gpu
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f'Using device - {device}')

    # Compound dataset supplied by the caller
    dataset = arguments['dataset']

    # Load Model with weights (if available)
    model: torch.nn.Module = get_model(arguments.get('model_args'), device,
                                       dict(labels_count=len(dataset.classes))).to(device)

    # Create loss function and optimizer
    criterion = create_loss(arguments['loss_args'])
    lr = arguments['optimizer_args']['lr']
    optimizer = create_optimizer(model.parameters(), arguments['optimizer_args'])

    train_data_args = arguments['train_data_args']
    val_data_args = arguments['val_data_args']
    nb_epochs = arguments['nb_epochs']
    lr_strategy = train_data_args.get('lr_strategy')  # ToDo - Move lr_strategy to optimizer_args
    validate_step_size = val_data_args['validate_step_size']

    max_validation_accuracy, max_validation_path = 0, None

    # Only create dataloaders once
    train_dataloader = dataset.get_train_dataloader(train_data_args)
    validation_dataloader = dataset.get_validation_dataloader(val_data_args)

    # Pipeline - loop over the dataset multiple times
    for epoch in range(nb_epochs):
        print(f"Training, Epoch {epoch + 1}/{nb_epochs}")

        # Schedule learning rate if required: scale it every lr_strategy[0] epochs
        if lr_strategy and epoch > 0 and epoch % lr_strategy[0] == 0:
            new_lr = lr * lr_strategy[1]
            print(f'Changing LR from {lr} to {new_lr}')
            _set_learning_rate(optimizer, new_lr)
            lr = new_lr

        # Train the model
        if train_data_args['to_train']:
            start = time.time()
            train_accuracy = _train_one_epoch(device, model, train_dataloader, criterion, optimizer)
            print(f"Epoch-{epoch}, Train_accuracy-{train_accuracy}, Time taken = {time.time() - start} seconds.")

        # Cross-Validate the model
        if validate_step_size > 0 and epoch % validate_step_size == 0:
            print(f"Validation, Epoch {epoch + 1}/{nb_epochs}")
            val_accuracy = evaluate(device, model, validation_dataloader)
            print(f'Accuracy of the network on the {dataset.val_dataset_size} validation images: {val_accuracy} %')

            # Save Model. The first validated model is always checkpointed so that a
            # best checkpoint exists even if the accuracy never rises above 0.
            if max_validation_path is None or val_accuracy > max_validation_accuracy:
                attempt, eval_metric, attribution_method, percentile = arguments['model_name_args']
                # NOTE(review): there is no '-' between the percentile and epoch fields;
                # the file name is kept as-is in case other tooling relies on it.
                max_validation_path = os.path.join(outdir,
                                                   f'run_{attempt}-'
                                                   f'attr_{attribution_method}-'
                                                   f'percentile_{percentile}'
                                                   f'epoch_{epoch:04}-'
                                                   f'model-{arguments["model_args"]["model_arch_name"].split(".")[-1]}-'
                                                   f'val_{val_accuracy}.pth')
                torch.save(model.state_dict(), max_validation_path)
                max_validation_accuracy = val_accuracy
                print(f'Model saved at: {max_validation_path}')

        # Exit loop if training not needed
        if not train_data_args['to_train']:
            break

    print('Finished Training')
    print(f'Max Validation accuracy is {max_validation_accuracy}')
    print(f'Best Model saved at: {max_validation_path}')

    # Evaluate model on test set, using the best checkpoint when one exists.
    if max_validation_path is not None:
        model.load_state_dict(torch.load(max_validation_path, map_location=device), strict=False)
    else:
        # Validation never ran (disabled or zero epochs), so nothing was saved.
        print('No checkpoint was saved during validation; evaluating the current model weights.')
    test_dataloader = dataset.get_test_dataloader(val_data_args)  # batch size for val/test is same
    test_accuracy = evaluate(device, model, test_dataloader)
    return test_accuracy


def evaluate(device, model, dataloader):
    """
    Compute the classification accuracy of ``model`` over ``dataloader``.

    The model is put into eval mode (and left there) and gradients are disabled
    for the whole pass.

    Args:
        device: torch device every batch is moved to before the forward pass.
        model: network to score; called as ``model(inputs)``.
        dataloader: iterable yielding ``(inputs, labels)`` batches.

    Returns:
        Percentage of samples whose highest-scoring output along dim 1 equals
        the label.
    """
    model.eval()
    n_seen = 0
    n_hits = 0
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader):
            inputs, labels = inputs.to(device), labels.to(device)

            # Predicted class is the index of the largest score along dim 1.
            predicted = model(inputs).argmax(dim=1)
            n_seen += labels.size(0)
            n_hits += (predicted == labels).sum().item()
    return 100 * n_hits / n_seen


def main():
    """
    Run the retraining experiments selected on the command line.

    For every combination of evaluation metric, attempt, dataset, percentile and
    attribution method a CompoundCifarDataset is built from the stored
    attribution maps, a fresh model is trained with train_and_evaluate_model and
    the resulting test accuracy is appended to a results table that is printed
    after each run.

    Reads the module-level ``opt`` namespace (``opt.dataset`` and ``opt.model``).

    Raises:
        ValueError: if the selected model or dataset has no configuration below.
    """
    # NOTE(review): --output_dir is parsed at module level but is not consulted
    # here; results go under this hard-coded directory.
    outdir = 'roar_kar_experiments_abs_no_abs'
    datasets_name = [opt.dataset]
    eval_metrics = ['roar']  # ['roar', 'kar']
    percentiles = [10, 30, 50, 70, 90]
    attribution_methods = ['VanillaSaliency', 'input_x_grad', 'gbp', 'gradcam', 'rectgard', 'integrad',
                           'sparsity0_prune_pgd']
    prune_methods = ['prune_grad_abs', 'prune_pgd_abs', 'PruneGrad-Mid']

    # Prune thresholds are only known for CIFAR10/Resnet8; every other
    # combination simply runs without the pruning-based attribution methods.
    prune_thresholds = []
    if opt.dataset == 'CIFAR10' and opt.model == 'Resnet8':
        prune_thresholds = [92]

    attribution_methods.extend(f'{prune_method}_{prune_threshold}'
                               for prune_method, prune_threshold in itertools.product(prune_methods, prune_thresholds))

    attempts = list(range(3))
    print('Attribution methods', attribution_methods)

    # debug samples
    debug = False
    if debug:
        indices = 6
        train_debug_samples = random.sample(range(1, 45000), indices)
        test_debug_samples = random.sample(range(1, 10000), indices)
        print(f'train_debug_samples: {train_debug_samples}')
        print(f'test_debug_samples: {test_debug_samples}')
        timestamp = datetime.datetime.now().isoformat()
        debug_outdir = os.path.join(outdir, timestamp + '_roar_retrain_debug_samples')
        os.makedirs(debug_outdir, exist_ok=False)
        print(f'Output Dir: {debug_outdir}')

    model_config = dict(
        ConvNetSimple=dict(
            model_arch_name='src.models.classification.ConvNetSimple.ConvNetSimple',
            model_weights_path=None,
            model_constructor_args=dict(
                input_size=SupportedDataset.CIFAR10_Enum.value['image_size'],
                number_of_input_channels=SupportedDataset.CIFAR10_Enum.value[
                    'channels'],
                number_of_classes=SupportedDataset.CIFAR10_Enum.value['labels_count'],
            )),
        Resnet8=dict(
            model_arch_name='src.models.classification.PytorchCifarResnet.ResNet8',
            model_weights_path=None,
            model_constructor_args=dict()),
        Resnet8v2=dict(
            model_arch_name='src.models.classification.PytorchCifarResnet.ResNet8v2',
            model_weights_path=None,
            model_constructor_args=dict()),
        Resnet10=dict(
            model_arch_name='src.models.classification.PytorchCifarResnet.ResNet10',
            model_weights_path=None,
            model_constructor_args=dict()),
        Resnet18=dict(
            model_arch_name='torchvision.models.resnet18',
            model_weights_path=None,
            model_constructor_args=dict(
                pretrained=True
            )),
    )

    # The CLI accepts model names that have no configuration here; fail with a
    # readable message instead of a bare KeyError.
    if opt.model not in model_config:
        raise ValueError(f'No model configuration is defined for model {opt.model!r}; '
                         f'configured models: {sorted(model_config)}')

    # Training Configs for different dataset
    training_configs = dict(CIFAR10=dict(batch_size=256,
                                         model_args=dict(
                                             model_config[opt.model]
                                         ),
                                         optimizer_args=dict(
                                             name='torch.optim.Adam',
                                             lr=1e-3
                                         ),
                                         nb_epochs=25
                                         ),
                            Food101=dict(batch_size=128,
                                         model_args=dict(
                                             model_config[opt.model]
                                         ),
                                         optimizer_args=dict(
                                             name='torch.optim.Adam',
                                             lr=1e-4,
                                             weight_decay=1e-4,
                                         ),
                                         nb_epochs=20,
                                         lr_strategy=(10, 0.1)  # decrease lr to 0.1*lr every 10 epochs
                                         )
                            )

    # Same for datasets: check before any output directory is created.
    if opt.dataset not in training_configs:
        raise ValueError(f'No training configuration is defined for dataset {opt.dataset!r}; '
                         f'configured datasets: {sorted(training_configs)}')

    tt = texttable.Texttable()
    tt.header(['attempt', 'dataset', 'eval_metric', 'attribution_method', 'percentile', 'test-accuracy'])

    # One retraining run per (metric, attempt, dataset, percentile, attribution method)
    timestamp = datetime.datetime.now().isoformat()
    for eval_metric, attempt, dataset_name, percentile, attribution_method in itertools.product(eval_metrics,
                                                                                                attempts,
                                                                                                datasets_name,
                                                                                                percentiles,
                                                                                                attribution_methods):
        # Output Directory
        if not debug:
            models_savedir = os.path.join(outdir,
                                          timestamp + '_' + dataset_name + '_' + eval_metric)  # For saving results
            os.makedirs(models_savedir, exist_ok=True)

        # Create Compound Dataset from Image Dataset and attribution image dataset
        training_config = training_configs[dataset_name]
        dataset_modes = ['train', 'test']
        arch_name = training_config["model_args"]["model_arch_name"].split(".")[-1]
        # E.g. roar_kar_experiments/CIFAR10_train/modelname_attrname/
        attribution_paths = [os.path.join(outdir,
                                          f'{opt.dataset}_{dataset_mode}/{arch_name}_{attribution_method}')
                             for dataset_mode in dataset_modes]
        compound_dataset = CompoundCifarDataset(dataset_name,
                                                attribution_paths[0],
                                                attribution_paths[1],
                                                roar=eval_metric == 'roar',
                                                percentile=percentile)

        # View some samples of dataset
        if debug:
            if attempt == 0:
                print(attribution_method, percentile)
                # The train split is dumped with the train indices and its own
                # name, so the two dumps do not share a name.
                compound_dataset.debug(outdir=debug_outdir,
                                       name=f'{opt.dataset}_{dataset_modes[0]}_{percentile}_{attribution_method}',
                                       train=True,
                                       indices=train_debug_samples)
                compound_dataset.debug(outdir=debug_outdir,
                                       name=f'{opt.dataset}_{dataset_modes[1]}_{percentile}_{attribution_method}',
                                       train=False,
                                       indices=test_debug_samples)
            continue  # No need to do any training

        # Retrain Model on Compound Dataset and Evaluate
        train_data_args = dict(
            batch_size=training_config['batch_size'],
            shuffle=True,
            to_train=True,
        )
        # train_and_evaluate_model reads the schedule from train_data_args, so
        # forward it when the dataset config defines one.
        if 'lr_strategy' in training_config:
            train_data_args['lr_strategy'] = training_config['lr_strategy']

        val_data_args = dict(
            batch_size=train_data_args['batch_size'] * 4,
            shuffle=False,
            validate_step_size=1,
        )

        loss_args = dict(
            name='torch.nn.CrossEntropyLoss'
        )

        # Always retrain from start
        assert training_config['model_args']['model_weights_path'] is None

        arguments = dict(
            dataset=compound_dataset,
            model_name_args=(attempt + 1, eval_metric, attribution_method, percentile),
            train_data_args=train_data_args,
            val_data_args=val_data_args,
            model_args=training_config['model_args'],
            loss_args=loss_args,
            optimizer_args=training_config['optimizer_args'],
            outdir=models_savedir,
            nb_epochs=training_config['nb_epochs'],
            random_seed=attempt  # To get proper stats on multiple runs, dont use fixed random seed.
        )

        final_roar_test_accuracy = train_and_evaluate_model(arguments)
        tt.add_row([attempt + 1, dataset_name, eval_metric, attribution_method, percentile, final_roar_test_accuracy])
        print(tt.draw())


# Run the experiment sweep only when this file is executed as a script.
if __name__ == '__main__':
    main()
