import os

# os.environ["MKL_VERBOSE"] = "0"
os.environ["MKL_DEBUG_CPU_TYPE"] = "5"
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module='numpy')

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from collections import defaultdict
import numpy as np
from tqdm import tqdm
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch.nn.functional as F
import random
from scipy.stats import entropy
from sklearn.metrics import mutual_info_score
from torch.utils.data import DataLoader, Subset

import torch
import matplotlib.pyplot as plt
import matplotlib
from PIL import Image
import math

from train_linear_model import linear_model
from utils import NoisyMNISTDataset
from utils import compute_layer_averages, save_layer_averages
from train_conv_model import CatDogCNN
from torch.nn.functional import cosine_similarity
from conv_neuronization import cal_strength

def get_layer_output(layer, input, output):
    """Forward-hook callback that stores ``output`` in the global ``layer_output``.

    Registered via ``register_forward_hook``; ``layer`` and ``input`` are
    ignored. Returning ``None`` leaves the hooked module's output unchanged.
    """
    global layer_output
    layer_output = output
    return None

def calculate_sensitivity_on_input(model, target_layer, neuron_indices, test_loader=None):
    """Average pairwise cosine similarity between neurons' input gradients.

    For each batch of ``test_loader`` the input is flattened. A forward hook on
    ``model.<target_layer>`` (``get_layer_output``) captures that layer's output.
    For every index in ``neuron_indices``, the gradient of
    ``layer_output[0, index]`` with respect to the flattened input is collected.
    Only the first sample of each batch is used. The mean cosine similarity over
    all neuron pairs is computed per batch, then averaged over batches.

    Args:
        model: network exposing ``fc1``/``fc2``/``fc3`` submodules.
        target_layer: one of ``'fc1'``, ``'fc2'``, ``'fc3'``.
        neuron_indices: flat sequence of at least two integer neuron indices.
        test_loader: iterable of ``(data, target)`` batches. When omitted, a
            module-level ``test_loader`` global is used, as the original code did.

    Returns:
        The overall average similarity. It is also printed.

    Raises:
        ValueError: on an unknown ``target_layer``, fewer than two neurons, no
            available ``test_loader``, or a loader that yields no batches.
    """
    global layer_output
    if target_layer not in ('fc1', 'fc2', 'fc3'):
        raise ValueError(f"Unknown target layer: {target_layer}")
    n = len(neuron_indices)
    if n < 2:
        raise ValueError("neuron_indices must contain at least two neurons")
    if test_loader is None:
        # Backward compatibility: the original read a global ``test_loader``.
        test_loader = globals().get('test_loader')
        if test_loader is None:
            raise ValueError("No test_loader given and no module-level test_loader defined")

    layer = getattr(model, target_layer)
    # Index pairs (i, j) with i < j, used to pick the upper triangle.
    pair_rows, pair_cols = torch.triu_indices(n, n, offset=1)

    similarities = []
    for data, _target in tqdm(test_loader):
        # Flatten first, then make the flattened tensor the grad leaf.
        # A view of a leaf is not a leaf itself, so its ``.grad`` would stay None.
        data = data.view(data.size(0), -1).detach().requires_grad_(True)

        hook = layer.register_forward_hook(get_layer_output)
        try:
            model(data)
        finally:
            hook.remove()

        rows = []
        for neuron_index in neuron_indices:
            model.zero_grad()
            data.grad = None
            layer_output[0, neuron_index].backward(retain_graph=True)
            rows.append(data.grad.detach().clone().flatten())
        jacobian_matrix = torch.stack(rows)  # one row per selected neuron

        # (n, n) matrix of pairwise cosine similarities between the rows.
        similarity_matrix = F.cosine_similarity(
            jacobian_matrix.unsqueeze(1), jacobian_matrix.unsqueeze(0), dim=-1
        )
        similarities.append(similarity_matrix[pair_rows, pair_cols].mean().item())

    if not similarities:
        raise ValueError("test_loader yielded no batches")
    overall_similarity = sum(similarities) / len(similarities)
    print(f"Overall average similarity between neurons: {overall_similarity}")
    return overall_similarity

def calculate_sensitivity_on_input_jacobian(model, target_layer, test_loader):
    """Mean Gram matrix ``J @ J.T`` of a hidden layer's Jacobian w.r.t. its input.

    For each sample, ``J`` has one row per hidden unit returned by
    ``model.get_hidden_layer_<target_layer>_output`` and one column per
    flattened input feature, with ``J[i] = d hidden[i] / d input``. The
    per-sample Gram matrices (hidden x hidden) are summed and divided by the
    number of samples.

    Batches of any size are supported. Per-sample rows come from
    ``hidden[:, i].sum().backward()``; this assumes samples in a batch do not
    interact (e.g. no BatchNorm in train mode), which should be confirmed for the
    model in use. With batch size 1 the result matches the original per-batch mean.

    Raises:
        ValueError: on an unknown ``target_layer`` or a loader with no samples.
    """
    getter_names = {
        'fc1': 'get_hidden_layer_fc1_output',
        'fc2': 'get_hidden_layer_fc2_output',
        'fc3': 'get_hidden_layer_fc3_output',
    }
    if target_layer not in getter_names:
        raise ValueError(f"Unknown target layer: {target_layer}")
    get_hidden = getattr(model, getter_names[target_layer])

    accumulated_jacobian = None
    num_samples = 0
    for data_mnist, _target in test_loader:
        data_mnist = data_mnist.view(data_mnist.size(0), -1).requires_grad_(True)
        hidden_layer_output = get_hidden(data_mnist)
        batch_size, num_hidden = hidden_layer_output.shape

        # jacobian[b, i] = d hidden[b, i] / d input[b]
        jacobian = data_mnist.new_zeros(batch_size, num_hidden, data_mnist.size(1))
        for i in range(num_hidden):
            model.zero_grad()
            data_mnist.grad = None
            hidden_layer_output[:, i].sum().backward(retain_graph=True)
            jacobian[:, i] = data_mnist.grad

        gram = (jacobian @ jacobian.transpose(1, 2)).sum(dim=0)
        accumulated_jacobian = gram if accumulated_jacobian is None else accumulated_jacobian + gram
        num_samples += batch_size

    if num_samples == 0:
        raise ValueError("test_loader yielded no samples")
    return accumulated_jacobian / num_samples

def calculate_sensitivity_on_output_jacobian(model, target_layer, test_loader):
    """Mean Gram matrix of the model output's Jacobian w.r.t. a hidden layer.

    The hidden tensor is the output of the ``model.<target_layer>`` submodule.
    It is captured with a forward hook during ``model(...)``, so it is part of
    the graph that produces ``output``. NOTE(review): this may differ from
    ``get_hidden_layer_<layer>_output`` if that helper applies an extra
    activation; confirm this is the intended tensor.

    For each sample, ``J`` has one row per hidden unit and one column per model
    output, with ``J[k, i] = d output[i] / d hidden[k]``. The per-sample Gram
    matrices ``J @ J.T`` (hidden x hidden) are summed and divided by the number
    of samples.

    Raises:
        ValueError: on an unknown ``target_layer`` or a loader with no samples.
    """
    if target_layer not in ('fc1', 'fc2', 'fc3'):
        raise ValueError(f"Unknown target layer: {target_layer}")
    layer = getattr(model, target_layer)

    captured = {}

    def _capture(module, inputs, out):
        captured['hidden'] = out

    accumulated_jacobian = None
    num_samples = 0
    for data_mnist, _target in test_loader:
        # Input requires grad so a graph exists even if model params are frozen.
        data_mnist = data_mnist.view(data_mnist.size(0), -1).requires_grad_(True)
        handle = layer.register_forward_hook(_capture)
        try:
            output = model(data_mnist)
        finally:
            handle.remove()
        hidden_output = captured.pop('hidden')

        batch_size, num_hidden = hidden_output.shape
        num_outputs = output.shape[1]
        jacobian = hidden_output.new_zeros(batch_size, num_hidden, num_outputs)
        for i in range(num_outputs):
            # Column i holds d output[:, i] / d hidden for every sample.
            (grad_i,) = torch.autograd.grad(
                output[:, i].sum(), hidden_output, retain_graph=True
            )
            jacobian[:, :, i] = grad_i

        gram = (jacobian @ jacobian.transpose(1, 2)).sum(dim=0)
        accumulated_jacobian = gram if accumulated_jacobian is None else accumulated_jacobian + gram
        num_samples += batch_size

    if num_samples == 0:
        raise ValueError("test_loader yielded no samples")
    return accumulated_jacobian / num_samples

def calculate_statistics(data, mean):
    """Score how consistently ``data`` deviates from ``mean``.

    With ``dev = data - mean``:

    * ``ncs = 1 / (1 + std(dev))`` uses the unbiased std and rewards low spread.
    * ``consistency = sum(dev) / (sum(|dev|) + 1e-6)`` is close to +-1 when all
      deviations share a sign, and close to 0 when they cancel out.

    Returns:
        ``ncs * |consistency|`` as a 0-d tensor.
    """
    deviations = data - mean
    std_dev = torch.std(deviations, unbiased=True)
    ncs = 1 / (1 + std_dev)
    # |d| * sign(d) == d, so the numerator is simply the signed sum.
    # The epsilon must be 1e-6. In Python ``10^(-6)`` is bitwise XOR and
    # evaluates to -16.
    consistency = torch.sum(deviations) / (torch.sum(torch.abs(deviations)) + 1e-6)
    return ncs * torch.abs(consistency)

def weighted_statistics(layer_output, average_output, communities):
    """Size-weighted mean of ``calculate_statistics`` over neuron communities.

    Each community is a sequence of indices into the first dimension of
    ``layer_output`` and ``average_output``. Each community is scored with
    ``calculate_statistics``, and the scores are averaged with weights equal to
    the number of elements selected from ``layer_output``.

    Returns:
        0-d tensor holding the weighted mean. It is NaN when ``communities`` is
        empty, because the result is 0 / 0.
    """
    scored = []
    for members in communities:
        selected = layer_output[members]
        scored.append((calculate_statistics(selected, average_output[members]), selected.numel()))

    stats = torch.tensor([score for score, _ in scored])
    sizes = torch.tensor([size for _, size in scored], dtype=stats.dtype)
    return (stats * sizes).sum() / sizes.sum()

def calculate_average_from_matrix(matrix, index_list):
    """Mean of ``matrix[a, b]`` over every pair where ``a`` precedes ``b`` in ``index_list``.

    Pairs follow list order, so only the entries above the "diagonal" of the
    selected sub-matrix are used.

    Returns:
        The pair average, or ``0`` when ``index_list`` has two or fewer entries.
        NOTE(review): a two-member list has one valid pair but still returns 0;
        confirm this is intended and not an off-by-one (``< 2``).
    """
    if len(index_list) <= 2:
        return 0
    pair_values = [
        matrix[first, second]
        for pos, first in enumerate(index_list)
        for second in index_list[pos + 1:]
    ]
    return sum(pair_values) / len(pair_values)

def estimate_modularity(q_kc, target_layer, model, test_loader, mean_jacobian, sensitivity_scale=30):
    """Score a partition ``q_kc`` of a linear layer's neurons into communities.

    The score is ``sensitivity_on_input / sensitivity_scale + weighted_average_score``:

    1. ``sensitivity_on_input`` sums, over communities, the mean pairwise entry
       of ``mean_jacobian`` (via ``calculate_average_from_matrix``). The matrix
       is meant to come from ``calculate_sensitivity_on_input_jacobian`` for the
       same layer.
    2. ``weighted_average_score`` is the mean, over batches of ``test_loader``,
       of ``weighted_statistics``. That function compares the layer's outputs
       with the saved per-layer mean in
       ``weights/linear_each_layer_output/<layer>_average.pth``.

    Args:
        q_kc: list of communities, each a list of neuron indices.
        target_layer: one of ``'fc1'``, ``'fc2'``, ``'fc3'``.
        model: model exposing ``get_hidden_layer_<layer>_output``.
        test_loader: iterable of ``(data, target)`` batches.
        mean_jacobian: square matrix indexed by neuron.
        sensitivity_scale: divisor applied to the summed input sensitivity.
            The default of 30 is the constant the original hard-coded.

    Raises:
        ValueError: on an unknown ``target_layer`` or a loader with no batches.
    """
    average_paths = {
        'fc1': 'weights/linear_each_layer_output/fc1_average.pth',
        'fc2': 'weights/linear_each_layer_output/fc2_average.pth',
        'fc3': 'weights/linear_each_layer_output/fc3_average.pth',
    }
    getter_names = {
        'fc1': 'get_hidden_layer_fc1_output',
        'fc2': 'get_hidden_layer_fc2_output',
        'fc3': 'get_hidden_layer_fc3_output',
    }
    if target_layer not in average_paths:
        raise ValueError(f"Unknown target layer: {target_layer}")

    # Step 1: sensitivity on input (higher when community members share sensitivity).
    sensitivity_on_input = sum(
        calculate_average_from_matrix(mean_jacobian, community) for community in q_kc
    )

    # Step 2: similarity on output, against the precomputed mean layer output.
    mean_feature_value = torch.load(average_paths[target_layer])
    get_hidden = getattr(model, getter_names[target_layer])

    weighted_average = 0
    num_batches = 0
    # Pure evaluation: no gradients are needed, so avoid building autograd graphs.
    with torch.no_grad():
        for data_mnist, _target in test_loader:
            hidden_output = get_hidden(data_mnist.view(data_mnist.size(0), -1))
            # squeeze() drops the batch dim; this assumes batch size 1 (TODO confirm).
            weighted_average += weighted_statistics(hidden_output.squeeze(), mean_feature_value, q_kc)
            num_batches += 1

    if num_batches == 0:
        raise ValueError("test_loader yielded no batches")
    weighted_average_score = weighted_average / num_batches
    return sensitivity_on_input / sensitivity_scale + weighted_average_score

# Module-level slot written by ``hook``; ``get_sensitivity`` reads it after a forward pass.
layer_activations = None


def hook(module, input, output):
    """Forward-hook callback that stores ``output`` in the global ``layer_activations``.

    ``module`` and ``input`` are ignored. Returning ``None`` leaves the hooked
    module's output unchanged.
    """
    global layer_activations
    layer_activations = output
    return None

def get_sensitivity(model, target_layer, kernel_index, input_tensor):
    """Flattened input-gradient map for one output channel of a conv layer.

    Runs ``model(input_tensor)`` with the module-level ``hook`` attached to
    ``model.<target_layer>``. It then backpropagates from channel
    ``kernel_index`` of that layer's output, using the output itself as the
    upstream gradient. This equals the gradient of ``0.5 * sum(out ** 2)``.

    Returns:
        The gradient w.r.t. ``input_tensor`` for batch element 0 and input
        channel 0, flattened to 1-D.

    Side effects:
        Sets ``input_tensor.requires_grad = True`` and overwrites
        ``input_tensor.grad``. Updates the global ``layer_activations``.

    Raises:
        ValueError: if ``target_layer`` is not one of ``conv1`` to ``conv4``.
            Previously an unknown layer silently reused activations left over
            from an earlier call.
    """
    if target_layer not in ('conv1', 'conv2', 'conv3', 'conv4'):
        raise ValueError(f"Unknown target layer: {target_layer}")

    input_tensor.requires_grad = True
    if input_tensor.grad is not None:
        input_tensor.grad.zero_()

    hook_handle = getattr(model, target_layer).register_forward_hook(hook)
    try:
        model(input_tensor)
    finally:
        # Always detach the hook, even if the forward pass raises.
        hook_handle.remove()

    conv_kernel_output = layer_activations[:, kernel_index, :, :]
    conv_kernel_output.backward(conv_kernel_output)

    sensitivity_map = input_tensor.grad.data.clone()[0, 0]
    return sensitivity_map.view(-1)

def calculate_sensitivity_similarity_between_two_neurons(model, target_layer,i,j, input_tensor):
    """Signed sensitivity strength between kernels ``i`` and ``j`` of ``target_layer``.

    Each kernel's input-sensitivity vector is computed with ``get_sensitivity``.
    The result is ``sign(cosine_similarity(vec1, vec2)) * cal_strength(vec1, vec2)``.

    ``torch.sign`` is used so that a similarity of exactly 0 (e.g. when one
    sensitivity vector is all zeros) gives a sign of 0. Previously
    ``similarity / abs(similarity)`` produced NaN in that case.
    """
    vec1 = get_sensitivity(model, target_layer, i, input_tensor)
    vec2 = get_sensitivity(model, target_layer, j, input_tensor)

    similarity = cosine_similarity(vec1.unsqueeze(0), vec2.unsqueeze(0))
    sign = torch.sign(similarity)
    strength = cal_strength(vec1, vec2)

    return sign * strength

def calculate_sensitivity_on_input_conv_single_image(model, target_layer, q_kc, input_tensor):
    """Average over communities of the mean pairwise sensitivity similarity on one image.

    ``q_kc`` is a list of communities, each a list of kernel indices of
    ``target_layer``. Every pair of member kernels is scored with
    ``calculate_sensitivity_similarity_between_two_neurons`` on ``input_tensor``.
    Pair scores are averaged within each community, and the community averages
    are averaged across all communities.

    Communities with fewer than two members have no pairs. They contribute 0
    but still count in the denominator.

    Raises:
        ValueError: if ``q_kc`` is empty.
    """
    if len(q_kc) == 0:
        raise ValueError("q_kc must contain at least one community")

    sensitivity_similarity_q_kc = 0
    for community in q_kc:
        pair_total = 0
        pair_count = 0
        for pos, first in enumerate(community):
            for second in community[pos + 1:]:
                # Pass the kernel indices themselves, not their positions in the community.
                pair_total += calculate_sensitivity_similarity_between_two_neurons(
                    model, target_layer, first, second, input_tensor
                )
                pair_count += 1
        if pair_count:
            # Accumulate inside the loop so every community contributes.
            sensitivity_similarity_q_kc += pair_total / pair_count

    return sensitivity_similarity_q_kc / len(q_kc)

def calculate_sensitivity_on_input_conv_single_dataloader(model, target_layer, q_kc, test_loader):
    """Mean of the per-image community sensitivity score over ``test_loader``.

    For each batch, element ``[0]`` is used as the input tensor and scored with
    ``calculate_sensitivity_on_input_conv_single_image``. NaN scores are skipped.
    Each score is computed only once; the original computed it twice per image.

    Raises:
        ValueError: if the loader is empty or every image produced NaN.
    """
    total = 0
    count = 0
    for data_afhq in test_loader:
        input_tensor = data_afhq[0]
        sensitivity_value = calculate_sensitivity_on_input_conv_single_image(model, target_layer, q_kc, input_tensor)
        if math.isnan(sensitivity_value):
            continue
        total += sensitivity_value
        count += 1

    if count == 0:
        raise ValueError("No image in test_loader produced a finite sensitivity value")
    return total / count

def calculate_statistics_conv(data, mean):
    """Consistency score of ``data`` around ``mean``; used for conv feature maps.

    The score is ``(1 / (1 + std(r))) * |sum(|r| * sign(r)) / (sum(|r|) + 1e-6)|``,
    where ``r = data - mean`` and ``std`` is unbiased. The first factor rewards
    low spread. The second is close to 1 when the residuals share a sign.
    """
    residual = data - mean
    spread_term = 1 / (1 + torch.std(residual, unbiased=True))
    signed_total = torch.sum(torch.abs(residual) * torch.sign(residual))
    magnitude_total = torch.sum(torch.abs(residual)) + torch.tensor(1e-6)
    return spread_term * torch.abs(signed_total / magnitude_total)

def weighted_statistics_conv(layer_output, average_output, communities):
    """Size-weighted mean of the conv consistency statistic over kernel communities.

    For each community (a list of channel indices), the channels are selected
    from ``layer_output`` and ``average_output`` using ``[:, channels, :, :]``.
    They are scored with ``calculate_statistics_conv``, and the scores are
    averaged, weighting each community by the number of elements selected from
    ``layer_output``.

    The selected tensors are passed unflattened so that ``data - mean``
    broadcasts. A batch of feature maps can then be compared with a batch-1
    average. For batch size 1 the result equals scoring the flattened tensors,
    since the statistic reduces over all elements.

    Returns:
        0-d tensor with the weighted mean.
    """
    community_stats = []
    weights = []

    for community in communities:
        community_data = layer_output[:, community, :, :]
        community_mean = average_output[:, community, :, :]
        # Conv counterpart of ``calculate_statistics``; its epsilon is 1e-6.
        stat = calculate_statistics_conv(community_data, community_mean)
        community_stats.append(stat)
        weights.append(community_data.numel())

    community_stats = torch.tensor(community_stats)
    weights = torch.tensor(weights, dtype=community_stats.dtype)

    weighted_average = torch.sum(community_stats * weights) / torch.sum(weights)
    return weighted_average

def calculate_average_feature_conv(model, test_loader, save_dir='weights/conv/average_conv_feature_map'):
    """Average the conv1 to conv4 feature maps over ``test_loader`` and save them.

    For each batch, element ``[0]`` is fed to
    ``model.get_hidden_layer_conv{1..4}_output``. The outputs are summed, then
    divided by the number of batches. The results are written to
    ``<save_dir>/average_conv_<k>_feature_map.pth`` for ``k`` = 1 to 4, and
    ``save_dir`` is created if it does not exist. The running sums require every
    batch to produce the same feature-map shape; with ``batch_size=1`` this holds.

    Args:
        model: model exposing the four ``get_hidden_layer_conv*_output`` methods.
        test_loader: iterable of batches.
        save_dir: output directory. The default is the path the original hard-coded.

    Returns:
        ``True`` on success.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    layer_getters = [
        model.get_hidden_layer_conv1_output,
        model.get_hidden_layer_conv2_output,
        model.get_hidden_layer_conv3_output,
        model.get_hidden_layer_conv4_output,
    ]
    feature_map_sums = [0] * len(layer_getters)
    count = 0
    # Only forward values are needed. Without no_grad the running sums chain
    # every image's autograd graph together, so memory grows with the dataset.
    with torch.no_grad():
        for data_afhq in tqdm(test_loader):
            input_tensor = data_afhq[0]
            for k, get_output in enumerate(layer_getters):
                feature_map_sums[k] += get_output(input_tensor)
            count += 1

    if count == 0:
        raise ValueError("test_loader yielded no batches")

    os.makedirs(save_dir, exist_ok=True)
    for k, total in enumerate(feature_map_sums, start=1):
        torch.save(total / count, os.path.join(save_dir, f'average_conv_{k}_feature_map.pth'))

    return True

def estimate_modularity_conv(model, target_layer, q_kc, test_loader):
    """Score a partition ``q_kc`` of a conv layer's kernels into communities.

    Returns ``sensitivity_on_input_score + similarity_on_output``:

    1. ``sensitivity_on_input_score`` comes from
       ``calculate_sensitivity_on_input_conv_single_dataloader``. It needs autograd.
    2. ``similarity_on_output`` is the mean over batches of
       ``weighted_statistics_conv``. It compares the layer's feature maps with
       the average map saved by ``calculate_average_feature_conv`` under
       ``weights/conv/average_conv_feature_map/``.

    Raises:
        ValueError: on an unknown ``target_layer`` (checked before the expensive
            step 1), or if step 2 sees no batches (e.g. a one-shot iterator
            already consumed by step 1).
    """
    feature_map_paths = {
        'conv1': 'weights/conv/average_conv_feature_map/average_conv_1_feature_map.pth',
        'conv2': 'weights/conv/average_conv_feature_map/average_conv_2_feature_map.pth',
        'conv3': 'weights/conv/average_conv_feature_map/average_conv_3_feature_map.pth',
        'conv4': 'weights/conv/average_conv_feature_map/average_conv_4_feature_map.pth',
    }
    if target_layer not in feature_map_paths:
        raise ValueError(f"Unknown target layer: {target_layer}")

    # Step 1: similarity of the kernels' input sensitivities.
    sensitivity_on_input_score = calculate_sensitivity_on_input_conv_single_dataloader(model, target_layer, q_kc, test_loader)

    # Step 2.1: load the precomputed average feature map.
    average_conv_feature_map = torch.load(feature_map_paths[target_layer])
    # e.g. 'conv3' -> model.get_hidden_layer_conv3_output
    get_feature_map = getattr(model, f'get_hidden_layer_{target_layer}_output')

    # Step 2.2: similarity on output. This is pure evaluation, so no autograd graph.
    weighted_average = 0
    num_batches = 0
    with torch.no_grad():
        for data_afhq in test_loader:
            conv_feature_map = get_feature_map(data_afhq[0])
            weighted_average += weighted_statistics_conv(conv_feature_map, average_conv_feature_map, q_kc)
            num_batches += 1

    if num_batches == 0:
        raise ValueError("test_loader yielded no batches for the output-similarity step")
    similarity_on_output = weighted_average / num_batches

    return sensitivity_on_input_score + similarity_on_output

if __name__ == "__main__":
    # Experiment driver. Sections wrapped in ''' ... ''' are string-literal
    # expression statements, i.e. disabled code kept for reference. Only the
    # conv-layer section further down actually executes.
    
    ###################################################
    ###################################################
    
    ################################### linear layer #############################
    # Disabled: noisy-MNIST linear-model experiments.
    '''
    model = linear_model()
    model.load_state_dict(torch.load('weights/best_model.pth'))
    ckp = torch.load('weights/best_model.pth')
    # model.eval()

    transform = transforms.Compose([transforms.ToTensor(),])

    mnist_train = NoisyMNISTDataset(
        image_folder='/Users/lan/my_codes/neuron_analysis/InnerSightNet/data/noisy_mnist/train',
        labels_file=os.path.join('/Users/lan/my_codes/neuron_analysis/InnerSightNet/data/noisy_mnist/train', 'labels.txt'),
        transform=transform
        )
    mnist_test = NoisyMNISTDataset(
        image_folder='/Users/lan/my_codes/neuron_analysis/InnerSightNet/data/noisy_mnist/test',
        labels_file=os.path.join('/Users/lan/my_codes/neuron_analysis/InnerSightNet/data/noisy_mnist/test', 'labels.txt'),
        transform=transform
        )

    train_size = int(0.8 * len(mnist_train))
    val_size = len(mnist_train) - train_size
    train_dataset, val_dataset = random_split(mnist_train, [train_size, val_size])
    bz = 1
    train_loader = DataLoader(train_dataset, batch_size=bz, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=bz, shuffle=True)
    test_loader = DataLoader(mnist_test, batch_size=bz, shuffle=True)

    target_layer = 'fc3'
    neuron_indices = [[0, 15, 31],
                      [1, 13, 18, 24, 25, 30],
                      [2, 4, 7, 9, 27],
                      [3, 6, 26, 29],
                      [5, 20, 22],
                      [8, 14, 16, 21, 28],
                      [10, 17, 19, 23],
                      [11, 12]]
    '''
    # NOTE(review): calculate_sensitivity_on_input indexes neuron_indices as a
    # flat list of ints, but neuron_indices above is a list of communities.
    # calculate_sensitivity_on_input(model, target_layer, neuron_indices)
    
    # mean_jacobian = calculate_sensitivity_on_input_jacobian(model, target_layer, test_loader)
    
    # Disabled: plot the output-Jacobian Gram matrix as a heat map.
    '''
    mean_jacobian = calculate_sensitivity_on_output_jacobian(model, target_layer, test_loader)
    
    matrix_np = mean_jacobian.numpy()
    plt.imshow(matrix_np, cmap='jet', interpolation='nearest')
    plt.colorbar()
    plt.savefig('/Users/lan/my_codes/neuron_analysis/InnerSightNet/matrix_hd.png', dpi=1000)
    plt.close()
    '''
    
    # layer_averages = compute_layer_averages(model, test_loader)
    # save_layer_averages(layer_averages, save_dir='/Users/lan/my_codes/neuron_analysis/InnerSightNet/weights/linear_each_layer_output/')
    
    # Disabled: sanity check of weighted_statistics using the fc3 average as both inputs.
    '''
    layer_output = torch.load('weights/linear_each_layer_output/fc3_average.pth')
    average_output = torch.load('weights/linear_each_layer_output/fc3_average.pth')
    weighted_statistic_value = weighted_statistics(layer_output, average_output, neuron_indices)
    print(weighted_statistic_value.item())
    '''
    # Disabled: precompute and save mean input-Jacobian Gram matrices for fc1/fc2/fc3.
    '''
    # sensitivity_on_input, weighted_average_score = estimate_modularity(neuron_indices, target_layer, model, test_loader)
    target_layer = 'fc1'
    mean_jacobian_128 = calculate_sensitivity_on_input_jacobian(model, target_layer, test_loader)
    torch.save(mean_jacobian_128, 'weights/mean_jacobian/mean_jacobian_128.pth')
    target_layer = 'fc2'
    mean_jacobian_64 = calculate_sensitivity_on_input_jacobian(model, target_layer, test_loader)
    torch.save(mean_jacobian_64, 'weights/mean_jacobian/mean_jacobian_64.pth')
    target_layer = 'fc3'
    mean_jacobian_32 = calculate_sensitivity_on_input_jacobian(model, target_layer, test_loader)
    torch.save(mean_jacobian_32, 'weights/mean_jacobian/mean_jacobian_32.pth')
    '''
    
    ################################### conv layer #############################
    # Active part: load CatDogCNN weights from
    # weights/conv_best_conv_model_img_size128.pth and score a kernel partition
    # of one conv layer.
    
    model = CatDogCNN()
    model.load_state_dict(torch.load('weights/conv_best_conv_model_img_size128.pth'))
    model = model.eval()

    # Resize to 128x128, then normalize with the standard ImageNet channel mean/std.
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    test_data = datasets.ImageFolder(root='data/Cat_Dog_data/test/', transform=transform)
    # batch_size=1: get_sensitivity only reads input gradients of sample 0.
    # NOTE(review): this loader is used only by the commented-out calls below.
    test_loader_cal_weights = DataLoader(test_data, batch_size=1, shuffle=False)
    
    # NOTE(review): this q_kc is reassigned below before any active use, so
    # these values are currently unused.
    q_kc = [[0, 3, 4, 5, 6, 8, 10, 13, 18, 26, 28, 29],
            [1, 2, 20, 27, 31],
            [7, 12, 14, 17, 19, 22, 23, 25, 30],
            [9, 11, 15, 16, 21, 24]]
    target_layer = 'conv3'
    
    # calculate average feature map
    # calculate_average_feature_conv(model, test_loader_cal_weights)
    
    # calculate similarity of conv: step 1
    # sensitivity_similarity_q_kc_test_dataloader = calculate_sensitivity_on_input_conv_single_dataloader(model, target_layer, q_kc, test_loader_cal_weights)
    
    # Restrict to the first 16 test images.
    # NOTE(review): this binds a module-level `test_loader`, which
    # calculate_sensitivity_on_input can pick up as a global.
    test_data_subset = Subset(test_data, range(16))
    test_loader = DataLoader(test_data_subset, batch_size=1, shuffle=False)
    # Kernel communities of target_layer scored by the active call below.
    q_kc = [[0, 1, 2,3,4,5,6,7,10,11,12,13,14,15,16,17,18,20,21,22,24,25,28,29,30,31],
            [8, 9, 23],
            [19, 26, 27]]
    
    # Runs only step 1 of estimate_modularity_conv (input-sensitivity similarity).
    sensitivity_similarity_q_kc_test_dataloader = calculate_sensitivity_on_input_conv_single_dataloader(model, target_layer, q_kc, test_loader)
    print(sensitivity_similarity_q_kc_test_dataloader)
    
    # NOTE(review): estimate_modularity_conv returns one combined score, so the
    # two-value unpacking in the commented call below looks like it would fail.
    # sensitivity_on_input_score, similarity_on_output = estimate_modularity_conv(model, target_layer, q_kc, test_loader)
    # print(sensitivity_on_input_score, similarity_on_output)