import os
import argparse
import importlib
import numpy as np
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from models.relu_mlp import Net
from datasets.get_datasets import get_datasets
from utils.data_logger import CSVLogger
from utils.convergence_check import ConvergenceCheck
from utils.reproducibility import set_random_seed, load_config, save_config

import core.regularizer as regularizer
from core.misprediction_analysis import MispredictionAnalyzer
from core.structure_manipulation import MLPManipulator
from core.folding_search import find_decision_boundary, find_closest_point_on_decision_boundary
from core.parameter_optimization import calculate_gradients_over_inputs, calculate_gradients_over_predictions, initialize_wB_new, initialize_wA_new_and_bA_new

# Configuration files live in a "configs" folder next to this script, so all
# paths are anchored on the script location rather than the working directory.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]
CONFIGS_DIR = os.path.join(BASE_DIR, 'configs')
DATASET_CONFIGS, LOSS_FUNCTIONS_CONFIGS = (
    os.path.join(CONFIGS_DIR, yml_name)
    for yml_name in ('datasets.yml', 'loss_functions.yml')
)

# File names (or str.format templates) of the artifacts written per experiment.
CHECKPOINT_NAME = 'hidden_dim={}_epoch={}.pth'    # filled with (hidden_dim, epoch)
INTERMEDIATE_NAME = 'hidden_dim={}.npz'           # filled with (hidden_dim,)
GROWTH_LOG_NAME = 'growth_log.csv'
EXP_CONFIGS_NAME = 'exp_config.yml'


def grow_model(model, dataset, criterion, intermediate_path, min_mistake_proportion=0.0001, **kwargs):
    ''' broaden the hidden layer by folding decision boundary along typical mistakes

    Picks the first cluster of mis-predicted samples that is large enough,
    derives initial values for one new set of hidden-layer parameters from a
    representative sample of that cluster, and returns the broadened model.

    Args:
        model: network to grow; it is moved to the CPU and put in eval mode here.
        dataset: dataset handed to MispredictionAnalyzer; its length is the
            denominator of the mistake proportion.
        criterion: loss function handed to the analyzer and the gradient helpers.
        intermediate_path: destination of the compressed .npz file holding the
            intermediates of this growth step.
        min_mistake_proportion: clusters covering a smaller fraction of the
            dataset than this are skipped.
        **kwargs: forwarded unchanged to MLPManipulator.

    Returns:
        (new_model, x_target, x_fold) after a growth step, or
        (model, None, None) when the analyzer yields no further cluster.
        Callers must handle the None case before using the last two items.

    Raises:
        AssertionError: if the selected target sample is predicted correctly.
    '''

    # the analysis and the parameter surgery below are all done on the CPU
    model.cpu()

    # find typical mistakes given the current model
    analyzer = MispredictionAnalyzer(model, dataset, criterion, sample_size=None, task='classification')
    # consumed with next() below, so this is expected to be an iterator/generator
    typical_mistakes = analyzer.find_typical_mistakes(
        n_clusters=20,
        standardization=False,
        softmax=True,
        method='largest_cluster',
        representative='cluster_centroid'
    )

    # advance through the clusters until one is large enough to be worth folding
    while True:
        try:
            # `offset` is unpacked but not used anywhere in this function
            typical_mispredicted_samples, (x_target, y_target), offset = next(typical_mistakes)
        except StopIteration:
            print("No more typical mistakes worth analysis.")
            return model, None, None

        # fraction of the whole dataset covered by this cluster of mistakes
        proportion = len(typical_mispredicted_samples) / len(dataset)
        if proportion < min_mistake_proportion:
            print("Mistake proportion too small ({:.4f}% < {:.4}%). Moving on to the next cluster of mistakes...".format(proportion * 100, min_mistake_proportion * 100))
            continue

        print("Processing cluster with {} mis-predicted instances ({:.4f}%)...".format(len(typical_mispredicted_samples), proportion * 100))
        break

    # NOTE(review): presumably the analyzer returns array-likes rather than tensors
    # (torch.tensor() on an existing tensor copies and warns) -- confirm
    x_target = torch.tensor(x_target)
    y_target = torch.tensor(y_target)

    # verify misprediction
    model.eval()
    with torch.no_grad():
        # a batch dimension is added because the model is called on a single sample
        y_pred = model(x_target.unsqueeze(0)).argmax().detach().item()
    # NOTE(review): this check is an assert, so it disappears under `python -O`
    assert int(y_pred) != int(y_target), "Unexpected error: target input {} correctly predicted.".format(x_target)

    # find decision boundary
    architecture = MLPManipulator(model, **kwargs)

    # parameters of the linear segment that x_target activates (per the helper's name)
    W_activated, b_activated = architecture.find_activated_linear_segment(x_target, model)

    # boundary between the true class and the wrongly predicted class
    w_dec_bound, b_dec_bound = find_decision_boundary(int(y_target), y_pred, W_activated, b_activated)

    # point on that boundary closest to the target; used as the folding point
    x_fold = find_closest_point_on_decision_boundary(x_target, w_dec_bound, b_dec_bound)

    # calculate gradients
    dL_dx_before = calculate_gradients_over_inputs(
        model, x_target.unsqueeze(0), y_target.unsqueeze(0), criterion
    ).squeeze(0)
    dL_df_before = calculate_gradients_over_predictions(
        model, x_target.unsqueeze(0), y_target.unsqueeze(0), criterion
    ).squeeze(0)

    # extract original parameters
    # NOTE(review): WA and bA are never read afterwards; they are kept because it is
    # not visible here whether get_parameters(requires_grad=True) has side effects
    WA = architecture.get_parameters(model, layer_number=1, operation_type='Linear', parameter_type='weight', requires_grad=True)
    bA = architecture.get_parameters(model, layer_number=1, operation_type='Linear', parameter_type='bias', requires_grad=True)
    WB = architecture.get_parameters(model, layer_number=2, operation_type='Linear', parameter_type='weight', requires_grad=True)

    # calculate initial values for the newly added parameters
    wB_new = initialize_wB_new(WB)
    wA_new, bA_new = initialize_wA_new_and_bA_new(x_fold, x_target, dL_dx_before, dL_df_before, wB_new)

    # broaden model
    new_model = architecture.broaden_hidden_layer(model, wA_new, bA_new, wB_new)
    print("Network broaden.")

    # save intermediate results in an .npz file
    # each cluster entry is unpacked as an (x, y) pair
    xs_cluster, ys_cluster = zip(*typical_mispredicted_samples)
    # NOTE(review): y_target is stored as a torch tensor while x_target and x_fold
    # are converted with .numpy() -- confirm this asymmetry is intended
    np.savez_compressed(
        intermediate_path,
        x_target=x_target.numpy(),
        y_target=y_target,
        y_predict=y_pred,
        xs_cluster=xs_cluster,
        ys_cluster=ys_cluster,
        x_fold=x_fold.numpy()
    )
    print("Intermediates saved to: {}".format(intermediate_path))

    return new_model, x_target, x_fold


def train_model(model, device, data_loader, criterion, optimizer, regularization=None, start_epoch=1, max_epoch=1000, print_freq=1, checkpoint_freq=1, checkpoint_dir='.'):
    ''' train model with optional regularization until convergence

    Training stops at the first of: max_epoch reached, every training sample
    predicted correctly during an epoch, or the convergence checker firing on
    the accumulated epoch cost.

    Args:
        model: network to train; moved to `device` and put in train mode.
        device: device the model and every batch are moved to.
        data_loader: loader yielding (inputs, targets) batches.
        criterion: loss function called as criterion(outputs, targets).
        optimizer: optimizer already bound to the model's parameters.
        regularization: optional dict mapping the name of a function in the
            `regularizer` module to a dict of its keyword arguments; each
            penalty is added to the batch cost.
        start_epoch: first epoch number (used when resuming).
        max_epoch: last epoch number, inclusive.
        print_freq: print the batch status every `print_freq` batches.
        checkpoint_freq: save a checkpoint every `checkpoint_freq` epochs.
        checkpoint_dir: directory for checkpoints; a falsy value disables the
            periodic checkpoints and None also disables the final one.

    Returns:
        Path of the checkpoint holding the final model state, or None when
        checkpoint_dir is None.

    Raises:
        TypeError: if `regularization` or one of its values is not a dict.
        ValueError: if a regularization name is unknown to `regularizer`.
    '''
    # validate regularization configurations
    if regularization:
        if not isinstance(regularization, dict):
            raise TypeError("Invalid value for parameter \"regularization\": expected dict, got {}.".format(type(regularization)))
        for name, params in regularization.items():
            if not hasattr(regularizer, name):
                raise ValueError("Unknown regularization: \"{}\"".format(name))
            if not isinstance(params, dict):
                raise TypeError("Invalid value for regularization \"{}\": expected dict, got {}.".format(name, type(params)))

    # setup model on device
    model.to(device)

    # create convergence checker
    is_converge = ConvergenceCheck(optimization='min', patience=10)

    # Defined up front so the final save/return below is well-formed even when
    # the loop body never runs (start_epoch > max_epoch, e.g. when resuming a
    # checkpoint that was saved at the last epoch).
    checkpoint_path = None
    epoch = start_epoch - 1

    # start training
    model.train()
    for epoch in range(start_epoch, max_epoch + 1):
        epoch_correct = 0
        epoch_cost = 0.0
        print("Epoch {} / {}:".format(epoch, max_epoch))

        for batch_index, (inputs, targets) in enumerate(data_loader, start=1):
            # move data to device
            inputs, targets = inputs.to(device), targets.to(device)

            # forward pass
            optimizer.zero_grad()
            outputs = model(inputs)
            cost = criterion(outputs, targets)

            # message of batch status (reports the cost before any penalty is added)
            status = '\r    batch: {} / {} | cost: {:.4f}'.format(batch_index, len(data_loader), cost.item())

            # apply regularization penalties, if specified
            if regularization:
                for name, kwargs in regularization.items():
                    penalty = getattr(regularizer, name)(model, **kwargs)
                    cost = cost + penalty
                    status += ' | {}: {:.4f}'.format(name, penalty.item())

            # backward propagation
            cost.backward()
            optimizer.step()

            # update statistics (the epoch cost includes the penalties)
            epoch_cost += cost.item()
            preds = outputs.argmax(dim=1, keepdim=True)
            epoch_correct += preds.eq(targets.view_as(preds)).sum().item()

            # print batch status as per print_freq
            if batch_index % print_freq == 0:
                print("\r" + status, end='', flush=True)

        # print end-of-epoch statistics (the blank line wipes the last batch status)
        print('\r{}'.format(' ' * 120), end='')
        print('\r    epoch average cost: {:.4f}'.format(epoch_cost / len(data_loader)))
        print('    epoch accuracy:     {:.2f}%'.format(100.0 * epoch_correct / len(data_loader.dataset)))

        # only a checkpoint written for the current epoch may be returned
        checkpoint_path = None

        # save checkpoint
        if checkpoint_dir and (epoch % checkpoint_freq == 0 or epoch == max_epoch):
            checkpoint_path = save_model(model, epoch, checkpoint_dir)

        # early stopping condition 1: all predictions are correct
        if epoch_correct == len(data_loader.dataset):
            print("Perfect accuracy achieved! Stopping training.")
            break

        # early stopping condition 2: the epoch cost no longer improves
        if is_converge(epoch_cost):
            print("Convergence reached. Stopping training.")
            break

    # make sure the final state is on disk if the last epoch wrote no checkpoint;
    # skipped when the caller disabled checkpointing with checkpoint_dir=None
    if checkpoint_path is None and checkpoint_dir is not None:
        checkpoint_path = save_model(model, epoch, checkpoint_dir)

    return checkpoint_path


def save_model(model, epoch, dir_path, checkpoint_name=CHECKPOINT_NAME):
    ''' write a checkpoint holding the model's size attributes and its parameters

    The file name is `checkpoint_name` formatted with (model.hidden_dim, epoch)
    and is placed inside `dir_path`. Returns the path that was written.
    '''
    file_name = checkpoint_name.format(model.hidden_dim, epoch)
    path = os.path.join(dir_path, file_name)

    # the size attributes are stored next to the weights so the architecture
    # can be rebuilt before the state dict is loaded
    snapshot = {
        'epoch': epoch,
        'in_size': model.in_size,
        'out_size': model.out_size,
        'hidden_dim': model.hidden_dim,
        'state_dict': model.state_dict(),
    }
    torch.save(snapshot, path)

    print("Checkpoint saved to: ", path)
    return path


def _count_mispredictions(model, device, data_loader):
    ''' count how many samples of a data loader the model classifies incorrectly '''
    model.to(device)
    was_training = model.training
    model.eval()
    num_wrong = 0
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            # same prediction rule as the accuracy bookkeeping in train_model
            preds = model(inputs).argmax(dim=1, keepdim=True)
            num_wrong += preds.ne(targets.view_as(preds)).sum().item()
    # restore whichever mode the caller had set
    model.train(was_training)
    return num_wrong


def main(args):
    ''' main training workflow for growing neural network model

    Either starts a new experiment (--export_dir) or resumes one from a
    checkpoint (--load_checkpoint), trains the model until train_model stops,
    and then, unless --stop_growing is set, alternates between adding one
    hidden-layer neuron and fine-tuning until the training set is classified
    perfectly or no cluster of mistakes is left to fold.

    Args:
        args: argparse.Namespace with the attributes defined by the command-line
            parser of this script (export_dir, load_checkpoint, stop_growing,
            dataset, loss_function, random_seed, learn_rate, weight_decay,
            batch_size, max_epoch, gpu_id, num_workers). It is not modified.

    Raises:
        FileNotFoundError: a configuration file, the parent of the experiment
            folder, or the checkpoint file does not exist.
        FileExistsError: the experiment folder already exists.
        ValueError: unknown dataset or loss function, invalid checkpoint
            extension, or missing experiment configuration next to the checkpoint.
        RuntimeError: neither/both of --export_dir and --load_checkpoint given,
            or the checkpoint is inconsistent with the current arguments/dataset.
    '''
    # set global random seeds for reproducibility
    set_random_seed(args.random_seed)
    print("Random seed: {}".format(args.random_seed))

    # specify training device
    device = torch.device('cuda:{}'.format(args.gpu_id) if torch.cuda.is_available() else 'cpu')
    print("Using device: {}".format(device))

    # load dataset configuration
    # TODO(review): the original author left a "BUG: load_checkpoint" marker here
    # without further detail -- confirm whether an issue with resuming remains.
    if not os.path.exists(DATASET_CONFIGS):
        raise FileNotFoundError("Dataset configuration file not found: {}".format(DATASET_CONFIGS))
    all_dataset_configs = load_config(DATASET_CONFIGS)
    dataset_config = all_dataset_configs.get(args.dataset)
    if not dataset_config:
        valid_datasets = ', '.join(map(str, all_dataset_configs.keys()))
        raise ValueError("Invalid dataset \"{}\". Valid options are: {}".format(args.dataset, valid_datasets))

    # get dataset object with corresponding preprocessing and train-test sets splitting
    train_dataset, test_dataset = get_datasets(dataset_config)

    # create data loaders (page-locked memory is only useful when copying to a GPU)
    pin_memory = device.type == 'cuda'
    train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                   num_workers=args.num_workers, pin_memory=pin_memory)
    # NOTE(review): the test loader is built but not used anywhere in this workflow
    test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                                  num_workers=args.num_workers, pin_memory=pin_memory)

    # define loss function
    if not os.path.exists(LOSS_FUNCTIONS_CONFIGS):
        raise FileNotFoundError("Loss function configuration file not found: {}".format(LOSS_FUNCTIONS_CONFIGS))
    all_loss_configs = load_config(LOSS_FUNCTIONS_CONFIGS)
    loss_function_configs = all_loss_configs.get(args.loss_function)
    if not loss_function_configs:
        valid_losses = ', '.join(map(str, all_loss_configs.keys()))
        raise ValueError("Invalid loss function \"{}\". Valid options are: {}".format(args.loss_function, valid_losses))
    loss_module = importlib.import_module(loss_function_configs['module_name'])
    criterion = getattr(loss_module, loss_function_configs['function_name'])()

    # retrieve model dimensions (plain int so it serializes cleanly in checkpoints)
    input_size = dataset_config['size'] if isinstance(dataset_config['size'], int) else int(np.prod(dataset_config['size']))
    num_classes = dataset_config['num_classes']

    # experiment setup: either create new experiment or continue from checkpoint
    exp_path = ckpt_path = configs = None
    start_epoch = 1
    hidden_dim = 0
    if args.export_dir and not args.load_checkpoint:
        # validate experiment directory
        parent_dir = os.path.dirname(args.export_dir)
        if parent_dir and not os.path.exists(parent_dir):
            raise FileNotFoundError("Parent directory \"{}\" does not exist. Cannot create experiment folder.".format(parent_dir))
        if os.path.exists(args.export_dir):
            raise FileExistsError("Experiment folder \"{}\" already exists. Choose a new name or delete existing directory.".format(args.export_dir))

        # create new experiment folder
        os.mkdir(args.export_dir)
        exp_path = args.export_dir
        ckpt_path = os.path.join(exp_path, "checkpoints")
        os.mkdir(ckpt_path)

        # save experiment configurations; built as a new dict because vars(args)
        # is the namespace's own __dict__ and popping from it would strip the
        # attributes from `args`
        configs = {key: value for key, value in vars(args).items()
                   if key not in ('export_dir', 'load_checkpoint')}
        save_config(configs, os.path.join(exp_path, EXP_CONFIGS_NAME))

        # initialize model with base hidden dimension = number of classes
        hidden_dim = num_classes
        model = Net(in_size=input_size, out_size=num_classes, hidden_dim=hidden_dim)
        print("\nInitialized base model with {} hidden-layer neurons.".format(hidden_dim))

        # save initial model state
        save_model(model, epoch=0, dir_path=ckpt_path)

        print("\nPre-training base model to convergence...")

    elif args.load_checkpoint and not args.export_dir:
        # ensure the existence of the checkpoint file
        if not os.path.isfile(args.load_checkpoint):
            raise FileNotFoundError("Checkpoint file not found: {}".format(args.load_checkpoint))

        # validate checkpoint file extension and experiment folder structure
        ckpt_path, ckpt_name = os.path.split(args.load_checkpoint)
        exp_path, ckpt_folder = os.path.split(ckpt_path)
        configs_path = os.path.join(exp_path, EXP_CONFIGS_NAME)
        if os.path.splitext(ckpt_name)[1].lower() not in ['.pt', '.pth']:
            raise ValueError("Invalid checkpoint name \"{}\": the model must have either \".pt\" or \".pth\" an extension.".format(ckpt_name))
        # The configuration file is what the rest of this branch needs; the
        # former three-way `and` let a missing file through to load_config.
        if not os.path.isfile(configs_path):
            raise ValueError("Failed to retrieve experimental configurations from the specified checkpoint path: {}".format(ckpt_path))

        # validate whether all arguments in the namespace are consistent with the original experimental configurations
        configs = load_config(configs_path)
        for key in configs:
            if configs[key] != getattr(args, key):
                raise RuntimeError("Configuration \"{}\" mismatch! Expected {} but got {}.".format(key, configs[key], getattr(args, key)))

        # load existing checkpoint's attributes and parameters onto the CPU so a
        # checkpoint written on a GPU can be resumed on any machine; train_model
        # moves the model to the training device afterwards.
        # NOTE: weights_only=False unpickles arbitrary objects -- only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(args.load_checkpoint, map_location='cpu', weights_only=False)

        # verify checkpoint compatibility with current dataset
        if input_size != checkpoint['in_size']:
            raise RuntimeError("Input dimension mismatch! Dataset requires {} but checkpoint has {}.".format(input_size, checkpoint['in_size']))
        if num_classes != checkpoint['out_size']:
            raise RuntimeError("Output dimension mismatch! Dataset has {} classes but checkpoint expects {}.".format(num_classes, checkpoint['out_size']))

        # load model with checkpoint parameters
        hidden_dim = checkpoint['hidden_dim']
        model = Net(in_size=input_size, out_size=num_classes, hidden_dim=hidden_dim)
        model.load_state_dict(checkpoint['state_dict'])

        start_epoch = checkpoint['epoch'] + 1
        print("\nResuming training from checkpoint \"{}\" at epoch {}...".format(args.load_checkpoint, start_epoch))

    else:
        raise RuntimeError("Must specify either --export_dir (new experiment) or --load_checkpoint (resume training).")

    #########################################################################################################################################################################
    #                                                                  Pre-training Phase                                                                                   #
    #########################################################################################################################################################################

    # pre-train base model or resume fine-tuning checkpoint model until convergence
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learn_rate, weight_decay=args.weight_decay)
    regularization = {
        'diversity_loss': {'coef': 0.5},
        'orthogonality_loss': {'coef': 0.5},
    }
    train_checkpoint_path = train_model(
        model,
        device,
        train_data_loader,
        criterion,
        optimizer,
        regularization=regularization,
        start_epoch=start_epoch,
        max_epoch=args.max_epoch,
        checkpoint_dir=ckpt_path
    )

    #########################################################################################################################################################################
    #                                                                 Network Growing Phase                                                                                 #
    #########################################################################################################################################################################

    # create folder to store growth intermediates if not exist
    intermediate_dir = os.path.join(exp_path, 'intermediates')
    os.makedirs(intermediate_dir, exist_ok=True)

    if not args.stop_growing:
        # main training loop: grow model until perfect training accuracy is achieved.
        # train_model only returns a checkpoint path, so perfection is measured
        # here on the training set after every training run.
        is_perfect = _count_mispredictions(model, device, train_data_loader) == 0
        while not is_perfect:
            assert hidden_dim == model.hidden_dim, "Unknown error: current hidden_dim {} does not match model {}".format(hidden_dim, model.hidden_dim)

            # specify the path to the file for saving intermediates during network growth
            intermediate_path = os.path.join(intermediate_dir, INTERMEDIATE_NAME.format(hidden_dim))

            # expand network architecture
            print("\nGrowing hidden-layer dimension: {} -> {} ...".format(hidden_dim, hidden_dim + 1))
            model, target_point, folding_point = grow_model(model, train_dataset, criterion, intermediate_path=intermediate_path, input_dim=input_size)

            # grow_model hands back the unchanged model with None points when it
            # finds no cluster of mistakes to fold; nothing was grown in that case
            if target_point is None or folding_point is None:
                print("\nStopping growth: no remaining cluster of mistakes to fold.")
                break
            hidden_dim += 1

            # save grown model
            grown_checkpoint_path = save_model(model, epoch=0, dir_path=ckpt_path)

            # save growth log
            growth_log = OrderedDict([
                ('before_growth', train_checkpoint_path),
                ('after_growth', grown_checkpoint_path),
                ('intermediates', intermediate_path),
            ])
            log_path = os.path.join(exp_path, GROWTH_LOG_NAME)
            logger = CSVLogger(log_path, column_names=list(growth_log.keys()), append=os.path.exists(log_path))
            logger.add_row(growth_log)
            print("Growth log saved to:", log_path)

            # setup regularization for the newly added parameters
            regularization = {
                'fold_plane_alignment_loss': {'target_point': target_point.to(device),
                                              'folding_point': folding_point.to(device),
                                              'coef': 1000,
                                              'layer_name': 'fc_layer_1',
                                              'new_param_index': -1,
                }
            }

            # re-define optimizer for expanded network model
            optimizer = torch.optim.Adam(model.parameters(), lr=args.learn_rate, weight_decay=args.weight_decay)

            # fine-tune expanded model
            print("\nFine-tuning expanded network with {} hidden-layer neurons...".format(hidden_dim))
            train_checkpoint_path = train_model(model,
                device,
                train_data_loader,
                criterion,
                optimizer,
                regularization=regularization,
                max_epoch=args.max_epoch,
                checkpoint_dir=ckpt_path
            )

            is_perfect = _count_mispredictions(model, device, train_data_loader) == 0

        if is_perfect:
            print("\nTraining completed successfully! Achieved 100% training accuracy!")


if __name__ == '__main__':
    # command-line options as (flags, add_argument keywords) pairs, registered
    # in the order listed so the --help output keeps its layout
    cli_options = [
        (('-exp', '--export_dir'),
         {'type': str, 'default': None,
          'help': 'path to directory to be created for new experiment (exclusive with --load_checkpoint)'}),
        (('-load', '--load_checkpoint'),
         {'type': str, 'default': None,
          'help': 'path to an existing checkpoint file to resume training (exclusive with --export_dir)'}),
        (('--stop_growing',),
         {'action': 'store_true',
          'help': 'stop after model pre-training without growth'}),
        (('-data', '--dataset'),
         {'type': str, 'default': 'mnist',
          'help': 'name of the dataset for training'}),
        (('-loss', '--loss_function'),
         {'type': str, 'default': 'cross-entropy',
          'help': 'name of the loss function to be used for training'}),
        (('-seed', '--random_seed'),
         {'type': int, 'default': 0,
          'help': 'global random seed for the reproducibility of the current experiment'}),
        (('-lr', '--learn_rate'),
         {'type': float, 'default': 0.001,
          'help': 'learning rate for base model training and expanded model fine-tuning after growth'}),
        (('-wd', '--weight_decay'),
         {'type': float, 'default': 0.001,
          'help': 'weight decay (L2 regularization) during optimization'}),
        (('-batch', '--batch_size'),
         {'type': int, 'default': 128,
          'help': 'number of samples per training batch'}),
        (('-epoch', '--max_epoch'),
         {'type': int, 'default': 100,
          'help': 'maximum epochs for fine-tuning per growth iteration'}),
        (('-gpu', '--gpu_id'),
         {'type': int, 'default': 0,
          'help': 'id of the GPU to be used'}),
        (('-worker', '--num_workers'),
         {'type': int, 'default': 0,
          'help': 'number of subprocesses for data loading'}),
    ]

    cli = argparse.ArgumentParser(description='Train neural network with dynamic architecture growth')
    for flags, options in cli_options:
        cli.add_argument(*flags, **options)

    # parse the command line and hand the namespace to the training workflow
    main(cli.parse_args())