import json
import torch
import numpy as np
import torch.backends.cudnn as cudnn
import time
import yaml
import argparse

# Import data utilities
import torch.utils.data as data
import data.active_learning.active_learning as active_learning
from data.ambiguous_mnist.ambiguous_mnist_dataset import AmbiguousMNIST
from data.fast_mnist import create_MNIST_dataset

# Import network architectures
from models.crenet_res18 import CreResNet18

# Import train and test utils
from utils.train_crenets_utils import train_single_epoch, model_save_name

from metrics.classification_metrics_credal import test_classification_net_ensemble


from utils.credal_ensemble_utils import credal_ensemble_forward_pass

# Mapping model name to model function
# models = {"resnet18": resnet18}

def load_config(yaml_file):
    """Read a YAML configuration file and return its parsed contents.

    The file is parsed with PyYAML's safe loader, so only plain YAML
    structures are constructed (no arbitrary Python objects).

    Args:
        yaml_file: Path to the YAML file to read.

    Returns:
        The parsed document as produced by PyYAML.
    """
    with open(yaml_file, 'r') as stream:
        return yaml.safe_load(stream)


# Accept a YAML file as a command-line argument.
# NOTE: the parsing and config loading below run at module import time,
# not only under the ``__main__`` guard.
parser = argparse.ArgumentParser(description='Process parameters from a YAML file.')
parser.add_argument('config_file', type=str, help='Path to the YAML configuration file')
args = parser.parse_args()

config = load_config(args.config_file)

# Uncertainty type used to score pool samples during acquisition; read from
# the 'UnType' key of the YAML config. The acquisition loop treats "HU" and
# "GH" specially and uses the entropy difference for any other value.
args_uncs = config['UnType']
# Model name and seed; both are passed to model_save_name() to build the
# results file name, and the seed also seeds torch / numpy.
args_model_name = "cre_resnet18"
predefined_seed = 0
# predefined_seed = 1
# predefined_seed = 2
# predefined_seed = 3
# predefined_seed = 4

# When True, the training pool is a subsample of MNIST concatenated with
# AmbiguousMNIST loaded from `dataset_root`; when False, plain MNIST is used.
# enable_ambiguous = False
enable_ambiguous = True
dataset_root = "data/ambiguous_mnist/"
# NOTE(review): args_al_type is not referenced anywhere in the visible script.
args_al_type = "ensemble"
# Number of independently initialised models in the ensemble.
args_num_ensemble = 5
# args_num_ensemble = 3


# NOTE(review): args_threshold is not referenced anywhere in the visible script.
args_threshold = 1.0
# Number of MNIST indices drawn for the subsample when enable_ambiguous is True.
args_subsample = 1000

# Active learning schedule: total size of the initial labelled set (split
# evenly over the classes), the training-set size at which the loop stops,
# and how many pool samples are acquired per iteration.
args_num_initial_samples = 20
args_max_training_samples = 300
args_acquisition_batch_size = 5

# Training epochs per active learning iteration.
args_epochs = 20

# Batch sizes for training, for validation/testing, and for scoring the pool.
args_train_batch_size = 64
args_test_batch_size = 512
args_scoring_batch_size = 128

print('Applied Uncertainty Type: ', args_uncs)
# Manual overrides of the uncertainty type (disabled):
# args_uncs = "GH"
# args_uncs = "HU"
# args_uncs = "DH"

def class_probs(data_loader, num_classes=10):
    """Estimate the empirical class distribution of the labels in a loader.

    Iterates once over ``data_loader`` (which must yield ``(input, label)``
    pairs) and counts how often each class id in ``range(num_classes)``
    occurs among the labels.

    The counts are normalised by the number of labels actually iterated,
    not by ``len(data_loader.dataset)``. For a loader that visits every
    sample exactly once the two are identical; for loaders that use an
    over/under-sampling sampler or ``drop_last`` this keeps the result a
    proper distribution.

    Args:
        data_loader: Iterable yielding ``(input, label)`` batches, where
            ``label`` is a tensor (or array-like) of class ids.
        num_classes: Number of classes; labels outside
            ``range(num_classes)`` are ignored when counting.

    Returns:
        A float tensor of shape ``(num_classes,)`` with the relative
        frequency of each class. All zeros if the loader yields no labels
        (instead of NaN from a 0/0 division).
    """
    class_count = torch.zeros(num_classes)
    num_labels = 0
    for _, label in data_loader:
        label = torch.as_tensor(label)
        # Compare every label against every class id in one broadcasted op
        # instead of looping over the classes in Python.
        class_ids = torch.arange(num_classes, device=label.device)
        matches = label.reshape(-1, 1) == class_ids
        # Move the per-batch counts to the accumulator's dtype/device so
        # labels living on the GPU are handled as well.
        class_count += matches.sum(dim=0).to(class_count)
        num_labels += label.numel()

    if num_labels == 0:
        return class_count
    return class_count / num_labels

if __name__ == "__main__":
    # Active learning experiment: repeatedly train a fresh ensemble on the
    # currently labelled set, evaluate it on the test set, score the unlabelled
    # pool with the configured uncertainty measure and acquire the top scorers,
    # until the labelled set reaches `args_max_training_samples`. The per-
    # iteration test accuracies are written to a JSON file at the end.

    # Local import: only the checkpointing logic of this script block needs it.
    import copy

    start_time = time.time()

    # Checking if GPU is available
    cuda = torch.cuda.is_available()

    # Seed torch AND numpy before any random draw. The MNIST subsample below
    # uses np.random, so seeding numpy only before the validation shuffle
    # would leave the subsample different on every launch.
    torch.manual_seed(predefined_seed)
    np.random.seed(predefined_seed)
    device = torch.device("cuda" if cuda else "cpu")

    model_fn = CreResNet18

    # Creating the datasets
    num_classes = 10
    train_dataset, test_dataset = create_MNIST_dataset()
    if enable_ambiguous:
        # NOTE: np.random.choice draws with replacement by default, so the
        # subsample may contain repeated indices.
        indices = np.random.choice(len(train_dataset), args_subsample)
        mnist_train_dataset = data.Subset(train_dataset, indices)
        train_dataset = data.ConcatDataset(
            [mnist_train_dataset, AmbiguousMNIST(root=dataset_root, train=True, device=device),]
        )

    # Creating a validation split (10% of the training data)
    idxs = list(range(len(train_dataset)))
    split = int(np.floor(0.1 * len(train_dataset)))
    # Re-seed so the split permutation depends only on the seed and the
    # dataset length, not on how many random numbers were consumed above.
    np.random.seed(predefined_seed)
    np.random.shuffle(idxs)

    train_idx, val_idx = idxs[split:], idxs[:split]
    val_dataset = data.Subset(train_dataset, val_idx)
    train_dataset = data.Subset(train_dataset, train_idx)

    # True division yields a float here (2.0 with the defaults).
    # NOTE(review): confirm get_balanced_sample_indices accepts a non-integer
    # per-class count before changing args_num_initial_samples.
    initial_sample_indices = active_learning.get_balanced_sample_indices(
        train_dataset, num_classes=num_classes, n_per_digit=args_num_initial_samples / num_classes,
    )

    kwargs = {"num_workers": 0, "pin_memory": False} if cuda else {}
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args_test_batch_size, shuffle=False, **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args_test_batch_size, shuffle=False, **kwargs)

    # Run experiment
    num_runs = 1

    # Per-run result containers; only test_accs is filled and saved below.
    test_accs = {i: [] for i in range(num_runs)}
    ambiguous_dict = {i: [] for i in range(num_runs)}
    ambiguous_entropies_dict = {i: {} for i in range(num_runs)}

    # No learning rate is passed to Adam below, so it runs with its default.
    weight_decay = 5e-4

    for run in range(num_runs):
        print("Experiment run: " + str(run) + " =====================================================================>")

        torch.manual_seed(predefined_seed + run)

        # Setup data for the experiment
        # Split off the initial samples first
        active_learning_data = active_learning.ActiveLearningData(train_dataset)

        # Acquiring the first training dataset from the total pool. This is random acquisition
        active_learning_data.acquire(initial_sample_indices)

        # Train loader for the current acquired training set
        sampler = active_learning.RandomFixedLengthSampler(
            dataset=active_learning_data.training_dataset, target_length=5056
        )
        train_loader = torch.utils.data.DataLoader(
            active_learning_data.training_dataset, sampler=sampler, batch_size=args_train_batch_size, **kwargs,
        )

        # Not used below; kept so the set of loaders matches the original script.
        small_train_loader = torch.utils.data.DataLoader(
            active_learning_data.training_dataset, shuffle=True, batch_size=args_train_batch_size, **kwargs,
        )

        # Pool loader for the current acquired training set
        pool_loader = torch.utils.data.DataLoader(
            active_learning_data.pool_dataset, batch_size=args_scoring_batch_size, shuffle=False, **kwargs,
        )

        # Run active learning iterations
        active_learning_iteration = 0
        while True:
            print("Active Learning Iteration: " + str(active_learning_iteration) + " ================================>")

            # A fresh ensemble (and one optimizer per member) is trained from
            # scratch in every active learning iteration.
            model_ensemble = [
                model_fn().to(device=device)
                for _ in range(args_num_ensemble)
            ]
            optimizers = []
            for model in model_ensemble:
                optimizers.append(torch.optim.Adam(model.parameters(), weight_decay=weight_decay))
                model.train()

            # Train
            print("Length of train dataset: " + str(len(train_loader.dataset)))
            best_states = None
            best_val_accuracy = 0
            for epoch in range(args_epochs):

                for (model, optimizer) in zip(model_ensemble, optimizers):
                    train_single_epoch(epoch, model, train_loader, optimizer, device)

                # choose lower probs as reference
                _, val_accuracy, _, _, _ = test_classification_net_ensemble(model_ensemble, val_loader, device=device)

                # Snapshot the weights themselves. Keeping a reference to the
                # live models would silently track later epochs, and always
                # snapshotting the first epoch avoids ending up with no model
                # at all when the validation accuracy never exceeds 0.
                if best_states is None or val_accuracy > best_val_accuracy:
                    best_val_accuracy = val_accuracy
                    best_states = [copy.deepcopy(model.state_dict()) for model in model_ensemble]

            # Restore the weights of the best validation epoch (if any epoch ran).
            if best_states is not None:
                for model, state in zip(model_ensemble, best_states):
                    model.load_state_dict(state)

            print("Training ended")

            # Testing the models
            print("Testing the model: Ensemble======================================>")
            for model in model_ensemble:
                model.eval()
            (conf_matrix, accuracy, labels_list, predictions, confidences,) = test_classification_net_ensemble(
                model_ensemble, test_loader, device=device
            )

            # Coerce to a plain float so the value is JSON-serialisable
            # regardless of the numeric type returned by the metric helper.
            percentage_correct = float(100.0 * accuracy)
            test_accs[run].append(percentage_correct)

            print("Test set: Accuracy: ({:.2f}%)".format(percentage_correct))

            # Breaking clause
            if len(active_learning_data.training_dataset) >= args_max_training_samples:
                break

            # Acquisition phase
            print("Performing acquisition ========================================")

            for model in model_ensemble:
                model.eval()
            ensemble_uncs = []
            with torch.no_grad():
                # The batch is deliberately not called `data`: this block runs
                # at module scope, where `data` is the torch.utils.data alias.
                for pool_batch, _ in pool_loader:
                    pool_batch = pool_batch.to(device)
                    _mean_output, max_entropy, diff_entropy, gh = credal_ensemble_forward_pass(
                        model_ensemble, pool_batch, uncertainty_type=args_uncs
                    )
                    if args_uncs == "HU":
                        ensemble_uncs.append(max_entropy)
                    elif args_uncs == "GH":
                        ensemble_uncs.append(gh)
                    else:
                        ensemble_uncs.append(diff_entropy)

                # Convert NumPy arrays to PyTorch tensors
                ensemble_uncs = [torch.from_numpy(arr) for arr in ensemble_uncs]

                ensemble_uncs = torch.cat(ensemble_uncs, dim=0)

                (candidate_scores, candidate_indices,) = active_learning.get_top_k_scorers(
                    ensemble_uncs, args_acquisition_batch_size
                )

            # Performing acquisition
            active_learning_data.acquire(candidate_indices)
            active_learning_iteration += 1

    # Save the dictionaries
    end_time = time.time()
    print(end_time-start_time)

    save_name = model_save_name(args_model_name, predefined_seed)
    save_ensemble_mi = "_" + args_uncs
    if enable_ambiguous:
        accuracy_file_name = (
            "test_accs_max05_" + save_name + save_ensemble_mi + "_dirty_mnist_" + str(args_subsample) + ".json"
        )
    else:
        accuracy_file_name = "test_accs_" + save_name  + save_ensemble_mi + "_mnist.json"

    with open(accuracy_file_name, "w") as acc_file:
        json.dump(test_accs, acc_file)
