import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from collections import defaultdict
import numpy as np
from tqdm import tqdm
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch.nn.functional as F
import random
from scipy.stats import entropy
from sklearn.metrics import mutual_info_score

import torch
import matplotlib.pyplot as plt
import matplotlib
import os
from train_linear_model import linear_model
from similarity_assessment import estimate_modularity
from similarity_assessment import calculate_sensitivity_on_input_jacobian
from utils import NoisyMNISTDataset
from utils import threshold_tensor_for_activate, threshold_tensor_for_inhibition, e_step, m_step, get_neuron_lists

# Build the model and load the trained checkpoint. The checkpoint file is read
# from disk once and reused both to initialise the model and to pull out the
# raw weight matrices below (previously it was deserialised twice).
model = linear_model()
ckp = torch.load('weights/best_model.pth')
model.load_state_dict(ckp)
# NOTE(review): model.eval() is never called here; whether that matters
# depends on linear_model (e.g. dropout / batch-norm layers) -- confirm.

# Raw weight tensors stored in the checkpoint under 'fc1.weight' .. 'fc4.weight'.
weight_net_0 = ckp['fc1.weight']
weight_net_1 = ckp['fc2.weight']
weight_net_2 = ckp['fc3.weight']
weight_net_3 = ckp['fc4.weight']

transform = transforms.Compose([transforms.ToTensor(),])

mnist_train = NoisyMNISTDataset(
    image_folder='data/noisy_mnist/train',
    labels_file=os.path.join('data/noisy_mnist/train', 'labels.txt'),
    transform=transform
    )
mnist_test = NoisyMNISTDataset(
    image_folder='data/noisy_mnist/test',
    labels_file=os.path.join('data/noisy_mnist/test', 'labels.txt'),
    transform=transform
    )

# 80/20 train/validation split of the training set. No generator/seed is
# passed here, so the split depends on torch's global RNG state.
train_size = int(0.8 * len(mnist_train))
val_size = len(mnist_train) - train_size
train_dataset, val_dataset = random_split(mnist_train, [train_size, val_size])
# Batch size 1: every loader yields one sample per batch.
bz = 1
train_loader = DataLoader(train_dataset, batch_size=bz, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=bz, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=bz, shuffle=True)

def _normalize_act_inh(act, inh):
    """Jointly normalise a paired activation/inhibition array in place.

    Row r of ``act`` and row r of ``inh`` are both divided by the same value,
    the combined row sum of the two arrays, so that together they sum to 1.
    The denominator is computed once, before either array is modified.
    Recomputing it after ``act`` has already been divided would normalise
    ``inh`` by the wrong amount.
    """
    denom = act.sum(axis=1, keepdims=True) + inh.sum(axis=1, keepdims=True)
    act /= denom
    inh /= denom


def _normalize_em_params(pi_c, tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh):
    """Normalise all EM parameters in place.

    ``pi_c`` is scaled to sum to 1. Each act/inh tau pair is normalised
    jointly, row by row.
    """
    pi_c /= pi_c.sum()
    _normalize_act_inh(tau_ci_act, tau_ci_inh)
    _normalize_act_inh(tau_cj_prime_act, tau_cj_prime_inh)


def _random_em_init(num_states, num_features_i, num_features_j):
    """Draw a random, normalised EM starting point.

    Arrays are drawn with ``np.random.rand`` in the order pi_c, tau_ci_act,
    tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh.

    Returns:
        tuple: (pi_c, tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh)
    """
    pi_c = np.random.rand(num_states)
    tau_ci_act = np.random.rand(num_states, num_features_i)
    tau_ci_inh = np.random.rand(num_states, num_features_i)
    tau_cj_prime_act = np.random.rand(num_states, num_features_j)
    tau_cj_prime_inh = np.random.rand(num_states, num_features_j)
    _normalize_em_params(pi_c, tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh)
    return pi_c, tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh


def cal_best_community_num(weight1, weight2, threshold1, threshold2, threshold3, threshold4, num_iterations_, num_starts_, num_states_begin, num_states_end, target_layer, model, test_loader, mean_jacobian):
    """Search community counts with multi-start EM and keep the best-scoring fit.

    For every community count C in ``range(num_states_begin, num_states_end)``
    the EM procedure (``e_step`` / ``m_step``) is run from ``num_starts_``
    random starting points, for ``num_iterations_`` iterations each. Every
    fit is scored with ``estimate_modularity`` applied to the neuron groups
    from ``get_neuron_lists(q_kc)``. The highest-scoring fit over all C
    and all restarts is returned.

    Args:
        weight1, weight2: weight tensors that are thresholded with
            ``threshold_tensor_for_activate`` / ``threshold_tensor_for_inhibition``
            and converted via ``.numpy()``. The callers in this file pass the
            weights of two consecutive layers.
        threshold1: threshold for the activating part of ``weight1``.
        threshold2: threshold for the inhibiting part of ``weight1``.
        threshold3: threshold for the activating part of ``weight2``.
        threshold4: threshold for the inhibiting part of ``weight2``.
        num_iterations_: EM iterations per restart; must be >= 1.
        num_starts_: number of random restarts per community count.
        num_states_begin, num_states_end: the community counts tried are
            ``range(num_states_begin, num_states_end)`` (end exclusive).
        target_layer, model, test_loader, mean_jacobian: forwarded unchanged
            to ``estimate_modularity``.

    Returns:
        list: ``[pi_c, tau_ci_act, tau_ci_inh, tau_cj_prime_act,
        tau_cj_prime_inh, q_kc, best_score_local_list]`` for the best fit.
        ``best_score_local_list`` holds the best score found for each
        community count tried, in the order tried.

    Raises:
        ValueError: if ``num_iterations_ < 1`` (``q_kc`` would never be computed).
        RuntimeError: if no fit scored above ``-inf``, e.g. an empty community
            range, ``num_starts_ == 0``, or only NaN scores.
    """
    if num_iterations_ < 1:
        raise ValueError(f"num_iterations_ must be >= 1, got {num_iterations_}")

    # Thresholded activating / inhibiting connectivity of each weight matrix,
    # converted to NumPy and transposed.
    A_ik_act = threshold_tensor_for_activate(weight1, threshold=threshold1).numpy().T
    A_ik_inh = threshold_tensor_for_inhibition(weight1, threshold=threshold2).numpy().T
    B_jk_act = threshold_tensor_for_activate(weight2, threshold=threshold3).numpy().T
    B_jk_inh = threshold_tensor_for_inhibition(weight2, threshold=threshold4).numpy().T

    num_features_i = A_ik_act.shape[0]
    num_features_j = B_jk_act.shape[1]
    num_observations = A_ik_act.shape[1]

    best_score_global = -np.inf
    best_c_global = None
    best_model_params_global = None
    # Best score found for each community count tried, in order.
    best_score_local_list = []

    for num_states in tqdm(range(num_states_begin, num_states_end)):
        best_score_local = -np.inf
        best_c_local = None
        best_model_params_local = None

        # Multi-start EM: run num_starts_ random restarts and keep the best-scoring one.
        for _start in range(num_starts_):
            pi_c, tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh = _random_em_init(
                num_states, num_features_i, num_features_j)

            for _ in range(num_iterations_):
                q_kc = e_step(pi_c, tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh,
                              A_ik_act, A_ik_inh, B_jk_act, B_jk_inh, num_observations, num_states)
                pi_c, tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh = m_step(
                    q_kc, A_ik_act, A_ik_inh, B_jk_act, B_jk_inh, num_observations,
                    num_states, num_features_i, num_features_j, pi_c,
                    tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh)
                _normalize_em_params(pi_c, tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh)

            neuron_lists = get_neuron_lists(q_kc)
            score = estimate_modularity(neuron_lists, target_layer, model, test_loader, mean_jacobian)

            if score > best_score_local:
                best_score_local = score
                best_c_local = num_states
                best_model_params_local = [pi_c, tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh, q_kc]

        # Recorded once per community count, after all restarts have run.
        best_score_local_list.append(best_score_local)
        # Update the global best.
        if best_score_local > best_score_global:
            best_score_global = best_score_local
            best_c_global = best_c_local
            best_model_params_global = best_model_params_local

    if best_model_params_global is None:
        raise RuntimeError(
            "no EM fit scored above -inf; check the num_states range, num_starts_ "
            "and that estimate_modularity returns finite scores")

    print(f"Best C value is: {best_c_global}")

    # Attach the score history to the params that are actually returned.
    best_model_params_global.append(best_score_local_list)

    return best_model_params_global

# Hyperparameters shared by every layer pair below. The four thresholds are
# passed, in order, as threshold1..threshold4 of cal_best_community_num:
# activating/inhibiting thresholds for weight1, then for weight2.
COMMUNITY_THRESHOLDS = (0.15, 0.1, 0.15, 0.1)
EM_NUM_ITERATIONS = 200
EM_NUM_STARTS = 100
# Community counts tried: range(NUM_STATES_BEGIN, NUM_STATES_END), i.e. 1..19.
NUM_STATES_BEGIN = 1
NUM_STATES_END = 20
# NOTE(review): epsilon is not used in the visible code; kept in case other
# code reads it.
epsilon = 1e-10

# Layer pair fc1 / fc2 (files tagged "128").
weight1 = weight_net_0
weight2 = weight_net_1
target_layer = 'fc1'
mean_jacobian = torch.load('weights/mean_jacobian/mean_jacobian_128.pth')
best_model_params_128 = cal_best_community_num(
    weight1, weight2, *COMMUNITY_THRESHOLDS, EM_NUM_ITERATIONS, EM_NUM_STARTS,
    NUM_STATES_BEGIN, NUM_STATES_END, target_layer, model, test_loader, mean_jacobian,
)
q_kc_128 = best_model_params_128[5]  # element 5 of the returned list is q_kc
# NOTE(review): this path uses "best_linear_model_params" while the other two
# use "best_model_params"; left unchanged because downstream loaders may rely
# on it -- confirm which naming is intended.
torch.save(best_model_params_128, 'weights/best_linear_model_params_128.pth')

# Layer pair fc2 / fc3 (files tagged "64").
weight1 = weight_net_1
weight2 = weight_net_2
target_layer = 'fc2'
mean_jacobian = torch.load('weights/mean_jacobian/mean_jacobian_64.pth')
best_model_params_64 = cal_best_community_num(
    weight1, weight2, *COMMUNITY_THRESHOLDS, EM_NUM_ITERATIONS, EM_NUM_STARTS,
    NUM_STATES_BEGIN, NUM_STATES_END, target_layer, model, test_loader, mean_jacobian,
)
q_kc_64 = best_model_params_64[5]
torch.save(best_model_params_64, 'weights/best_model_params_64.pth')

# Layer pair fc3 / fc4 (files tagged "32").
weight1 = weight_net_2
weight2 = weight_net_3
target_layer = 'fc3'
mean_jacobian = torch.load('weights/mean_jacobian/mean_jacobian_32.pth')
best_model_params_32 = cal_best_community_num(
    weight1, weight2, *COMMUNITY_THRESHOLDS, EM_NUM_ITERATIONS, EM_NUM_STARTS,
    NUM_STATES_BEGIN, NUM_STATES_END, target_layer, model, test_loader, mean_jacobian,
)
q_kc_32 = best_model_params_32[5]
torch.save(best_model_params_32, 'weights/best_model_params_32.pth')