import argparse
import logging
import os
import time
import pandas as pd
import torch
from torchinfo import summary # type:ignore
from scratch_models import *
from torchvision import models
from torchmetrics.classification import Accuracy, F1Score #type:ignore
from tqdm import tqdm
from utils import *

# Setup logging for other files to use. This is placed ahead of the dataset
# imports below, presumably so those modules see the configured logging --
# TODO confirm. A log file left over from a previous run is not deleted here.
# The log directory is created first: the --log_dir handling only runs later in
# the script, and a file-based log handler cannot be opened inside a missing
# directory (assumes global_config_logger does not create it itself).
os.makedirs('./logs', exist_ok = True)
addLoggingLevel('TRACE', logging.DEBUG - 5)
global_config_logger(log_file = './logs/train.log', log_level = logging.DEBUG)

from datasets import *
from datamodule import *


if __name__ == '__main__':
    # Logger used by this entry point
    logger = setup_logger(name = __name__)

    # Command line interface
    parser = argparse.ArgumentParser(description = 'Train a model')

    # Dataset flags as (option, help text); they are registered in this order
    dataset_flags = (
        ('--imagenette', 'Imagenette Dataset'),
        ('--imagenet100', 'First 100 classes of Imagenet'),
        ('--imagenet', 'Imagenet Dataset'),
        ('--imagewoof', 'Imagewoof Dataset'),
        ('--cifar10', 'Cifar10 Dataset'),
        ('--fashionmnist', 'FashionMNIST Dataset'),
        ('--mnist', 'MNIST Dataset'),
        ('--svhn', 'SVHN Dataset'),
    )
    for flag_name, flag_help in dataset_flags:
        parser.add_argument(flag_name, action = 'store_true', help = flag_help)

    # Model hyperparameters
    parser.add_argument('--epochs', type = int, default = 10, help = 'Number of training loops')
    parser.add_argument('--lr', type = float, default = 0.001, help = 'Learning rate of the model')
    parser.add_argument('--batch_size', type = int, default = 32, help = 'Batch size of data')

    # Hardware parameters
    parser.add_argument('--gpu', action = 'store_true', help = 'GPU Flag')
    parser.add_argument('--num_workers', type = int, default = 1, help = 'Number of workers for dataloader')

    # Directory parameters
    parser.add_argument('--log_dir', type = str, default = './logs', help = 'Directory to store logs')
    parser.add_argument('--checkpoint_dir', type = str, default = './checkpoints', help = 'Directory to store model checkpoints')

    # Class removal parameters
    parser.add_argument('--remove_classes', action = 'store_true', help = 'Removed frozen classes from the dataset')
    parser.add_argument('--f1_threshold', type = float, default = 0.85, help = 'Threshold score to remove class')
    parser.add_argument('--residue', type = float, default = 0.00, help = 'Percentage of class to leave behind')
    parser.add_argument('--reset_residue', type = int, default = 0, help = 'Determine interval of when to reset the residue')
    parser.add_argument('--warmup_period', type = int, default = 0, help = 'Number of epochs to warm-up')
    parser.add_argument('--removal_restriction', type = int, default = 0, help = 'Restriction policy on classes to remove')

    # Other parameters
    parser.add_argument('--metric_file_name', type = str, default = 'metrics.csv', help = 'Name of metric file')
    parser.add_argument('--fresh_start', action = 'store_true', help = 'Prevent checkpoint loading.')
    parser.add_argument('--checkpoint_file', type = str, default = 'baseline.pth', help = 'Name of checkpoint file to save to.')

    # Both switches are off unless given on the command line
    parser.set_defaults(fresh_start = False, remove_classes = False)

    # Read the command line
    args = parser.parse_args()

    # Create the log and checkpoint directories if they are missing.
    # makedirs also creates missing parent directories (so nested paths work)
    # and exist_ok avoids failing when the directory is already there.
    for required_dir in (args.log_dir, args.checkpoint_dir):
        os.makedirs(required_dir, exist_ok = True)

    logger.info('Running {} experiment'.format(args.metric_file_name))

    # Determine which dataset to use.
    # Rows are checked top to bottom and the first flag that is set wins.
    # Row layout: (args attribute, display name, class count, factory,
    #              training directory, validation directory).
    # A validation directory of None marks a dataset object that is built once
    # from the first directory and exposes its own .train / .val after setup().
    # The factories are lambdas so a dataset class is only looked up when it is
    # the one actually selected.
    dataset_options = (
        ('imagenette', 'Imagenette', 10, lambda path: Imagenette(image_dir = path),
            '/data/progressive_data_dropout/imagenette/train',
            '/data/progressive_data_dropout/imagenette/val'),
        ('imagenet100', 'Imagenet100', 100, lambda path: Imagenet(image_dir = path),
            '/data/progressive_data_dropout/imagenet/train',
            '/data/progressive_data_dropout/imagenet/val'),
        ('imagenet', 'Imagenet', 1000, lambda path: Imagenet(image_dir = path),
            '/data/progressive_data_dropout/imagenet/train',
            '/data/progressive_data_dropout/imagenet/val'),
        ('imagewoof', 'Imagewoof', 10, lambda path: Imagenet(image_dir = path),
            '/data/progressive_data_dropout/imagewoof/train',
            '/data/progressive_data_dropout/imagewoof/val'),
        ('cifar10', 'Cifar10', 10, lambda path: CustomCIFAR10(image_dir = path),
            '/data/progressive_data_dropout/cifar10', None),
        ('fashionmnist', 'FashionMNIST', 10, lambda path: CustomFashionMNIST(image_dir = path),
            '/data/progressive_data_dropout/fashionmnist', None),
        ('mnist', 'MNIST', 10, lambda path: CustomMNIST(image_dir = path),
            '/data/progressive_data_dropout/mnist', None),
        ('svhn', 'SVHN', 10, lambda path: CustomSVHN(image_dir = path),
            '/data/progressive_data_dropout/svhn', None),
    )

    for flag_attr, dataset_name, num_classes, build_dataset, train_dir, val_dir in dataset_options:
        if getattr(args, flag_attr):
            break
    else:
        # No dataset flag was passed on the command line
        raise NotImplementedError('Please select a dataset flag')

    logger.info('Selected {} dataset\n'.format(dataset_name))

    if val_dir is not None:
        # Separate training / validation image folders
        training_dataset = build_dataset(train_dir)
        logger.info('Training set has {} instances'.format(len(training_dataset)))
        validation_dataset = build_dataset(val_dir)
        logger.info('Validation set has {} instances'.format(len(validation_dataset)))
    else:
        # One dataset object that provides both splits once setup() has run
        full_dataset = build_dataset(train_dir)
        full_dataset.setup()
        training_dataset = full_dataset.train
        logger.info('Training set has {} instances'.format(len(training_dataset)))
        validation_dataset = full_dataset.val
        logger.info('Validation set has {} instances'.format(len(validation_dataset)))

    # Every dataset trains an untrained ResNet-18. Imagenette builds the
    # default network and swaps in a new final layer; the rest pass the class
    # count to the constructor.
    if args.imagenette:
        model = models.resnet18(pretrained = False)
        model.fc = nn.Linear(512, num_classes)
    else:
        model = models.resnet18(pretrained = False, num_classes = num_classes)

    # Log a torchinfo summary of the model for an input of size
    # (batch_size, 3, 224, 224)
    summary_input_size = (args.batch_size, 3, 224, 224)
    model_summary = str(summary(model, summary_input_size, verbose = 0))
    logger.info('Model Summary\n{}'.format(model_summary))

    # Wrap both datasets in datamodules that share the loader settings
    logger.debug('Creating Datamodules')
    datamodule_options = dict(batch_size = args.batch_size, num_workers = args.num_workers)
    training_datamodule = CustomDataModule(dataset = training_dataset, **datamodule_options)
    validation_datamodule = CustomDataModule(dataset = validation_dataset, **datamodule_options)

    # Imagenet100 reuses the full Imagenet folders, so class labels 100-999 are
    # removed (with no residue) from both splits
    if args.imagenet100:
        logger.debug('Removing classes 100 - 999 from Imagenet')
        for class_label in range(100, 1000):
            for datamodule in (training_datamodule, validation_datamodule):
                datamodule.remove_class_with_residue(class_label = class_label, residue = 0.00)

    # Build the dataloaders; only the training data is shuffled
    logger.debug('Creating Dataloaders')
    training_dataloader = training_datamodule.dataloader(shuffle = True)
    validation_dataloader = validation_datamodule.dataloader(shuffle = False)

    logger.info('Training set has {} batches'.format(len(training_dataloader)))
    logger.info('Validation set has {} batches\n'.format(len(validation_dataloader)))

    # Cross-entropy loss optimised with Adam at the requested learning rate
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr = args.lr)
    logger.info('Loss Function: {}'.format(loss_fn))
    logger.info('Optimizer: {}\n'.format(optimizer))

    # Overall metrics plus per-class F1 (average = None) for each split.
    # NOTE(review): Accuracy() / F1Score() are created without arguments; newer
    # torchmetrics releases may require more -- confirm the pinned version.
    training_accuracy, validation_accuracy = Accuracy(), Accuracy()
    training_f1score, validation_f1score = F1Score(), F1Score()
    per_class_training_f1score = F1Score(num_classes = num_classes, average = None)
    per_class_validation_f1score = F1Score(num_classes = num_classes, average = None)

    # Metric file location, the per-epoch metric row, and the table of all rows
    metric_path = os.path.join(args.log_dir, args.metric_file_name)
    metric_dict = {}
    metric_df = pd.DataFrame()

    # GPU flag checking: CUDA is used only when requested and available
    device = 'cuda:0' if args.gpu and torch.cuda.is_available() else 'cpu'

    # Bookkeeping for class removal. These lists always start empty; nothing
    # below restores them from a checkpoint.
    removed_classes = []
    processed_classes = []

    # Send the model to the correct device
    model = model.to(device)

    # Resume state. last_epoch stays at -1 (no epoch finished yet) unless a
    # checkpoint is loaded, so it is defined on every path.
    checkpoint_path = os.path.join(args.checkpoint_dir, args.checkpoint_file)
    checkpoint = None
    last_epoch = -1
    if os.path.exists(checkpoint_path) and not args.fresh_start:
        # map_location remaps the stored tensors onto the device chosen for
        # this run, so a checkpoint saved on a GPU can be resumed on the CPU.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location = device)

    # Load the model state dictionary if checkpoint is available
    if checkpoint:
        logger.info('Loading checkpoint state dictionary.')
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        last_epoch = checkpoint['epoch']
        logger.info('Last Epoch: {}'.format(last_epoch))

        # Continue the existing metric file instead of starting a new table
        if os.path.exists(metric_path):
            metric_df = pd.read_csv(metric_path)

    # Send the metrics to the correct device
    training_accuracy = training_accuracy.to(device)
    validation_accuracy = validation_accuracy.to(device)
    training_f1score = training_f1score.to(device)
    validation_f1score = validation_f1score.to(device)
    per_class_training_f1score = per_class_training_f1score.to(device)
    per_class_validation_f1score = per_class_validation_f1score.to(device)

    # Only printed next to the current validation accuracy; best-model tracking
    # is not active, so this value is never updated.
    previous_validation_accuracy = 0

    # Training loop. TODO: consider moving this section into the model file.
    # Start right after the last finished epoch (last_epoch is -1 when nothing
    # was resumed) so epochs that already happened are not repeated.
    for current_epoch in range(last_epoch + 1, args.epochs):

        time_start = time.perf_counter()

        # ---Training Start---
        current_training_loss_total = 0.0
        current_training_loss = 0.0
        training_batches_used = 0
        # Loss tensor of the most recent training batch of this epoch
        training_loss = None

        # Switch the model to training mode
        model.train()

        # Progress bar (Needs reset every epoch)
        progress_bar = tqdm(training_dataloader)

        for current_batch_idx, batch_data in enumerate(progress_bar):
            # Extract the image and labels from batch
            images, labels = batch_data

            # Send the tensors to the correct device
            images = images.to(device)
            labels = labels.to(device)

            # Skip single-example batches (presumably because batch norm cannot
            # compute training statistics from one example -- TODO confirm).
            # `continue` skips just this batch instead of ending the epoch.
            if images.size(dim=0) == 1:
                logger.debug('Skipped batch due to only 1 example. ')
                continue

            # Zero out gradients
            optimizer.zero_grad()

            # Get the model output for this batch
            outputs = model(images)
            if current_batch_idx < 5:
                logger.debug('Training Batch {} Correct Results: {}'.format(
                    current_batch_idx, torch.sum(torch.eq(labels, torch.argmax(outputs, dim = -1)))))

            # Compute the loss
            training_loss = loss_fn(outputs, labels)

            # Update the training metrics
            training_accuracy.update(outputs, labels)
            training_f1score.update(outputs, labels)
            per_class_training_f1score.update(outputs, labels)

            # Calculate the loss gradients
            training_loss.backward()

            # Adjust the learning weights
            optimizer.step()

            # Keep track of the running loss
            current_training_loss_total += training_loss.item()
            training_batches_used += 1

            # Average the loss across the batches that were actually used
            current_training_loss = current_training_loss_total / training_batches_used

            # Update progress bar
            progress_bar.set_description('Training Loss: {:.2f}'.format(current_training_loss), refresh = True)

        progress_bar.close()
        logger.info(str(progress_bar))

        # Total the training metrics over all batches
        total_training_accuracy = training_accuracy.compute().item()
        total_training_f1score = training_f1score.compute().item()
        total_per_class_training_f1score = per_class_training_f1score.compute()

        # Record the metrics in the dictionary
        metric_dict['training_loss'] = round(current_training_loss, 2)
        metric_dict['training_accuracy'] = round(total_training_accuracy * 100, 2)
        metric_dict['training_f1score'] = round(total_training_f1score, 2)
        # ---Training End---

        # ---Validation Start---
        current_validation_loss_total = 0.0
        current_validation_loss = 0.0
        validation_batches_used = 0

        # Switch the model to validation mode
        model.eval()

        # Progress bar (Needs reset every epoch)
        progress_bar = tqdm(validation_dataloader)

        # Nothing is back-propagated during validation, so run it without
        # autograd bookkeeping.
        with torch.no_grad():
            for current_batch_idx, batch_data in enumerate(progress_bar):
                # Extract the image and labels from batch
                images, labels = batch_data

                # Send the tensors to the correct device
                images = images.to(device)
                labels = labels.to(device)

                # Same single-example rule as in training, reported through
                # the logger like the training loop does.
                if images.size(dim=0) == 1:
                    logger.debug('Skipped batch due to only 1 example. ')
                    continue

                # Get the model output for this batch
                outputs = model(images)

                # Compute the loss
                validation_loss = loss_fn(outputs, labels)

                # Update the validation metrics
                validation_accuracy.update(outputs, labels)
                validation_f1score.update(outputs, labels)
                per_class_validation_f1score.update(outputs, labels)

                # Keep track of the running loss
                current_validation_loss_total += validation_loss.item()
                validation_batches_used += 1

                # Average the loss across the batches that were actually used
                current_validation_loss = current_validation_loss_total / validation_batches_used

                # Update progress bar
                progress_bar.set_description('Validation Loss: {:.2f}'.format(current_validation_loss), refresh = True)

        progress_bar.close()
        logger.info(str(progress_bar))

        # Compute the validation metrics
        total_validation_accuracy = validation_accuracy.compute().item()
        total_validation_f1score = validation_f1score.compute().item()
        total_per_class_validation_f1score = per_class_validation_f1score.compute()

        # Record the metrics in the dictionary
        # NOTE(review): training_accuracy is stored multiplied by 100 while
        # validation_accuracy is stored as computed. Left unchanged so the
        # metric file keeps its current scale -- confirm which one is intended.
        metric_dict['validation_loss'] = round(current_validation_loss, 2)
        metric_dict['validation_accuracy'] = round(total_validation_accuracy, 2)
        metric_dict['validation_f1score'] = round(total_validation_f1score, 2)
        metric_dict['training_datapoints'] = len(training_datamodule.dataset)

        # Record per class train/validation scores
        for class_id, train_f1_score in enumerate(total_per_class_training_f1score):
            metric_dict_key = 'Class {} Train F1-Score'.format(class_id)
            metric_dict[metric_dict_key] = round(train_f1_score.item(), 2)

        for class_id, val_f1_score in enumerate(total_per_class_validation_f1score):
            metric_dict_key = 'Class {} Val F1-Score'.format(class_id)
            metric_dict[metric_dict_key] = round(val_f1_score.item(), 2)

        # Reset the metric states so the next epoch starts from scratch
        training_accuracy.reset()
        validation_accuracy.reset()

        training_f1score.reset()
        validation_f1score.reset()
        per_class_training_f1score.reset()
        per_class_validation_f1score.reset()
        # ---Validation End---

        # Rank the classes by training F1 score (best first) and keep the
        # labels whose score is strictly above the threshold, in that order
        sorted_scores, class_indicies = torch.sort(total_per_class_training_f1score, descending = True)
        labels_above_threshold = class_indicies[sorted_scores > args.f1_threshold].tolist()
        logger.debug('Sorted F1 Scores: {}'.format(sorted_scores))

        # Check that warmup period has been met
        if current_epoch >= args.warmup_period:
            logger.debug('Warmup Period has been reached')
            num_classes_removed_this_epoch = 0
            # Loop through the classes above the f1_threshold
            for class_label in labels_above_threshold:
                # Each class is handled at most once over the whole run
                if class_label not in processed_classes:
                    logger.info('Class {} is above the threshold'.format(class_label))

                    # ---Class Removal Start---
                    if args.remove_classes:
                        logger.info('Preforming Data Dropout on Class {}'.format(class_label))
                        training_datamodule.remove_class_with_residue(class_label = class_label, residue = args.residue)

                        # Add class to the list of removed classes
                        removed_classes.append(class_label)
                        logger.debug('Removed classes: {}'.format(removed_classes))

                        num_classes_removed_this_epoch += 1

                    # ---Class Removal End---
                    processed_classes.append(class_label)
                    logger.debug('Processed classes: {}'.format(processed_classes))

                # Leave the loop once the per-epoch removal limit is reached
                # (a limit of 0 means no restriction)
                if args.removal_restriction != 0 and num_classes_removed_this_epoch == args.removal_restriction:
                    logger.debug('Ht removal restriction policy of {} classes.'.format(args.removal_restriction))
                    break

            # Every reset_residue epochs, run the removal again for each
            # removed class, unless none or all classes have been removed
            if args.remove_classes and args.reset_residue > 0 and current_epoch % args.reset_residue == 0 and len(removed_classes) > 0 and len(removed_classes) != num_classes:
                for class_label in removed_classes:
                    logger.info('Reshuffling Data Dropout Examples on Class {}'.format(class_label))
                    training_datamodule.remove_class_with_residue(class_label = class_label, residue = args.residue)

            # Reload the training dataloader so it reflects the removals
            if args.residue != 0 or len(removed_classes) != num_classes:
                logger.debug('Reinitializing the dataloaders')
                training_dataloader = training_datamodule.dataloader(shuffle = True)

        time_end = time.perf_counter()
        metric_dict['Epoch Time'] = round(time_end - time_start, 2)
        # Write metrics to a dataframe and save it off as a csv
        metric_df = pd.concat([metric_df, pd.DataFrame(metric_dict, index = [current_epoch])], ignore_index = True)
        metric_df.to_csv(metric_path, index = False)

        # Save the model to the checkpoint folder. The loss is stored as a
        # detached CPU tensor so the file does not depend on the training
        # device or on autograd state (None if no batch was trained).
        checkpoint_dictionary = {
            'epoch': current_epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': training_loss.detach().cpu() if training_loss is not None else None,
        }
        print('Previous validation:', previous_validation_accuracy)
        print('Current validation:', metric_dict['validation_accuracy'])
        # The checkpoint is overwritten after every epoch; keeping only the
        # best-scoring model is currently not enabled.
        torch.save(checkpoint_dictionary, checkpoint_path)
        logger.info('Finished Epoch {}'.format(current_epoch))

        # Stop early once every class has been removed from the training data
        if len(removed_classes) == num_classes:
            break

    # Kept inside the __main__ guard: logger is only created there
    logger.info('Finished Training')