import os
import shutil
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.nn import functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image

from torch.utils.data import DataLoader, random_split
from collections import defaultdict
import numpy as np
import torchvision
import random
from scipy.stats import entropy
from sklearn.metrics import mutual_info_score
from collections import deque

from torch.nn.functional import cosine_similarity
from torchvision.transforms import InterpolationMode

'''
class CatDogCNN(nn.Module):
    def __init__(self):
        super(CatDogCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        
        self.fc1 = nn.Linear(256 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 2)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = self.pool(F.relu(self.conv4(x)))
        x = self.pool(F.relu(self.conv5(x)))
        
        x = x.view(-1, 256 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x
    
    def get_hidden_layer_conv1_output(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        return x
    
    def get_hidden_layer_conv2_output(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        return x
    
    def get_hidden_layer_conv3_output(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        return x
    
    def get_hidden_layer_conv4_output(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = self.pool(F.relu(self.conv4(x)))
        return x
    
activation = {}

def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

model = CatDogCNN()
model.load_state_dict(torch.load('weights/conv_best_conv_model_img_size128.pth'))
# model = model.eval()
model.train()

for param in model.parameters():
    param.requires_grad = True

test_data = datasets.ImageFolder(root='data/Cat_Dog_data/test/', transform=transform)
test_loader_cal_weights = DataLoader(test_data, batch_size=1, shuffle=False)
img = Image.open("data/Cat_Dog_data/test/cat/flickr_cat_000008.jpg")
input_tensor = transform(img).unsqueeze(0).requires_grad_(True)
model.conv1.register_forward_hook(get_activation('conv1'))

input_tensor.requires_grad = True
layer_activations = None

def hook(module, input, output):
    global layer_activations
    layer_activations = output

def get_sensitivity(kernel_index):
    if input_tensor.grad is not None:
        input_tensor.grad.zero_()
          
    hook_handle = model.conv1.register_forward_hook(hook)
    model(input_tensor)
    hook_handle.remove()
    
    conv_kernel_output = layer_activations[:, kernel_index, :, :]
    conv_kernel_output.backward(conv_kernel_output)
    
    sensitivity_map = input_tensor.grad.data.clone()
    sensitivity_map = sensitivity_map[0, 0]
    
    vec = sensitivity_map.view(-1)
    return vec

def calculate_sensitivity_similarity_between_two_neurons(i,j):
    
    vec1 = get_sensitivity(i)
    vec2 = get_sensitivity(j)
    
    similarity = cosine_similarity(vec1.unsqueeze(0), vec2.unsqueeze(0))
    sign = similarity/abs(similarity)
    strength = cal_strength(vec1, vec2) 
    
    return sign*strength

sensitivity_map_1 = get_sensitivity(1)
sensitivity_map_2 = get_sensitivity(6)


from conv_neuronization import cal_strength
for i in range(16):
    for j in range(i+1,16):
        
        value = calculate_sensitivity_similarity_between_two_neurons(i,j)
        print(value)

plt.imshow(sensitivity_map_1.cpu().numpy(), cmap='hot')
plt.colorbar()
plt.show()

plt.imshow(sensitivity_map_2.cpu().numpy(), cmap='hot')
plt.colorbar()
plt.show()
'''
import torch

def calculate_statistics(data, mean):
    """Score how tightly and how one-sidedly ``data`` sits around ``mean``.

    The score is the product of two factors computed from the element-wise
    deviations ``data - mean``:

    * a spread factor ``1 / (1 + std)``, where ``std`` is the sample
      (Bessel-corrected) standard deviation of the deviations; it is 1 when
      there is no spread and shrinks towards 0 as the spread grows;
    * a consistency factor ``|sum(dev)| / (sum(|dev|) + 1e-6)``, between 0
      and 1: near 1 when all deviations share a sign, 0 when positive and
      negative deviations cancel out.

    Args:
        data: Floating-point tensor of observed values.
        mean: Tensor of reference values, same shape as (or broadcastable
            against) ``data``.

    Returns:
        A 0-dim tensor holding ``spread_factor * consistency_factor``.
        Inputs with fewer than two elements are treated as having zero
        spread instead of yielding NaN.
    """
    deviations = data - mean

    # The sample standard deviation is undefined for fewer than two values
    # (torch.std returns NaN there); treat that case as "no spread".
    if deviations.numel() > 1:
        std_dev = torch.std(deviations, unbiased=True)
    else:
        std_dev = deviations.new_zeros(())
    ncs = 1 / (1 + std_dev)

    # |d| * sign(d) is d itself, so the numerator is the signed sum.
    signed_total = torch.sum(deviations)
    absolute_total = torch.sum(torch.abs(deviations))
    # The epsilon keeps the ratio finite when every deviation is zero.
    consistency = signed_total / (absolute_total + 1e-6)

    return ncs * torch.abs(consistency)

def weighted_statistics(layer_output, average_output, communities):
    """Return the size-weighted average of the per-community scores.

    For every community (a list of channel indices) the selected channels of
    ``layer_output`` and ``average_output`` are flattened and scored with
    ``calculate_statistics``. The scores are then averaged, each weighted by
    the number of elements its community contributed.

    Args:
        layer_output: Tensor indexed as ``[:, channels, :, :]``.
        average_output: Reference tensor indexed the same way as
            ``layer_output``.
        communities: Iterable of channel-index lists; each list selects the
            channels (dimension 1) belonging to one community.

    Returns:
        A 0-dim tensor with the weighted average, on the same device as the
        inputs. NaN when ``communities`` is empty, since the average over
        no communities is undefined.
    """
    community_stats = []
    weights = []

    for community in communities:
        community_data = layer_output[:, community, :, :].flatten()
        community_mean = average_output[:, community, :, :].flatten()
        community_stats.append(
            calculate_statistics(community_data, community_mean)
        )
        weights.append(community_data.numel())

    # torch.stack rejects an empty list; keep the 0 / 0 -> NaN outcome explicit.
    if not community_stats:
        return torch.full((), float("nan"), device=layer_output.device)

    # Stack instead of torch.tensor(list_of_tensors): stacking keeps the
    # scores on their device and in the autograd graph, and avoids copying
    # each one through a Python scalar.
    stats = torch.stack(community_stats)
    weights = torch.tensor(weights, dtype=stats.dtype, device=stats.device)

    return torch.sum(stats * weights) / torch.sum(weights)

# Demo run on random data: two random tensors of shape (1, 16, 32, 32) stand
# in for a layer's output and its reference average.
layer_output, average_output = torch.rand(1, 16, 32, 32), torch.rand(1, 16, 32, 32)
# Three groups of channel indices that are scored together.
communities = [
    [1, 2, 4],
    [0, 5, 7],
    [12, 14],
]
# The inputs are never modified inside the loop, so each of the 16 passes
# evaluates the same tensors and communities.
for i in range(16):
    result = weighted_statistics(
        layer_output,
        average_output,
        communities,
    )
    print(result)
