from collections import Counter

import numpy as np
from sklearn.base import clone, is_classifier
from sklearn.ensemble._base import BaseEnsemble
from tqdm import tqdm
import pandas as pd
import numpy as np
import os
import random
import warnings
from sklearn.model_selection import StratifiedKFold
from aif360.sklearn.metrics import generalized_entropy_error
from fairlearn.metrics import demographic_parity_difference, equalized_odds_difference
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score
from sklearn.neural_network import MLPClassifier
import argparse
import time
# Command-line interface for the experiment script.
# NOTE: the arguments are parsed at import time, so importing this module
# consumes sys.argv. main() later overwrites args.sen_index with the value
# reported by load_data() for each dataset.
parser = argparse.ArgumentParser(description='FairGBFC')
parser.add_argument('--sen_index', type=int, default=1, help='index of sensitive attribute')
args = parser.parse_args()
# Suppress every warning category for the rest of the process.
warnings.filterwarnings("ignore")


def mask_to_idx(mask: np.ndarray):
    """Return the first-axis indices at which ``mask`` is truthy."""
    first_axis_hits, *_ = np.nonzero(mask)
    return first_axis_hits


def idx_to_mask(indices: np.ndarray, size: int):
    """Build a boolean array of length ``size`` that is True exactly at ``indices``."""
    selected = np.full(size, False)
    selected[indices] = True
    return selected


def dict_info(d):
    """Render a mapping as text, one ``key: value`` pair per newline-terminated line."""
    return ''.join(f'{key}: {value}\n' for key, value in d.items())


def seed_generator(master_seed):
    """
    Endless stream of integer seeds derived deterministically from one seed.

    Parameters:
        master_seed (int): Seed for the underlying ``np.random.default_rng``.

    Yields:
        int: The next integer drawn by ``rng.integers`` with the uint32
        maximum as the upper bound.
    """
    upper_bound = np.iinfo(np.uint32).max
    stream = np.random.default_rng(master_seed)
    while True:
        yield stream.integers(upper_bound)



class FairEnsemble(BaseEnsemble):
    """Ensemble of classifiers fitted on group-rebalanced copies of the data.

    ``fit`` works in two stages:

    1. One "expert" per sensitive group is trained on that group's samples
       only (stored in ``estimators_grp_``).
    2. ``n_estimators`` members are trained (stored in ``estimators_``). For
       each member, every group whose positive ratio differs from the target
       ratio has samples of one class dropped or label-flipped, and the base
       estimator is fitted on the resulting set.

    Labels and sensitive attribute values must both take exactly the values
    0 and 1.

    Options in ``how`` (missing keys fall back to ``default_how``):

    * ``'pred'``: which estimators ``predict``/``predict_proba`` average --
      ``'calib'`` (``estimators_``), ``'adv'``, ``'dis'`` or ``'adv+dis'``
      (group experts).
    * ``'consensus'``: ``'disag'`` or ``'other-disag'``; how the expert
      predictions are turned into the sampling-weight mask.
    * ``'uniformity'``: sampling weight of candidates selected by the
      consensus mask (all other candidates have weight 1).
    * ``'drop_ratio'``: fraction of the maximum number of samples to
      drop/flip.
    * ``'pos_ratio'``: target positive ratio -- ``'overall'``, ``'max'`` or
      ``'min'`` (over groups).
    * ``'bootstrap'``: falsy to disable, otherwise the ``stratify`` mode
      handed to ``bootstrap_sampling``.
    * ``'action'``: ``'drop'`` or ``'flip'``.
    """

    # Defaults for the ``how`` options; user-supplied keys override them.
    default_how = {
        'pred': 'calib',
        'consensus': 'disag',  # 'disag', 'other-disag',
        'uniformity': 0.1,
        'drop_ratio': 1.0,
        'pos_ratio': 'max',
        'bootstrap': False,
        'action': 'drop',
    }
    # Defaults for the ``fair_how`` options.
    default_fair_how = {
        'method': None,
        'constraint': None,
    }

    def __init__(
        self,
        estimator=None,
        *,
        n_estimators=10,
        fair_how=None,
        how=None,
        verbose=False,
        bootstrap=False,
        random_state=None,
    ):
        """Store the configuration and pre-draw one seed per ensemble member.

        Parameters:
            estimator: Classifier used as the template for every fitted model.
            n_estimators (int): Number of calibrated ensemble members.
            fair_how (dict | None): Overrides for ``default_fair_how``. Only
                ``'method': None`` is supported.
            how (dict | None): Overrides for ``default_how``.
            verbose, bootstrap: Stored as attributes; not read elsewhere in
                this class (bootstrapping is controlled by
                ``how['bootstrap']``).
            random_state (int | None): Master seed.

        Raises:
            AssertionError: If ``estimator`` is not a classifier.
            NotImplementedError: If ``fair_how['method']`` is not ``None``.
        """
        assert is_classifier(estimator), "Base estimator must be a classifier"
        self.fair_how = self.default_fair_how.copy()
        if fair_how is not None:
            self.fair_how.update(fair_how)
        self.init_estimator = estimator
        # Previously a non-None method left ``self.estimator`` unset and the
        # validation below failed with an opaque AttributeError.
        if self.fair_how['method'] is not None:
            raise NotImplementedError(
                f"fair_how['method']={self.fair_how['method']!r} is not supported; "
                "only None is implemented."
            )
        self.estimator = estimator
        self.base_estimator = None
        self.estimator_params = {}
        self._validate_estimator()
        self.flag_proba = hasattr(self.estimator_, 'predict_proba')
        # store parameters
        self.n_estimators = n_estimators
        self.how = self.default_how.copy()
        if how is not None:  # dict.update(None) raises TypeError
            self.how.update(how)
        self.verbose = verbose
        self.bootstrap = bootstrap
        # Master random seed
        self.random_state = random_state
        # Seeds for each estimator
        self.seed_generator = seed_generator(random_state)
        self.random_seeds = [
            next(self.seed_generator) for _ in range(self.n_estimators)
        ]

    def _init_stats(self, X, y, s):
        """Compute per-group and dataset-level label statistics.

        Populates ``self.stats_grp`` (one dict per group with sizes, class
        counts, ratios, global sample indices and the 'adv'/'dis' and
        'maj'/'min' tags) and ``self.stats_meta`` (overall positive ratio and
        the group labels with extreme positive ratio / size).

        Raises:
            AssertionError: If ``y`` or ``s`` take values other than 0 and 1.
        """
        y_count, g_count = Counter(y), Counter(s)
        groups = list(g_count.keys())
        assert (
            set(y_count.keys()) == set(g_count.keys()) == {0, 1}
        ), "Label (y) and sensitive attribute (s) values should be in {0, 1}."
        data_pos_ratio = (y == 1).sum() / len(y)
        avg_group_size = len(y) / len(g_count)
        stats_grp = {}
        for g in groups:
            mask_g = s == g
            y_g = y[mask_g]
            g_pos_ratio = (y_g == 1).sum() / g_count[g]
            stats_grp[g] = {
                'size': g_count[g],
                'n_pos': (y_g == 1).sum(),
                'n_neg': (y_g == 0).sum(),
                'pos_ratio': g_pos_ratio,
                'neg_ratio': 1 - g_pos_ratio,
                'idx': mask_to_idx(mask_g),  # global index
                # global index of positive samples
                'idx_pos': mask_to_idx((y == 1) & mask_g),
                # global index of negative samples
                'idx_neg': mask_to_idx((y == 0) & mask_g),
                # group type: advantaged or disadvantaged
                'type_group': 'adv' if g_pos_ratio > data_pos_ratio else 'dis',
                # group size: majority or minority
                'type_size': 'maj' if g_count[g] > avg_group_size else 'min',
            }
        self.stats_grp = stats_grp
        self.stats_meta = {
            'pos_ratio': data_pos_ratio,
            # group label with max/min positive ratio
            'g_max_pos_ratio': max(stats_grp, key=lambda x: stats_grp[x]['pos_ratio']),
            'g_min_pos_ratio': min(stats_grp, key=lambda x: stats_grp[x]['pos_ratio']),
            # group label with max/min size
            'g_max_size': max(stats_grp, key=lambda x: stats_grp[x]['size']),
            'g_min_size': min(stats_grp, key=lambda x: stats_grp[x]['size']),
        }

    def print_distribution(self, y, s, prefix=""):
        """Print the number of samples in every (group, label) cell."""
        res = {}
        for g in self.stats_grp:
            msk_g = s == g
            for y_ in [0, 1]:
                res[f"g={g}, y={y_}"] = (msk_g & (y == y_)).sum()
        print(prefix, res)

    def new_estimator(self, estimator=None, random_state=None):
        """Return an unfitted clone of ``estimator`` (default: ``self.estimator_``).

        ``random_state`` is applied when the estimator accepts that parameter.
        """
        estimator = clone(self.estimator_ if estimator is None else estimator)
        try:
            estimator.set_params(random_state=random_state)
        except Exception:
            # Best effort: estimators without a ``random_state`` parameter
            # are used as they are.
            pass
        return estimator

    def fit(self, X, y, sensitive_features: np.ndarray):
        """Fit the group experts and then the calibrated ensemble members.

        Parameters:
            X, y: Training features and binary labels.
            sensitive_features: Binary group membership, one value per sample.

        Returns:
            self
        """
        s = sensitive_features
        self._init_stats(X, y, s)
        # empty estimators
        self.estimators_grp_ = {g: [] for g in self.stats_grp}
        self.estimators_ = []
        self.fit_iter_grp_experts(X, y, s, random_state=self.random_state)
        for seed in self.random_seeds[: self.n_estimators]:
            self.fit_iter_calibrated(X, y, s, random_state=seed)
        return self

    def fit_iter_grp_experts(
        self, X, y, sensitive_features: np.ndarray, random_state=None
    ):
        """Fit one clone of the initial estimator per group, on that group only.

        Each fitted expert is appended to ``self.estimators_grp_[g]``.
        """
        s = sensitive_features
        seeds = seed_generator(random_state)
        X_, y_, s_ = X.copy(), y.copy(), s.copy()

        # group-specific estimators
        for g in self.stats_grp:
            msk_g = s == g
            X_g, y_g, s_g = X_[msk_g], y_[msk_g], s_[msk_g]
            estimator = self.new_estimator(
                estimator=self.init_estimator, random_state=next(seeds)
            )
            try:
                estimator.fit(X_g, y_g, sensitive_features=s_g)
            except Exception:
                # Fall back to the plain signature for estimators that do not
                # accept ``sensitive_features``.
                estimator.fit(X_g, y_g)
            self.estimators_grp_[g].append(estimator)

    def fit_iter_calibrated(
        self, X, y, sensitive_features: np.ndarray, random_state=None
    ):
        """Fit one ensemble member on a rebalanced copy of the training data.

        For every group whose positive ratio differs from the target ratio,
        samples of one class (positives for 'adv' groups, negatives for 'dis'
        groups) are picked at random, weighted by the group experts'
        predictions, and either removed (``action='drop'``) or label-flipped
        (``action='flip'``). The fitted estimator is appended to
        ``self.estimators_``.

        Raises:
            AssertionError: If the group experts have not been fitted.
            ValueError: On an unsupported ``pos_ratio``, ``action`` or
                ``consensus`` option.
        """
        assert (
            min([len(clfs) for clfs in self.estimators_grp_.values()]) > 0
        ), "Group-specific experts not fitted yet."
        s = sensitive_features
        how = self.how
        grp = self.stats_grp
        meta = self.stats_meta
        seeds = seed_generator(random_state)
        X_, y_, s_ = X.copy(), y.copy(), s.copy()

        # compute target positive ratio
        if how['pos_ratio'] == 'overall':
            pos_ratio_tgt = meta['pos_ratio']
        elif how['pos_ratio'] == 'max':
            pos_ratio_tgt = grp[meta['g_max_pos_ratio']]['pos_ratio']
        elif how['pos_ratio'] == 'min':
            pos_ratio_tgt = grp[meta['g_min_pos_ratio']]['pos_ratio']
        else:
            raise ValueError(
                f"Unsupport pos_ratio: {how['pos_ratio']}, choose from ['overall', 'max', 'min']"
            )

        # Validate once up front instead of in every branch below.
        action = how['action']
        if action not in ('drop', 'flip'):
            raise ValueError(
                f"Unsupport action: {action}, choose from ['drop', 'flip']"
            )

        drop_ratio = how['drop_ratio']
        # Predictions of every group's expert(s) on the full training set.
        y_pred_grp = {
            g: self.predict_estimators(
                X_, sensitive_features=s_, estimators=self.estimators_grp_[g]
            )
            for g in grp
        }
        dict_target_subgrp_labels = {g: None for g in grp}
        msk_drop = np.zeros_like(y_, dtype=bool)
        for g, info in grp.items():
            # Random number generator for this group. It is created before any
            # early exit so each group always consumes exactly one seed.
            rng = np.random.RandomState(next(seeds))
            msk_g = s_ == g
            n_neg = info['n_neg']
            n_pos = info['n_pos']
            # if pos_ratio is already satisfied, skip
            if info['pos_ratio'] == pos_ratio_tgt:
                continue
            y_pred_self = y_pred_grp[g]
            y_pred_other = np.mean(
                [y_pred_grp[g_] for g_ in grp if g_ != g], axis=0
            )

            # locating samples to drop & compute max number of samples to drop
            if info['type_group'] == 'adv':  # advantaged group
                assert (
                    info['pos_ratio'] > pos_ratio_tgt
                ), "Advantaged group should have higher pos_ratio than target."
                # candidates are the ground truth positive samples
                msk_drop_subgrp = (y_ == 1) & msk_g
                dict_target_subgrp_labels[g] = 1
                if action == 'drop':
                    n_drop_max = n_pos - int((n_neg / (1 - pos_ratio_tgt)) - n_neg)
                else:  # flip
                    n_drop_max = n_pos - int(info['size'] * pos_ratio_tgt)
            elif info['type_group'] == 'dis':  # disadvantaged group
                assert (
                    info['pos_ratio'] < pos_ratio_tgt
                ), "Disadvantaged group should have lower pos_ratio than target."
                # candidates are the ground truth negative samples
                msk_drop_subgrp = (y_ == 0) & msk_g
                dict_target_subgrp_labels[g] = 0
                if action == 'drop':
                    n_drop_max = n_neg - int((n_pos / pos_ratio_tgt) - n_pos)
                else:  # flip
                    n_drop_max = n_neg - int(info['size'] * (1 - pos_ratio_tgt))
            else:
                raise ValueError(f"Unrecognized group type: {info['type_group']}")

            # Mask of candidates that receive the reduced weight
            # ``uniformity``: True where this group's expert prediction
            # differs from the other groups' mean prediction.
            # NOTE(review): this mask was previously named ``msk_agree``
            # although it is built with ``!=``; confirm that down-weighting
            # these samples (rather than the complementary set) is intended.
            if how['consensus'] == 'disag':
                msk_differs = y_pred_self != y_pred_other
            elif how['consensus'] == 'other-disag':
                msk_differs = (y_pred_self != y_pred_other) & (y_pred_self == y_)
            else:
                raise ValueError(
                    f"Unsupport consensus: {how['consensus']}, choose from ['disag', 'other-disag']"
                )

            # random sampling
            n_drop = int(n_drop_max * drop_ratio)  # number of samples to drop
            if n_drop <= 0:
                # Nothing to remove for this group; also avoids sampling from
                # an empty candidate set.
                continue
            idx_full_weight = mask_to_idx(msk_drop_subgrp & ~msk_differs)
            idx_low_weight = mask_to_idx(msk_drop_subgrp & msk_differs)
            # assign sample probability (order must match idx_g)
            idx_g = np.concatenate([idx_full_weight, idx_low_weight])
            p_drop_g = np.concatenate(
                [
                    np.ones_like(idx_full_weight, dtype=float),
                    np.full_like(idx_low_weight, how['uniformity'], dtype=float),
                ]
            )
            p_drop_g /= p_drop_g.sum()
            idx_drop_g = rng.choice(idx_g, size=n_drop, replace=False, p=p_drop_g)
            # update the drop mask
            msk_drop |= idx_to_mask(idx_drop_g, len(y_))

        # get calibrated set
        if action == 'drop':
            msk_keep = ~msk_drop
            X_, y_, s_ = X_[msk_keep], y_[msk_keep], s_[msk_keep]
        else:  # flip
            y_[msk_drop] = 1 - y_[msk_drop]
        # bootstrap sampling
        if how['bootstrap']:
            X_, y_, s_ = self.bootstrap_sampling(
                X_,
                y_,
                s_,
                dict_target_subgrp_labels=dict_target_subgrp_labels,
                stratify=how['bootstrap'],
                random_state=random_state,
            )
        # fit the calibrated estimator
        estimator = self.new_estimator(random_state=random_state)
        try:
            estimator.fit(X_, y_, sensitive_features=s_)
        except Exception:
            # Fall back to the plain signature for estimators that do not
            # accept ``sensitive_features``.
            estimator.fit(X_, y_)
        self.estimators_.append(estimator)

    def predict(self, X, sensitive_features):
        """Return the column index (0 or 1) with the highest averaged probability."""
        y_pred_proba = self.predict_proba(X, sensitive_features)
        return np.argmax(y_pred_proba, axis=1)

    def get_estimators(self):
        """Return the fitted estimators selected by ``how['pred']``.

        Raises:
            ValueError: If ``how['pred']`` is not one of 'adv+dis', 'adv',
                'dis', 'calib'.
        """
        pred_mode = self.how['pred']
        if pred_mode == 'adv+dis':
            estimators = self.estimators_grp_[0] + self.estimators_grp_[1]
        elif pred_mode == 'adv':
            # ``stats_meta`` holds the group labels; the previously referenced
            # ``meta_info`` attribute is never defined. The group with the
            # highest positive ratio is used as the advantaged one.
            estimators = self.estimators_grp_[self.stats_meta['g_max_pos_ratio']]
        elif pred_mode == 'dis':
            estimators = self.estimators_grp_[self.stats_meta['g_min_pos_ratio']]
        elif pred_mode == 'calib':
            estimators = self.estimators_
        else:
            raise ValueError(f"Unknown pred_mode: {pred_mode}")
        return estimators

    def predict_proba(self, X, sensitive_features):
        """Average class probabilities over the estimators chosen by ``how['pred']``."""
        estimators = self.get_estimators()
        return self.predict_proba_estimators(X, sensitive_features, estimators)

    def predict_proba_estimators(self, X, sensitive_features, estimators: list):
        """Average the class probabilities of ``estimators``.

        When the base estimator has no ``predict_proba`` the hard predictions
        are one-hot encoded over two classes before averaging.
        """

        def to_one_hot(y, n_classes):
            y_one_hot = np.zeros((len(y), n_classes))
            # Predicted labels may be floats (0.0 / 1.0); cast so they can be
            # used as column indices.
            y_one_hot[np.arange(len(y)), np.asarray(y).astype(int)] = 1
            return y_one_hot

        if self.flag_proba:
            try:
                y_pred_proba = [
                    estimator.predict_proba(X, sensitive_features=sensitive_features)
                    for estimator in estimators
                ]
            except Exception:
                y_pred_proba = [estimator.predict_proba(X) for estimator in estimators]
        else:
            try:
                y_pred = [
                    estimator.predict(X, sensitive_features=sensitive_features)
                    for estimator in estimators
                ]
            except Exception:
                y_pred = [estimator.predict(X) for estimator in estimators]
            y_pred_proba = [to_one_hot(labels, n_classes=2) for labels in y_pred]
        return np.mean(np.array(y_pred_proba), axis=0)

    def predict_estimators(self, X, sensitive_features, estimators: list):
        """Predict with ``estimators``.

        With ``predict_proba`` available, returns the argmax of the averaged
        probabilities; otherwise returns the element-wise mean of the
        estimators' hard predictions.
        """
        if self.flag_proba:
            y_pred_proba = self.predict_proba_estimators(
                X, sensitive_features, estimators
            )
            return np.argmax(y_pred_proba, axis=1)
        try:
            y_pred = [
                estimator.predict(X, sensitive_features=sensitive_features)
                for estimator in estimators
            ]
        except Exception:
            y_pred = [estimator.predict(X) for estimator in estimators]
        return np.mean(np.array(y_pred), axis=0)

    @staticmethod
    def bootstrap_sampling(
        X, y, s, dict_target_subgrp_labels=None, stratify='regular', random_state=None
    ):
        """Resample ``(X, y, s)`` with replacement.

        Parameters:
            dict_target_subgrp_labels (dict | None): Maps each group to the
                label of its rebalanced subgroup; required for 'untargeted'.
            stratify: 'regular' (or ``True``) samples from all rows; 'y', 's'
                and 'y+s' resample within each label / group / (label, group)
                cell, preserving cell sizes; 'untargeted' keeps each group's
                target subgroup unchanged and resamples the other cells.
            random_state: Seed for ``np.random.RandomState``.

        Returns:
            The resampled ``X``, ``y`` and ``s``.

        Raises:
            ValueError: On an unknown ``stratify`` mode.
        """
        rng = np.random.RandomState(random_state)
        # ``==`` instead of ``is``: identity comparison with a string literal
        # is unreliable. ``True`` is accepted so a boolean ``how['bootstrap']``
        # selects the regular bootstrap.
        if stratify is True or stratify == 'regular':
            idx = rng.choice(len(y), size=len(y), replace=True)
        elif stratify == 'y':
            idx = np.concatenate(
                [
                    rng.choice(mask_to_idx(y == y_), size=(y == y_).sum(), replace=True)
                    for y_ in np.unique(y)
                ]
            )
        elif stratify == 's':
            idx = np.concatenate(
                [
                    rng.choice(mask_to_idx(s == s_), size=(s == s_).sum(), replace=True)
                    for s_ in np.unique(s)
                ]
            )
        elif stratify == 'y+s':
            idx = np.concatenate(
                [
                    rng.choice(
                        mask_to_idx((y == y_) & (s == s_)),
                        size=((y == y_) & (s == s_)).sum(),
                        replace=True,
                    )
                    for y_ in np.unique(y)
                    for s_ in np.unique(s)
                ]
            )
        elif stratify == 'untargeted':
            idxs = []
            for g in np.unique(s):
                for y_ in np.unique(y):
                    msk = (y == y_) & (s == g)
                    if y_ == dict_target_subgrp_labels[g]:
                        # the target subgroup was already undersampled: keep as is
                        idxs.append(mask_to_idx(msk))
                    else:
                        # every other cell is bootstrap-resampled
                        idxs.append(
                            rng.choice(mask_to_idx(msk), size=msk.sum(), replace=True)
                        )
            idx = np.concatenate(idxs)
        else:
            raise ValueError(f"Unknown bootstrap stratify: {stratify}")
        return X[idx], y[idx], s[idx]
    

def load_data(name):
    """Load a processed dataset CSV and print summary statistics about it.

    The CSV is read from ``../Datasets/processed dataset/<name>.csv`` relative
    to this file. Column 0 is treated as the label.

    Parameters:
        name (str): Dataset name without the ``.csv`` extension.

    Returns:
        tuple: ``(data, sen_index - 1, cat_indices, cont_indices)`` where
        ``data`` is the loaded DataFrame and ``sen_index - 1`` is the position
        of the sensitive attribute once the label column is removed.
        ``(None, None, None, None)`` if the CSV file does not exist.

    Raises:
        ValueError: If the file exists but ``name`` is not a known dataset.
    """
    # Per-dataset column layout: (sen_index, cat_indices, cont_indices).
    # sen_index is the sensitive column in the raw table, label column included.
    # NOTE(review): the original comment on cat_indices reads "column indices
    # within the attributes"; presumably they count from the first feature
    # column -- confirm.
    dataset_configs = {
        # sensitive attribute: Gender
        'credit_approval.csv': (
            1,
            [0, 3, 4, 5, 6, 8, 9, 11, 12],
            [1, 2, 7, 10, 13, 14],
        ),
        # sensitive attribute: sex
        'adult.csv': (
            8,
            [1, 2, 3, 4, 5, 6, 7, 9],
            [0, 8],
        ),
        # sensitive attribute: SEX
        'credit_default.csv': (
            2,
            [1, 2, 3, 5, 6, 7, 8, 9, 10],
            [0, 4, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22],
        ),
        # sensitive attribute: gender
        'law_admission.csv': (
            6,
            [4, 5, 6],
            [0, 1, 2, 3],
        ),
        # sensitive attribute: Gender
        'german.csv': (
            1,
            [0, 1, 2, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26],
            [3, 4, 5],
        ),
        # sensitive attribute: sex
        'por.csv': (
            2,
            [0, 1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27],
            [2, 28, 29, 30],
        ),
    }

    src_path = os.path.dirname(os.path.abspath(__file__))
    dataset_path = os.path.join(src_path, '..', 'Datasets/processed dataset/')
    name = name + '.csv'
    print(dataset_path)

    data_path = os.path.join(dataset_path, name)
    if not os.path.exists(data_path):
        print(f"Error: File not found at {data_path}")
        return None, None, None, None
    if name not in dataset_configs:
        # Previously an unknown name left the column layout unbound and
        # failed later with an UnboundLocalError.
        raise ValueError(
            f"Unknown dataset: {name}, choose from {sorted(dataset_configs)}"
        )
    sen_index, cat_indices, cont_indices = dataset_configs[name]
    data = pd.read_csv(data_path)

    label_col = data.iloc[:, 0]
    sensi_col = data.iloc[:, sen_index]
    print(f"\nDataset: {name}")

    # Share of each value of the sensitive attribute.
    sen_col_ratio = sensi_col.value_counts(normalize=True)
    print("\n✅ 敏感属性列的值比例：")
    print(sen_col_ratio)

    # Share of each label value.
    col0_ratio = label_col.value_counts(normalize=True)
    print("\n✅ 标签的值比例：")
    print(col0_ratio)

    pos_label = 1
    groups = sensi_col.unique()
    if len(groups) >= 2:
        # Only the first two distinct groups are compared.
        g1, g2 = groups[0], groups[1]
        g1_data = data[sensi_col == g1]
        g2_data = data[sensi_col == g2]

        g1_total = len(g1_data)
        g1_pos = (g1_data.iloc[:, 0] == pos_label).sum()
        g2_total = len(g2_data)
        g2_pos = (g2_data.iloc[:, 0] == pos_label).sum()

        # Absolute difference between the groups' positive rates.
        dp = abs(g1_pos / g1_total - g2_pos / g2_total)
        print(f"\n✅ DP值（|P({g1}, 正例) - P({g2}, 正例)|）：{dp:.6f}")

        # Print more detailed statistics.
        pos_total = (label_col == pos_label).sum()
        neg_total = (label_col != pos_label).sum()

        print("\n🔍 样本详细统计：")
        print(f" 敏感属性 - {g1} 样本总数：{g1_total}")
        print(f" 敏感属性 - {g1} 中正样本数：{g1_pos}")
        print(f" 敏感属性 - {g2} 样本总数：{g2_total}")
        print(f" 敏感属性 - {g2} 中正样本数：{g2_pos}")
        print(f"  - 正类样本总数：{pos_total}")
        print(f"  - 负类样本总数：{neg_total}")

    return data, sen_index - 1, cat_indices, cont_indices

def main():
    """Run 5-fold cross-validated FairEnsemble experiments on every dataset.

    For each dataset: load it, drop rows whose feature vector duplicates an
    earlier row, min-max scale the features, train and evaluate a
    single-member FairEnsemble around an MLP on each fold, and append the mean
    and standard deviation of accuracy, F1, recall, demographic parity
    difference, equalized odds difference and generalized entropy error to
    ``<result dir>/<dataset>.txt``. Datasets whose CSV is missing are skipped.
    """
    filenames = ['credit_approval', 'adult', 'credit_default', 'law_admission','german', 'por']
    n_splits = 5
    random_seed = 42
    np.random.seed(random_seed)
    # NOTE(review): '__file__' is a string literal here, so this resolves
    # relative to the current working directory, not to this file's location.
    # Kept as is so results are written where they were before.
    src_path = os.path.dirname(os.path.realpath('__file__'))
    single_ensemble_kwargs = {
        'n_estimators': 1,
        'random_state': 0,
    }
    result_dir = os.path.join(
        src_path, 'FairGBFC', 'comparative_result', 'Fair', 'GroupDebias'
    )
    # Make sure the output directory exists before the first open().
    os.makedirs(result_dir, exist_ok=True)

    def avg_var_str(metric, values, dataset):
        # Two report lines: mean and standard deviation of one metric.
        return (f'Average {metric} of {dataset}: {np.mean(values):.6f}\n'
                f'Std {metric} of {dataset}: {np.std(values):.6f}\n')

    last_file_path = None
    for dataset_name in filenames:
        data_frame, sen_index, _cat_indices, _cont_indices = load_data(dataset_name)
        if data_frame is None:
            # load_data already reported the missing file.
            print(dataset_name, 'skipped: dataset could not be loaded')
            continue
        args.sen_index = sen_index
        print(data_frame)
        print(args.sen_index)

        # Keep the first row for every distinct feature vector (columns 1..).
        # A set of tuples makes this linear instead of quadratic in the number
        # of rows.
        seen_features = set()
        unique_rows = []
        for row in data_frame.values.tolist():
            features = tuple(row[1:])
            if features not in seen_features:
                seen_features.add(features)
                unique_rows.append(row)
        data = np.array(unique_rows)
        numberSample = data.shape[0]

        # NOTE(review): the scaler is fitted on the whole dataset before the
        # cross-validation split, so test rows influence the scaling.
        minMax = MinMaxScaler()
        data = np.hstack((data[:, 0].reshape(numberSample, 1),
                            minMax.fit_transform(data[:, 1:])))
        train_data = data[:, 1:]
        train_target = data[:, 0]

        skf = StratifiedKFold(n_splits, shuffle=True, random_state=42)

        acc_list, f1_list, recall_list = [], [], []
        dp_list, ge_list, eo_list = [], [], []

        for train_index, test_index in skf.split(train_data, train_target):
            train, test = data[train_index], data[test_index]
            X_train, y_train = train[:, 1:], train[:, 0]
            X_test, Y_test = test[:, 1:], test[:, 0]
            S_test = X_test[:, sen_index]

            mlp = MLPClassifier(
                hidden_layer_sizes=(128, 64, 32, 16),
                activation='relu',
                solver='adam',
                alpha=0.0001,
                batch_size='auto',
                learning_rate='constant',
                learning_rate_init=0.001,
                max_iter=500,
                early_stopping=True,
                validation_fraction=0.1,
                n_iter_no_change=10,
                random_state=42
            )

            model = FairEnsemble(
                estimator=mlp,
                how={'pos_ratio': 'max', 'uniformity': 0.1, 'drop_ratio': 1, 'bootstrap': False},
                **single_ensemble_kwargs
            )
            sensitive_train = X_train[:, sen_index]

            model.fit(X_train, y_train, sensitive_features=sensitive_train)
            predict_label = model.predict(X_test, sensitive_features=S_test)

            acc_list.append(accuracy_score(Y_test, predict_label))
            f1_list.append(f1_score(Y_test, predict_label))
            recall_list.append(recall_score(Y_test, predict_label))
            dp_list.append(
                demographic_parity_difference(Y_test, predict_label, sensitive_features=S_test)
            )
            eo_list.append(
                equalized_odds_difference(Y_test, predict_label, sensitive_features=S_test)
            )
            ge_list.append(generalized_entropy_error(Y_test, predict_label))

        file_path = os.path.join(result_dir, str(dataset_name) + '.txt')
        # Context manager so every per-dataset file is closed, not just the
        # last one.
        with open(file_path, mode='a') as file:
            file.write(dataset_name + '\n\n')
            file.write(avg_var_str('accuracy', acc_list, dataset_name))
            file.write(avg_var_str('f1', f1_list, dataset_name))
            file.write(avg_var_str('recall', recall_list, dataset_name))
            file.write(avg_var_str('dp', dp_list, dataset_name))
            file.write(avg_var_str('eo', eo_list, dataset_name))
            file.write(avg_var_str('ge', ge_list, dataset_name))
            file.write('\n')
        last_file_path = file_path
        print(dataset_name, 'done!')

    # As before, the completion marker goes into the last result file written.
    if last_file_path is not None:
        with open(last_file_path, mode='a') as file:
            file.write('all done!!!!!')


if __name__ == '__main__':
    main()