import os
import shutil
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.nn import functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image

from torch.utils.data import DataLoader, random_split
from collections import defaultdict
import numpy as np
import torchvision
import random
from scipy.stats import entropy
from sklearn.metrics import mutual_info_score
from collections import deque

from torch.nn.functional import cosine_similarity
from torchvision.transforms import InterpolationMode
from train_conv_model import CatDogCNN

target_shape = (128, 128)

def get_activation(layer, input, model):
    """Run ``model`` on ``input`` and return the output produced by ``layer``.

    A temporary forward hook is registered on ``layer`` for the duration of a
    single forward pass of ``model``. The hook is always removed afterwards,
    even if the forward pass raises.

    Args:
        layer: sub-module of ``model`` whose output is wanted.
        input: value passed straight to ``model(...)``.
        model: module to run.

    Returns:
        The detached output of ``layer``. If ``layer`` is called more than
        once during the pass, the output of its last call is returned.

    Raises:
        RuntimeError: if ``layer`` was not called during the forward pass.
    """
    captured = {}

    def hook(module, inputs, output):
        captured['value'] = output.detach()

    handle = layer.register_forward_hook(hook)
    try:
        # The captured output is detached, so no autograd graph is needed.
        with torch.no_grad():
            model(input)
    finally:
        # Without this, an exception in model(...) would leave the hook
        # attached to ``layer`` for every later forward pass.
        handle.remove()

    if 'value' not in captured:
        raise RuntimeError('layer was not called during the forward pass of model')
    return captured['value']

def cal_strength(vec1, vec2):
    """Return a sign-split KL-divergence "strength" between two activation tensors.

    Each tensor is split into two non-negative parts:

    * positive part: the tensor with negative entries replaced by 0.001;
    * negative part: the magnitudes of the negative entries, with positive
      entries replaced by 0.001.

    Each part is flattened and normalised to sum to 1. The KL divergence
    ``KL(part(vec1) || part(vec2))`` is then computed with
    ``scipy.stats.entropy`` for both parts, and the two values are summed.

    Args:
        vec1, vec2: tensors with the same number of elements (flattened internally).

    Returns:
        The sum of the positive-part and negative-part KL divergences.

    NOTE(review): entries that are exactly 0 are kept as 0 in both parts (only
    strictly negative/positive entries are replaced), so a zero in ``vec2``
    matched by a non-zero in ``vec1`` can make the divergence infinite, and
    an all-zero tensor makes a part sum to 0 (NaN) — confirm this is intended.
    """
    floor = 0.001

    def _split(vec):
        # Positive part keeps non-negative values; negatives become `floor`.
        positive = torch.where(vec < 0, torch.ones_like(vec) * floor, vec)
        # Negative part keeps magnitudes of non-positive values; positives become `floor`.
        negative = torch.where(vec > 0, torch.ones_like(vec) * floor, -vec)
        return positive, negative

    def _to_prob(part):
        # Normalise by the total so the part is a probability distribution,
        # and move to CPU/drop grad so .numpy() is always legal.
        flat = part.detach().flatten().cpu()
        return (flat / torch.sum(flat)).numpy()

    positive1, negative1 = _split(vec1)
    positive2, negative2 = _split(vec2)

    kl_div_non_negative = entropy(_to_prob(positive1), _to_prob(positive2))
    kl_div_negative = entropy(_to_prob(negative1), _to_prob(negative2))

    return kl_div_non_negative + kl_div_negative

def _capture_activations(net, layers, input_tensor):
    """Run one forward pass of ``net`` and return the detached output of each module in ``layers``, in order."""
    captured = {}
    handles = []
    for idx, layer in enumerate(layers):
        def hook(module, inputs, output, idx=idx):
            captured[idx] = output.detach()
        handles.append(layer.register_forward_hook(hook))
    try:
        with torch.no_grad():
            net(input_tensor)
    finally:
        # Always detach the hooks, even if the forward pass fails.
        for handle in handles:
            handle.remove()
    if len(captured) != len(layers):
        raise RuntimeError('a requested layer was not called during the forward pass')
    return [captured[idx] for idx in range(len(layers))]


def _channel_vectors(activation):
    """Bilinearly resize every channel of a (B, C, H, W) activation to ``target_shape`` and flatten it to (B, C, H'*W')."""
    resized = F.interpolate(activation, size=target_shape, mode='bilinear')
    return resized.flatten(start_dim=2)


def cal_weight(layer1, layer2, net=None, loader=None):
    """Average signed strength between every channel of ``layer1`` and every channel of ``layer2``.

    For each image, both layers' activations are captured in a single forward
    pass. Each channel is resized to ``target_shape`` and flattened. For each
    channel pair ``(i, j)`` the function adds
    ``sign(<vec1_i, vec2_j>) * cal_strength(vec1_i, vec2_j)``. The sign of the
    dot product is the sign of the cosine similarity (0 when either vector is
    all zeros). The accumulated matrix is divided by the number of images seen.

    Args:
        layer1: source module; its output must be 4-D (batch, channels, H, W)
            for bilinear interpolation.
        layer2: target module; same requirement.
        net: model containing both layers. Defaults to the module-level
            ``model`` set in the ``__main__`` block.
        loader: iterable of ``(images, labels)`` batches of already-transformed
            images. Defaults to the module-level ``test_loader_cal_weights``.

    Returns:
        A float tensor of shape (channels of layer1, channels of layer2).

    Raises:
        ValueError: if ``loader`` yields no images.
    """
    net = model if net is None else net
    loader = test_loader_cal_weights if loader is None else loader

    weight = None
    n_images = 0
    for images, _labels in tqdm(loader):
        # The loader already applies the dataset transform, so the batch is fed directly.
        act1, act2 = _capture_activations(net, (layer1, layer2), images)
        vecs1_batch = _channel_vectors(act1)
        vecs2_batch = _channel_vectors(act2)
        if weight is None:
            weight = torch.zeros((vecs1_batch.shape[1], vecs2_batch.shape[1]))

        for vecs1, vecs2 in zip(vecs1_batch, vecs2_batch):
            # (C1, C2) matrix of signs of all pairwise dot products at once.
            signs = torch.sign(vecs1 @ vecs2.T)
            for i, vec1 in enumerate(vecs1):
                for j, vec2 in enumerate(vecs2):
                    weight[i, j] += signs[i, j] * float(cal_strength(vec1, vec2))
            n_images += 1

    if n_images == 0:
        raise ValueError('loader yielded no images; cannot average weights')
    return weight / n_images
if __name__ == "__main__":
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    test_data = datasets.ImageFolder(root='data/Cat_Dog_data/test/', transform=transform)
    test_loader_cal_weights = DataLoader(test_data, batch_size=1, shuffle=False)

    # Single probe image. Its activations are used to read each conv layer's
    # channel count. Converting to RGB guarantees 3 channels for the 3-channel
    # Normalize above; the `with` block closes the underlying file.
    with Image.open("data/Cat_Dog_data/test/cat/flickr_cat_000008.jpg") as probe:
        img = probe.convert('RGB')
    input_tensor = transform(img).unsqueeze(0)

    model = CatDogCNN()
    # This script never moves the model to a GPU, so load tensors onto the CPU.
    model.load_state_dict(torch.load('weights/conv_best_conv_model_img_size128.pth', map_location='cpu'))
    model = model.eval()

    out_dir = 'weights/conv/neuronization'
    os.makedirs(out_dir, exist_ok=True)

    # (file suffix, source layer, target layer) for each consecutive conv pair.
    layer_pairs = [
        ('12', model.conv1, model.conv2),
        ('23', model.conv2, model.conv3),
        ('34', model.conv3, model.conv4),
        ('45', model.conv4, model.conv5),
    ]

    weight_avgs = {}
    for suffix, layer_a, layer_b in layer_pairs:
        activation_a = get_activation(layer_a, input_tensor, model)
        activation_b = get_activation(layer_b, input_tensor, model)
        # Module-level accumulator of shape (channels of layer_a, channels of
        # layer_b); kept because cal_weight may accumulate into this global.
        weight = torch.zeros((activation_a.shape[1], activation_b.shape[1]))
        weight_avgs[suffix] = cal_weight(layer_a, layer_b)
        torch.save(weight_avgs[suffix], os.path.join(out_dir, 'weight_avg_' + suffix + '.pth'))