import numpy as np
from sympy import true
np.random.seed(0)
import torch
import os
from ptranking.utils.bigdata.BigPickle import pickle_save, pickle_load
from ptranking.data.data_utils import LTRDataset
from ptranking.data.data_utils import get_data_meta, SPLIT_TYPE, LABEL_TYPE


import sys
from ptranking.utils.args.argsUtil import ArgsUtil

from ptranking.ltr_adhoc.eval.parameter import ModelParameter, DataSetting, EvalSetting, PriParameter, ScoringFunctionParameter

from ptranking.ltr_adhoc.eval.ltr import LTREvaluator, LTR_ADHOC_MODEL, YAHOO_LTR, ISTELLA_LTR

def _aug_dataset(dataset, pri_feature_config, fold_k):
    """
    Generate Gumbel-perturbed privileged features for one dataset split and
    pickle them to ``<dataset.dir_data>Pri_Feature/fold-<fold_k>/<split>_<type>-<temperature>.torch``.

    Only the "Gumbel" privileged-feature type is handled; for any other type,
    or when ``dataset`` is None (e.g. a split that was not loaded), nothing is
    generated or written.

    :param dataset: iterable yielding ``(qid, features, y)`` triples and exposing
        ``split_type`` and ``dir_data`` attributes, or None.
    :param pri_feature_config: object whose ``json_dict`` maps ``'type'`` and
        ``'temperature'`` to sequences; only the first element of each is used.
    :param fold_k: fold index, used only to name the cache sub-directory.
    :return: None
    """
    # Identity test: ``== None`` would defer to the dataset's own ``__eq__``
    # (and raises for array-like objects compared element-wise).
    if dataset is None:
        return

    pri_type = pri_feature_config.json_dict['type'][0]
    if pri_type not in ["Gumbel"]:
        return

    # These draws consume the global NumPy RNG, so every feature generated
    # below depends on them; keep this call even though the id is only printed.
    seed_id = " ".join([str(np.random.randint(100)) for _ in range(5)])
    # NOTE(review): this is not an f-string, so "{0}" is printed literally; confirm intent.
    print("Seed {0} id is : " + seed_id)

    temperature = pri_feature_config.json_dict['temperature'][0]
    cache_file = (f"Pri_Feature/fold-{fold_k}/"
                  + str(dataset.split_type) + "_"
                  + "-".join([pri_type, str(temperature)])
                  + ".torch")
    # dir_data is concatenated directly, so it is expected to end with a separator.
    cache_file = dataset.dir_data + cache_file
    os.makedirs(os.path.dirname(cache_file), exist_ok=True)

    print("Generating Privileged Features - " + pri_type)
    pri_features = []
    for qid, features, y in dataset:
        # Temperature-scaled labels plus Gumbel noise.
        item_score = temperature * y + np.random.gumbel(size=y.shape)
        # Pure Gumbel noise of the same shape, independent of the labels.
        alternative_score = torch.tensor(np.random.gumbel(size=y.shape))
        pri_features.append((qid, item_score, alternative_score))
    pickle_save(pri_features, cache_file)






def determine_files(data_dict, fold_k=None):
    """
    Determine the train / validation / test file paths for a dataset.

    Paths are built with ``os.path.join``, so ``data_dict['dir_data']`` works
    with or without a trailing separator.

    :param data_dict: data settings; reads ``'data_id'`` and ``'dir_data'``.
    :param fold_k: fold index; used only for datasets that are in neither
        YAHOO_LTR nor ISTELLA_LTR, whose files live under ``<dir_data>/Fold<fold_k>/``.
    :return: ``(file_train, file_vali, file_test)``; ``file_vali`` is None for
        ISTELLA_LTR ids other than 'Istella_X' and 'Istella_S'.
    """
    data_id, dir_data = data_dict['data_id'], data_dict['dir_data']

    if data_id in YAHOO_LTR:
        prefix = data_id.lower()
        return (os.path.join(dir_data, prefix + '.train.txt'),
                os.path.join(dir_data, prefix + '.valid.txt'),
                os.path.join(dir_data, prefix + '.test.txt'))

    if data_id in ISTELLA_LTR:
        # Only Istella_X and Istella_S get a validation file; other Istella ids get None.
        has_vali = data_id in ('Istella_X', 'Istella_S')
        file_vali = os.path.join(dir_data, 'vali.txt') if has_vali else None
        return (os.path.join(dir_data, 'train.txt'),
                file_vali,
                os.path.join(dir_data, 'test.txt'))

    fold_k_dir = os.path.join(dir_data, 'Fold' + str(fold_k))
    return (os.path.join(fold_k_dir, 'train.txt'),
            os.path.join(fold_k_dir, 'vali.txt'),
            os.path.join(fold_k_dir, 'test.txt'))


def load_data(eval_dict, data_dict, fold_k, pri_feature_config=None):
    """
    Load the train / validation / test datasets for one fold.

    :param eval_dict: evaluation settings; reads ``'mask_label'``,
        ``'do_validation'`` and ``'do_summary'``. It is passed to the training
        set only when ``'mask_label'`` is truthy.
    :param data_dict: data settings; reads ``'data_id'``, ``'train_batch_size'``,
        ``'train_presort'``, ``'test_batch_size'`` and ``'validation_batch_size'``,
        and is forwarded to determine_files and to every LTRDataset.
    :param fold_k: fold index forwarded to determine_files.
    :param pri_feature_config: currently unused; accepted for call-site compatibility.
    :return: ``(train_data, vali_data, test_data)``. ``vali_data`` is None when
        the dataset is "Istella", when no validation file exists, or when
        neither validation nor summary is requested.
    """
    file_train, file_vali, file_test = determine_files(data_dict, fold_k=fold_k)

    train_batch_size, train_presort = data_dict['train_batch_size'], data_dict['train_presort']
    # eval_dict is only needed by the training set when label masking is enabled.
    input_eval_dict = eval_dict if eval_dict['mask_label'] else None
    train_data = LTRDataset(file=file_train,
                            split_type=SPLIT_TYPE.Train,
                            batch_size=train_batch_size,
                            shuffle=True,
                            presort=train_presort,
                            data_dict=data_dict,
                            eval_dict=input_eval_dict)

    test_data = LTRDataset(file=file_test,
                           split_type=SPLIT_TYPE.Test,
                           shuffle=False,
                           data_dict=data_dict,
                           batch_size=data_dict['test_batch_size'])

    # Skip validation when determine_files found no validation file, rather
    # than constructing an LTRDataset with file=None.
    has_vali_file = data_dict['data_id'] != "Istella" and file_vali is not None
    if has_vali_file and (eval_dict['do_validation'] or eval_dict['do_summary']):
        vali_data = LTRDataset(file=file_vali,
                               split_type=SPLIT_TYPE.Validation,
                               shuffle=False,
                               batch_size=data_dict['validation_batch_size'],
                               data_dict=data_dict)
    else:
        vali_data = None

    return train_data, vali_data, test_data


if __name__ == '__main__':

    banner = "################################"
    print(banner)
    print(f"# Start to generate data with seed {0}")
    print(banner)

    # The JSON settings file lives in the directory given on the command line.
    json_path = ArgsUtil(given_root='./').get_l2r_args().dir_json + 'Data_Eval_ScoringFunction.json'

    # Every setting object is built from the same JSON file, in this order.
    eval_setting = EvalSetting(debug=False, eval_json=json_path)
    data_setting = DataSetting(data_json=json_path)
    sf_parameter = ScoringFunctionParameter(sf_json=json_path)
    # NOTE(review): model_id receives the JSON path rather than a model name; confirm intent.
    model_parameter = ModelParameter(model_id=json_path)
    pri_setting = PriParameter(para_json=json_path)

    # Exhaust each grid; only the final configuration of each is kept.
    for eval_dict in eval_setting.grid_search():
        pass
    for data_dict in data_setting.grid_search():
        pass

    fold_num = data_dict['fold_num']
    print(fold_num)

    for fold_k in range(1, fold_num + 1):
        train_data, vali_data, test_data = load_data(eval_dict, data_dict, fold_k, pri_setting)
        for split in (train_data, vali_data, test_data):
            _aug_dataset(split, pri_setting, fold_k)

    print(banner)
    print(f"# Finish generating data with seed {0}")
    print(banner)
