import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from collections import defaultdict
import numpy as np
from tqdm import tqdm
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch.nn.functional as F
import random
from scipy.stats import entropy
from sklearn.metrics import mutual_info_score

import torch
import matplotlib.pyplot as plt
import matplotlib
import os
from torch.utils.data import DataLoader, Subset

from train_conv_model import CatDogCNN
from similarity_assessment import estimate_modularity_conv
from similarity_assessment import calculate_sensitivity_on_input_jacobian
from utils import NoisyMNISTDataset
from utils import threshold_tensor_for_activate, threshold_tensor_for_inhibition, e_step, m_step, get_neuron_lists

# Trained CatDogCNN whose conv layers are analysed below.
# map_location='cpu': nothing in this script moves tensors to a GPU, and the
# weight matrices below are later converted with `.numpy()`, which only works on
# CPU tensors. Without it, a checkpoint saved from a CUDA run would fail to load
# on a CPU-only host.
model = CatDogCNN()
model.load_state_dict(
    torch.load('weights/conv_best_conv_model_img_size128.pth', map_location='cpu')
)
model = model.eval()

# Per-layer-pair weight matrices (file names weight_avg_<ij>); presumably the
# averaged connectivity between conv layers i and j -- TODO confirm.
weight_avg_12 = torch.load('weights/conv/neuronization/weight_avg_12.pth', map_location='cpu')
weight_avg_23 = torch.load('weights/conv/neuronization/weight_avg_23.pth', map_location='cpu')
weight_avg_34 = torch.load('weights/conv/neuronization/weight_avg_34.pth', map_location='cpu')
weight_avg_45 = torch.load('weights/conv/neuronization/weight_avg_45.pth', map_location='cpu')

# Resize to 128x128 (the size named in the checkpoint file) and normalise with
# the standard ImageNet channel statistics.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

test_data = datasets.ImageFolder(root='data/Cat_Dog_data/test/', transform=transform)
# Full test set, one image per batch, fixed order (not used in the code visible here).
test_loader_cal_weights = DataLoader(test_data, batch_size=1, shuffle=False)

# Small fixed subset used for modularity scoring.
# NOTE(review): ImageFolder lists samples grouped by class, so the first 20
# indices are probably all from a single class -- confirm this is intended.
test_data_subset = Subset(test_data, range(20))
test_loader = DataLoader(test_data_subset, batch_size=1, shuffle=False)

def _normalize_pair(act, inh):
    """Jointly normalise a pair of per-state parameter matrices.

    Both matrices are divided by the SAME per-row denominator (row sum of
    ``act`` plus row sum of ``inh``), so afterwards each row of ``act`` and
    ``inh`` together sums to 1. The denominator is computed once, before any
    division, so the second matrix is not scaled by an already-normalised
    first matrix. New arrays are returned; the inputs are not modified.
    """
    denom = act.sum(axis=1, keepdims=True) + inh.sum(axis=1, keepdims=True)
    return act / denom, inh / denom


def _normalize_params(pi_c, tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh):
    """Normalise ``pi_c`` to sum to 1 and each (act, inh) tau pair jointly per row."""
    pi_c = pi_c / pi_c.sum()
    tau_ci_act, tau_ci_inh = _normalize_pair(tau_ci_act, tau_ci_inh)
    tau_cj_prime_act, tau_cj_prime_inh = _normalize_pair(tau_cj_prime_act, tau_cj_prime_inh)
    return pi_c, tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh


def _fit_em_once(A_ik_act, A_ik_inh, B_jk_act, B_jk_inh, num_states, num_iterations):
    """Run one randomly initialised EM fit with ``num_states`` communities.

    Returns ``(params, q_kc)``. ``params`` is the tuple
    ``(pi_c, tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh)``
    after the last iteration. ``q_kc`` is the output of the last E-step
    (``None`` if ``num_iterations`` is 0).
    """
    num_features_i = A_ik_act.shape[0]
    num_features_j = B_jk_act.shape[1]
    num_observations = A_ik_act.shape[1]

    # Random start. Arguments are evaluated left to right, so the RNG draws
    # happen in a fixed order (pi_c, tau_ci_act, tau_ci_inh, tau_cj_prime_act,
    # tau_cj_prime_inh) and seeded runs are reproducible.
    pi_c, tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh = _normalize_params(
        np.random.rand(num_states),
        np.random.rand(num_states, num_features_i),
        np.random.rand(num_states, num_features_i),
        np.random.rand(num_states, num_features_j),
        np.random.rand(num_states, num_features_j),
    )

    q_kc = None
    for _ in range(num_iterations):
        q_kc = e_step(pi_c, tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh,
                      A_ik_act, A_ik_inh, B_jk_act, B_jk_inh, num_observations, num_states)
        pi_c, tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh = m_step(
            q_kc, A_ik_act, A_ik_inh, B_jk_act, B_jk_inh, num_observations,
            num_states, num_features_i, num_features_j, pi_c,
            tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh)
        # Re-normalise after every M-step so the parameters stay valid distributions.
        pi_c, tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh = _normalize_params(
            pi_c, tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh)

    return (pi_c, tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh), q_kc


def cal_best_community_num(weight1, weight2, threshold1, threshold2, threshold3, threshold4, num_iterations_, num_starts_, num_states_begin, num_states_end, target_layer, model, test_loader):
    """Search for the community count C that maximises the modularity score.

    For every C in ``range(num_states_begin, num_states_end)``, the EM model is
    fitted ``num_starts_`` times from random starts. Each fit runs
    ``num_iterations_`` EM iterations and is scored with
    ``estimate_modularity_conv(model, target_layer, q_kc, test_loader)``.

    Args:
        weight1, weight2: weight tensors. They are thresholded into
            "activate"/"inhibition" masks, converted with ``.numpy()`` (so they
            must live on the CPU) and transposed.
        threshold1, threshold2: thresholds for the activate/inhibition masks of
            ``weight1``.
        threshold3, threshold4: same, for ``weight2``.
        num_iterations_: EM iterations per random start (must be >= 1).
        num_starts_: random restarts per candidate C.
        num_states_begin, num_states_end: half-open range of C values to try.
        target_layer, model, test_loader: forwarded to
            ``estimate_modularity_conv``.

    Returns:
        A list ``[pi_c, tau_ci_act, tau_ci_inh, tau_cj_prime_act,
        tau_cj_prime_inh, q_kc, best_score_per_C]`` for the best-scoring fit.
        ``best_score_per_C`` holds one entry (the best score across restarts)
        per C tried, in order.

    Raises:
        ValueError: if ``num_iterations_ < 1``, or if no fit produced a score
            greater than ``-inf`` (e.g. an empty C range, zero restarts, or
            all-NaN scores).
    """
    if num_iterations_ < 1:
        # Otherwise q_kc would never be computed.
        raise ValueError(f"num_iterations_ must be >= 1, got {num_iterations_}")

    # The rest of the code treats shape[0] of A_* as the i-features and shape[1]
    # as the observations; shape[1] of B_* is the j-features.
    A_ik_act = threshold_tensor_for_activate(weight1, threshold=threshold1).numpy().T
    A_ik_inh = threshold_tensor_for_inhibition(weight1, threshold=threshold2).numpy().T
    B_jk_act = threshold_tensor_for_activate(weight2, threshold=threshold3).numpy().T
    B_jk_inh = threshold_tensor_for_inhibition(weight2, threshold=threshold4).numpy().T

    best_score_global = -np.inf
    best_c_global = None
    best_model_params_global = None
    best_score_local_list = []  # best score across restarts, one entry per C

    for num_states in tqdm(range(num_states_begin, num_states_end)):
        best_score_local = -np.inf
        best_model_params_local = None

        for _ in range(num_starts_):
            params, q_kc = _fit_em_once(A_ik_act, A_ik_inh, B_jk_act, B_jk_inh,
                                        num_states, num_iterations_)
            score = estimate_modularity_conv(model, target_layer, q_kc, test_loader)
            if score > best_score_local:
                best_score_local = score
                best_model_params_local = [*params, q_kc]

        # Done once per C, after all restarts.
        best_score_local_list.append(best_score_local)
        if best_score_local > best_score_global:
            best_score_global = best_score_local
            best_c_global = num_states
            best_model_params_global = best_model_params_local

    print(f"Best C value is: {best_c_global}")
    if best_model_params_global is None:
        raise ValueError(
            "no community model was fitted: check num_states_begin/num_states_end, "
            "num_starts_, and that the modularity scores are finite"
        )
    # Attach the score history to the params that are actually returned.
    best_model_params_global.append(best_score_local_list)

    return best_model_params_global

# Fit a community model for each consecutive pair of conv layers and save the
# best parameters. The weight_avg_* tensors are the ones already loaded in the
# setup section at the top of this script; re-reading them here was redundant.

# NOTE(review): not read by any code visible here; kept once in case other
# modules import it.
epsilon = 1e-10

# NOTE(review): conv12 passes both matrices transposed (weight_avg_*.T), while
# conv23 and conv34 pass them as stored. Unless weight_avg_12 is saved in a
# different orientation from the others, one of these is wrong -- confirm.
weight1 = weight_avg_12.T
weight2 = weight_avg_23.T
target_layer = 'conv1'
best_conv12_model_params = cal_best_community_num(
    weight1, weight2,
    threshold1=1.5, threshold2=2, threshold3=1.5, threshold4=2,
    num_iterations_=200, num_starts_=100,
    num_states_begin=2, num_states_end=15,
    target_layer=target_layer, model=model, test_loader=test_loader,
)
torch.save(best_conv12_model_params, 'weights/conv/best_conv12_model_params.pth')

weight1 = weight_avg_23
weight2 = weight_avg_34
target_layer = 'conv2'
best_conv23_model_params = cal_best_community_num(
    weight1, weight2,
    threshold1=1.5, threshold2=2, threshold3=1.5, threshold4=2,
    num_iterations_=200, num_starts_=100,
    num_states_begin=2, num_states_end=15,
    target_layer=target_layer, model=model, test_loader=test_loader,
)
torch.save(best_conv23_model_params, 'weights/conv/best_conv23_model_params.pth')

weight1 = weight_avg_34
weight2 = weight_avg_45
target_layer = 'conv3'
best_conv34_model_params = cal_best_community_num(
    weight1, weight2,
    threshold1=1.5, threshold2=2, threshold3=1.5, threshold4=2,
    num_iterations_=200, num_starts_=100,
    num_states_begin=2, num_states_end=15,
    target_layer=target_layer, model=model, test_loader=test_loader,
)
torch.save(best_conv34_model_params, 'weights/conv/best_conv34_model_params.pth')
