import json
import logging
import os
import warnings
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional, Union, Any

import datasets
import mne
import numpy as np
import pandas as pd
import pytz
from mne.io import BaseRaw
from numpy import ndarray
from pandas import DataFrame

from common.type import DatasetTaskType
from data.processor.builder import EEGDatasetBuilder, EEGConfig


logger = logging.getLogger('preproc')


@dataclass
class SeedVConfig(EEGConfig):
    """Builder configuration for the SEED-V emotion EEG dataset.

    Defaults describe the ``pretrain`` variant; other variants are created by
    overriding fields (e.g. ``is_cross_subject=False``).
    """

    # ---- builder-config metadata -------------------------------------------
    name: str = 'pretrain'
    version: Optional[Union[datasets.utils.Version, str]] = datasets.utils.Version("3.0.1")
    description: Optional[str] = (
        "The emotion induction method adopted in this experiment is stimulus material induction. "
        "In other words, it is generated by allowing the subjects to watch specific emotional "
        "stimulus materials to achieve the purpose of inducing the subjects' corresponding emotional state."
    )
    # NOTE: this literal is whitespace-sensitive -- the single leading space on
    # every line (and before the closing quotes) is part of the stored string.
    citation: Optional[str] = """\
 @ARTICLE{9395500,
 author={Liu, Wei and Qiu, Jie-Lin and Zheng, Wei-Long and Lu, Bao-Liang},
 journal={IEEE Transactions on Cognitive and Developmental Systems},
 title={Comparing Recognition Performance and Robustness of Multimodal Deep Learning Models for Multimodal Emotion Recognition},
 year={2022},
 volume={14},
 number={2},
 pages={715-729},
 keywords={Emotion recognition;Electroencephalography;Robustness;Deep learning;Correlation;Brain modeling;Computational modeling;Bimodal deep autoencoder (BDAE);deep canonical correlation analysis (DCCA);electroencephalography (EEG);eye movement;multimodal deep learning;multimodal emotion recognition;robustness},
 doi={10.1109/TCDS.2021.3071170}}
 """

    # ---- preprocessing -----------------------------------------------------
    # Presumably the notch-filter frequency in Hz (consumed by the base builder) -- confirm.
    filter_notch: float = 50.0
    persist_drop_last: bool = False

    # ---- dataset identity and raw-file layout ------------------------------
    dataset_name: Optional[str] = "seed_v"
    task_type: DatasetTaskType = DatasetTaskType.EMOTION
    file_ext: str = "set"
    # 60 channel names under the '10_10' key; the builder's
    # standardize_chs_names('10_10') returns this list unchanged.
    # NOTE(review): CB1/CB2 are not listed -- confirm the resampled files omit them.
    montage: dict[str, list[str]] = field(
        default_factory=lambda: {
            '10_10': [
                'FP1', 'FPZ', 'FP2',
                'AF3', 'AF4',
                'F7', 'F5', 'F3', 'F1', 'FZ', 'F2', 'F4', 'F6', 'F8',
                'FT7', 'FC5', 'FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'FC6', 'FT8',
                'T7', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'T8',
                'TP7', 'CP5', 'CP3', 'CP1', 'CPZ', 'CP2', 'CP4', 'CP6', 'TP8',
                'P7', 'P5', 'P3', 'P1', 'PZ', 'P2', 'P4', 'P6', 'P8',
                'PO7', 'PO5', 'PO3', 'POZ', 'PO4', 'PO6', 'PO8',
                'O1', 'OZ', 'O2',
            ],
        }
    )

    # ---- splitting / windowing ---------------------------------------------
    # Presumably fractions used by the base builder's split logic -- confirm.
    valid_ratio: float = 0.125
    test_ratio: float = 0.125
    wnd_div_sec: int = 10
    suffix_path: str = os.path.join('SEED', 'SEED-V')
    scan_sub_dir: str = os.path.join("EEG_raw", 'resampled')

    # Emotion names, indexed by the integer label ids stored in the builder's label table.
    category: list[str] = field(
        default_factory=lambda: [
            'Disgust',
            'Fear',
            'Sad',
            'Neutral',
            'Happy',
        ]
    )
    is_cross_subject: bool = True


class SeedVBuilder(EEGDatasetBuilder):
    """Dataset builder for SEED-V EEG recordings stored as EEGLAB ``.set`` files.

    Registered configurations:

    * ``pretrain`` -- every recording yields the single event ``('default', 0, -1)``.
    * ``finetune`` -- one labelled event per trial; splits are assigned by
      ``_divide_all_split_by_sub``.
    * ``finetune_sub_dependent`` -- one labelled event per trial; each
      recording is split by trial position (see ``_divide_by_uniform_label``).
    """

    BUILDER_CONFIG_CLASS = SeedVConfig
    BUILDER_CONFIGS = [
        BUILDER_CONFIG_CLASS(name='pretrain'),
        BUILDER_CONFIG_CLASS(name='finetune', is_finetune=True),
        BUILDER_CONFIG_CLASS(name='finetune_sub_dependent', is_finetune=True, is_cross_subject=False),
    ]

    def __init__(self, config_name='pretrain', **kwargs):
        """Create the builder for ``config_name`` and load the static SEED-V metadata tables."""
        super().__init__(config_name, **kwargs)
        self._load_meta_info()

    def _load_meta_info(self):
        """Populate ``sub_meta``, ``label_meta`` and ``time_meta`` from hard-coded tables.

        * ``sub_meta``   -- DataFrame with columns ``name`` (subject id 1..16),
          ``sex`` and ``age``; row label ``i`` holds subject ``i + 1``.
        * ``label_meta`` -- int32 array of shape (3, 15): per session, the
          emotion id of each trial (an index into ``config.category``).
        * ``time_meta``  -- int32 array of shape (3, 2, 15): per session, the
          start (row 0) and end (row 1) time of each trial. The values are
          multiplied by 1000 when emitted as events (presumably s -> ms).
        """
        sub_info = {
            'name': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
            'sex': ['F', 'M', 'F', 'F', 'M', 'F', 'F', 'M', 'F', 'M', 'F', 'F', 'F', 'M', 'M', 'F'],
            'age': [21, 21, 23, 21, 20, 23, 23, 21, 23, 24, 24, 20, 22, 22, 19, 19],
        }

        # (session, trial) -> emotion id.
        label_info = np.array([
            [4, 1, 3, 2, 0, 4, 1, 3, 2, 0, 4, 1, 3, 2, 0],
            [2, 1, 3, 0, 4, 4, 0, 3, 2, 1, 3, 4, 1, 2, 0],
            [2, 1, 3, 0, 4, 4, 0, 3, 2, 1, 3, 4, 1, 2, 0],
        ], dtype=np.int32)

        # (session, [start, end], trial).
        time_info = np.array([
            [[30, 132, 287, 555, 773, 982, 1271, 1628, 1730, 2025, 2227, 2435, 2667, 2932, 3204],
             [102, 228, 524, 742, 920, 1240, 1568, 1697, 1994, 2166, 2401, 2607, 2901, 3172, 3359]],
            [[30, 299, 548, 646, 836, 1000, 1091, 1392, 1657, 1809, 1966, 2186, 2333, 2490, 2741],
             [267, 488, 614, 773, 967, 1059, 1331, 1622, 1777, 1908, 2153, 2302, 2428, 2709, 2817]],
            [[30, 353, 478, 674, 825, 908, 1200, 1346, 1451, 1711, 2055, 2307, 2457, 2726, 2888],
             [321, 418, 643, 764, 877, 1147, 1284, 1418, 1679, 1996, 2275, 2425, 2664, 2857, 3066]],
        ], dtype=np.int32)

        self.sub_meta = pd.DataFrame(sub_info)
        self.time_meta: ndarray = time_info
        self.label_meta: ndarray = label_info

    def _resolve_file_name(self, file_path: str) -> dict[str, Any]:
        """Parse a ``<subject>_<session>_<YYYYMMDD>`` recording file name.

        Returns:
            ``{'subject': int, 'session': int, 'date': 'YYYYMMDD'}``; the date is
            round-tripped through ``strptime`` so it is validated as a real date.

        Raises:
            ValueError: if the name does not have exactly three ``_``-separated
                parts, or the subject/session/date parts do not parse.
        """
        file_name = self._extract_file_name(file_path)
        parts = file_name.split('_')
        if len(parts) != 3:
            raise ValueError(
                f"Expected a '<subject>_<session>_<YYYYMMDD>' file name, got {file_name!r} ({file_path!r})")
        subject, session, date = parts
        # Localising does not change the formatted day; it ties the parsed date to
        # Asia/Shanghai (presumably the recording site's timezone).
        date_obj = pytz.timezone('Asia/Shanghai').localize(datetime.strptime(date, "%Y%m%d"))
        return {
            'subject': int(subject),
            'session': int(session),
            'date': date_obj.strftime("%Y%m%d"),
        }

    def _resolve_exp_meta_info(self, file_path: str) -> dict[str, Any]:
        """Collect per-recording metadata.

        Extends the file-name fields with ``montage`` (always ``'10_10'``),
        ``time`` (``raw.duration`` as reported by MNE), ``sex`` (1 for ``'M'``,
        2 otherwise) and ``age`` from ``sub_meta``.
        """
        info = self._resolve_file_name(file_path)
        # sub_meta rows are ordered by subject id 1..16, so id - 1 is the row label.
        sex, age = self.sub_meta.loc[info['subject'] - 1, ['sex', 'age']]
        # Header-only read (preload=False) is enough to obtain the duration.
        with self._read_raw_data(file_path, preload=False, verbose=False) as raw:
            duration = raw.duration

        info.update({
            'montage': '10_10',
            'time': duration,
            'sex': 1 if sex == 'M' else 2,
            'age': age,
        })
        return info

    def _resolve_exp_events(self, file_path: str, info: dict[str, Any]):
        """Return the events of one recording as ``(name, start, end)`` tuples.

        Without ``is_finetune`` a single ``('default', 0, -1)`` event is returned.
        Otherwise one event per trial of ``info['session']`` is returned, named by
        ``config.category`` and with start/end taken from ``time_meta`` times 1000.

        Raises:
            ValueError: if the session id is outside the metadata tables (a
                negative index would otherwise silently select another session),
                or the label and time tables disagree in length.
        """
        if not self.config.is_finetune:
            return [('default', 0, -1)]

        session = info['session']
        n_sessions = len(self.label_meta)
        if not 1 <= session <= n_sessions:
            raise ValueError(
                f"Session must be in [1, {n_sessions}], got {session} for {file_path!r}")

        labels = self.label_meta[session - 1]
        starts, ends = self.time_meta[session - 1]
        if not len(labels) == len(starts) == len(ends):
            raise ValueError(
                f"Label/time tables disagree for session {session}: "
                f"{len(labels)} labels vs {len(starts)} starts / {len(ends)} ends")

        return [
            (self.config.category[label], int(start) * 1000, int(end) * 1000)
            for label, start, end in zip(labels, starts, ends)
        ]

    def _divide_split(self, df: DataFrame) -> DataFrame:
        """Assign train/valid/test splits.

        Uses ``_divide_all_split_by_sub`` when ``config.is_cross_subject`` is set,
        and the per-recording trial-position split otherwise.
        """
        if self.config.is_cross_subject:
            return self._divide_all_split_by_sub(df)
        return self._divide_by_uniform_label(df)

    @staticmethod
    def _divide_by_uniform_label(df: DataFrame) -> DataFrame:
        """Split every recording into train/valid/test by trial position.

        Each input row's ``label`` column holds a JSON list of events. Every row
        is emitted three times, with ``label`` replaced by the JSON of trials
        ``[0:5]`` ('train'), ``[5:10]`` ('valid') and ``[10:15]`` ('test'), and
        ``split`` set accordingly. In every session of ``label_meta`` each such
        block of five trials contains all five emotions. Other columns are copied
        unchanged, the result has a fresh RangeIndex, and ``df`` is not modified.
        """
        split_ranges = (('train', 0, 5), ('valid', 5, 10), ('test', 10, 15))

        # Positional repetition (not label-based) so duplicate index labels are
        # safe; reset_index returns a new frame that columns can be assigned on.
        positions = np.repeat(np.arange(len(df)), len(split_ranges))
        out = df.take(positions).reset_index(drop=True)

        new_labels = []
        for raw_label in df['label']:
            events = json.loads(raw_label)
            new_labels.extend(json.dumps(events[start:stop]) for _, start, stop in split_ranges)

        out['split'] = [name for name, _, _ in split_ranges] * len(df)
        out['label'] = new_labels
        return out

    def standardize_chs_names(self, montage: str):
        """Return the channel-name list configured for ``montage`` (e.g. ``'10_10'``)."""
        return self.config.montage[montage]

    def _walk_raw_data_files(self):
        """Return the sorted, normalised paths of all raw recordings under the scan directory.

        Walks ``<raw_path>/<scan_sub_dir>`` recursively and keeps files whose
        extension is ``config.file_ext`` (a leading dot in the config is tolerated).
        """
        scan_path = os.path.join(self.config.raw_path, self.config.scan_sub_dir)
        # Match a real extension: a bare endswith('set') would also accept e.g. '*reset'.
        suffix = '.' + self.config.file_ext.lstrip('.')
        raw_data_files = [
            os.path.normpath(os.path.join(root, file))
            for root, _, files in os.walk(scan_path)
            for file in files
            if file.endswith(suffix)
        ]
        # os.walk order is filesystem-dependent; sort for reproducible builds.
        return sorted(raw_data_files)

    def _read_raw_data(self, file_path: str, preload: bool = False, verbose: bool = False) -> BaseRaw:
        """Open an EEGLAB recording with ``mne.io.read_raw_eeglab``.

        Any ``RuntimeWarning`` raised while reading is suppressed.
        """
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=RuntimeWarning)
            data = mne.io.read_raw_eeglab(file_path, preload=preload, verbose=verbose)
        return data


if __name__ == "__main__":
    # Smoke run: preprocess and prepare the subject-dependent fine-tuning
    # configuration with a single worker, then print the resulting dataset.
    seed_v_builder = SeedVBuilder(config_name='finetune_sub_dependent')
    seed_v_builder.preproc(n_proc=1)
    seed_v_builder.download_and_prepare(num_proc=1)
    print(seed_v_builder.as_dataset())



