import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import time
from tqdm import tqdm
# ❶ 방금 확인한 클래스 이름만 매칭
from spikingjelly.activation_based.neuron import LIFNode   # ← 수정된 import 경로

import os, re, hashlib, pickle, random   
from pathlib import Path
from typing import Optional, Dict, List, Union
import json
import numpy as np
import pandas as pd


@torch.no_grad()
def safe_mean(tensors):
    """Float 파라미터만 평균, 그 외(dtype!=float)는 첫 번째 값 그대로 반환"""
    if torch.is_floating_point(tensors[0]):
        return torch.mean(torch.stack(tensors), dim=0)
    # bias/num_batches_tracked 등은 그대로 복사
    return tensors[0].clone()


def setup_seed(seed: int) -> None:
    """재현성을 위한 시드 설정"""
    import random
    import numpy as np
    import torch
    
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# ════════════════════════════════════════════════════════════════
# ChannelRegistry ─ 클라이언트별 채널 매핑 관리
# ════════════════════════════════════════════════════════════════
class ChannelRegistry:
    """
    각 클라이언트의 채널 매핑을 관리하는 클래스
    
    Attributes
    ----------
    global_map : Dict[str, int]
        전역 채널 이름 → 인덱스 매핑
    client_maps : Dict[str, Dict[str, int]]
        클라이언트별 로컬 채널 이름 → 인덱스 매핑
    max_ch : int
        최대 채널 수
    """
    
    def __init__(self, global_map: Dict[str, int], client_maps: Dict[str, Dict[str, int]]):
        self.global_map = global_map
        self.client_maps = client_maps
        self.max_ch = len(global_map)
    
    @classmethod
    def from_yaml_files(
        cls,
        client_meta_paths: Dict[str, Path]  # client_id → yml_path
    ) -> "ChannelRegistry":
        """
        각 클라이언트의 channel_map.yml 들을 읽어
        • 전역 채널 이름 집합(합집합) 생성
        • channel_name → global_index 딕셔너리 작성
        """
        try:
            import yaml
        except ImportError as e:  # yaml 미설치 시 안내
            raise ImportError("`pyyaml` 가 필요합니다:  pip install pyyaml") from e

        # 1) 전역 채널 이름 수집
        channel_set: set = set()
        client_maps: Dict[str, Dict[str, int]] = {}
        for cid, yml_path in client_meta_paths.items():
            with open(yml_path, "r", encoding="utf-8") as fp:
                local_map = yaml.safe_load(fp)  # {'LIGHTING':0, 'HVAC':1, ...}
            client_maps[cid] = local_map
            channel_set.update(local_map.keys())

        # 2) 전역 인덱스(알파벳 정렬) 부여 → 재현성 확보
        global_map = {ch: idx for idx, ch in enumerate(sorted(channel_set))}
        return cls(global_map=global_map, client_maps=client_maps)

    # ------------------------------------------------------------------ #
    # DataFrame / ndarray 헬퍼                                            #
    # ------------------------------------------------------------------ #
    def reorder(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        입력 DataFrame 의 열을 전역 인덱스 순서로 재정렬  
        • 없는 채널은 0 으로 pad
        """
        # 1) 전역 순서 리스트
        ordered_cols: List[str] = [
            ch for ch, _ in sorted(self.global_map.items(), key=lambda x: x[1])
        ]

        # 2) DataFrame 에 없는 채널 → 0 채널 추가
        missing = [c for c in ordered_cols if c not in df.columns]
        if missing:
            # zeros DataFrame을 한 번만 생성하여 concat
            zeros = pd.DataFrame(
                0.0,
                index=df.index,
                columns=missing
            )
            df = pd.concat([df, zeros], axis=1)

        # 3) 최종 정렬
        return df.loc[:, ordered_cols]

    def restore_order(self, df: pd.DataFrame, client_id: str) -> pd.DataFrame:
        """
        전역 순서 DataFrame → 특정 클라이언트(channel_map.yml 기준) 원래 순서로 복원  
        • pad 된 0‧NaN 채널은 자동 제거
        """
        cli_map = self.client_maps.get(client_id)
        if cli_map is None:
            raise KeyError(f"Unknown client_id: {client_id}")

        ordered_cols = [ch for ch, _ in sorted(cli_map.items(), key=lambda x: x[1])]
        # 해당 클라이언트가 갖지 않은 채널은 제거
        ordered_cols = [c for c in ordered_cols if c in df.columns]
        return df.loc[:, ordered_cols]

    # ------------------------------------------------------------------ #
    # ndarray ↔ DataFrame 변환                                            #
    # ------------------------------------------------------------------ #
    def to_dataframe(
        self,
        arr: np.ndarray,  # shape = (max_ch, L)
        client_id: str,
        index: pd.Index = None,
    ) -> pd.DataFrame:
        """모델 출력(또는 라벨)을 DataFrame 으로 변환하여 시각화 편의 제공"""
        if arr.ndim != 2:
            raise ValueError("expected arr shape (max_ch, L)")
        arr = arr.copy()

        # 채널 이름 배치
        ordered_cols = [ch for ch, _ in sorted(self.global_map.items(), key=lambda x: x[1])]
        df = pd.DataFrame(arr.T, columns=ordered_cols, index=index)

        # 클라이언트 원래 순서 + 없는 채널 제거
        return self.restore_order(df, client_id)

    # ------------------------------------------------------------------ #
    # 유틸                                                                #
    # ------------------------------------------------------------------ #
    def get_mask(self, client_id: str) -> np.ndarray:
        """
        해당 클라이언트가 보유한 채널이면 1, 없으면 0인 mask (shape = [max_ch])
        """
        mask = np.zeros(self.max_ch, dtype=np.float32)
        cli_map = self.client_maps.get(client_id, {})
        for ch in cli_map.keys():
            idx = self.global_map[ch]
            mask[idx] = 1.0
        return mask

    def save(self, file_path: str) -> None:
        """
        ChannelRegistry를 JSON 파일로 저장
        
        Parameters
        ----------
        file_path : str
            저장할 파일 경로
        """
        import json
        
        data = {
            "global_map": self.global_map,
            "client_maps": self.client_maps,
            "max_ch": self.max_ch
        }
        
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

