import torch
from sklearn.model_selection import StratifiedKFold
from models.tab_transformer_pytorch import get_model_and_train_components, tab_transformer_train, tab_transformer_eval
from utils import split_cont_categ,ColInfo
import os
from augmentation import CategoricalFeatureConverter
from sklearn.preprocessing import StandardScaler, MinMaxScaler


def training_fe(k, config, train_set, test_set, cols_info_tuple:ColInfo, save_path):
    """Train the tab-transformer model on ``train_set`` and checkpoint each run.

    With ``k > 1`` the training data is divided into ``k`` stratified, shuffled
    folds (seeded with ``config['split_seed']``) and one model is trained per
    fold. With ``k == 1`` a single model is trained on the full training data,
    which is then also used as the validation data. With ``k < 1`` nothing is
    trained or saved.

    After every run the weights returned by ``tab_transformer_train`` are
    loaded back into the model, the model is scored on the validation and test
    data, both scores are printed, and a checkpoint is written with
    ``torch.save``.

    Args:
        k: Number of folds, i.e. number of models to train.
        config: Mapping providing ``'split_seed'``, ``'model_params'`` and
            ``'train_params'``.
        train_set: ``(features, labels)`` pair of pandas objects.
        test_set: ``(features, labels)`` pair of pandas objects.
        cols_info_tuple: Column description; its ``cont_name``, ``cate_idxs``
            and ``cate_dims`` attributes are read.
        save_path: Directory receiving the checkpoints. The file is named
            ``fe_model.pth`` when ``k == 1`` and ``<fold>_fe_model.pth``
            otherwise. The directory is created if it does not exist.

    Returns:
        None. Each checkpoint is a dict with the keys ``'state_dict'``,
        ``'val_perf'`` and ``'test_perf'``.
    """
    seed = config['split_seed']
    train_x = train_set[0].reset_index(drop=True)
    train_y = train_set[1].reset_index(drop=True)
    test_x = test_set[0].reset_index(drop=True)
    test_y = test_set[1].reset_index(drop=True)

    # Create the output directory before training so that a missing directory
    # cannot make torch.save fail only after a full (expensive) training run.
    # An empty save_path means "current directory" for os.path.join, so there
    # is nothing to create in that case.
    if save_path:
        os.makedirs(save_path, exist_ok=True)

    folds = []
    if k > 1:
        splitter = StratifiedKFold(n_splits=k, shuffle=True, random_state=seed)
        folds = list(splitter.split(train_x, train_y))

    # The test data is the same for every fold, so split it only once.
    test_cont, test_categ = split_cont_categ(test_x, cols_info_tuple.cont_name, cols_info_tuple.cate_idxs)
    test_target = test_y.to_numpy()

    for i in range(k):
        if k > 1:
            train_ind, val_ind = folds[i]
            x_train = train_x.iloc[train_ind].reset_index(drop=True)
            y_train = train_y.iloc[train_ind].reset_index(drop=True)
            x_val = train_x.iloc[val_ind].reset_index(drop=True)
            y_val = train_y.iloc[val_ind].reset_index(drop=True)
        else:
            # Single run: there is no held-out split, so the "validation"
            # data is the training data itself.
            x_train, y_train = train_x, train_y
            x_val, y_val = train_x, train_y

        print('train set shape', len(x_train))
        print('val set shape', len(x_val))

        print('training .....')
        train_cont, train_categ = split_cont_categ(x_train, cols_info_tuple.cont_name, cols_info_tuple.cate_idxs)
        val_cont, val_categ = split_cont_categ(x_val, cols_info_tuple.cont_name, cols_info_tuple.cate_idxs)
        train_target = y_train.to_numpy()
        val_target = y_val.to_numpy()

        # A fresh model/optimizer/scheduler is built for every fold.
        cls, criterion, optimizer, scheduler = get_model_and_train_components(
            cols_info_tuple.cate_dims,
            len(cols_info_tuple.cont_name),
            config['model_params'],
        )
        # NOTE(review): device, save_best_model_dict and save_metrics are fixed
        # here; if config['train_params'] also contains one of these keys the
        # call raises TypeError (duplicate keyword argument).
        best_model_dict = tab_transformer_train(
            cls,
            criterion,
            optimizer,
            scheduler,
            train_cont,
            train_categ,
            train_target,
            val_cont,
            val_categ,
            val_target,
            device="cuda:0",
            **config['train_params'],
            save_best_model_dict=False,
            save_metrics=False,
        )
        # Evaluate with the returned weights rather than whatever state the
        # model was left in when training finished.
        cls.load_state_dict(best_model_dict)

        # The f1_* names suggest an F1 score; the metric itself is defined by
        # tab_transformer_eval -- TODO confirm.
        f1_val = tab_transformer_eval(cls, val_cont, val_categ, val_target)
        f1_test = tab_transformer_eval(cls, test_cont, test_categ, test_target)

        print("Val score", f1_val)
        print("Test score", f1_test)

        file_name = 'fe_model.pth' if k == 1 else str(i) + '_fe_model.pth'
        torch.save({
            'state_dict': best_model_dict,
            'val_perf': f1_val,
            'test_perf': f1_test}, os.path.join(save_path, file_name))


class FeatureExtractor:
    """Convert a tabular dataset into a feature tensor, optionally grouped by class.

    The features are the output of
    ``CategoricalFeatureConverter.convert_to_one_hot_labels`` with the
    continuous columns passed through ``scaler`` when one is configured.
    No model is involved in the extraction.
    """

    def __init__(self, scaler, device="cuda:0"):
        """Store the scaler and the target device.

        Args:
            scaler: Object whose ``transform`` method is applied to the
                continuous columns (it is never fitted here), or ``None`` to
                leave those columns unscaled.
            device: Torch device the feature tensor is moved to.
        """
        self.device = device
        self.scaler = scaler

    def split_into_class(self, feature_list, num_classes, label_list):
        """Group the rows of ``feature_list`` by their class label.

        Args:
            feature_list: Indexable along its first axis with a list of row
                positions (e.g. a tensor); ``shape[0]`` gives the row count.
            num_classes: Number of classes; labels are used as list indices
                into ``num_classes`` buckets.
            label_list: Integer label of every row, aligned with
                ``feature_list``.

        Returns:
            Dict mapping every class id in ``range(num_classes)`` to the rows
            of ``feature_list`` carrying that label, in their original order.
        """
        rows_per_class = [[] for _ in range(num_classes)]
        for row in range(feature_list.shape[0]):
            rows_per_class[label_list[row]].append(row)
        return {
            class_id: feature_list[rows, ...]
            for class_id, rows in enumerate(rows_per_class)
        }

    def extractor_features_from_dst(self, dst, num_classes, cols_info_tuple: "ColInfo", train_val_index=None,
                                    is_split_by_class=True):
        """Build the feature tensor for ``dst`` and optionally partition it.

        Args:
            dst: ``(features, labels)`` pair; ``dst[0]`` is handed to the
                categorical converter and ``dst[1]`` is a pandas object that
                is only read when ``is_split_by_class`` is true.
            num_classes: Number of classes used when grouping by class.
            cols_info_tuple: Column description; its ``cate_name`` and
                ``cont_name`` attributes are read.
            train_val_index: Optional pair ``(train_index, val_index)`` of row
                positions. When given, the features are divided into a train
                part and a validation part.
            is_split_by_class: When true, every returned feature set is a dict
                as produced by ``split_into_class``; otherwise it is the plain
                feature tensor.

        Returns:
            Without ``train_val_index``: a single feature set.
            With ``train_val_index``: a ``(train, val)`` tuple of feature sets.
        """
        cate_converter = CategoricalFeatureConverter(cate_col=cols_info_tuple.cate_name)
        feature_list = cate_converter.convert_to_one_hot_labels(dst[0])
        numeric_col = cols_info_tuple.cont_name
        if numeric_col is not None and len(numeric_col) > 0 and self.scaler is not None:
            # NOTE(review): if convert_to_one_hot_labels returns the input
            # frame itself rather than a copy, this assignment also modifies
            # dst[0] in place -- confirm.
            feature_list[numeric_col] = self.scaler.transform(feature_list[numeric_col])

        # dtype=float is the Python float type, i.e. double precision.
        feature_list = torch.tensor(feature_list.values, dtype=float).to(self.device)

        if train_val_index is None:
            if is_split_by_class:
                return self.split_into_class(feature_list, num_classes, dst[1].values)
            return feature_list

        train_index, val_index = train_val_index[0], train_val_index[1]
        train_feature_list = feature_list[train_index, :]
        val_feature_list = feature_list[val_index, :]

        if not is_split_by_class:
            # Previously this combination fell off the end of the method and
            # returned None; return the two parts instead.
            return train_feature_list, val_feature_list

        train_labels = dst[1].iloc[train_index.tolist()].values
        val_labels = dst[1].iloc[val_index.tolist()].values
        return (self.split_into_class(train_feature_list, num_classes, train_labels),
                self.split_into_class(val_feature_list, num_classes, val_labels))