import os
import logging
import torch
from seml.utils import flatten
import wandb
import json
import numpy as np
from collections import OrderedDict
import pandas as pd
from pathlib import Path
import itertools
import random
from pprint import pprint

from train import train
from dataset import get_dataset

from models.model_loader import load_model
from models.ModifiedEvidentialN import ModifiedEvidentialNet

from utils.io_utils import DataWriter
from utils.metrics import accuracy, confidence, anomaly_detection, our_confidence, our_anomaly_detection
from utils.metrics import compute_X_Y_alpha, compute_X_Y_alpha_with_features_and_uncertainties, name2abbrv
from utils.ece import _ECELoss

# Registry mapping the ``model_type`` config value to the class built when training from scratch.
create_model = dict(menet=ModifiedEvidentialNet)
# Lower the root logger's threshold to INFO so ``logging.info(...)`` records are not filtered out.
logging.root.setLevel(logging.INFO)


def _to_numpy(value):
    """Return ``value`` as a NumPy array.

    Torch tensors are detached from autograd and moved to CPU first, so tensors that live on a
    GPU or require grad can still be serialized (plain ``Tensor.numpy()`` raises for both).
    Anything else (NumPy arrays, lists, ...) goes through ``np.asarray``.
    """
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return np.asarray(value)


def save_detailed_results(Y_all, X_all, alpha_pred_all, features_all, uncertainties_all, predicted_labels_all,
                          save_dir, dataset_name, config_id, seed, is_ood=False):
    """
    Save per-sample labels, predictions, features and uncertainties for one dataset.

    Writes ``{id_|ood_}{dataset_name}_{config_id}_seed{seed}.pkl`` containing everything, plus one
    ``.npy`` file per array and per uncertainty type, all inside ``save_dir`` (created if missing).

    Args:
        Y_all: True labels (tensor or array-like).
        X_all: Input data. Accepted for interface compatibility; it is NOT written to disk.
        alpha_pred_all: Model predictions (Dirichlet alpha, or softmax probabilities).
        features_all: Last-layer features.
        uncertainties_all: Mapping ``uncertainty_name -> per-sample values``. Stored as-is in the
            pickle; each value is converted to a NumPy array for its ``.npy`` file.
        predicted_labels_all: Predicted class labels.
        save_dir: Directory to save results into.
        dataset_name: Name of the dataset (used in the file names).
        config_id: Configuration ID (used in the file names).
        seed: Random seed (used in the file names).
        is_ood: Whether this is OOD data; selects the ``ood_`` / ``id_`` file-name prefix.
    """
    import pickle

    os.makedirs(save_dir, exist_ok=True)

    ood_prefix = "ood_" if is_ood else "id_"
    filename = f"{ood_prefix}{dataset_name}_{config_id}_seed{seed}"
    base_path = os.path.join(save_dir, filename)

    # Convert each tensor exactly once; the same arrays feed both the pickle and the .npy files.
    arrays = {
        'true_labels': _to_numpy(Y_all),
        'predicted_labels': _to_numpy(predicted_labels_all),
        'features': _to_numpy(features_all),
        'alpha_predictions': _to_numpy(alpha_pred_all),
    }

    results = {
        **arrays,
        'uncertainties': uncertainties_all,
        'dataset_name': dataset_name,
        'config_id': config_id,
        'seed': seed,
        'is_ood': is_ood,
    }

    # Save everything in one pickle file
    with open(f"{base_path}.pkl", 'wb') as f:
        pickle.dump(results, f)

    # Also save as numpy files for easier access
    for array_name, array in arrays.items():
        np.save(f"{base_path}_{array_name}.npy", array)

    # Save uncertainties as separate numpy files
    for uncertainty_name, uncertainty_values in uncertainties_all.items():
        np.save(f"{base_path}_uncertainty_{uncertainty_name}.npy", _to_numpy(uncertainty_values))

    print(f"Saved detailed results for {dataset_name} (OOD: {is_ood}) to {save_dir}/{filename}.*")


def _seed_everything(seed):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs and force deterministic cuDNN.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Disable cuDNN autotuning and request deterministic kernels for reproducible runs.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def _full_config_name(run_config):
    """Join the values of ``run_config`` (except ``name_model``) into a '-'-separated run identifier.

    Dict-valued entries are flattened with ``seml.utils.flatten`` and their leaf values joined by '-'.
    """
    parts = []
    for key, value in run_config.items():
        if isinstance(value, dict):
            value = "-".join(str(val) for val in flatten(value).values())
        if key != 'name_model':
            parts.append(str(value))
    return '-'.join(parts)


def main(config_dict):
    """Train (or load) and evaluate a model for every hyper-parameter combination in ``config_dict``.

    For each element of the Cartesian product of ``seeds``, ``lr``, ``fisher_c``, ``name_model``,
    ``lamb1_list``, ``lamb2_list``, ``kl_c`` and ``kl_start_epoch`` this function:

    1. seeds all RNGs and loads the in-distribution (ID) dataset;
    2. loads a pre-trained model (when ``name_model`` is not None) or builds and trains a new one;
    3. evaluates the model on the ID test split and on every OOD dataset. It always saves per-sample
       results. It optionally writes per-metric CSVs (``store_stat``), a results CSV
       (``store_results``) and wandb tables (``use_wandb``).

    Args:
        config_dict: Experiment configuration, e.g. the JSON config extended by the ``__main__`` block
            with ``config_id``, ``suffix``, ``model_dir``, ``results_dir`` and ``stat_dir``.
    """
    config_id = config_dict['config_id']
    suffix = config_dict['suffix']

    seeds = config_dict['seeds']

    dataset_name = config_dict['dataset_name']
    ood_dataset_names = config_dict['ood_dataset_names']
    split = config_dict['split']

    # Model parameters
    model_type = config_dict['model_type']
    name_model_list = config_dict['name_model']

    # Architecture parameters
    directory_model = config_dict['directory_model']
    architecture = config_dict['architecture']
    input_dims = config_dict['input_dims']
    output_dim = config_dict['output_dim']
    hidden_dims = config_dict['hidden_dims']
    kernel_dim = config_dict['kernel_dim']
    k_lipschitz = config_dict['k_lipschitz']

    # Training parameters
    max_epochs = config_dict['max_epochs']
    patience = config_dict['patience']
    frequency = config_dict['frequency']
    batch_size = config_dict['batch_size']
    lr_list = config_dict['lr']
    loss = config_dict['loss']
    lamb1_list = config_dict['lamb1_list']
    lamb2_list = config_dict['lamb2_list']
    kl_c_list = config_dict['kl_c']

    clf_type = config_dict['clf_type']
    fisher_c_list = config_dict['fisher_c']
    noise_epsilon = config_dict['noise_epsilon']

    # Mix parameters
    mix = config_dict.get('mix', False)
    mix_inter = config_dict.get('mix_inter', False)
    mix_inter_alpha = config_dict.get('mix_inter_alpha', 1.0)
    mix_inter_beta = config_dict.get('mix_inter_beta', 1.0)
    mix_noise = config_dict.get('mix_noise', False)
    noise_mix_alpha = config_dict.get('noise_mix_alpha', 1.0)
    noise_mix_beta = config_dict.get('noise_mix_beta', 1.0)
    noise_mix_ratio = config_dict.get('noise_mix_ratio', 1.0)

    # Optimizer and scheduler parameters
    optimizer_type = config_dict.get('optimizer_type', 'adam')
    use_cosine_annealing = config_dict.get('use_cosine_annealing', False)
    cosine_lr_min_ratio = config_dict.get('cosine_lr_min_ratio', 5e-6)

    # Sample-wise KL loss weighting parameter
    use_sample_wise_kl_weight = config_dict.get('use_sample_wise_kl_weight', False)
    kl_start_epoch_list = config_dict.get('kl_start_epoch', [100])

    model_dir = config_dict['model_dir']
    results_dir = config_dict['results_dir']
    # Base statistics directory. It is never reassigned; each setting derives ``run_stat_dir`` from it.
    stat_dir = config_dict['stat_dir']
    store_results = config_dict['store_results']
    store_stat = config_dict['store_stat']

    use_wandb = config_dict['use_wandb']

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    settings = itertools.product(seeds, lr_list, fisher_c_list, name_model_list, lamb1_list, lamb2_list,
                                 kl_c_list, kl_start_epoch_list)
    for seed, lr, fisher_c, name_model, lamb1, lamb2, kl_c, kl_start_epoch in settings:
        _seed_everything(seed)

        ## Load dataset
        # NOTE: the configured ``output_dim`` is overridden by the value returned for the dataset.
        train_loader, val_loader, test_loader, N, output_dim = get_dataset(dataset_name, batch_size=batch_size, split=split, seed=seed)

        logging.info(f'Received the following configuration: seed {seed}')
        logging.info(f'DATASET | '
                     f'dataset_name {dataset_name} - '
                     f'ood_dataset_names {ood_dataset_names} - '
                     f'split {split}')

        ## Train or Load a pre-trained model
        if name_model is not None:
            logging.info(f'MODEL: {name_model}')
            run_config = OrderedDict(name_model=name_model, model_type=model_type, seed=seed,
                                     dataset_name=dataset_name, split=split, loss=loss, epsilon=noise_epsilon)

            if use_wandb:
                run = wandb.init(project='Re-EDL', reinit=True,
                                 group=f'{dataset_name}_{ood_dataset_names}',
                                 name=f'{model_type}_{loss}_ep{noise_epsilon}_{seed}')

            model = load_model(directory_model=directory_model, name_model=name_model, model_type=model_type)
            # Derive from the base dir; reassigning ``stat_dir`` would keep appending per-run suffixes.
            run_stat_dir = stat_dir + f'{name_model}'

        else:
            logging.info(f'ARCHITECTURE | '
                         f' model_type {model_type} - '
                         f' architecture {architecture} - '
                         f' input_dims {input_dims} - '
                         f' output_dim {output_dim} - '
                         f' hidden_dims {hidden_dims} - '
                         f' kernel_dim {kernel_dim} - '
                         f' k_lipschitz {k_lipschitz}')
            logging.info(f'TRAINING | '
                         f' max_epochs {max_epochs} - '
                         f' patience {patience} - '
                         f' frequency {frequency} - '
                         f' batch_size {batch_size} - '
                         f' lr {lr} - '
                         f' loss {loss}')
            logging.info(f'MODEL PARAMETERS | '
                         f' clf_type {clf_type} - '
                         f' fisher_c {fisher_c} - '
                         f' kl_c {kl_c} - '
                         f' lamb1 {lamb1} -'
                         f' lamb2 {lamb2} - '
                         f' mix {mix} - '
                         f' mix_inter {mix_inter} - '
                         f' mix_noise {mix_noise} - '
                         f' optimizer_type {optimizer_type} - '
                         f' use_cosine_annealing {use_cosine_annealing}')

            run_config = OrderedDict(model_type=model_type, seed=seed, dataset_name=dataset_name, split=split,
                                     architecture=architecture, input_dims=input_dims, output_dim=output_dim,
                                     hidden_dims=hidden_dims, kernel_dim=kernel_dim, k_lipschitz=k_lipschitz,
                                     max_epochs=max_epochs, patience=patience, frequency=frequency,
                                     batch_size=batch_size, clf_type=clf_type, lr=lr, loss=loss, fisher_c=fisher_c,
                                     kl_c=kl_c, lamb1=lamb1, lamb2=lamb2, mix=mix, mix_inter=mix_inter,
                                     mix_inter_alpha=mix_inter_alpha, mix_inter_beta=mix_inter_beta,
                                     mix_noise=mix_noise, noise_mix_alpha=noise_mix_alpha,
                                     noise_mix_beta=noise_mix_beta, noise_mix_ratio=noise_mix_ratio,
                                     optimizer_type=optimizer_type, use_cosine_annealing=use_cosine_annealing,
                                     use_sample_wise_kl_weight=use_sample_wise_kl_weight, kl_start_epoch=kl_start_epoch)

            if use_wandb:
                run = wandb.init(project='Re-EDL', reinit=True,
                                 group=f'{__file__}_{dataset_name}_{architecture}_{suffix}',
                                 name=f'{model_type}_{seed}_{loss}_lr{lr}_f{fisher_c}_{clf_type}')
                wandb.config.update(run_config)

            # Keyword arguments forwarded to the model constructor.
            filtered_config_dict = {'seed': seed,
                                    'architecture': architecture,
                                    'input_dims': input_dims,
                                    'output_dim': output_dim,
                                    'hidden_dims': hidden_dims,
                                    'kernel_dim': kernel_dim,
                                    'k_lipschitz': k_lipschitz,
                                    'batch_size': batch_size,
                                    'lr': lr,
                                    'loss': loss,
                                    'clf_type': clf_type,
                                    'fisher_c': fisher_c,
                                    'kl_c': kl_c,
                                    'lamb1': lamb1,
                                    'lamb2': lamb2,
                                    'mix': mix,
                                    'mix_inter': mix_inter,
                                    'mix_inter_alpha': mix_inter_alpha,
                                    'mix_inter_beta': mix_inter_beta,
                                    'mix_noise': mix_noise,
                                    'noise_mix_alpha': noise_mix_alpha,
                                    'noise_mix_beta': noise_mix_beta,
                                    'noise_mix_ratio': noise_mix_ratio,
                                    'optimizer_type': optimizer_type,
                                    'use_cosine_annealing': use_cosine_annealing,
                                    'num_epochs': max_epochs,
                                    'train_loader_len': len(train_loader),
                                    'cosine_lr_min_ratio': cosine_lr_min_ratio,
                                    'use_sample_wise_kl_weight': use_sample_wise_kl_weight,
                                    'kl_start_epoch': kl_start_epoch,
                                    }

            model = create_model[model_type](**filtered_config_dict)

            if torch.cuda.is_available():
                # torch.backends.cudnn.benchmark = True
                device_count = torch.cuda.device_count()
                if device_count > 1:
                    # NOTE(review): the DataParallel wrapper is unwrapped immediately via ``.module``,
                    # so training still runs on a single device despite this message.
                    print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
                    model = torch.nn.DataParallel(model)
                    model = model.module

            full_config_name = _full_config_name(run_config)

            # NOTE(review): checkpoints of settings that share (seed, kl_c) overwrite each other.
            model_path = model_dir + f'{seed}-{kl_c}'
            # Derive from the base dir; reassigning ``stat_dir`` would keep appending per-run suffixes.
            run_stat_dir = stat_dir + f'model-{full_config_name}'

            Path(model_dir).mkdir(parents=True, exist_ok=True)

            model.to(device)
            train(model, train_loader, val_loader, max_epochs=max_epochs, frequency=frequency, patience=patience,
                  model_path=model_path, full_config_dict=run_config, use_wandb=use_wandb, device=device,
                  output_dim=output_dim)
            # Load the best checkpoint for testing, mapped onto the evaluation device.
            model.load_state_dict(torch.load(model_path + '_best', map_location=device)['model_state_dict'])

        ## Test model
        model.to(device)
        model.eval()

        mc_dropout = False
        mc_iter = 100
        if mc_dropout:
            print("MC Dropout Mode is ON!")
            from utils.enable_test_time_dropout import enable_test_time_dropout
            model, _ = enable_test_time_dropout(model)

        detailed_dir = f'{run_stat_dir}/detailed_results'

        with torch.no_grad():
            if loss == 'MSE-softmax' or loss == 'CE-softmax':
                return_softmax = True
                metric_list = ['max_alpha']
            else:
                return_softmax = False
                metric_list = ['max_prob', 'max_alpha', 'alpha0',
                               'differential_entropy', 'mutual_information', 'edl_mpu']

            # when return_softmax=True, id_alpha_pred_all is actually softmax prob
            id_Y_all, id_X_all, id_alpha_pred_all, id_features_all, id_uncertainties_all, id_predicted_labels_all = \
                compute_X_Y_alpha_with_features_and_uncertainties(model, test_loader, device, return_softmax=return_softmax,
                                                                  mc_dropout=mc_dropout, mc_iter=mc_iter, lamb1=lamb1, lamb2=lamb2)

            # Save ID dataset results
            save_detailed_results(
                Y_all=id_Y_all,
                X_all=id_X_all,
                alpha_pred_all=id_alpha_pred_all,
                features_all=id_features_all,
                uncertainties_all=id_uncertainties_all,
                predicted_labels_all=id_predicted_labels_all,
                save_dir=detailed_dir,
                dataset_name=dataset_name,
                config_id=config_id,
                seed=seed,
                is_ood=False
            )

            # Disabled diagnostic: when enabled it prints ECE / Brier score and skips the remaining
            # evaluation for this setting (note it also skips ``run.finish()``).
            calculate_ece_and_brier_score = False
            if calculate_ece_and_brier_score:
                labels = id_Y_all
                logits = id_alpha_pred_all
                if return_softmax:
                    prob = logits
                else:
                    prob = logits / torch.sum(logits, dim=-1, keepdim=True)
                ece_criterion = _ECELoss().cuda()
                ece = ece_criterion(prob, labels).item() * 100
                print('Expected Callibration Error: %.2f' % ece)

                labels_onehot = torch.eye(id_alpha_pred_all.shape[-1])[labels]
                brier_score = torch.norm(prob - labels_onehot, dim=1, p=2).mean() * 100
                print('Brier Score: %.2f' % brier_score)
                continue

            # Save metrics
            metrics = {}
            scores = {}
            ood_scores = {}
            metrics['id_accuracy'] = accuracy(Y=id_Y_all, alpha=id_alpha_pred_all).tolist()

            for name in metric_list:
                abb_name = name2abbrv[name]
                save_path = None
                if store_stat:
                    save_path = f'{run_stat_dir}/{config_id}_id_{abb_name}.csv'
                    Path(run_stat_dir).mkdir(parents=True, exist_ok=True)

                if model_type == 'evnet' or model_type == 'duq':
                    aupr, auroc, score = confidence(Y=id_Y_all, alpha=id_alpha_pred_all, uncertainty_type=name,
                                                    save_path=save_path, return_scores=True)
                elif model_type == 'menet' or model_type == 'ablation':
                    aupr, auroc, score = our_confidence(Y=id_Y_all, alpha=id_alpha_pred_all, uncertainty_type=name,
                                                        save_path=save_path, return_scores=True)
                else:
                    raise NotImplementedError
                metrics[f'id_{abb_name}_apr'], metrics[f'id_{abb_name}_auroc'] = aupr * 100, auroc * 100

                scores[f'{abb_name}'] = score

            ood_dataset_loaders = {}
            for ood_dataset_name in ood_dataset_names:
                run_config['ood_dataset_name'] = ood_dataset_name
                _, _, ood_test_loader, _, _ = get_dataset(ood_dataset_name, batch_size=batch_size,
                                                          split=split, seed=seed)
                ood_dataset_loaders[ood_dataset_name] = ood_test_loader

                ood_Y_all, ood_X_all, ood_alpha_pred_all, ood_features_all, ood_uncertainties_all, ood_predicted_labels_all = \
                    compute_X_Y_alpha_with_features_and_uncertainties(
                        model, ood_test_loader, device, noise_epsilon=noise_epsilon, return_softmax=return_softmax,
                        mc_dropout=mc_dropout, mc_iter=mc_iter, lamb1=lamb1, lamb2=lamb2)

                # Save OOD dataset results
                save_detailed_results(
                    Y_all=ood_Y_all,
                    X_all=ood_X_all,
                    alpha_pred_all=ood_alpha_pred_all,
                    features_all=ood_features_all,
                    uncertainties_all=ood_uncertainties_all,
                    predicted_labels_all=ood_predicted_labels_all,
                    save_dir=detailed_dir,
                    dataset_name=ood_dataset_name,
                    config_id=config_id,
                    seed=seed,
                    is_ood=True
                )

                # Accuracy on a noise-perturbed copy of the ID dataset.
                if ood_dataset_name == dataset_name and noise_epsilon != 0:
                    metrics['ood_accuracy'] = accuracy(Y=ood_Y_all, alpha=ood_alpha_pred_all).tolist()

                for name in metric_list:
                    abb_name = name2abbrv[name]
                    save_path = None
                    if store_stat:
                        save_path = f'{run_stat_dir}/{config_id}_ood_{abb_name}.csv'
                    if model_type == 'evnet' or model_type == 'duq':
                        aupr, auroc, _, ood_score = anomaly_detection(alpha=id_alpha_pred_all, ood_alpha=ood_alpha_pred_all,
                                                                      uncertainty_type=name, save_path=save_path, return_scores=True)
                    elif model_type == 'menet' or model_type == 'ablation':
                        aupr, auroc, _, ood_score = our_anomaly_detection(alpha=id_alpha_pred_all, ood_alpha=ood_alpha_pred_all,
                                                                          uncertainty_type=name, save_path=save_path, return_scores=True)
                    else:
                        raise NotImplementedError
                    metrics[f'ood_{abb_name}_apr'], metrics[f'ood_{abb_name}_auroc'] = aupr * 100, auroc * 100
                    ood_scores[f'{abb_name}'] = ood_score

                print("Metrics: ")
                pprint(metrics)

                if use_wandb:
                    data_df = pd.DataFrame(data=[metrics])
                    wandb_table = wandb.Table(dataframe=data_df)
                    wandb.log({'{}'.format(ood_dataset_name): wandb_table})

                if store_results:
                    # One CSV row per OOD dataset: run config (lists stringified) plus current metrics.
                    row_dict = run_config.copy()
                    for k, v in run_config.items():
                        if isinstance(v, list):
                            row_dict[k] = str(v)

                    row_dict.update(metrics)  # shallow copy

                    Path(results_dir).mkdir(parents=True, exist_ok=True)
                    data_writer = DataWriter(dump_period=1)
                    csv_file = f'{results_dir}/{config_id}.csv'
                    data_writer.add(row_dict, csv_file)

        if use_wandb:
            run.finish()

    return


if __name__ == '__main__':
    # Entry point: read configs/<tree>/<config_id>.json, add run-specific paths, and call main().
    use_argparse = True

    if use_argparse:
        import argparse
        my_parser = argparse.ArgumentParser()
        my_parser.add_argument('--configid', action='store', type=str, required=True)
        my_parser.add_argument('--suffix', type=str, default='debug', required=False)
        args = my_parser.parse_args()
        args_configid = args.configid
        args_suffix = args.suffix
    else:
        args_configid = 'test'
        args_suffix = 'debug'

    # A config id may be nested ("tree/sub/id"). The last path component is the config id and
    # everything before it (possibly empty) is the sub-directory under configs/.
    config_tree, _, my_config_id = args_configid.rpartition('/')

    PROJPATH = os.getcwd()
    cfg_dir = f'{PROJPATH}/configs'
    os.makedirs(cfg_dir, exist_ok=True)
    cfg_path = f'{PROJPATH}/configs/{config_tree}/{my_config_id}.json'
    logging.info(f'Reading Configuration from {cfg_path}')

    # JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
    with open(cfg_path, encoding='utf-8') as f:
        proced_config_dict = json.load(f)

    proced_config_dict['config_id'] = my_config_id
    proced_config_dict['suffix'] = args_suffix

    proced_config_dict['model_dir'] = f'{PROJPATH}/saved_models/{my_config_id}/'
    proced_config_dict['results_dir'] = f'{PROJPATH}/saved_models/{my_config_id}/'
    proced_config_dict['stat_dir'] = f'{PROJPATH}/results/{config_tree}_stat/'

    main(proced_config_dict)

