import numpy as np


class EEGFeatureStructure:
    """Index layout of per-channel EEG features, grouped by feature domain.

    Flat indices are computed as ``channel * n_features + feature`` (see
    ``_get_mask``), which matches a row-major flattening of a matrix whose
    rows are channels and whose columns are features.
    NOTE(review): assumes callers pass vectors flattened from an
    ``(n_channels, n_features)`` matrix -- confirm against the feature
    extractor.

    Attributes:
        n_channels: Number of channels the masks are built for.
        args: Configuration object; only ``args.prototype`` is read here.
        n_features: Number of features per channel (stride between channels).
        feature_groups: ``{group name: 1-D integer index array}``.
        domain_weights: ``{group name: float weight}``.

    For an unrecognised ``args.prototype`` the last two attributes are left
    unset, so reading them raises ``AttributeError``.
    """

    # Per-channel stride used until a known prototype overrides it.
    n_features = 15

    # prototype -> {group: ((start, end) feature range within ONE channel, weight)}
    # The end index is exclusive. Group order is significant: it is the
    # iteration order of feature_groups / domain_weights.
    _LAYOUTS = {
        'prototype1': {
            'time_stats': ((0, 4), 0.9),  # time domain: statistics
            'hjorth': ((4, 6), 1.0),  # time domain: Hjorth parameters
            'psd': ((6, 11), 1.5),  # frequency domain: PSD (highest weight)
            'wavelet': ((11, 15), 1.2),  # time-frequency: wavelet components
        },
        'prototype2': {
            'time_stats': ((0, 6), 0.9),  # time domain
            'psd': ((6, 11), 1.3),  # frequency domain
            'wavelet': ((11, 15), 1.35),  # time-frequency domain
            'sleep': ((15, 19), 0.9),  # sleep features
        },
        'prototype3': {
            'time_stats': ((0, 6), 0.9),  # time domain
            'psd': ((6, 11), 1.5),  # frequency domain
            'wavelet': ((11, 15), 1.2),  # time-frequency domain
        },
    }

    def __init__(self, args, n_channels=6):
        """Build the index masks and weights for ``args.prototype``.

        Args:
            args: Object exposing a ``prototype`` attribute.
            n_channels: Number of channels to generate indices for.
        """
        self.n_channels = n_channels
        self.args = args

        layout = self._LAYOUTS.get(self.args.prototype)
        if layout is None:
            # Unknown prototype: feature_groups / domain_weights stay unset.
            return

        # The stride between channels must cover every declared feature.
        # A fixed stride of 15 made the 'sleep' group (features 15..18 of
        # 'prototype2') alias features 0..3 of the following channel.
        self.n_features = max(end for (_start, end), _weight in layout.values())

        self.feature_groups = {
            name: self._get_mask(start, end)
            for name, ((start, end), _weight) in layout.items()
        }
        self.domain_weights = {
            name: weight for name, (_span, weight) in layout.items()
        }

    def _get_mask(self, start_feat, end_feat):
        """Return flat indices of features ``[start_feat, end_feat)`` per channel.

        The result is ordered channel by channel and is an empty index array
        when ``n_channels`` is 0.
        """
        channel_offsets = np.arange(self.n_channels)[:, None] * self.n_features
        return (channel_offsets + np.arange(start_feat, end_feat)).ravel()

    def flatten_to_vector(self, feature_matrix):
        """Return ``feature_matrix`` reshaped to one dimension."""
        return feature_matrix.reshape(-1)

    def group_features(self, feature_vector):
        """Split ``feature_vector`` into ``{group name: selected entries}``."""
        return {domain: feature_vector[mask]
                for domain, mask in self.feature_groups.items()}


class PrototypeSimilarityCalculator:
    """Similarity between two feature vectors, weighted per feature group.

    The supplied feature structure must provide ``feature_groups``,
    ``domain_weights`` and ``group_features(vector)``.
    """

    def __init__(self, feature_structure):
        self.fs = feature_structure

    def domain_weighted_cosine(self, x, y):
        """Return the weighted mean of per-group cosine similarities.

        Each group's cosine similarity between ``x`` and ``y`` is scaled by
        that group's entry in ``fs.domain_weights``; the total is divided by
        the sum of all weights.
        """
        weights = self.fs.domain_weights
        grouped_x = self.fs.group_features(x)
        grouped_y = self.fs.group_features(y)

        weighted_total = 0.0
        for name in self.fs.feature_groups:
            part_x, part_y = grouped_x[name], grouped_y[name]
            # The small constant keeps the division defined for zero vectors.
            denominator = np.linalg.norm(part_x) * np.linalg.norm(part_y) + 1e-8
            weighted_total += (np.dot(part_x, part_y) / denominator) * weights[name]

        return weighted_total / sum(weights.values())


def get_SimilarityCalculator(args, n_channels):
    """Return a PrototypeSimilarityCalculator for ``args`` and ``n_channels``.

    Both arguments are forwarded to ``EEGFeatureStructure``; the resulting
    structure is what the calculator is built on.
    """
    return PrototypeSimilarityCalculator(
        EEGFeatureStructure(n_channels=n_channels, args=args)
    )




