import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from collections import defaultdict
import numpy as np
from tqdm import tqdm
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch.nn.functional as F
import random
from scipy.stats import entropy
from sklearn.metrics import mutual_info_score
import os
import torch
import matplotlib.pyplot as plt
import matplotlib
from torch.utils.data import Dataset
from PIL import Image

# from similarity_assessment import estimate_modularity

def threshold_tensor_for_activate(tensor, threshold):
    """Map a tensor to soft on/off values marking strongly positive entries.

    Args:
        tensor: Input tensor to compare against ``threshold``.
        threshold: Cut-off; the comparison is strict (``>``).

    Returns:
        A tensor with 0.999 where ``tensor > threshold`` and 0.001 elsewhere.
    """
    is_active = tensor > threshold
    on_value = torch.tensor(0.999)
    off_value = torch.tensor(0.001)
    return torch.where(is_active, on_value, off_value)

def threshold_tensor_for_inhibition(tensor, threshold):
    """Map a tensor to soft on/off values marking strongly negative entries.

    Args:
        tensor: Input tensor to compare against ``-threshold``.
        threshold: Cut-off magnitude; the comparison is strict (``<``).

    Returns:
        A tensor with 0.999 where ``tensor < -threshold`` and 0.001 elsewhere.
    """
    is_inhibited = tensor < -threshold
    on_value = torch.tensor(0.999)
    off_value = torch.tensor(0.001)
    return torch.where(is_inhibited, on_value, off_value)

# E-step
def e_step(pi_c, tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh,
           A_ik_act, A_ik_inh, B_jk_act, B_jk_inh, num_observations, num_states):
    """Compute the normalised responsibility of every state for every observation.

    For observation ``k`` and state ``c`` the unnormalised log score is
    ``log(pi_c[c])`` plus four Bernoulli-style terms of the form
    ``sum(log(tau) * x + log(1 - tau) * (1 - x))``, one per (tau, indicator)
    pair. Scores are normalised across states with a log-sum-exp.

    Args:
        pi_c: Per-state weights, indexed ``[state]``.
        tau_ci_act, tau_ci_inh: Parameters indexed ``[state, feature]``, paired
            with ``A_ik_act`` / ``A_ik_inh``.
        tau_cj_prime_act, tau_cj_prime_inh: Parameters indexed
            ``[state, feature]``, paired with ``B_jk_act`` / ``B_jk_inh``.
        A_ik_act, A_ik_inh: Indicators indexed ``[feature, observation]``.
        B_jk_act, B_jk_inh: Indicators indexed ``[observation, feature]``
            (they are transposed before use).
        num_observations: Number of observations to iterate over.
        num_states: Number of states to iterate over.

    Returns:
        Array of shape ``(num_observations, num_states)`` whose rows sum to 1.
    """

    def log_term(tau_row, indicator):
        # Bernoulli-style log-likelihood of one indicator column under tau_row.
        return np.sum(np.log(tau_row) * indicator
                      + np.log(1 - tau_row) * (1 - indicator))

    B_act_by_feature = B_jk_act.T
    B_inh_by_feature = B_jk_inh.T

    responsibilities = np.zeros((num_observations, num_states))
    for obs in range(num_observations):
        # Pair each parameter table with this observation's indicator column,
        # in the same order the terms are accumulated.
        pairings = (
            (tau_ci_act, A_ik_act[:, obs]),
            (tau_ci_inh, A_ik_inh[:, obs]),
            (tau_cj_prime_act, B_act_by_feature[:, obs]),
            (tau_cj_prime_inh, B_inh_by_feature[:, obs]),
        )

        log_scores = np.zeros(num_states)
        for state in range(num_states):
            log_scores[state] = np.log(pi_c[state])
            for tau, indicator in pairings:
                log_scores[state] += log_term(tau[state, :], indicator)

        # Normalise in log space to avoid underflow.
        log_total = np.logaddexp.reduce(log_scores)
        responsibilities[obs, :] = np.exp(log_scores - log_total)

    return responsibilities

def m_step(q_kc, A_ik_act, A_ik_inh, B_jk_act, B_jk_inh, num_observations,
           num_states, num_features_i, num_features_j, pi_c,
           tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh):
    """Re-estimate the tau tables from responsibilities, modifying them in place.

    For each state the four tau rows are set to the responsibility-weighted
    sums of the matching indicator arrays, divided by the state's total
    responsibility (skipped when that total is not positive). ``pi_c`` is
    then rescaled to sum to 1, and each tau row is rescaled by the combined
    sum of its activating and inhibitory rows.

    NOTE(review): ``pi_c`` is only renormalised here; it is not recomputed
    from ``q_kc``. Also, each inhibitory row is divided using the activating
    row *after* that row has already been rescaled, so the two rows do not
    share a single denominator. Both behaviours are kept as found.

    Args:
        q_kc: Responsibilities indexed ``[observation, state]``.
        A_ik_act, A_ik_inh: Indicators indexed ``[feature, observation]``.
        B_jk_act, B_jk_inh: Indicators indexed ``[observation, feature]``.
        num_observations: Number of observations (used for reshaping).
        num_states: Number of states to update.
        num_features_i, num_features_j: Accepted but not used.
        pi_c: Per-state weights; rescaled in place.
        tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh:
            Parameter tables indexed ``[state, feature]``; overwritten in place.

    Returns:
        The tuple ``(pi_c, tau_ci_act, tau_ci_inh, tau_cj_prime_act,
        tau_cj_prime_inh)`` — the same array objects that were passed in.
    """
    for state in range(num_states):
        resp = q_kc[:, state]
        resp_as_row = resp.reshape(1, num_observations)
        resp_as_col = resp.reshape(num_observations, 1)

        # Responsibility-weighted sums over observations.
        tau_ci_act[state, :] = np.sum(A_ik_act * resp_as_row, axis=1)
        tau_ci_inh[state, :] = np.sum(A_ik_inh * resp_as_row, axis=1)
        tau_cj_prime_act[state, :] = np.sum(B_jk_act * resp_as_col, axis=0)
        tau_cj_prime_inh[state, :] = np.sum(B_jk_inh * resp_as_col, axis=0)

        # Leave the sums untouched when the state has no responsibility mass.
        resp_mass = np.sum(resp)
        divisor = resp_mass if resp_mass > 0 else 1
        tau_ci_act[state, :] /= divisor
        tau_ci_inh[state, :] /= divisor
        tau_cj_prime_act[state, :] /= divisor
        tau_cj_prime_inh[state, :] /= divisor

    pi_c /= np.sum(pi_c)

    for state in range(num_states):
        # Order matters: the second division of each pair sees the first's result.
        tau_ci_act[state, :] /= np.sum(tau_ci_act[state, :]) + np.sum(tau_ci_inh[state, :])
        tau_ci_inh[state, :] /= np.sum(tau_ci_act[state, :]) + np.sum(tau_ci_inh[state, :])
        tau_cj_prime_act[state, :] /= (np.sum(tau_cj_prime_act[state, :])
                                       + np.sum(tau_cj_prime_inh[state, :]))
        tau_cj_prime_inh[state, :] /= (np.sum(tau_cj_prime_act[state, :])
                                       + np.sum(tau_cj_prime_inh[state, :]))

    return pi_c, tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh

def filter_array(arr):
    """Return a copy of ``arr`` with every entry below 0.01 replaced by 0.

    Entries that are not strictly less than 0.01 are kept unchanged.
    """
    below_cutoff = arr < 0.01
    return np.where(below_cutoff, 0, arr)

def nonzero_coordinates(matrix):
    """Return the index tuples of all non-zero entries of ``matrix``.

    The result has one row per non-zero entry and one column per axis of
    ``matrix``, in the order produced by ``np.nonzero``.
    """
    indices_per_axis = np.nonzero(matrix)
    return np.asarray(indices_per_axis).T

def get_neuron_lists(q_k_c):
    """Group row indices of ``q_k_c`` by the columns in which they are non-negligible.

    Entries below 0.01 are zeroed via ``filter_array``; every remaining
    non-zero ``(row, column)`` coordinate then places ``row`` into the group
    for ``column``. A row can therefore appear in more than one group.

    Args:
        q_k_c: Array indexed ``[row, column]`` (the caller passes
            responsibilities indexed ``[observation, state]``).

    Returns:
        A list of lists of row indices, one inner list per column that has at
        least one surviving entry, ordered by first appearance. Returns an
        empty list when no entry survives the filter.
    """
    filtered_q_k_c = filter_array(q_k_c)
    coordinates = nonzero_coordinates(filtered_q_k_c)

    groups = defaultdict(list)
    for row in coordinates:
        groups[row[1]].append(row[0])

    # Built once after the loop so an empty coordinate set yields [] instead
    # of leaving the result unassigned.
    return list(groups.values())

def _renormalise_em_params(pi_c, tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh):
    """Rescale the EM parameters in place, in the exact order the search loop relies on."""
    pi_c /= pi_c.sum()
    # NOTE(review): each inhibitory table is divided using the row sums of the
    # activating table *after* that table was rescaled on the preceding line,
    # so the pair does not share one denominator. Kept as found so results
    # stay numerically unchanged.
    tau_ci_act /= (tau_ci_act.sum(axis=1, keepdims=True) + tau_ci_inh.sum(axis=1, keepdims=True))
    tau_ci_inh /= (tau_ci_act.sum(axis=1, keepdims=True) + tau_ci_inh.sum(axis=1, keepdims=True))
    tau_cj_prime_act /= (tau_cj_prime_act.sum(axis=1, keepdims=True)
                         + tau_cj_prime_inh.sum(axis=1, keepdims=True))
    tau_cj_prime_inh /= (tau_cj_prime_act.sum(axis=1, keepdims=True)
                         + tau_cj_prime_inh.sum(axis=1, keepdims=True))


def cal_best_community_num(weight1, weight2, threshold1, threshold2, threshold3, threshold4, num_iterations_, num_starts_, num_states_begin, num_states_end):
    """Search over state counts and random restarts for the best-scoring grouping.

    For every ``num_states`` in ``range(num_states_begin, num_states_end)``
    and every one of ``num_starts_`` random initialisations, runs
    ``num_iterations_`` rounds of ``e_step`` / ``m_step``, converts the final
    responsibilities into groups with ``get_neuron_lists`` and scores them
    with ``estimate_modularity``. The parameters of the highest-scoring run
    are returned.

    Args:
        weight1: Tensor thresholded (with ``threshold1`` / ``threshold2``),
            converted with ``.numpy()`` and transposed to give the
            ``[feature, observation]`` indicator arrays.
        weight2: Tensor thresholded (with ``threshold3`` / ``threshold4``),
            converted with ``.numpy()`` and transposed to give the
            ``[observation, feature]`` indicator arrays.
        threshold1: Activation cut-off for ``weight1``.
        threshold2: Inhibition cut-off for ``weight1``.
        threshold3: Activation cut-off for ``weight2``.
        threshold4: Inhibition cut-off for ``weight2``.
        num_iterations_: EM iterations per restart (an earlier note in this
            function suggested 200).
        num_starts_: Random restarts per state count (an earlier note in this
            function suggested 100).
        num_states_begin: First state count tried (inclusive).
        num_states_end: End of the state-count range (exclusive).

    Returns:
        ``[pi_c, tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh,
        q_kc, score_history]`` for the best run, where ``score_history`` holds
        the running best score recorded after every restart. Returns ``None``
        when no run produced a score greater than ``-inf`` (for example when
        the state-count range is empty).
    """
    A_ik_act = threshold_tensor_for_activate(weight1, threshold=threshold1).numpy().T
    A_ik_inh = threshold_tensor_for_inhibition(weight1, threshold=threshold2).numpy().T
    B_jk_act = threshold_tensor_for_activate(weight2, threshold=threshold3).numpy().T
    B_jk_inh = threshold_tensor_for_inhibition(weight2, threshold=threshold4).numpy().T

    num_features_i = A_ik_act.shape[0]
    num_features_j = B_jk_act.shape[1]
    num_observations = A_ik_act.shape[1]

    best_score_global = -np.inf
    best_c_global = None
    best_model_params_global = None
    best_score_local_list = []

    for num_states in tqdm(range(num_states_begin, num_states_end)):
        best_score_local = -np.inf
        best_c_local = None
        best_model_params_local = None

        for _start in range(num_starts_):
            # Fresh arrays per restart, so previously stored best parameters
            # are never mutated by the in-place updates below.
            pi_c = np.random.rand(num_states)
            tau_ci_act = np.random.rand(num_states, num_features_i)
            tau_ci_inh = np.random.rand(num_states, num_features_i)
            tau_cj_prime_act = np.random.rand(num_states, num_features_j)
            tau_cj_prime_inh = np.random.rand(num_states, num_features_j)
            _renormalise_em_params(pi_c, tau_ci_act, tau_ci_inh,
                                   tau_cj_prime_act, tau_cj_prime_inh)

            for _ in range(num_iterations_):
                q_kc = e_step(pi_c, tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh,
                              A_ik_act, A_ik_inh, B_jk_act, B_jk_inh, num_observations, num_states)
                pi_c, tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh = m_step(
                    q_kc, A_ik_act, A_ik_inh, B_jk_act, B_jk_inh, num_observations,
                    num_states, num_features_i, num_features_j, pi_c,
                    tau_ci_act, tau_ci_inh, tau_cj_prime_act, tau_cj_prime_inh)
                _renormalise_em_params(pi_c, tau_ci_act, tau_ci_inh,
                                       tau_cj_prime_act, tau_cj_prime_inh)

            lists = get_neuron_lists(q_kc)
            # NOTE(review): estimate_modularity is not defined in the visible
            # part of this file (its import from similarity_assessment is
            # commented out); it must be in scope for this call to succeed.
            score = estimate_modularity(lists, weight1, weight2)

            if score > best_score_local:
                best_score_local = score
                best_c_local = num_states
                best_model_params_local = [pi_c, tau_ci_act, tau_ci_inh,
                                           tau_cj_prime_act, tau_cj_prime_inh, q_kc]

            best_score_local_list.append(best_score_local)

            if best_score_local > best_score_global:
                best_score_global = best_score_local
                best_c_global = best_c_local
                best_model_params_global = best_model_params_local

    print(f"Best C value is: {best_c_global}")

    # Attach the score history to the list that is actually returned. The
    # per-state-count list may belong to a different (non-winning) state count
    # or may not exist at all.
    if best_model_params_global is not None:
        best_model_params_global.append(best_score_local_list)

    return best_model_params_global

class NoisyMNISTDataset(Dataset):
    """Dataset of grayscale images whose labels come from a text file.

    Each non-blank line of the labels file holds an image file name followed
    by whitespace and an integer label. Images are opened from
    ``image_folder`` on access and converted to mode ``'L'``.
    """

    def __init__(self, image_folder, labels_file, transform=None):
        """Read the labels file and index the image file names.

        Args:
            image_folder: Directory that image file names are joined onto.
            labels_file: Path of the text file with ``<filename> <label>`` lines.
            transform: Optional callable applied to each loaded image.

        Raises:
            ValueError: If a non-blank line lacks a label or the label is not
                an integer.
        """
        self.image_folder = image_folder
        self.labels_file = labels_file
        self.transform = transform
        self.labels = {}

        with open(self.labels_file, 'r') as f:
            for line in f:
                entry = line.strip()
                if not entry:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                # Split on the last whitespace run so file names that contain
                # spaces are still parsed.
                image_filename, label = entry.rsplit(maxsplit=1)
                self.labels[image_filename] = int(label)

        self.image_filenames = list(self.labels)

    def __len__(self):
        """Return the number of distinct image file names read from the labels file."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th file name.

        The image is the mode-``'L'`` conversion of the file, passed through
        ``transform`` when one was supplied.
        """
        image_filename = self.image_filenames[idx]
        image_path = os.path.join(self.image_folder, image_filename)
        # convert() returns a separate image, so the source file can be closed
        # as soon as the conversion is done.
        with Image.open(image_path) as source:
            image = source.convert('L')
        label = self.labels[image_filename]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

def compute_layer_averages(model, data_loader, device='cpu'):
    """Average each layer's output over every sample yielded by ``data_loader``.

    The model is switched to eval mode and moved to ``device``. For every
    batch, the outputs of ``model.get_hidden_layer_fc1_output``,
    ``model.get_hidden_layer_fc2_output``, ``model.get_hidden_layer_fc3_output``
    and ``model(data)`` are summed over dimension 0 and accumulated; the totals
    are finally divided by the number of samples seen.

    Accumulators are sized from the first batch's outputs, so any layer
    widths are supported.

    Args:
        model: Module exposing the three ``get_hidden_layer_fc*_output``
            methods and being callable on a batch.
        data_loader: Iterable of ``(data, target)`` pairs; ``target`` is ignored.
        device: Device the model, data and accumulators live on.

    Returns:
        Tuple ``(fc1_average, fc2_average, fc3_average, fc4_average)``.

    Raises:
        ValueError: If ``data_loader`` yields no samples, since the averages
            would otherwise be 0/0.
    """
    model.eval()  # Set the model to evaluation mode
    model.to(device)

    layer_sums = None  # created lazily once output shapes are known
    num_samples = 0

    with torch.no_grad():  # No need to track gradients for this operation
        for data, _ in tqdm(data_loader):
            data = data.to(device)
            num_samples += data.size(0)  # Increment total count of samples

            # Per-layer sums over the batch dimension, in layer order.
            batch_sums = [
                model.get_hidden_layer_fc1_output(data).sum(dim=0),
                model.get_hidden_layer_fc2_output(data).sum(dim=0),
                model.get_hidden_layer_fc3_output(data).sum(dim=0),
                model(data).sum(dim=0),
            ]

            if layer_sums is None:
                # Default-dtype accumulators matching each layer's output shape.
                layer_sums = [torch.zeros(batch_sum.shape, device=device)
                              for batch_sum in batch_sums]

            for running_total, batch_sum in zip(layer_sums, batch_sums):
                running_total += batch_sum

    if layer_sums is None or num_samples == 0:
        raise ValueError("data_loader yielded no samples; cannot compute layer averages")

    # Calculate the average for each layer
    fc1_average, fc2_average, fc3_average, fc4_average = (
        running_total / num_samples for running_total in layer_sums
    )

    return fc1_average, fc2_average, fc3_average, fc4_average

def save_layer_averages(averages, save_dir='layer_averages'):
    """Save each layer average to ``<save_dir>/fc<i>_average.pth``.

    Args:
        averages: Iterable of objects to save with ``torch.save``; the i-th
            item (counting from 1) is written to ``fc{i}_average.pth``.
        save_dir: Target directory; created (including parents) if missing.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_dir, exist_ok=True)

    for i, layer_average in enumerate(averages, 1):
        torch.save(layer_average, os.path.join(save_dir, f'fc{i}_average.pth'))
