import os
from PIL import Image
from torchvision import transforms
import torch
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
import seaborn as sns
import torch.nn as nn
import torchvision.models as models
import torchvision
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import matplotlib.patheffects as path_effects
import random
import pickle
from scipy.stats import wilcoxon
from scipy.stats import shapiro
from scipy import stats
from scipy.stats import ttest_rel
import numpy as np
from scipy.stats import linregress
from scipy.interpolate import make_interp_spline, BSpline
import matplotlib.pyplot as plt
from sklearn.model_selection import ParameterGrid
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
import torch.nn.functional as F
from torch.autograd import Variable
from scipy.ndimage.filters import gaussian_filter1d
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import json
from model import *
from utils import *
import time
import itertools

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


def construct_sequence(n, temperature=2.0):
    """Return softmax weights over the ranks ``1..n`` scaled by ``temperature``.

    For a positive ``temperature`` later ranks receive larger weights, and a
    larger ``temperature`` flattens the distribution. The maximum rank is
    subtracted before exponentiating for numerical stability. The result is a
    1-D NumPy array of length ``n`` whose entries sum to 1.
    """
    ranks = np.arange(1, n + 1)
    shifted = ranks - np.max(ranks)
    weights = np.exp(shifted / temperature)
    return weights / weights.sum()


def incremental_predict_test(model, bag):
    """Score every prefix of ``bag`` with ``model`` and return the scores.

    For each prefix length ``k`` in ``1..bag.shape[1]`` the model is called on
    ``bag[:, :k, ...]``. The last element of its return value is taken as the
    (already activated) bag prediction and converted to a Python number.
    The model is assumed to produce a single-element prediction per call
    (i.e. batch size 1) -- TODO confirm for every model family.

    Returns:
        list of floats, one per prefix length, in increasing length order.
    """
    def _prefix_score(length):
        prediction = model(bag[:, :length, ...])[-1]
        # Promote 0-d tensors to shape (1,) before extracting the scalar.
        if len(prediction.shape) == 0:
            prediction = prediction.unsqueeze(0)
        return prediction.item()

    return [_prefix_score(length) for length in range(1, bag.shape[1] + 1)]



def initialize_training(config):
    """Build the model, optimizer and loss function described by ``config``.

    The architecture is chosen from substrings of
    ``config['parameters']['experiment_name']``, checked in this order:

    * ``"ADMIL"``            -> ``config['model_class']()`` with no arguments.
    * ``"One_Stream_Trans"`` -> ``config['model_class']`` with transformer kwargs.
    * ``"Two_Stream_Trans"`` -> same as above; prints a notice for ``"VGG"`` runs.
    * ``"SA_DMIL"``          -> ``SA_DMIL()`` trained with ``SmoothMIL`` loss.

    Every other model uses ``nn.BCELoss``.

    Args:
        config: dict with ``'model_class'`` and ``'parameters'`` (hyper-parameter
            dict containing at least ``experiment_name``, ``learning_rate`` and
            ``weight_decay``). Transformer variants also need ``'feature_dim'``
            plus ``num_heads``/``num_layers``/``ff_dim``/``dropout``/``clip_ratio``.

    Returns:
        ``(model, optimizer, loss_func)``; the model has been moved to ``device``
        and the optimizer is Adam over its parameters.

    Raises:
        NotImplementedError: if the experiment name matches no known model.
    """
    para = config['parameters']
    experiment_name = para['experiment_name']
    model_class = config["model_class"]
    loss_func = nn.BCELoss()

    def _build_transformer():
        # Every transformer variant shares one constructor signature.
        return model_class(
            feature_dim=config['feature_dim'],
            num_heads=para['num_heads'],
            num_layers=para['num_layers'],
            ff_dim=para['ff_dim'],
            output_dim=1,
            dropout=para['dropout'],
            clip_ratio=para['clip_ratio'],
        )

    if "ADMIL" in experiment_name:
        model = model_class()
    elif "One_Stream_Trans" in experiment_name:
        model = _build_transformer()
    elif "Two_Stream_Trans" in experiment_name:
        model = _build_transformer()
        if "VGG" in experiment_name:
            print("Successfully load Trans_VGG Model")
    elif "SA_DMIL" in experiment_name:
        model = SA_DMIL()
        loss_func = SmoothMIL(alpha=para['alpha_SADMIL'], S_k=1)
    else:
        # The original `raise("...")` raised a str, which is itself a TypeError
        # in Python 3 and hid the real cause.
        raise NotImplementedError(
            f"No model is registered for experiment {experiment_name!r}"
        )

    optimizer = torch.optim.Adam(
        model.parameters(), lr=para["learning_rate"], weight_decay=para["weight_decay"]
    )
    model = model.to(device)

    return model, optimizer, loss_func


def load_model_for_seed(experiment_name, seed, configuration, base_path="/data/lxl213/outputs/model_save_final_exp"):
    """Build the model described by ``configuration`` and load one seed's checkpoint.

    The checkpoint is read from ``<base_path>/<experiment_name>_<seed>/``. The
    file is ``model_best.pth``, except for experiments whose name contains
    ``"ADMIL_MAX"``, which use ``model_best_val.pth``.

    Args:
        experiment_name: run prefix; ``_<seed>`` is appended to form the folder name.
        seed: seed index, also written into a copy of the configuration.
        configuration: config dict accepted by ``initialize_training``; it is
            not mutated.
        base_path: root directory holding the per-run checkpoint folders.

    Returns:
        The model returned by ``initialize_training`` with the checkpoint
        weights loaded.
    """
    # Copy both dict levels that are modified so the caller's config is untouched.
    config = configuration.copy()
    config['parameters'] = {**configuration['parameters'], 'seed': seed}
    model_exp_path = f"{experiment_name}_{seed}"

    model, _, _ = initialize_training(config)
    is_admil_max = "ADMIL_MAX" in str(experiment_name)
    print(experiment_name, '\n', is_admil_max)
    checkpoint_name = "model_best_val.pth" if is_admil_max else "model_best.pth"
    model_path = os.path.join(base_path, model_exp_path, checkpoint_name)

    # map_location remaps stored tensors onto `device`, so checkpoints saved on
    # a GPU still load when `device` has fallen back to the CPU.
    model.load_state_dict(torch.load(model_path, map_location=device))
    print("Successfully load model at ", model_path)

    return model

def load_all_models(configurations_list, experiment_names=None, num_seeds=5):
    """Load every seed of every experiment and group the models by experiment.

    Pairing is positional: the i-th value of ``configurations_list`` configures
    the i-th entry of ``experiment_names``. Extra configurations are ignored.

    Args:
        configurations_list: ordered mapping of config name -> config dict.
        experiment_names: checkpoint-folder prefixes to load. Defaults to the
            four UTD runs this script originally compared.
        num_seeds: seeds ``0..num_seeds-1`` are loaded for each experiment.

    Returns:
        dict mapping experiment name -> list of models ordered by seed.

    Raises:
        ValueError: if there are fewer configurations than experiment names.
    """
    if experiment_names is None:
        experiment_names = (
            "1206_UTD_MIL_Trans_One_Stream_Trans_New1_Incremental_Inception_Epoch60_clip0.6",
            "1123_ADMIL",
            "1125_UTD_SA_DMIL",
            "0201_ADMIL_MAX_UTD",
        )

    configs = list(configurations_list.values())
    if len(configs) < len(experiment_names):
        raise ValueError(
            f"{len(experiment_names)} experiments but only {len(configs)} configurations"
        )

    models_by_dataset = {}
    for experiment_name, config in zip(experiment_names, configs):
        models_by_dataset[experiment_name] = [
            load_model_for_seed(experiment_name, seed, config)
            for seed in range(num_seeds)
        ]
    return models_by_dataset


# Per-model configurations. Insertion order matters: load_all_models pairs
# these entries positionally with its experiment names. The nested
# "experiment_name" drives model selection in initialize_training.
configurations_list = {
    "One_Stream_Trans_UTD": {
        "model_class": MIL_Trans_One_Stream_New1_Incremental_Inception,
        "feature_dim": 288,
        "parameters": {
            "num_epochs": 40,
            "seed": 4,
            "learning_rate": 1e-4,
            "weight_decay": 1e-4,
            "dropout": 0.2,
            "num_heads": 8,
            "num_layers": 2,
            "ff_dim": 128,
            "incremental_training": True,
            "alpha": 0.5,
            "beta": 0.5,
            "clip_ratio": 0.6,
            "experiment_name": "1206_UTD_MIL_Trans_One_Stream_Trans_New1_Incremental_Inception_Epoch60_clip0.6_4",
        },
    },
    "ADMIL_Baseline": {
        "model_class": ADMIL,
        "parameters": {
            "num_epochs": 2,
            "seed": 3,
            "learning_rate": 0.0005,
            "weight_decay": 1e-4,
            "dropout": 0.0,
            "num_heads": 8,
            "num_layers": 6,
            "incremental_training": False,
            "alpha": 0.5,
            "beta": 0.5,
            "experiment_name": "1123_ADMIL_test",
        },
    },
    "SA_DMIL": {
        "model_class": SA_DMIL,
        "parameters": {
            "num_epochs": 30,
            "seed": 4,
            "learning_rate": 1e-4,
            "weight_decay": 1e-4,
            "dropout": 0.1,
            "num_heads": 8,
            "num_layers": 6,
            "ff_dim": 512,
            "incremental_training": False,
            "alpha_SADMIL": 0.5,
            "beta": 0.5,
            "experiment_name": "1125_UTD_SA_DMIL_4",
        },
    },
    "ADMIL_Max": {
        "model_class": MIL_Max,
        "parameters": {
            "num_epochs": 40,
            "seed": 4,
            "learning_rate": 0.0005,
            "weight_decay": 1e-4,
            "dropout": 0.0,
            "num_heads": 8,
            "num_layers": 6,
            "incremental_training": False,
            "alpha": 0.5,
            "beta": 0.5,
            "experiment_name": "0201_ADMIL_MAX_UTD_4",
        },
    },
}


models_by_dataset = load_all_models(configurations_list)

# Ensure every loaded model sits on the evaluation device.
for experiment_name in models_by_dataset:
    models_by_dataset[experiment_name] = [
        loaded.to(device) for loaded in models_by_dataset[experiment_name]
    ]

# UTD test split, iterated one bag at a time in a fixed order.
path_utd = '/MIL_dataset/data_final_exp/UTD'
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# trusted files. Newer PyTorch releases may default to weights_only=True, which
# can reject a pickled Dataset object; confirm against the installed version.
test_dataset_utd = torch.load(path_utd + '/UTD_test_dataset.pt')
test_loader_utd = DataLoader(
    dataset=test_dataset_utd,
    batch_size=1,
    shuffle=False,
)
print("Successfully load the data")


def kept_instance_count(num_instances, crop_percentage):
    """Return how many leading instances remain after cropping ``crop_percentage``.

    The product is rounded to 9 decimals before truncation so floating-point
    noise does not drop an extra instance. For example,
    ``10 * (1 - 0.8) == 1.9999999999999996`` would otherwise keep 1 instead of 2.
    At least one instance is always kept.
    """
    return max(1, int(round(num_instances * (1 - crop_percentage), 9)))


def _forward_bag(model, bag, experiment_name):
    """Return the bag-level prediction of ``model``.

    Model families differ in call signature; the family is picked from
    substrings of ``experiment_name``.
    """
    if "One_Stream" in experiment_name:
        output = model(bag)[-1]
        # Promote 0-d outputs to shape (1,) so they compare against the label.
        if len(output.shape) == 0:
            output = output.unsqueeze(0)
        return output
    if "SA_DMIL" in experiment_name:
        output, _att_weights = model(bag)
        return output
    return model(bag, total_len=bag.shape[1])


def _accuracy_under_transform(model, test_loader, experiment_name, transform_bag):
    """Return the accuracy of ``model`` when ``transform_bag`` is applied to every bag.

    Predictions are thresholded at 0.5. The model is put in eval mode and run
    without gradient tracking. Returns 0 for an empty loader.
    """
    correct_test = 0
    total_test = 0
    model.eval()
    with torch.no_grad():
        for bag, label, _bag_id, _bag_seq_digits in test_loader:
            bag = transform_bag(bag).to(device)
            label = label.float().to(device)
            output = _forward_bag(model, bag, experiment_name)
            predicted_test = (output > 0.5).float()
            total_test += label.size(0)
            correct_test += (predicted_test == label).sum().item()
    return correct_test / total_test if total_test > 0 else 0


def crop_and_test(crop_percentage, model, test_loader, experiment_name):
    """Return accuracy when only the leading ``1 - crop_percentage`` of each bag is kept.

    Bags are indexed as ``bag[:, instances, ...]``; instances are dropped from the
    end, and at least one instance is always kept.
    """
    def _crop(bag):
        return bag[:, :kept_instance_count(bag.shape[1], crop_percentage), ...]

    return _accuracy_under_transform(model, test_loader, experiment_name, _crop)


def reverse_and_test(model, test_loader, experiment_name):
    """Return accuracy when the instance order of every bag is reversed."""
    def _reverse(bag):
        return bag.flip(dims=[1]) if bag.shape[1] > 1 else bag

    return _accuracy_under_transform(model, test_loader, experiment_name, _reverse)


# Fractions of each bag dropped from the end: 0.0, 0.1, ..., 0.8.
crop_percentages = np.linspace(0, 0.8, 9)
model_accuracies = []

for experiment_name, models in models_by_dataset.items():
    # The list index doubles as the seed: load_all_models loads seeds in order.
    for seed, model in enumerate(models):
        # One accuracy per crop level, followed by the reversed-order accuracy.
        scores = [
            crop_and_test(crop_percentage, model, test_loader_utd, experiment_name)
            for crop_percentage in crop_percentages
        ]
        scores.append(reverse_and_test(model, test_loader_utd, experiment_name))
        model_accuracies.append(
            {"Experiment": experiment_name, "Seed": seed, "Accuracies": scores}
        )



output_file = '/outputs/model_uncertainty/crop_reverse_accuracies_UTD_0201_correct.pkl'
# Create the destination folder first so a missing directory does not discard
# the (expensive) evaluation results computed above at the very last step.
os.makedirs(os.path.dirname(output_file), exist_ok=True)
pd.to_pickle(model_accuracies, output_file)

print("Successfully computed and saved accuracies for crop and reverse tests.")


