import os
from PIL import Image
from torchvision import transforms
import torch
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
import seaborn as sns
import torch.nn as nn
import torchvision.models as models
import torchvision
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import matplotlib.patheffects as path_effects
import random
import pickle
from scipy.stats import wilcoxon
from scipy.stats import shapiro
from scipy import stats
from scipy.stats import ttest_rel
import numpy as np
from scipy.stats import linregress
from scipy.interpolate import make_interp_spline, BSpline
import matplotlib.pyplot as plt
from sklearn.model_selection import ParameterGrid
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
import torch.nn.functional as F
from torch.autograd import Variable
from scipy.ndimage.filters import gaussian_filter1d
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import json
from model import *
from utils import *
import time
import itertools

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device: {}".format(device))


# Weights of the USMIL uncertainty terms; weight_s / weight_o are the default
# arguments of USMIL_metric and all three values end up in the result filename.
weight_s, weight_o, weight_x = 0, 0, 1.5

def incremental_predict(model, bag, incremental_ratio=0.5):
    """
    Run the model on growing prefixes of a bag and collect every output.

    The first prefix holds ``ceil(total_len * incremental_ratio)`` instances
    (taken along dimension 1 of ``bag``); each following prefix adds one more
    instance until the whole bag is used.

    Args:
        model: Callable invoked as ``model(prefix, total_len=total_len)``.
        bag: Sliceable object with a ``shape`` attribute whose dimension 1
            enumerates the instances.
        incremental_ratio: Fraction of the bag contained in the first prefix.

    Returns:
        List of model outputs, one per prefix, shortest prefix first.
    """
    # Imported locally: the file has no explicit ``import math``, so relying
    # on the name would only work if a wildcard import happened to leak it.
    import math

    total_len = bag.shape[1]
    first_len = math.ceil(total_len * incremental_ratio)

    return [
        model(bag[:, :prefix_len, ...], total_len=total_len)
        for prefix_len in range(first_len, total_len + 1)
    ]

def construct_sequence(n, temperature=2.0):
    """
    Build softmax weights over the positions 1..n.

    Later positions receive larger weights; a higher ``temperature`` flattens
    the distribution. The returned array has length ``n`` and sums to 1.
    """
    positions = np.arange(1, n + 1)
    # Shift by the maximum before exponentiating for numerical stability.
    shifted = positions - np.max(positions)
    unnormalised = np.exp(shifted / temperature)
    return unnormalised / unnormalised.sum()

def compute_weighted_incremental_loss(outputs, labels, softmax_sequence):
    """
    Sum the BCE losses of all incremental outputs, each scaled by its weight.

    ``outputs[i]`` is compared against ``labels`` with ``nn.BCELoss`` and the
    result is multiplied by ``softmax_sequence[i]``. An empty ``outputs``
    yields the integer 0.
    """
    criterion = nn.BCELoss()
    return sum(
        criterion(prediction, labels) * softmax_sequence[step]
        for step, prediction in enumerate(outputs)
    )





def USMIL_metric(bag_outputs_dict, ratio=1, weight_s = weight_s, weight_o = weight_o, temperature=2):
    """
    Compute the USMIL uncertainty score from incremental bag predictions.

    For a bag with a single prediction ``p``:
        U_S = min(|p - 0|, |p - 1|)   and   U_O = p * (1 - p)
    For a bag with several predictions:
        U_S = standard deviation of the predictions
        U_O = softmax-weighted average of the absolute differences between
              consecutive predictions (later steps weigh more).
    The score is ``U_S * weight_s + U_O * weight_o``.

    Args:
        bag_outputs_dict: Mapping of bag id -> sequence of incremental
            predictions for that bag.
        ratio: Accepted for interface compatibility; not used.
        weight_s: Weight of U_S. The default is the module-level value
            captured when this function is defined.
        weight_o: Weight of U_O. Captured the same way as ``weight_s``.
        temperature: Softmax temperature for the U_O weights.

    Returns:
        A single scalar: every bag is scored, but only the score of the LAST
        bag in iteration order is returned (callers in this file pass
        one-bag dictionaries).

    Raises:
        ValueError: If the mapping is empty or a bag has no predictions.
    """
    if not bag_outputs_dict:
        # Previously this fell through to an unbound local at the return.
        raise ValueError("bag_outputs_dict must contain at least one bag")

    for bag_id, instance_outputs in bag_outputs_dict.items():
        if len(instance_outputs) == 0:
            # Fail with a readable message instead of a NumPy reduction error.
            raise ValueError(f"bag {bag_id!r} has no incremental outputs")

        if len(instance_outputs) == 1:  # exactly one prediction: no trajectory to analyse
            U_S = np.mean([min(abs(x - 0), abs(x - 1)) for x in instance_outputs])
            U_O = instance_outputs[0] * (1-instance_outputs[0])
        else:  # two or more predictions
            U_S = np.std(instance_outputs)
            abs_diff = np.abs(np.diff(instance_outputs))
            weights_diff = construct_sequence(len(abs_diff), temperature=temperature)
            U_O = np.average(abs_diff, weights=weights_diff)

        USMIL = U_S * weight_s + U_O * weight_o
    return USMIL





def calculate_entropy(probabilities):
    """
    Return the base-2 Shannon entropy of the given probabilities.

    Non-positive entries are skipped because log2 is undefined for them.
    """
    accumulated = 0
    for prob in probabilities:
        if prob > 0:
            accumulated += prob * np.log2(prob)
    return -accumulated

def calculate_entropy_for_bags(bag_outputs_dict):
    """
    Map every bag id to the entropy of its thresholded predictions.

    Each output is turned into a hard label (1 when it exceeds 0.5, else 0);
    the entropy is computed over the relative frequency of the two labels.
    """
    def label_frequencies(outputs):
        # Hard-threshold the predictions, then count how often each label occurs.
        hard_labels = [1 if value > 0.5 else 0 for value in outputs]
        return np.bincount(hard_labels) / len(hard_labels)

    return {
        bag_id: calculate_entropy(label_frequencies(outputs))
        for bag_id, outputs in bag_outputs_dict.items()
    }




def incremental_predict_test(model, bag):
    """
    Predict on every prefix of a bag and return the scores as Python numbers.

    For prefix lengths 1..bag.shape[1] the model is called on
    ``bag[:, :k, ...]``; the last element of what the model returns is taken
    as the prediction (the final activation is assumed to be part of the
    model) and converted with ``.item()``.
    """
    n_instances = bag.shape[1]
    predictions = []

    for prefix_len in range(1, n_instances + 1):
        prefix = bag[:, :prefix_len, ...]
        score = model(prefix)[-1]

        # Promote a zero-dimensional result to one dimension before reading it.
        if not score.shape:
            score = score.unsqueeze(0)

        predictions.append(score.item())

    return predictions


def plot_uncertainty_histogram(exp_path, test_dataset, trans = True, uncertainty_metric=custom_range_std_convergence_metric_v6):
    """
    Plot a histogram of per-bag uncertainty scores for a saved model.

    Args:
        exp_path: Prefix of the checkpoint; ``"model_best.pth"`` is appended
            by plain string concatenation, so include the trailing separator.
        test_dataset: Dataset yielding ``(bag, label, bag_id, bag_seq_digits)``.
        trans: Unused; kept so existing callers keep working.
        uncertainty_metric: Callable taking ``{bag_id: incremental outputs}``
            and returning a mapping whose values are the scores to plot.
    """
    # The model stays on the CPU in this function, so map the checkpoint
    # there as well; this also keeps GPU-saved checkpoints loadable on
    # CPU-only machines.
    model = BiSMIL()
    model.load_state_dict(torch.load(exp_path + "model_best.pth", map_location="cpu"))
    model.eval()

    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

    # Collect the incremental predictions of every bag.
    bag_incremental_outputs_dict = {}
    with torch.no_grad():
        for bag, _label, bag_id, _bag_seq_digits in tqdm(test_loader):
            # The loader batches the id (batch_size=1). Index it like the
            # evaluation loop does: a batched string id is a list and cannot
            # be used as a dictionary key.
            bag_incremental_outputs_dict[bag_id[0]] = incremental_predict_test(model, bag)

    uncertainty_scores = uncertainty_metric(bag_incremental_outputs_dict)

    # Get scores and calculate mean
    scores = list(uncertainty_scores.values())
    mean_scores = np.mean(scores)

    # Histogram with a dashed marker and label at the mean.
    plt.hist(scores, bins=20, alpha=0.5, color='g')
    plt.axvline(mean_scores, color='r', linestyle='dashed', linewidth=2)
    plt.text(mean_scores, plt.ylim()[1] * 0.9, 'Mean: {:.2f}'.format(mean_scores),
            horizontalalignment='center', fontsize=12, color='r')
    plt.title('Histogram of Uncertainty Metric')
    plt.xlabel('Value')
    plt.ylabel('Frequency')

    plt.tight_layout()
    plt.show()



def _build_transformer_model(model_class, config):
    """Instantiate a transformer-style model class from the config values."""
    para = config['parameters']
    return model_class(
        feature_dim=config['feature_dim'],
        num_heads=para['num_heads'],
        num_layers=para['num_layers'],
        ff_dim=para['ff_dim'],
        output_dim=1,
        dropout=para['dropout'],
        clip_ratio=para['clip_ratio'],
    )


def initialize_training(config):
    """
    Build the model, optimizer and loss function described by ``config``.

    The model family is selected by substring match on
    ``config['parameters']['experiment_name']``:
    ``"ADMIL"``, ``"One_Stream_Trans"``, ``"Two_Stream_Trans"`` (with or
    without ``"VGG"``) or ``"SA_DMIL"``, checked in that order.

    Args:
        config: Dict with ``"model_class"`` (a class, not an instance),
            ``"parameters"`` (hyper-parameters, including ``experiment_name``,
            ``learning_rate`` and ``weight_decay``) and, for the transformer
            variants, ``"feature_dim"``.

    Returns:
        Tuple ``(model, optimizer, loss_func)`` with the model already moved
        to the module-level ``device``.

    Raises:
        NotImplementedError: If the experiment name matches no known family.
    """
    para = config['parameters']
    model_class = config["model_class"]  # a class (not an instance), e.g. ModelA
    experiment_name = para['experiment_name']
    loss_func = nn.BCELoss()

    if "ADMIL" in experiment_name:
        model = model_class()
    elif "One_Stream_Trans" in experiment_name:
        model = _build_transformer_model(model_class, config)
    elif "Two_Stream_Trans" in experiment_name:
        model = _build_transformer_model(model_class, config)
        if "VGG" in experiment_name:
            print("Successfully load Trans_VGG Model")
    elif "SA_DMIL" in experiment_name:
        model = SA_DMIL()
        loss_func = SmoothMIL(alpha=para['alpha_SADMIL'], S_k=1)
    else:
        # Raising a plain string is itself a TypeError; use a real exception.
        raise NotImplementedError(
            f"Not implemented Error: no model setup for experiment_name {experiment_name!r}"
        )

    # Move the model first so the optimizer is built over the final parameters.
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=para["learning_rate"], weight_decay=para["weight_decay"])

    return model, optimizer, loss_func


def load_model_for_seed(experiment_name, seed, configuration, base_path="/outputs/model_save_final_exp"):
    """
    Build a model for one seed and load its best checkpoint.

    Args:
        experiment_name: Name used as the prefix of the checkpoint directory.
        seed: Seed index; written into a copy of the configuration and used
            as the directory suffix (``<experiment_name>_<seed>``).
        configuration: Config dict as expected by ``initialize_training``;
            it is not modified.
        base_path: Directory containing the per-seed checkpoint directories.

    Returns:
        The model with the weights of ``model_best.pth`` loaded.
    """
    # Copy both levels that get touched so the caller's config stays intact.
    config = configuration.copy()
    config['parameters'] = config['parameters'].copy()
    config['parameters']['seed'] = seed

    model, _, _ = initialize_training(config)

    model_path = os.path.join(base_path, f"{experiment_name}_{seed}", "model_best.pth")
    # map_location lets checkpoints saved on a GPU load on the CPU fallback
    # that the module-level ``device`` selection allows.
    model.load_state_dict(torch.load(model_path, map_location=device))
    print("Successfully load model at ", model_path)

    return model

def load_all_models(configurations_list):
    """
    Load the models of seeds 0-4 for every experiment name.

    The i-th experiment name is paired with the i-th entry of
    ``configurations_list`` (by position, not by key). Returns a dict mapping
    each experiment name to its list of loaded models.
    """
    models_by_dataset = {"model_name": []}
    config_keys = list(configurations_list)

    for position, experiment_name in enumerate(models_by_dataset):
        configuration = configurations_list[config_keys[position]]
        models_by_dataset[experiment_name].extend(
            load_model_for_seed(experiment_name, seed, configuration)
            for seed in range(5)
        )

    return models_by_dataset




# Per-dataset model configurations. Each entry supplies the model class, the
# feature dimension and the hyper-parameters read by initialize_training.
# load_all_models pairs these entries with experiment names by position.
# NOTE(review): every "experiment_name" below is an empty string, which matches
# none of the substring checks in initialize_training and therefore reaches its
# error branch; load_model_for_seed copies the config without setting that key.
# Presumably the real names were removed - confirm before running.
configurations_list = {
    "Two_Stream_Trans_UTD": {
        "model_class": BiDMIL,
        "feature_dim" : 288,
        "parameters": {
            "num_epochs": 40,
            "seed": 4,
            "learning_rate": 1e-4,
            "weight_decay": 1e-4,
            "dropout": 0.2,
            "num_heads": 8,
            "num_layers": 2,
            "ff_dim":128,
            "incremental_training": True,
            "alpha":  0.5,
            "beta": 0.5,
            "clip_ratio": 0.6,
            "experiment_name": ""
        }},
    "Two_Stream_Trans_RSNA": {
        "model_class": BiDMIL,
        "feature_dim" : 288,
        "parameters": {
            "num_epochs": 60,
            "seed": 1,
            "learning_rate": 1e-4,
            "weight_decay": 1e-4,
            "dropout": 0.2,
            "num_heads": 8,
            "num_layers": 2,
            "ff_dim":128,
            "incremental_training": True,
            "alpha": 0.5,
            "beta": 0.5,
            "clip_ratio": 0.5,
            "experiment_name": ""
        }},
    "Two_Stream_Trans_Covid": {
        "model_class": BiDMIL,
        "feature_dim" : 288,
        "parameters": {
            "num_epochs": 40,
            "seed": 0,
            "learning_rate": 1e-4,
            "weight_decay": 1e-4,
            "dropout": 0.2,
            "num_heads": 8,
            "num_layers": 2,
            "ff_dim":128,
            "incremental_training": True,
            "alpha": 1,
            "beta": 0,
            "clip_ratio": 1,
            "experiment_name": ""
        }},
}


# Load the trained models of every seed, grouped by experiment name.
models_by_dataset = load_all_models(configurations_list)

# Make sure every loaded model lives on the selected device.
# NOTE(review): the loop variable `models` rebinds the module-level name that
# the imports bound to torchvision.models.
for experiment_name, models in models_by_dataset.items():
    for i, model in enumerate(models):
        models_by_dataset[experiment_name][i] = model.to(device)

# UTD test split: saved dataset wrapped in a loader that yields one bag at a time.
path_utd ='/MIL_dataset/data_final_exp/UTD'
test_dataset_utd = torch.load( path_utd+ '/UTD_test_dataset.pt')
test_loader_utd = DataLoader(test_dataset_utd, batch_size=1, shuffle=False)


# RSNA test split, handled the same way.
path_rsna ='/MIL_dataset/data_final_exp/RSNA'
test_dataset_rsna = torch.load( path_rsna+ '/RSNA_test_dataset.pt')
test_loader_rsna = DataLoader(test_dataset_rsna, batch_size=1, shuffle=False)


# Covid test split. Unlike the two splits above, the loader itself is read from
# disk instead of being built from the dataset, so its batch size and shuffling
# are whatever was saved.
# NOTE(review): test_dataset_covid is not used in the visible code - confirm
# whether the loader should be constructed from it.
path_covid ='/MIL_dataset/data_final_exp/Covid'
test_dataset_covid = torch.load( path_covid+ '/Covid_test_dataset.pt')
test_loader_covid = torch.load(path_covid + '/Covid_test_loader.pt')
print("Successfully load the data")



# Experiment name -> test loader lookup used by the evaluation loop.
# NOTE(review): all three keys are the same empty string, so the dict literal
# keeps only the last entry (test_loader_covid). Presumably each key should be
# the experiment name of the matching dataset - confirm.
datasets = {
    "": test_loader_utd,
    "": test_loader_rsna,
    "": test_loader_covid
}

# Candidate values for each tunable argument of the uncertainty metric.
param_grid = dict(
    ratio=[1, 0.5],
    mid_mean_distance_weight=[1.0, 0.5],
    mid_range_weight=[0.4, 0.2],
    mid_std_weight=[0.4, 0.3],
    convergence_speed_weight=[2.0, 1.0],
    temperature1=[1.6, 1],
    temperature2=[0.24, 0.1],
    temperature3=[0.4, 0.3],
)

# Enumerate the full grid, then keep one randomly drawn combination.
param_combinations = random.sample(
    list(itertools.product(*param_grid.values())),
    1,
)

# Marker recorded for runs that use the metric's built-in defaults.
default_params = dict.fromkeys(param_grid, 'default')

# Accumulates one record per evaluated (experiment, params, seed).
all_results = []

# Evaluate every loaded model on its test loader and collect, per bag, the
# incremental predictions and the uncertainty scores derived from them.
for experiment_name, models in models_by_dataset.items():
    # Raises KeyError when the experiment name has no entry in `datasets`.
    test_loader = datasets[experiment_name]

    if experiment_name != "":
        # Named experiments: evaluate once per sampled parameter combination.
        for params in param_combinations:
            param_dict = dict(zip(param_grid.keys(), params))
            # NOTE(review): this `continue` skips the rest of the loop body, so
            # the evaluation below is unreachable and named experiments add
            # nothing to all_results. Confirm whether it is a leftover from
            # debugging or a deliberate switch-off.
            continue

            for seed, model in enumerate(models):
                model_results = []
                model.to(device)
                model.eval()
                start_time = time.time()
                with torch.no_grad():
                    for i, (bag, label, bag_id, bag_seq_digits) in enumerate(test_loader):
                        bag = bag.to(device)
                        label = label.to(device)

                        outputs = incremental_predict_test(model, bag)

                        # bag_id arrives batched from the loader; [0] picks the
                        # id of the single bag (assumes batch size 1 - confirm
                        # for loaders that were read from disk).
                        uncertainty = custom_range_std_convergence_metric_v6({bag_id[0]: outputs}, **param_dict)
                        # uncertainty = calculate_entropy_for_bags({bag_id[0]: outputs})
                        entropy_uncertainty = calculate_entropy_for_bags({bag_id[0]: outputs})

                        USMIL = USMIL_metric({bag_id[0]: outputs})
                        # NOTE(review): USMIL_metric_new is not defined in this
                        # file; presumably it comes from a wildcard import - confirm.
                        USMIL_new = USMIL_metric_new({bag_id[0]: outputs})
                        

                        model_results.append({
                            "Bag ID": bag_id[0],
                            "True Label": label.item(),
                            "Predicted Label": outputs[-1],
                            "Incremental Predictions": outputs,
                            "Uncertainty": uncertainty[bag_id[0]],
                            "Entropy_Uncertainty":entropy_uncertainty[bag_id[0]],
                            "USMIL_Uncertainty" : USMIL,
                            "USMIL_new" : USMIL_new,
                        })

                end_time = time.time()
                test_duration = end_time - start_time

                print(f"Experiment: {experiment_name}, Params: {param_dict}, Seed: {seed}, Test Duration: {test_duration:.2f} seconds")

                all_results.append({
                    "Experiment": experiment_name,
                    "Params": param_dict,
                    "Seed": seed,
                    "Results": model_results
                })

    else:
        # Empty experiment name: evaluate with the metric's default parameters.
        for seed, model in enumerate(models):
            model_results = []
            model.to(device)
            model.eval()
            start_time = time.time()
            with torch.no_grad():
                for i, (bag, label, bag_id, bag_seq_digits) in enumerate(test_loader):
                    bag = bag.to(device)
                    label = label.to(device)

                    # One prediction per bag prefix; the last one uses the whole bag.
                    outputs = incremental_predict_test(model, bag)

                    uncertainty = custom_range_std_convergence_metric_v6({bag_id[0]: outputs})
                    entropy_uncertainty = calculate_entropy_for_bags({bag_id[0]: outputs})
                    USMIL = USMIL_metric({bag_id[0]: outputs})

                    model_results.append({
                        "Bag ID": bag_id[0],
                        "True Label": label.item(),
                        "Predicted Label": outputs[-1],
                        "Incremental Predictions": outputs,
                        "Uncertainty": uncertainty[bag_id[0]],
                        "Entropy_Uncertainty":entropy_uncertainty[bag_id[0]],
                        "USMIL_Uncertainty" : USMIL,
                    })

            end_time = time.time()
            test_duration = end_time - start_time

            print(f"Experiment: {experiment_name}, Seed: {seed}, Test Duration: {test_duration:.2f} seconds")

            all_results.append({
                "Experiment": experiment_name,
                "Params": default_params,
                "Seed": seed,
                "Results": model_results
            })

# Persist all collected results as a pickled DataFrame; the weights are encoded
# in the filename so different runs do not overwrite each other.
output_dir = '/outputs/uncertainty_tune'
# Create the directory up front: to_pickle does not create it, and failing here
# would discard the results of the whole evaluation run.
os.makedirs(output_dir, exist_ok=True)
df = pd.DataFrame(all_results)
output_file = os.path.join(output_dir, f'Uncertain_results__USMIL_with_output_uncertain_s_{weight_s}_o_{weight_o}_x_{weight_x}.pkl')
df.to_pickle(output_file)

