import os
import math
import torch
import torch.nn as nn

class MultiFeaturePrototypeEMA(nn.Module):
    """Per-distance-bin running statistics (EMA of mean and std) for several feature maps.

    For every feature key, points are bucketed by the Euclidean norm of their
    scaled spatial coordinates (``coords[:, 1:] * voxel_size``) into
    ``num_bins`` bins of width ``bin_size`` covering ``[0, max_range)``. Each
    bin keeps an exponential moving average of the per-channel mean and
    population std of the features that fell into it, plus the cumulative
    number of points seen.

    Args:
        cfg: config object; reads ``cfg.PROTOTYPE.PROTOTYPE_EMA_STEPLR`` (EMA
            momentum = weight given to the old value) and
            ``cfg.HYPERPARAMETER.VOXEL_SIZE`` (scale applied to coordinates).
        feature_dims: mapping ``key -> number of feature channels``.
        max_range: points whose distance is >= ``max_range`` are ignored.
        bin_size: width of each distance bin (same units as
            ``coords * voxel_size``).
        resume_path: optional path of a file previously written by :meth:`save`.
        device: where the statistics live; defaults to CUDA when available,
            otherwise CPU.

    Note: ``bin_stats`` is a plain dict, not registered buffers, so
    ``module.to(...)`` and ``state_dict()`` do not see it; use save/load.
    """

    def __init__(self, cfg, feature_dims, max_range=50.0, bin_size=2.0, resume_path=None,
                 device=None):
        super().__init__()
        self.cfg = cfg
        self.feature_dims = feature_dims
        self.feature_keys = list(feature_dims.keys())
        self.max_range = max_range
        self.bin_size = bin_size
        self.momentum = cfg.PROTOTYPE.PROTOTYPE_EMA_STEPLR  # e.g. 0.9
        self.voxel_size = cfg.HYPERPARAMETER.VOXEL_SIZE     # e.g. 0.05
        self.num_bins = int(math.ceil(self.max_range / self.bin_size))

        if device is None:
            # Previously hard-coded to CUDA, which made construction fail on
            # machines without a GPU.
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

        self.bin_stats = {
            key: {
                'mean': torch.zeros(self.num_bins, dim, device=self.device),
                'std': torch.zeros(self.num_bins, dim, device=self.device),
                # int64 so the cumulative point count cannot overflow on long runs.
                'count': torch.zeros(self.num_bins, dtype=torch.long, device=self.device),
            }
            for key, dim in feature_dims.items()
        }
        if resume_path:
            # Load onto the statistics device; the ``load`` default ('cpu') would
            # leave the stats on a different device than the incoming features.
            self.load(resume_path, device=self.device)

    @torch.no_grad()
    def update(self, features_dict):
        """Fold one batch of sparse features into the per-bin EMA statistics.

        Args:
            features_dict: mapping from every key in ``self.feature_keys`` to an
                object exposing ``.C`` (coordinates, shape ``(N, D)``; column 0
                is skipped -- presumably the batch index) and ``.F`` (features,
                shape ``(N, feature_dim)``, rows aligned with ``.C``).

        Runs under ``torch.no_grad`` so the stored statistics never capture the
        autograd graph of the network that produced the features.
        """
        for key in self.feature_keys:
            stats = self.bin_stats[key]
            coords = features_dict[key].C  # (N, D)
            feats = features_dict[key].F   # (N, feature_dim)

            spatial_coords = coords[:, 1:].float() * self.voxel_size  # (N, D-1)
            distances = torch.norm(spatial_coords, dim=1)              # (N,)
            valid_mask = distances < self.max_range
            if not bool(valid_mask.any()):
                continue

            # Compute on the statistics' own device/dtype, so stats loaded onto
            # a different device (or half-precision features) still work.
            target_device = stats['mean'].device
            valid_feats = feats[valid_mask].to(device=target_device, dtype=stats['mean'].dtype)
            bin_idx = (distances[valid_mask] / self.bin_size).floor().long().to(target_device)
            # Guard against float rounding pushing a point just below max_range
            # into a non-existent bin.
            bin_idx.clamp_(0, self.num_bins - 1)
            self._ema_update(stats, bin_idx, valid_feats)

    def _ema_update(self, stats, bin_idx, feats):
        """Blend one batch's per-bin mean/std into ``stats`` in place.

        Bins seen for the first time (count == 0) take the batch statistics
        directly; previously seen bins become
        ``momentum * old + (1 - momentum) * new``. Bins absent from the batch
        are left untouched. ``count`` accumulates the number of points.
        """
        num_bins, dim = stats['mean'].shape
        counts = torch.bincount(bin_idx, minlength=num_bins)   # points per bin
        present = (counts > 0).unsqueeze(1)                    # (num_bins, 1)
        denom = counts.clamp(min=1).unsqueeze(1).to(feats.dtype)  # avoid 0/0 in empty bins

        batch_mean = feats.new_zeros(num_bins, dim).index_add_(0, bin_idx, feats) / denom
        # Two-pass population variance (equivalent to std(unbiased=False));
        # centering first avoids the cancellation of E[x^2] - E[x]^2.
        centered = feats - batch_mean[bin_idx]
        sq_sum = feats.new_zeros(num_bins, dim).index_add_(0, bin_idx, centered * centered)
        batch_std = (sq_sum / denom).sqrt()

        fresh = (stats['count'] == 0).unsqueeze(1)  # read before count is bumped
        m = self.momentum
        for name, batch_val in (('mean', batch_mean), ('std', batch_std)):
            old = stats[name]
            blended = torch.where(fresh, batch_val, m * old + (1 - m) * batch_val)
            # copy_ keeps the same tensor object, so references handed out by
            # get_prototypes() stay valid.
            old.copy_(torch.where(present, blended, old))
        stats['count'] += counts

    def get_prototypes(self):
        """Return the live ``{key: {'mean', 'std', 'count'}}`` statistics dict."""
        return self.bin_stats

    def save(self, path):
        """Serialize the statistics dict to ``path`` with ``torch.save``."""
        torch.save(self.bin_stats, path)

    def load(self, path, device='cpu'):
        """Replace the statistics with those stored at ``path``, mapped to ``device``.

        NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.
        """
        self.bin_stats = torch.load(path, map_location=device)
        print(f"Loaded multi-feature EMA prototypes from {path}")

if __name__ == "__main__":
    # Smoke test: feed random features through the estimator and print the stats.
    class Config:
        class PROTOTYPE:
            PROTOTYPE_EMA_STEPLR = 0.9

        class HYPERPARAMETER:
            VOXEL_SIZE = 0.05

    cfg = Config()
    feature_dims = {"feat1": 64, "feat2": 128}
    estimator = MultiFeaturePrototypeEMA(cfg, feature_dims, max_range=50.0, bin_size=2.0)

    N = 1000
    feat1 = torch.randn(N, feature_dims["feat1"])
    feat2 = torch.randn(N, feature_dims["feat2"])
    # Coordinates are [batch_id, x, y, z] with every batch id 0. update() treats
    # x, y, z as voxel indices (it multiplies them by VOXEL_SIZE), so convert
    # random metric positions in [-50, 50) m to integer voxel indices. Passing
    # metres directly would shrink every distance by VOXEL_SIZE (max ~4.3 m)
    # and leave most distance bins empty.
    batch_ids = torch.zeros(N, 1, dtype=torch.int)
    xyz_m = torch.rand(N, 3) * 100 - 50
    xyz_vox = torch.floor(xyz_m / cfg.HYPERPARAMETER.VOXEL_SIZE).int()
    coords = torch.cat([batch_ids, xyz_vox], dim=1)

    class DummySparseTensor:
        """Minimal stand-in exposing .F (features) and .C (coordinates)."""

        def __init__(self, F, C):
            self.F = F
            self.C = C

    features_dict = {
        "feat1": DummySparseTensor(feat1, coords),
        "feat2": DummySparseTensor(feat2, coords),
    }

    estimator.update(features_dict)

    prototypes = estimator.get_prototypes()
    for key, stats in prototypes.items():
        print(f"Feature: {key}")
        print(f"Mean matrix (shape {stats['mean'].shape}):")
        print(stats['mean'])
        print(f"Std matrix (shape {stats['std'].shape}):")
        print(stats['std'])
        print(f"Count vector (shape {stats['count'].shape}):")
        print(stats['count'])
