import os
import argparse
import importlib
import numpy as np
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import core.regularizer as regularizer
from datasets.get_datasets import get_datasets
from models.relu_mlp import Net
from utils.data_logger import CSVLogger
from utils.efficiency import find_max_batch_size
from utils.reproducibility import set_random_seed, load_config

# file names used inside an experiment folder
EVALUATION_LOG_NAME = 'evaluations.csv'
EXP_CONFIGS_NAME = 'exp_config.yml'

# shared configuration files, resolved relative to the current working directory
_CONFIGS_DIR = os.path.join('.', 'src', 'configs')
DATASET_CONFIGS = os.path.join(_CONFIGS_DIR, 'datasets.yml')
LOSS_FUNCTIONS_CONFIGS = os.path.join(_CONFIGS_DIR, 'loss_functions.yml')

def evaluate_dataset_loss_and_accuracy(model, device, data_loader, criterion, print_freq=1):
    ''' evaluate model on dataset with loss and accuracy metrics

    Args:
        model: network to evaluate; it is moved to `device` and put in eval mode.
        device: torch device on which inference runs.
        data_loader: iterable of (inputs, targets) batches that exposes `.dataset`
            and supports `len()`.
        criterion: loss callable. If it has a `reduction` attribute, that attribute
            is temporarily set to 'none' to collect per-sample losses and is
            restored before returning.
        print_freq: print a progress line every `print_freq` batches; values <= 0
            disable progress printing.

    Returns:
        (avg_loss, accuracy): mean per-sample loss and accuracy in percent.
        Both are NaN when the dataset is empty.
    '''

    # setup model on device
    model.to(device)

    # per-sample losses are needed for an exact dataset-level average, so disable
    # the criterion's reduction for the duration of the evaluation only; the
    # original setting is restored in `finally` so the caller's object is untouched
    has_reduction = hasattr(criterion, 'reduction')
    if has_reduction:
        original_reduction = criterion.reduction
        criterion.reduction = 'none'

    # start inference
    model.eval()
    all_losses = []
    total_correct = 0
    try:
        with torch.no_grad():
            for batch_index, (batch_inputs, batch_targets) in enumerate(data_loader, start=1):
                # move data to device
                batch_inputs, batch_targets = batch_inputs.to(device), batch_targets.to(device)

                # forward pass
                batch_outputs = model(batch_inputs)

                # record loss function values; flattening keeps a 0-d (already
                # reduced) result as a one-element list so the length check
                # below reports the problem instead of a TypeError here
                batch_losses = criterion(batch_outputs, batch_targets)
                all_losses.extend(batch_losses.reshape(-1).cpu().tolist())

                # update number of correct predictions
                preds = batch_outputs.argmax(dim=1, keepdim=True)
                total_correct += preds.eq(batch_targets.view_as(preds)).sum().item()

                # print batch status as per print_freq
                if print_freq > 0 and batch_index % print_freq == 0:
                    print("\rbatch: {} / {}".format(batch_index, len(data_loader)), end='', flush=True)
    finally:
        if has_reduction:
            criterion.reduction = original_reduction

    # sanity check: one loss value per sample in the dataset
    num_samples = len(data_loader.dataset)
    assert len(all_losses) == num_samples, "Unknown error: number of recorded losses ({}) mismatch the total number of data in the dataset ({}).".format(len(all_losses), num_samples)

    # calculate average loss and overall accuracy (undefined for an empty dataset)
    if num_samples == 0:
        avg_loss = float('nan')
        accuracy = float('nan')
    else:
        avg_loss = sum(all_losses) / num_samples
        accuracy = 100.0 * total_correct / num_samples

    # print end-of-epoch statistics
    print('\r{}'.format(' ' * 120), end='')
    print("\r    average loss: {:.4f}, accuracy: {:.2f}%".format(avg_loss, accuracy))

    return avg_loss, accuracy


def evalutate_regularization(model, regularization):
    ''' compute each named regularization penalty for `model`

    Args:
        model: network passed as the first argument to every penalty function.
        regularization: mapping from a function name in `core.regularizer` to the
            keyword arguments that function is called with (callers in this file
            pass `coef=1.0`, presumably so the raw, unscaled penalty is returned).

    Returns:
        dict mapping each function name to whatever that function returned, in
        the iteration order of `regularization`.
    '''
    return {
        penalty_name: getattr(regularizer, penalty_name)(model, **penalty_kwargs)
        for penalty_name, penalty_kwargs in regularization.items()
    }


def _resolve_experiment_paths(checkpoint_file):
    ''' validate a checkpoint file and locate its experiment folder and config file

    The expected layout is `<exp_path>/checkpoints/<name>.pt|.pth`, with the
    experimental configuration stored at `<exp_path>/<EXP_CONFIGS_NAME>`.

    Returns:
        (exp_path, configs_path)

    Raises:
        FileNotFoundError: the checkpoint file does not exist.
        ValueError: invalid extension, or the folder layout / config file is missing.
    '''
    # validate checkpoint file
    if not os.path.isfile(checkpoint_file):
        raise FileNotFoundError("Checkpoint file not found: {}".format(checkpoint_file))

    # validate checkpoint file extension
    ckpt_path, ckpt_name = os.path.split(checkpoint_file)
    if os.path.splitext(ckpt_name)[1].lower() not in ('.pt', '.pth'):
        raise ValueError("Invalid checkpoint name \"{}\": the model must have either \".pt\" or \".pth\" an extension.".format(ckpt_name))

    # validate experiment folder structure and experimental configurations;
    # every requirement must hold, so fail as soon as ANY of them is violated
    # (an existing config file also implies its parent folder exists)
    exp_path, ckpt_folder = os.path.split(ckpt_path)
    configs_path = os.path.join(exp_path, EXP_CONFIGS_NAME)
    if ckpt_folder != 'checkpoints' or not os.path.isfile(configs_path):
        raise ValueError("Failed to retrieve experimental configurations from the specified checkpoint path: {}".format(ckpt_path))

    return exp_path, configs_path


def _load_checkpoint_model(checkpoint_file, input_size, num_classes):
    ''' rebuild a `Net` from a checkpoint and check its dimensions against the dataset '''
    # map tensors to CPU so checkpoints saved on a GPU also load on CPU-only hosts;
    # the evaluation routine moves the model to the target device afterwards.
    # NOTE(review): weights_only=False unpickles arbitrary Python objects, which can
    # execute code — only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_file, map_location='cpu', weights_only=False)
    assert checkpoint['in_size'] == input_size, "Unknown error: loaded checkpoint model's input dimension mismatch dataset requirements."
    assert checkpoint['out_size'] == num_classes, "Unknown error: loaded checkpoint model's output dimension mismatch dataset requirements."

    model = Net(in_size=input_size, out_size=num_classes, hidden_dim=checkpoint['hidden_dim'])
    model.load_state_dict(checkpoint['state_dict'])
    return model


def _to_loggable(value):
    ''' turn single-element tensors into Python floats; return other values unchanged '''
    # regularizer outputs may be tensors (assumption — see core.regularizer); as
    # plain floats they format in the summary and are written to the CSV as numbers
    if torch.is_tensor(value) and value.numel() == 1:
        return float(value.item())
    return value


def main(args):
    ''' evaluate saved model on training and testing sets

    Loads the checkpoint given by `args.load_checkpoint` together with its
    experiment's configuration. Evaluates loss/accuracy on the train and test
    splits plus the diversity/orthogonality regularization terms. Appends the
    results to `<exp_path>/evaluations.csv` and prints a summary.

    Raises:
        FileNotFoundError: the checkpoint file does not exist.
        ValueError: the checkpoint name or experiment layout is invalid.
    '''
    exp_path, configs_path = _resolve_experiment_paths(args.load_checkpoint)

    print("\nLoading checkpoint {}...".format(args.load_checkpoint))

    # load original experimental configurations
    configs = load_config(configs_path)

    # set global random seeds for reproducibility
    set_random_seed(configs['random_seed'])

    # determine inference device
    device = torch.device('cuda:{}'.format(args.gpu_id) if torch.cuda.is_available() else 'cpu')

    # define loss function
    assert os.path.exists(LOSS_FUNCTIONS_CONFIGS), "Loss function configuration file not found: {}".format(LOSS_FUNCTIONS_CONFIGS)
    loss_function_configs = load_config(LOSS_FUNCTIONS_CONFIGS).get(configs['loss_function'])
    assert loss_function_configs, "Unknown error: loss function name \"{}\" recorded in the experimental configurations is invalid.".format(configs['loss_function'])
    criterion = getattr(importlib.import_module(loss_function_configs['module_name']), loss_function_configs['function_name'])()

    # load dataset configuration
    assert os.path.exists(DATASET_CONFIGS), "Dataset configuration file not found: {}".format(DATASET_CONFIGS)
    dataset_config = load_config(DATASET_CONFIGS).get(configs['dataset'])
    assert dataset_config, "Unknown error: dataset name \"{}\" recorded in the experimental configurations is invalid.".format(configs['dataset'])

    # get dataset object with corresponding preprocessing and train-test sets splitting
    train_dataset, test_dataset = get_datasets(dataset_config)

    # retrieve model dimensions (a multi-dimensional input size is flattened)
    input_size = dataset_config['size'] if isinstance(dataset_config['size'], int) else int(np.prod(dataset_config['size']))
    num_classes = dataset_config['num_classes']

    # load model with checkpoint parameters
    model = _load_checkpoint_model(args.load_checkpoint, input_size, num_classes)

    # find the maximum safe batch size for inference on the specified device if not specified
    max_infer_batch_size = args.max_batch_size if args.max_batch_size is not None else find_max_batch_size(model, device, train_dataset)

    # create data loaders; pinned host memory only helps host-to-GPU copies
    pin_memory = device.type == 'cuda'
    train_data_loader = DataLoader(train_dataset, batch_size=max_infer_batch_size, shuffle=False,
                                   num_workers=args.num_workers, pin_memory=pin_memory)
    test_data_loader = DataLoader(test_dataset, batch_size=max_infer_batch_size, shuffle=False,
                                  num_workers=args.num_workers, pin_memory=pin_memory)

    # evaluate on training set
    print("\nEvaluating on training set...")
    train_loss, train_acc = evaluate_dataset_loss_and_accuracy(model, device, train_data_loader, criterion)

    # evaluate on testing set
    print("\nEvaluating on testing set...")
    test_loss, test_acc = evaluate_dataset_loss_and_accuracy(model, device, test_data_loader, criterion)

    # evaluate regularization terms
    regularization = {
        'diversity_loss': {'coef': 1.0},
        'orthogonality_loss': {'coef': 1.0},
    }
    print("\nEvaluating on regularization terms: {}...".format(', '.join(regularization.keys())))
    regularizer_evaluations = evalutate_regularization(model, regularization)

    # prepare results
    results = OrderedDict([
        ('checkpoint', os.path.basename(args.load_checkpoint)),
        ('train_accuracy', train_acc),
        ('train_loss', train_loss),
        ('test_accuracy', test_acc),
        ('test_loss', test_loss),
        ('diversity_loss', _to_loggable(regularizer_evaluations.get('diversity_loss', float('nan')))),
        ('orthogonality_loss', _to_loggable(regularizer_evaluations.get('orthogonality_loss', float('nan')))),
    ])

    # append to the experiment's evaluation log (header written only for a new file)
    log_path = os.path.join(exp_path, EVALUATION_LOG_NAME)
    logger = CSVLogger(log_path, column_names=list(results.keys()), append=os.path.exists(log_path))
    logger.add_row(results)

    print("\nEvaluation complete. Results saved to:", log_path)
    print("Summary:")
    for k, v in results.items():
        if isinstance(v, float):
            print("    {:<20} {:.4f}".format(k + ":", v))
        else:
            print("    {:<20} {}".format(k + ":", v))




if __name__ == '__main__':
    # command-line entry point: evaluate one saved checkpoint
    cli = argparse.ArgumentParser(description='Test model')
    cli.add_argument(
        '-load', '--load_checkpoint',
        type=str,
        required=True,
        help='path to an existing checkpoint file for testing',
    )
    cli.add_argument(
        '-gpu', '--gpu_id',
        type=int,
        default=0,
        help='id of the GPU to be used',
    )
    cli.add_argument(
        '-worker', '--num_workers',
        type=int,
        default=0,
        help='number of subprocesses for data loading',
    )
    cli.add_argument(
        '-batch', '--max_batch_size',
        type=int,
        default=None,
        help='maximum safe batch size for inference',
    )
    main(cli.parse_args())