from typing import List, Tuple


from batlinet.builders import TRAIN_TEST_SPLITTERS
from batlinet.train_test_split.base import BaseTrainTestSplitter


@TRAIN_TEST_SPLITTERS.register()
class MIX20TrainTestSplitter(BaseTrainTestSplitter):
    """Train/test splitter that uses a fixed, explicitly listed test set.

    Each file in the base class's ``_file_list`` whose stem (file name without
    extension) appears in ``_TEST_IDS`` is assigned to the test split; every
    other file is assigned to the training split. Within each split, files keep
    the order in which ``_file_list`` yields them.
    """

    # Cell IDs (file stems) that form the test split.
    # The split was originally derived from a random permutation of the cells
    # with seed 0, but that randomness was found to be inconsistent across
    # machines. To keep results reproducible, the resulting test set is listed
    # explicitly here. You can always define your own split if needed.
    # Stored as a frozenset so that each membership check is O(1) and the
    # collection is built once per class rather than once per instance.
    _TEST_IDS = frozenset(
        [
            "RWTH_011",
            "HUST_9-1",
            "MATR_b2c31",
            "CALCE_CS2_36",
            "MATR_b3c24",
            "MATR_b3c45",
            "RWTH_032",
            "HUST_3-8",
            "MATR_b4c36",
            "RWTH_049",
            "HUST_8-3",
            "HUST_10-8",
            "MATR_b4c23",
            "SNL_18650_NCA_25C_20-80_0.5-0.5C_d",
            "CALCE_CX2_38",
            "RWTH_046",
            "HUST_6-4",
            "MATR_b1c11",
            "HUST_10-5",
            "UL-PUR_N20-NA6_18650_NCA_23C_0-100_0.5-0.5C_f",
            "MATR_b4c38",
            "HUST_7-6",
            "RWTH_013",
            "HUST_7-3",
            "RWTH_008",
            "MATR_b4c5",
            "HUST_2-3",
            "HUST_4-5",
            "MATR_b1c0",
            "HUST_6-3",
            "SNL_18650_NCA_25C_0-100_0.5-0.5C_a",
            "MATR_b1c12",
            "MATR_b1c39",
            "MATR_b2c26",
            "CALCE_CX2_16",
            "HUST_1-7",
            "HUST_9-6",
            "MATR_b3c2",
            "MATR_b1c3",
            "RWTH_025",
            "HUST_1-3",
            "RWTH_031",
            "MATR_b4c44",
            "HUST_2-6",
            "MATR_b2c24",
            "MATR_b2c32",
            "MATR_b4c27",
            "SNL_18650_NCA_25C_20-80_0.5-0.5C_a",
            "MATR_b3c7",
            "MATR_b2c33",
            "MATR_b3c1",
            "MATR_b3c0",
            "HUST_6-2",
            "HUST_4-3",
            "RWTH_048",
            "SNL_18650_NCA_35C_0-100_0.5-1C_b",
            "UL-PUR_N15-OV3_18650_NCA_23C_0-100_0.5-0.5C_c",
            "RWTH_045",
            "RWTH_028",
            "MATR_b4c35",
            "HUST_4-7",
            "RWTH_034",
            "SNL_18650_LFP_25C_0-100_0.5-1C_d",
            "SNL_18650_LFP_25C_0-100_0.5-3C_c",
            "SNL_18650_LFP_25C_0-100_0.5-1C_c",
            "MATR_b4c31",
            "MATR_b2c45",
            "MATR_b2c12",
            "MATR_b4c1",
            "HUST_4-6",
            "MATR_b1c42",
            "MATR_b3c8",
            "MATR_b3c17",
            "HUST_1-2",
            "HUST_9-4",
            "HUST_1-1",
            "MATR_b4c8",
            "HUST_10-6",
            "HUST_4-4",
            "UL-PUR_N20-EX2_18650_NCA_23C_0-100_0.5-0.5C_b",
            "MATR_b4c32",
            "HUST_8-8",
            "MATR_b1c32",
            "RWTH_012",
            "MATR_b1c44",
            "MATR_b2c39",
            "MATR_b2c14",
            "RWTH_027",
            "RWTH_005",
            "MATR_b2c22",
            "MATR_b4c6",
            "SNL_18650_LFP_35C_0-100_0.5-1C_c",
            "MATR_b2c23",
            "SNL_18650_LFP_35C_0-100_0.5-1C_b",
            "RWTH_039",
            "MATR_b2c41",
            "SNL_18650_LFP_15C_0-100_0.5-2C_b",
            "MATR_b3c15",
            "MATR_b1c6",
            "HUST_10-1",
            "CALCE_CX2_34",
            "HUST_2-8",
            "HUST_8-4",
            "MATR_b3c4",
            "RWTH_019",
            "CALCE_CX2_36",
            "MATR_b1c16",
            "RWTH_037",
            "MATR_b1c33",
            "MATR_b3c18",
            "MATR_b1c40",
            "MATR_b1c28",
            "HUST_6-6",
            "HUST_1-8",
            "MATR_b2c10",
            "MATR_b1c21",
            "MATR_b4c22",
            "RWTH_041",
            "MATR_b1c45",
            "MATR_b4c17",
            "MATR_b3c12",
            "HUST_3-3",
            "RWTH_044",
            "HUST_5-5",
            "RWTH_035",
            "MATR_b1c8",
            "SNL_18650_NCA_25C_20-80_0.5-0.5C_c",
            "MATR_b4c34",
            "HUST_7-5",
            "MATR_b4c11",
            "HUST_5-7",
            "RWTH_023",
            "MATR_b3c6",
            "SNL_18650_LFP_25C_0-100_0.5-3C_a",
            "MATR_b2c30",
            "MATR_b3c25",
            "MATR_b4c25",
            "MATR_b4c9",
            "MATR_b1c36",
            "MATR_b3c31",
            "MATR_b1c22",
            "MATR_b2c43",
            "MATR_b1c26",
            "RWTH_002",
            "MATR_b4c20",
            "HUST_8-7",
            "MATR_b2c11",
        ]
    )

    def __init__(self, cell_data_path: str):
        """Initialise the base splitter and partition its files into train/test.

        Args:
            cell_data_path: Forwarded unchanged to ``BaseTrainTestSplitter.__init__``
                (presumably the location of the per-cell data files — confirm
                against the base class).
        """
        BaseTrainTestSplitter.__init__(self, cell_data_path)

        self.train_cells: List = []
        self.test_cells: List = []

        # NOTE(review): assumes the base class populates self._file_list with
        # path-like objects exposing `.stem` (e.g. pathlib.Path) — confirm.
        for filename in self._file_list:
            # filename like: HUST_1-1.pkl -> stem "HUST_1-1"
            if filename.stem in self._TEST_IDS:
                self.test_cells.append(filename)
            else:
                self.train_cells.append(filename)

    def split(self) -> Tuple[List, List]:
        """Return the ``(train_cells, test_cells)`` lists computed in ``__init__``.

        The returned lists are the instance's own attributes (not copies).
        """
        return self.train_cells, self.test_cells
