import torch
from torch import nn


electrode_list = [
 "A1",
 "A2",
 "C3",
 "C4",
 "CZ",
 "F3",
 "F4",
 "F7",
 "F8",
 "FP1",
 "FP2",
 "FZ",
 "O1",
 "O2",
 "P3",
 "P4",
 "PZ",
 "T3",
 "T4",
 "T5",
 "T6",
 "OZ",
 "CZ",
 "FPZ",
 "FZ",
 "P7",
 "P8",
 "PZ",
 "T7",
 "T8",
 "C1",
 "C2",
 "C5",
 "C6",
 "CP1",
 "CP2",
 "CP3",
 "CP4",
 "CPZ",
 "FC1",
 "FC2",
 "FC3",
 "FC4",
 "FCZ",
 "P1",
 "P2",
 "POZ",
 "AFZ",
 "P5",
 "P6",
 "PO3",
 "PO4",
 "AF3",
 "AF4",
 "AF7",
 "AF8",
 "CP5",
 "CP6",
 "F1",
 "F2",
 "F5",
 "F6",
 "FC5",
 "FC6",
 "FT7",
 "FT8",
 "FP1",
 "FP2",
 "IZ",
 "OZ",
 "P10",
 "P9",
 "PO7",
 "PO8",
 "TP7",
 "TP8",
 "F10",
 "F9",
 "FT10",
 "FT9",
 "FTT10H",
 "FTT9H",
 "PO10",
 "PO9",
 "TP10",
 "TP9",
 "TPP10H",
 "TPP8H",
 "TPP9H",
 "TTP7H",
 "CCP1H",
 "CCP2H",
 "CCP3H",
 "CCP4H",
 "CCP5H",
 "CCP6H",
 "CPP1H",
 "CPP2H",
 "CPP3H",
 "CPP4H",
 "CPP5H",
 "CPP6H",
 "FCC1H",
 "FCC2H",
 "FCC3H",
 "FCC4H",
 "FCC5H",
 "FCC6H",
 "FFC1H",
 "FFC2H",
 "FFC3H",
 "FFC4H",
 "FFC5H",
 "FFC6H",
 "FTT7H",
 "FTT8H",
 "PPO1H",
 "PPO2H",
 "TTP8H",
 "T10",
 "T9",
 "AFF1H",
 "AFF2H",
 "AFF5H",
 "AFF6H",
 "AFP1",
 "AFP2",
 "POO1",
 "POO2",
 "PO5",
 "PO6",
 "AFF1",
 "AFF2",
 "AFP3H",
 "AFP4H",
 "FFT7H",
 "FFT8H",
 "I1",
 "I2",
 "M1",
 "M2",
 "OI1H",
 "OI2H",
 "POO10H",
 "POO3H",
 "POO4H",
 "POO9H",
 "PPO1",
 "PPO10H",
 "PPO2",
 "PPO5H",
 "PPO6H",
 "PPO9H",
 "TPP7H",
 "A3",
 "A4",
 "A5",
 "A6",
 "A7",
 "A8",
 "A9",
 "A10",
 "A11",
 "A12",
 "A13",
 "A14",
 "A15",
 "A16",
 "A17",
 "A18",
 "A19",
 "A20",
 "A21",
 "A22",
 "A23",
 "A24",
 "A25",
 "A26",
 "A27",
 "A28",
 "A29",
 "A30",
 "A31",
 "A32",
 "B1",
 "B2",
 "B3",
 "B4",
 "B5",
 "B6",
 "B7",
 "B8",
 "B9",
 "B10",
 "B11",
 "B12",
 "B13",
 "B14",
 "B15",
 "B16",
 "B17",
 "B18",
 "B19",
 "B20",
 "B21",
 "B22",
 "B23",
 "B24",
 "B25",
 "B26",
 "B27",
 "B28",
 "B29",
 "B30",
 "B31",
 "B32",
 "C7",
 "C8",
 "C9",
 "C10",
 "C11",
 "C12",
 "C13",
 "C14",
 "C15",
 "C16",
 "C17",
 "C18",
 "C19",
 "C20",
 "C21",
 "C22",
 "C23",
 "C24",
 "C25",
 "C26",
 "C27",
 "C28",
 "C29",
 "C30",
 "C31",
 "C32",
 "D1",
 "D2",
 "D3",
 "D4",
 "D5",
 "D6",
 "D7",
 "D8",
 "D9",
 "D10",
 "D11",
 "D12",
 "D13",
 "D14",
 "D15",
 "D16",
 "D17",
 "D18",
 "D19",
 "D20",
 "D21",
 "D22",
 "D23",
 "D24",
 "D25",
 "D26",
 "D27",
 "D28",
 "D29",
 "D30",
 "D31",
 "D32",
 "CPZ",
 "FCZ",
 "FPZ",
 "POZ",
 "E1",
 "E10",
 "E100",
 "E101",
 "E102",
 "E103",
 "E104",
 "E105",
 "E106",
 "E107",
 "E108",
 "E109",
 "E11",
 "E110",
 "E111",
 "E112",
 "E113",
 "E114",
 "E115",
 "E116",
 "E117",
 "E118",
 "E119",
 "E12",
 "E120",
 "E121",
 "E122",
 "E123",
 "E124",
 "E125",
 "E126",
 "E127",
 "E128",
 "E13",
 "E14",
 "E15",
 "E16",
 "E17",
 "E18",
 "E19",
 "E2",
 "E20",
 "E21",
 "E22",
 "E23",
 "E24",
 "E25",
 "E26",
 "E27",
 "E28",
 "E29",
 "E3",
 "E30",
 "E31",
 "E32",
 "E33",
 "E34",
 "E35",
 "E36",
 "E37",
 "E38",
 "E39",
 "E4",
 "E40",
 "E41",
 "E42",
 "E43",
 "E44",
 "E45",
 "E46",
 "E47",
 "E48",
 "E49",
 "E5",
 "E50",
 "E51",
 "E52",
 "E53",
 "E54",
 "E55",
 "E56",
 "E57",
 "E58",
 "E59",
 "E6",
 "E60",
 "E61",
 "E62",
 "E63",
 "E64",
 "E65",
 "E66",
 "E67",
 "E68",
 "E69",
 "E7",
 "E70",
 "E71",
 "E72",
 "E73",
 "E74",
 "E75",
 "E76",
 "E77",
 "E78",
 "E79",
 "E8",
 "E80",
 "E81",
 "E82",
 "E83",
 "E84",
 "E85",
 "E86",
 "E87",
 "E88",
 "E89",
 "E9",
 "E90",
 "E91",
 "E92",
 "E93",
 "E94",
 "E95",
 "E96",
 "E97",
 "E98",
 "E99",
 "E129",
 "CP1",
 "CP2",
 "CP5",
 "CP6",
 "FC1",
 "FC2",
 "FC5",
 "FC6",
 "BIOSEMI128_A1",
 "BIOSEMI128_A2",
 "BIOSEMI128_A3",
 "BIOSEMI128_A4",
 "BIOSEMI128_A5",
 "BIOSEMI128_A6",
 "BIOSEMI128_A7",
 "BIOSEMI128_A8",
 "BIOSEMI128_A9",
 "BIOSEMI128_A10",
 "BIOSEMI128_A11",
 "BIOSEMI128_A12",
 "BIOSEMI128_A13",
 "BIOSEMI128_A14",
 "BIOSEMI128_A15",
 "BIOSEMI128_A16",
 "BIOSEMI128_A17",
 "BIOSEMI128_A18",
 "BIOSEMI128_A19",
 "BIOSEMI128_A20",
 "BIOSEMI128_A21",
 "BIOSEMI128_A22",
 "BIOSEMI128_A23",
 "BIOSEMI128_A24",
 "BIOSEMI128_A25",
 "BIOSEMI128_A26",
 "BIOSEMI128_A27",
 "BIOSEMI128_A28",
 "BIOSEMI128_A29",
 "BIOSEMI128_A30",
 "BIOSEMI128_A31",
 "BIOSEMI128_A32",
 "BIOSEMI128_B1",
 "BIOSEMI128_B2",
 "BIOSEMI128_B3",
 "BIOSEMI128_B4",
 "BIOSEMI128_B5",
 "BIOSEMI128_B6",
 "BIOSEMI128_B7",
 "BIOSEMI128_B8",
 "BIOSEMI128_B9",
 "BIOSEMI128_B10",
 "BIOSEMI128_B11",
 "BIOSEMI128_B12",
 "BIOSEMI128_B13",
 "BIOSEMI128_B14",
 "BIOSEMI128_B15",
 "BIOSEMI128_B16",
 "BIOSEMI128_B17",
 "BIOSEMI128_B18",
 "BIOSEMI128_B19",
 "BIOSEMI128_B20",
 "BIOSEMI128_B21",
 "BIOSEMI128_B22",
 "BIOSEMI128_B23",
 "BIOSEMI128_B24",
 "BIOSEMI128_B25",
 "BIOSEMI128_B26",
 "BIOSEMI128_B27",
 "BIOSEMI128_B28",
 "BIOSEMI128_B29",
 "BIOSEMI128_B30",
 "BIOSEMI128_B31",
 "BIOSEMI128_B32",
 "BIOSEMI128_C1",
 "BIOSEMI128_C2",
 "BIOSEMI128_C3",
 "BIOSEMI128_C4",
 "BIOSEMI128_C5",
 "BIOSEMI128_C6",
 "BIOSEMI128_C7",
 "BIOSEMI128_C8",
 "BIOSEMI128_C9",
 "BIOSEMI128_C10",
 "BIOSEMI128_C11",
 "BIOSEMI128_C12",
 "BIOSEMI128_C13",
 "BIOSEMI128_C14",
 "BIOSEMI128_C15",
 "BIOSEMI128_C16",
 "BIOSEMI128_C17",
 "BIOSEMI128_C18",
 "BIOSEMI128_C19",
 "BIOSEMI128_C20",
 "BIOSEMI128_C21",
 "BIOSEMI128_C22",
 "BIOSEMI128_C23",
 "BIOSEMI128_C24",
 "BIOSEMI128_C25",
 "BIOSEMI128_C26",
 "BIOSEMI128_C27",
 "BIOSEMI128_C28",
 "BIOSEMI128_C29",
 "BIOSEMI128_C30",
 "BIOSEMI128_C31",
 "BIOSEMI128_C32",
 "BIOSEMI128_D1",
 "BIOSEMI128_D2",
 "BIOSEMI128_D3",
 "BIOSEMI128_D4",
 "BIOSEMI128_D5",
 "BIOSEMI128_D6",
 "BIOSEMI128_D7",
 "BIOSEMI128_D8",
 "BIOSEMI128_D9",
 "BIOSEMI128_D10",
 "BIOSEMI128_D11",
 "BIOSEMI128_D12",
 "BIOSEMI128_D13",
 "BIOSEMI128_D14",
 "BIOSEMI128_D15",
 "BIOSEMI128_D16",
 "BIOSEMI128_D17",
 "BIOSEMI128_D18",
 "BIOSEMI128_D19",
 "BIOSEMI128_D20",
 "BIOSEMI128_D21",
 "BIOSEMI128_D22",
 "BIOSEMI128_D23",
 "BIOSEMI128_D24",
 "BIOSEMI128_D25",
 "BIOSEMI128_D26",
 "BIOSEMI128_D27",
 "BIOSEMI128_D28",
 "BIOSEMI128_D29",
 "BIOSEMI128_D30",
 "BIOSEMI128_D31",
 "BIOSEMI128_D32",
]


class RevePositionBank(nn.Module):
 def __init__(self):
 super().__init__()

 self.position_names = electrode_list
 self.mapping = {name.upper(): i for i, name in enumerate(self.position_names)}
 self.register_buffer("embedding", torch.randn(len(self.position_names), 3))

 def forward(self, channel_names: list[str]):
 # normalize and apply replacements
 channel_names = [cn.upper() for cn in channel_names]
 channel_names = ["TP7" if cn == "T1" else "TP8" if cn == "T2" else cn for cn in channel_names]

 indices = [self.mapping[q] for q in channel_names if q in self.mapping]

 if len(indices) < len(channel_names):
 print(f"Found {len(indices)} positions out of {len(channel_names)} channels")

 indices = torch.tensor(indices, device=self.embedding.device)

 return self.embedding[indices]

 def get_all_positions(self):
 return self.position_names
