import copy
import json
import time

import numpy as np
import torch
import torch.backends.cudnn as cudnn

# Import data utilities
import torch.utils.data as data
import data.active_learning.active_learning as active_learning
from data.ambiguous_mnist.ambiguous_mnist_dataset import AmbiguousMNIST
from data.fast_mnist import create_MNIST_dataset

# Import network architectures
from models.dnn_res18 import ResNet18

# Import train and test utils
from utils.train_utils import train_single_epoch, model_save_name

# Importing uncertainty metrics
from metrics.classification_metrics import test_classification_net_ensemble


from utils.ensemble_utils import ensemble_forward_pass


# ---------------------------------------------------------------------------
# Experiment configuration: module-level constants read by the __main__ block.
# ---------------------------------------------------------------------------

# Model name (used in the results file name) and global seed.
# Other seeds used for repeated experiments: 1, 2, 3, 4.
args_model_name = "dnn_resnet18"
predefined_seed = 0

# Dataset selection. True -> MNIST subsample concatenated with AmbiguousMNIST
# (results file tagged "_dirty_mnist_"); False -> MNIST only.
# An alternative dataset_root used previously: "./".
enable_ambiguous = True
dataset_root = "data/ambiguous_mnist/"

# Ensemble / acquisition settings.
args_al_type = "ensemble"   # not referenced elsewhere in this file
args_num_ensemble = 5       # number of ensemble members trained per AL iteration
args_threshold = 1.0        # not referenced elsewhere in this file

# Data budget.
args_subsample = 1000              # MNIST images sampled when enable_ambiguous is True
args_num_initial_samples = 20      # initial labelled set, divided evenly across classes
args_max_training_samples = 300    # stop once the labelled set reaches this size
args_acquisition_batch_size = 5    # pool points acquired per AL iteration

# Optimisation and loader batch sizes (train / val+test / pool scoring).
args_epochs = 20
args_train_batch_size, args_test_batch_size, args_scoring_batch_size = 64, 512, 128

# Acquisition score: "MI" -> mutual information; anything else -> predictive entropy.
args_uncs = "MI"

def class_probs(data_loader, num_classes=10):
    """Return the empirical class distribution of the labels a loader yields.

    Args:
        data_loader: iterable of ``(inputs, labels)`` batches, e.g. a
            ``torch.utils.data.DataLoader``. ``labels`` must be a tensor; it is
            flattened and compared against the integers ``0 .. num_classes-1``.
        num_classes: number of class buckets (default 10, as for MNIST).

    Returns:
        A CPU float tensor of shape ``(num_classes,)`` whose entry ``c`` is the
        fraction of iterated labels equal to ``c``. The denominator is the
        number of labels actually iterated, so loaders that resample (a custom
        sampler, ``drop_last``) still give a proper distribution; for a plain
        sequential loader this equals ``len(data_loader.dataset)``. Labels
        outside ``0 .. num_classes-1`` count toward the denominator only.
        An empty loader yields an all-zero tensor rather than NaNs.
    """
    class_count = torch.zeros(num_classes)
    classes = torch.arange(num_classes)
    total = 0
    # `_` rather than `data`: avoid shadowing the torch.utils.data alias.
    for _, label in data_loader:
        label = label.reshape(-1).cpu()
        # (N, 1) == (num_classes,) broadcasts to an (N, num_classes) mask;
        # summing over N counts every class in one vectorised step.
        class_count += (label.unsqueeze(1) == classes).sum(dim=0).to(class_count.dtype)
        total += label.numel()

    if total == 0:
        return class_count
    return class_count / total


if __name__ == "__main__":
    # Active-learning experiment. Each AL iteration trains a fresh deep ensemble
    # on the current labelled set and keeps the epoch with the best validation
    # accuracy. It then records test accuracy and acquires the pool points with
    # the highest ensemble uncertainty ("MI" or predictive entropy). The loop
    # stops once the labelled set reaches args_max_training_samples, and the
    # per-run test accuracies are written to a JSON file.
    start_time = time.time()

    # Checking if GPU is available
    cuda = torch.cuda.is_available()

    # Seed torch and NumPy up front. NumPy must be seeded *before* the MNIST
    # subsample below is drawn; otherwise that subsample changes on every run.
    torch.manual_seed(predefined_seed)
    np.random.seed(predefined_seed)

    device = torch.device("cuda" if cuda else "cpu")
    if cuda:
        for gpu_id in range(torch.cuda.device_count()):
            print(f"GPU {gpu_id}: {torch.cuda.get_device_name(gpu_id)}")
    print(f"Selected device: {device}")

    model_fn = ResNet18

    # Creating the datasets
    num_classes = 10
    train_dataset, test_dataset = create_MNIST_dataset()
    if enable_ambiguous:
        # MNIST subsample drawn without replacement (no duplicated images),
        # concatenated with AmbiguousMNIST.
        indices = np.random.choice(len(train_dataset), args_subsample, replace=False)
        mnist_train_dataset = data.Subset(train_dataset, indices)
        train_dataset = data.ConcatDataset(
            [mnist_train_dataset, AmbiguousMNIST(root=dataset_root, train=True, device=device),]
        )

    # Creating a validation split: seeded shuffle, first 10% -> validation.
    # Re-seeding here keeps the split independent of earlier NumPy draws.
    idxs = list(range(len(train_dataset)))
    split = int(np.floor(0.1 * len(train_dataset)))
    np.random.seed(predefined_seed)
    np.random.shuffle(idxs)

    train_idx, val_idx = idxs[split:], idxs[:split]
    val_dataset = data.Subset(train_dataset, val_idx)
    train_dataset = data.Subset(train_dataset, train_idx)

    # Integer division: pass a whole per-class count. How the helper treats a
    # fractional count (e.g. 25 / 10 = 2.5) is not visible from this file.
    initial_sample_indices = active_learning.get_balanced_sample_indices(
        train_dataset, num_classes=num_classes, n_per_digit=args_num_initial_samples // num_classes,
    )

    kwargs = {"num_workers": 0, "pin_memory": False} if cuda else {}
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args_test_batch_size, shuffle=False, **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args_test_batch_size, shuffle=False, **kwargs)

    # Run experiment
    num_runs = 1
    # run index -> list of test accuracies (in %), one entry per AL iteration.
    test_accs = {run: [] for run in range(num_runs)}

    for run in range(num_runs):
        print("Experiment run: " + str(run) + " =====================================================================>")

        torch.manual_seed(predefined_seed + run)

        # Split the initial labelled set off the pool.
        active_learning_data = active_learning.ActiveLearningData(train_dataset)
        active_learning_data.acquire(initial_sample_indices)

        # These loaders are built once per run and reused across iterations.
        # This assumes `acquire` grows training_dataset / shrinks pool_dataset
        # in place -- TODO confirm in data.active_learning.
        sampler = active_learning.RandomFixedLengthSampler(
            dataset=active_learning_data.training_dataset, target_length=5056
        )
        train_loader = torch.utils.data.DataLoader(
            active_learning_data.training_dataset, sampler=sampler, batch_size=args_train_batch_size, **kwargs,
        )

        # Pool loader used for scoring acquisition candidates
        pool_loader = torch.utils.data.DataLoader(
            active_learning_data.pool_dataset, batch_size=args_scoring_batch_size, shuffle=False, **kwargs,
        )

        # Run active learning iterations
        active_learning_iteration = 0
        while True:
            print("Active Learning Iteration: " + str(active_learning_iteration) + " ================================>")

            weight_decay = 5e-4

            # A fresh ensemble is trained from scratch every iteration; Adam
            # uses its default learning rate.
            model_ensemble = [model_fn().to(device=device) for _ in range(args_num_ensemble)]
            optimizers = [
                torch.optim.Adam(model.parameters(), weight_decay=weight_decay) for model in model_ensemble
            ]

            # Train
            print("Length of train dataset: " + str(len(train_loader.dataset)))
            best_states = None
            best_val_accuracy = -float("inf")
            for epoch in range(args_epochs):
                for model, optimizer in zip(model_ensemble, optimizers):
                    model.train()
                    train_single_epoch(epoch, model, train_loader, optimizer, device=device)

                # Validate in eval mode so any normalisation/dropout layers use
                # inference behaviour (the helper may also do this itself).
                for model in model_ensemble:
                    model.eval()
                _, val_accuracy, _, _, _ = test_classification_net_ensemble(model_ensemble, val_loader, device=device)

                if val_accuracy > best_val_accuracy:
                    best_val_accuracy = val_accuracy
                    # Snapshot the weights. Keeping a reference to `model_ensemble`
                    # would alias the live models, which keep training, so the
                    # "best" ensemble would silently be the last epoch's.
                    best_states = [copy.deepcopy(model.state_dict()) for model in model_ensemble]

            # Roll back to the best-validation snapshot, if one was taken.
            if best_states is not None:
                for model, state in zip(model_ensemble, best_states):
                    model.load_state_dict(state)

            print("Training ended")

            # Testing the models
            print("Testing the model: Ensemble======================================>")
            for model in model_ensemble:
                model.eval()
            _, accuracy, _, _, _ = test_classification_net_ensemble(model_ensemble, test_loader, device=device)

            percentage_correct = 100.0 * accuracy
            test_accs[run].append(percentage_correct)

            print("Test set: Accuracy: ({:.2f}%)".format(percentage_correct))

            # Breaking clauses: labelling budget reached, or nothing left to acquire.
            if len(active_learning_data.training_dataset) >= args_max_training_samples:
                break
            if len(active_learning_data.pool_dataset) == 0:
                print("Pool exhausted; stopping acquisition.")
                break

            # Acquisition phase
            print("Performing acquisition ========================================")

            for model in model_ensemble:
                model.eval()
            ensemble_uncs = []
            with torch.no_grad():
                # `batch`, not `data`: do not rebind the torch.utils.data alias.
                for batch, _ in pool_loader:
                    batch = batch.to(device)
                    _, predictive_entropy, mi = ensemble_forward_pass(model_ensemble, batch)
                    ensemble_uncs.append(mi if args_uncs == "MI" else predictive_entropy)

                ensemble_uncs = torch.cat(ensemble_uncs, dim=0)

                _, candidate_indices = active_learning.get_top_k_scorers(
                    ensemble_uncs, args_acquisition_batch_size
                )

            # Performing acquisition
            active_learning_data.acquire(candidate_indices)
            active_learning_iteration += 1

    end_time = time.time()
    print(f"Total elapsed time: {end_time - start_time:.2f}s")

    # Save the per-run test accuracies
    save_name = model_save_name(args_model_name, predefined_seed)
    save_ensemble_mi = "_" + args_uncs
    if enable_ambiguous:
        accuracy_file_name = (
            "test_accs_" + save_name + save_ensemble_mi + "_dirty_mnist_" + str(args_subsample) + ".json"
        )
    else:
        accuracy_file_name = "test_accs_" + save_name + save_ensemble_mi + "_mnist.json"

    with open(accuracy_file_name, "w") as acc_file:
        json.dump(test_accs, acc_file)