import random
from enum import Enum
from typing import Union

import numpy as np
import torch
import yaml
from omegaconf import DictConfig
from torch import nn, Tensor

from common.type import PriorType


def split_array_equal_part(arr: Union[list[int], Tensor], k: int):
    """Split ``arr`` into consecutive chunks of length ``k``.

    Every chunk has exactly ``k`` elements except possibly the last one, which
    holds the remainder when ``len(arr)`` is not a multiple of ``k``. Chunks are
    produced by slicing along the first dimension, so a list yields sub-lists
    and a tensor yields sub-tensors.

    Args:
        arr: Sequence to split (a list or a tensor).
        k: Chunk size; must be a positive integer.

    Returns:
        List of chunks in their original order; empty when ``arr`` is empty.

    Raises:
        ValueError: If ``k`` is not positive.
    """
    # With a negative step, range() would silently produce no chunks, and a
    # zero step fails with an unhelpful message, so reject both up front.
    if k <= 0:
        raise ValueError(f"chunk size k must be positive, got {k}")
    return [arr[start:start + k] for start in range(0, len(arr), k)]


def update_conf_dict(conf: DictConfig, src: str, dst: str):
    """Return a copy of ``conf`` in which key ``dst`` holds the value of ``src``.

    The assignment happens on ``conf.copy()``, so the top-level keys of the
    passed-in config are not changed. The ``src`` key is kept in the result.
    A missing ``src`` key propagates whatever lookup error ``conf`` raises.
    """
    updated = conf.copy()
    value = updated[src]
    updated[dst] = value
    return updated


def set_seed(seed: int, deterministic: bool):
    """Seed Python, NumPy and PyTorch RNGs with the same value.

    Args:
        seed: Value passed to ``random.seed``, ``np.random.seed``,
            ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``.
        deterministic: When true, also switch cuDNN to deterministic mode and
            turn off its benchmark autotuning. When false, the cuDNN flags are
            left as they are.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

    if not deterministic:
        return

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def get_num_params(model: nn.Module, only_trainable: bool = False) -> int:
    """
    Get the total number of model parameters.

    Args:
        model: Module whose parameters are counted. A parameter registered
            under several submodules is counted once.
        only_trainable: If True, count only parameters with ``requires_grad``
            set. Defaults to False, which counts every parameter.

    Returns:
        Total number of scalar elements across the selected parameters.
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not only_trainable
    )


def yaml_enum_represent(dumper: yaml.Dumper, enum_obj: Enum) -> yaml.ScalarNode:
    """Represent an ``Enum`` member as a YAML string scalar holding its value.

    Meant to be registered as a PyYAML multi-representer for ``Enum`` (as
    ``setup_yaml`` does), so members are dumped as their ``.value``.

    Args:
        dumper: The active PyYAML dumper.
        enum_obj: Enum member being serialized.

    Returns:
        A scalar node tagged ``tag:yaml.org,2002:str``.
    """
    # The node is tagged as a string, and PyYAML's resolver/emitter treat
    # scalar values as text. Coercing explicitly keeps non-string values
    # (e.g. IntEnum members, which this multi-representer also matches) from
    # breaking serialization. For string values, str() returns the same text.
    return dumper.represent_scalar(
        tag='tag:yaml.org,2002:str',
        value=str(enum_obj.value),
    )

def setup_yaml():
    """Register ``yaml_enum_represent`` for ``Enum`` on PyYAML's ``Representer``.

    The registration targets ``yaml.representer.Representer``. It therefore
    affects dumpers built on that representer, but not ``SafeDumper``, whose
    ``SafeRepresenter`` is a parent class.
    """
    representer_cls = yaml.representer.Representer
    # A multi-representer also matches subclasses, so one call covers every
    # concrete Enum type.
    representer_cls.add_multi_representer(Enum, yaml_enum_represent)


class ElectrodeSet:
    """
    A set of EEG electrodes based on the 10-10 system.

    ``Electrodes`` fixes a canonical order: index ``i`` refers to the ``i``-th
    name in that list. Each ``*_Groups`` dict maps a group name to the
    electrode names it contains. Every instance precomputes one boolean
    "prior" matrix per grouping, of shape ``(n_groups, len(Electrodes))``.
    Entry ``[g, e]`` is True when electrode ``e`` belongs to the ``g``-th group
    (in dict order).
    """
    Layout = '10-10'
    Count = 90  # equals len(Electrodes)
    Electrodes = [
        'FP1', 'FPZ', 'FP2',
        'AF9', 'AF7', 'AF5', 'AF3', 'AF1', 'AFZ', 'AF2', 'AF4', 'AF6', 'AF8', 'AF10',
        'T1', 'F9', 'F7', 'F5', 'F3', 'F1', 'FZ', 'F2', 'F4', 'F6', 'F8', 'F10', 'T2',
        'FT9', 'FT7', 'FC5', 'FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'FC6', 'FT8', 'FT10',
        'A1', 'T9', 'T7', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'T8', 'T10', 'A2',
        'TP9', 'TP7', 'CP5', 'CP3', 'CP1', 'CPZ', 'CP2', 'CP4', 'CP6', 'TP8', 'TP10',
        'P9', 'P7', 'P5', 'P3', 'P1', 'PZ', 'P2', 'P4', 'P6', 'P8', 'P10',
        'PO9', 'PO7', 'PO5', 'PO3', 'PO1', 'POZ', 'PO2', 'PO4', 'PO6', 'PO8', 'PO10',
        'O1', 'OZ', 'O2',
        'I1', 'IZ', 'I2',
    ]

    Hybrid_Groups = {
        'frontal': [
            'FP1', 'FPZ', 'FP2',
            'AF9', 'AF7', 'AF5', 'AF3', 'AF1', 'AFZ', 'AF2', 'AF4', 'AF6', 'AF8', 'AF10',
            'F9', 'F7', 'F5', 'F3', 'F1', 'FZ', 'F2', 'F4', 'F6', 'F8', 'F10',
        ],
        'central': [
            'FC5', 'FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'FC6',
            'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6',
            'CP5', 'CP3', 'CP1', 'CPZ', 'CP2', 'CP4', 'CP6',
        ],
        'temporal': [
            'FT9', 'FT7', 'T9', 'T7', 'TP9', 'TP7', 'T1',
            'FT10', 'FT8', 'T10', 'T8', 'TP10', 'TP8', 'T2',
        ],
        'parietal': [
            'CP3', 'CP1', 'CPZ', 'CP2', 'CP4',
            'P5', 'P3', 'P1', 'PZ', 'P2', 'P4', 'P6',
        ],
        'occipital': [
            'P9', 'P10',
            'PO9', 'PO7', 'PO5', 'PO3', 'PO1', 'POZ', 'PO2', 'PO4', 'PO6', 'PO8', 'PO10',
            'O1', 'OZ', 'O2',
            'I1', 'IZ', 'I2',
        ],
        'dmn': ['FPZ', 'AFZ', 'FZ', 'FCZ', 'CZ', 'CPZ', 'PZ', 'POZ', 'OZ', 'IZ', 'F3', 'F4', 'P3', 'P4', 'P5', 'P6', 'P7', 'P8', 'TP7', 'TP8'],
        'ecn': ['F1', 'F2', 'F3', 'F4', 'F5', 'F6', 'AF3', 'AF4', 'FC1', 'FC2', 'FC3', 'FC4', 'P1', 'P2', 'P3', 'P4', 'CP1', 'CP2', 'CP3', 'CP4'],
        'sn': ['FZ', 'FCZ', 'CZ', 'F1', 'F2', 'FC1', 'FC2', 'C1', 'C2', 'FC3', 'FC4', 'C3', 'C4']
    }

    Region_Groups = {
        'prefrontal': [
            'FP1', 'FPZ', 'FP2',
            'AF9', 'AF7', 'AF5', 'AF3', 'AF1', 'AFZ', 'AF2', 'AF4', 'AF6', 'AF8', 'AF10',
        ],
        'frontal': [
            'F9', 'F7', 'F5', 'F3', 'F1', 'FZ', 'F2', 'F4', 'F6', 'F8', 'F10',
            'FT9', 'FT7', 'FC5', 'FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'FC6', 'FT8', 'FT10',
        ],
        'central': [
            'FC1', 'FCZ', 'FC2',
            'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6',
            'CP1', 'CPZ', 'CP2',
        ],
        'parietal': [
            'TP9', 'TP7', 'CP5', 'CP3', 'CP1', 'CPZ', 'CP2', 'CP4', 'CP6', 'TP8', 'TP10',
            'P9', 'P7', 'P5', 'P3', 'P1', 'PZ', 'P2', 'P4', 'P6', 'P8', 'P10',
        ],
        'occipital': [
            'PO9', 'PO7', 'PO5', 'PO3', 'PO1', 'POZ', 'PO2', 'PO4', 'PO6', 'PO8', 'PO10',
            'O1', 'OZ', 'O2',
            'I1', 'IZ', 'I2',
        ],
        'temporal-left': ['T1', 'T7', 'T9', 'FT7', 'FT9', 'TP7', 'TP9'],
        'temporal-right': ['T2', 'T8', 'T10', 'FT8', 'FT10', 'TP8', 'TP10'],
        'midline': ['FPZ', 'AFZ', 'FZ', 'FCZ', 'CZ', 'CPZ', 'PZ', 'POZ', 'OZ', 'IZ']
    }

    Network_Groups = {
        'dmn': ['FP1', 'FPZ', 'FP2', 'AFZ', 'FZ', 'FCZ', 'CPZ', 'PZ', 'P5', 'P7', 'P6', 'P8', 'TP7', 'TP8', 'AF1', 'AF2'],
        'ecn': ['AF1', 'AF2', 'AF3', 'AF4', 'AF5', 'AF6', 'AF7', 'AF8', 'AF9', 'AF10', 'F1', 'F3', 'F5', 'F2', 'F4', 'F6', 'FC1', 'FC3', 'FC2', 'FC4', 'P1', 'P3', 'P2', 'P4', 'CP1', 'CP3', 'CP2', 'CP4'],
        'sn': ['FZ', 'FCZ', 'CZ', 'F1', 'F2', 'FC1', 'FC2', 'C1', 'C2', 'FC3', 'FC4'],
        'dan': ['F1', 'F3', 'F2', 'F4', 'P1', 'P3', 'P5', 'P2', 'P4', 'P6', 'PO3', 'PO4'],
        'van': ['F10', 'F8', 'FT10', 'FT8', 'T8', 'TP8', 'C6', 'CP6', 'P10', 'P8', 'P6', 'T10', 'TP10'],
        'visual': ['P9', 'P10', 'PO9', 'PO7', 'PO5', 'PO3', 'PO1', 'POZ', 'PO2', 'PO4', 'PO6', 'PO8', 'PO10', 'O1', 'OZ', 'O2', 'I1', 'IZ', 'I2'],
        'somatomotor': ['FC5', 'FC3', 'FC1', 'FC2', 'FC4', 'FC6', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP5', 'CP3', 'CP1', 'CP2', 'CP4', 'CP6'],
        'language': ['F9', 'F7', 'F5', 'FT9', 'FT7', 'T7', 'TP7', 'C5', 'CP5', 'P9', 'P7', 'P5', 'T9', 'TP9']
    }

    def __init__(self):
        # Name -> canonical index, and the inverse mapping.
        self.electrode_dict = {electrode: i for i, electrode in enumerate(self.Electrodes)}
        self.index_dict = {i: electrode for i, electrode in enumerate(self.Electrodes)}
        # Region rows first, then network rows. The two key sets are currently
        # disjoint, so the merge keeps every group.
        all_groups = {**self.Region_Groups, **self.Network_Groups}
        self.prior_matrix_dict = {
            'hybrid': self._create_boolean_matrix(self.Hybrid_Groups),
            'region': self._create_boolean_matrix(self.Region_Groups),
            'network': self._create_boolean_matrix(self.Network_Groups),
            'all': self._create_boolean_matrix(all_groups)
        }

    def __len__(self):
        """Return the number of electrodes in the layout (``Count``)."""
        return self.Count

    def get_electrodes_index(self, electrodes: list[str]) -> np.ndarray:
        """Map electrode names to their canonical indices as an int32 array.

        Raises:
            KeyError: If a name is not in ``Electrodes``.
        """
        return np.array([self.electrode_dict[electrode] for electrode in electrodes], dtype=np.int32)

    def get_electrodes_name(self, electrodes: list[int]) -> list[str]:
        """Map canonical indices back to electrode names.

        Raises:
            KeyError: If an index is outside ``range(len(Electrodes))``.
        """
        return [self.index_dict[electrode] for electrode in electrodes]

    def get_prior_matrix(self, sel: 'PriorType') -> np.ndarray:
        """Return the precomputed boolean group-membership matrix for ``sel``.

        ``sel.value`` is used as the key and must be one of ``'hybrid'``,
        ``'region'``, ``'network'`` or ``'all'``; otherwise ``KeyError`` is
        raised. The cached array itself is returned (not a copy), so callers
        should not modify it in place.
        """
        return self.prior_matrix_dict[sel.value]

    def _create_boolean_matrix(self, input_dict):
        """Build a boolean membership matrix from a ``{group: [names]}`` dict.

        Returns:
            Array of shape ``(len(input_dict), len(Electrodes))``, with rows in
            ``input_dict`` iteration order. Entry ``[g, e]`` is True when
            electrode ``e`` is listed in group ``g``. Names that are not in
            ``Electrodes`` are ignored.
        """
        element_to_index = {element: idx for idx, element in enumerate(self.Electrodes)}

        # Preallocating keeps the 2-D shape even for an empty dict. Plain
        # ``bool`` is used because the ``np.bool`` alias was removed in
        # NumPy 1.24 (it only returned in NumPy 2.0).
        bool_matrix = np.zeros((len(input_dict), len(self.Electrodes)), dtype=bool)

        for row, elements in enumerate(input_dict.values()):
            for element in elements:
                idx = element_to_index.get(element)
                if idx is not None:
                    bool_matrix[row, idx] = True

        return bool_matrix

