import torch
from challenge.custom.lif import CustomLIFGroup
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

def get_lif_time_constants(model):
    """
    遍历模型中所有的LIF神经元组，获取膜电位时间常数和突触时间常数
    
    参数:
        model: CustomRecurrentSpikingModel实例 - 包含神经元组的模型对象
        
    返回:
        list: 每个元素为字典，包含LIF组名称、膜电位时间常数和突触时间常数
    """
    results = []
    softplus = torch.nn.Softplus()  # 用于计算可学习时间常数
    
    # 遍历模型中的所有神经元组
    for group in model.groups:
        # 仅处理CustomLIFGroup类型的神经元组
        if isinstance(group, CustomLIFGroup):
            group_info = {
                "group_name": group.name,
                "tau_mem": None,
                "tau_syn": None
            }
            
            # 计算膜电位时间常数 (tau_mem)
            if group.learn_timescales:
                # 根据不同的前向传播方法计算时间常数
                if group.memsyn_het_forward_method == 'highpass':
                    tau_mem = group.tau_mem * softplus(group.mem_param)
                elif group.memsyn_het_forward_method == 'bandpass':
                    tau_mem = (group.memsyn_bandpass_high_ratio_cut * 
                              group.tau_mem * torch.sigmoid(group.mem_param))
                else:  # original
                    tau_mem = group.tau_mem * group.mem_param
                
                # 转换为numpy数组并移动到CPU
                group_info["tau_mem"] = tau_mem.cpu().detach().numpy()
            else:
                # 时间常数不可学习时直接使用初始值
                group_info["tau_mem"] = group.tau_mem
            
            # 计算突触时间常数 (tau_syn) - 仅当不是delta突触时
            if not group.is_delta_syn:
                if group.learn_timescales:
                    if group.memsyn_het_forward_method == 'highpass':
                        tau_syn = group.tau_syn * softplus(group.syn_param)
                    elif group.memsyn_het_forward_method == 'bandpass':
                        tau_syn = (group.memsyn_bandpass_high_ratio_cut * 
                                  group.tau_syn * torch.sigmoid(group.syn_param))
                    else:  # original
                        tau_syn = group.tau_syn * group.syn_param
                    
                    group_info["tau_syn"] = tau_syn.cpu().detach().numpy()
                else:
                    group_info["tau_syn"] = group.tau_syn
            else:
                group_info["tau_syn"] = None  # delta突触没有突触时间常数
            
            results.append(group_info)
    
    return results

def merge_time_constants(lif_time_constants_list):
    """合并不同random seed的时间常数数据"""
    merged_data = {}
    # 遍历每个seed的结果
    for seed_idx, seed_data in enumerate(lif_time_constants_list):
        if seed_data is None:  # 跳过无效数据
            continue
        # 遍历每个神经元组
        for group in seed_data:
            group_name = group["group_name"]
            # 初始化组数据
            if group_name not in merged_data:
                merged_data[group_name] = {
                    "tau_mem": [],
                    "tau_syn": []
                }
            # 添加膜电位时间常数
            if group["tau_mem"] is not None:
                if isinstance(group["tau_mem"], np.ndarray):
                    merged_data[group_name]["tau_mem"].extend(group["tau_mem"].flatten().tolist())
                else:
                    merged_data[group_name]["tau_mem"].append(group["tau_mem"])
            # 添加突触时间常数
            if group["tau_syn"] is not None:
                if isinstance(group["tau_syn"], np.ndarray):
                    merged_data[group_name]["tau_syn"].extend(group["tau_syn"].flatten().tolist())
                else:
                    merged_data[group_name]["tau_syn"].append(group["tau_syn"])
    return merged_data

def plot_tau_mem_distribution(group_name, tau_mem_values, save_path=None, bin_width=0.002):
    """绘制单个神经元组的膜电位时间常数区间百分比分布图"""
    plt.figure(figsize=(10, 6))
    
    # 计算数据范围并生成区间
    data = np.array(tau_mem_values)
    min_val = data.min()
    max_val = data.max()
    bins = np.arange(min_val, max_val + bin_width, bin_width)
    
    # 计算直方图并转换为百分比
    counts, edges = np.histogram(data, bins=bins)
    percentages = counts / counts.sum() * 100
    
    # 绘制直方图
    plt.bar(edges[:-1], percentages, width=bin_width, align='edge', edgecolor='black')
    plt.title(f'膜电位时间常数分布 (Group: {group_name})')
    plt.xlabel('时间常数 (ms)')
    plt.ylabel('百分比 (%)')
    plt.xticks(edges, rotation=45)
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"膜电位时间常数分布图已保存至: {save_path}")
    else:
        plt.show()
    plt.close()

def plot_tau_syn_distribution(group_name, tau_syn_values, save_path=None, bin_width=0.002):
    """绘制单个神经元组的突触时间常数区间百分比分布图"""
    plt.figure(figsize=(10, 6))
    
    # 计算数据范围并生成区间
    data = np.array(tau_syn_values)
    min_val = data.min()
    max_val = data.max()
    bins = np.arange(min_val, max_val + bin_width, bin_width)
    
    # 计算直方图并转换为百分比
    counts, edges = np.histogram(data, bins=bins)
    percentages = counts / counts.sum() * 100
    
    # 绘制直方图
    plt.bar(edges[:-1], percentages, width=bin_width, align='edge', edgecolor='black')
    plt.title(f'突触时间常数分布 (Group: {group_name})')
    plt.xlabel('时间常数 (ms)')
    plt.ylabel('百分比 (%)')
    plt.xticks(edges, rotation=45)
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"突触时间常数分布图已保存至: {save_path}")
    else:
        plt.show()
    plt.close()