from config import *
from data_loader.real_feedback_data_loader import load_real_data
from data_loader.synthetic_data_loader import load_synthetic_data
from ast import literal_eval
import re
import pandas as pd

def generator(dfx):
    """Yield one ``(user_id, main_vector, feedback, query_response)`` tuple per row.

    The first three elements are passed through exactly as read from the row;
    ``query_response`` is converted with ``tf.ragged.constant`` first.

    ``random``, ``np``, ``tf`` and ``SEED`` are not imported explicitly in this
    module -- presumably they arrive via ``from config import *`` (confirm).

    Note: because this is a generator function, the re-seeding below runs when
    the first item is requested, not when ``generator(dfx)`` is called.

    Args:
        dfx: DataFrame with ``user_id``, ``main_vector``, ``feedback`` and
            ``query_response`` columns.

    Yields:
        A 4-tuple per row, in the DataFrame's iteration order.
    """
    # Reset the Python, NumPy and TensorFlow RNGs to the shared seed.
    random.seed(SEED)
    np.random.seed(SEED)
    tf.random.set_seed(SEED)

    passthrough_columns = ('user_id', 'main_vector', 'feedback')
    for _, record in dfx.iterrows():
        ragged_response = tf.ragged.constant(record['query_response'])
        yield (*(record[name] for name in passthrough_columns), ragged_response)


def test_generator(dfx):
    """Yield ``(user_id, main_vector, feedback, query_response)`` tuples for ``dfx``.

    This produces exactly what ``generator`` produces for the same frame
    (including the lazy re-seeding of the ``random``, NumPy and TensorFlow
    RNGs on the first item), so it simply delegates to it.

    Args:
        dfx: DataFrame with ``user_id``, ``main_vector``, ``feedback`` and
            ``query_response`` columns.

    Yields:
        A 4-tuple per row, with ``query_response`` converted via
        ``tf.ragged.constant``.
    """
    yield from generator(dfx)


def synthetic_split(result_df_, vals):
    """Split each user's rows into a label-balanced train set and a test set.

    For every user, ``vals`` rows are drawn for training -- half with
    ``feedback == 1`` and half with ``feedback == 0`` -- and all of that user's
    remaining positive/negative rows go to the test set. When ``vals`` is odd,
    the extra training slot is given to one of the two labels at random (one
    draw, shared by all users).

    Users that cannot supply the requested number of rows for either label
    (or that have fewer than two rows of either label) are skipped entirely.
    Previously only the "fewer than two" case was skipped, so a user with,
    say, three positives and a request for five made ``np.random.choice``
    raise ``ValueError`` (sampling without replacement from too small a
    population).

    ``random``, ``np``, ``tf`` and ``SEED`` are presumably provided by
    ``from config import *`` (confirm).

    Args:
        result_df_: DataFrame with at least ``user_id`` and ``feedback``
            columns. Row selection is done through the index labels, so the
            index is assumed to be unique -- TODO confirm with callers.
        vals: Total number of training rows to draw per user.

    Returns:
        ``(df_train, df_test)``: two DataFrames with the same columns as
        ``result_df_`` and a fresh default index.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    tf.random.set_seed(SEED)

    # Per-label training budget. The random draw for odd ``vals`` must stay
    # here (before any sampling) so the RNG stream matches earlier runs.
    npx = nnx = vals // 2
    if vals % 2:
        if np.random.rand() <= 0.5:
            npx += 1
        else:
            nnx += 1

    # A user needs at least two rows per label and at least as many rows as
    # will be sampled without replacement.
    min_positive = max(2, npx)
    min_negative = max(2, nnx)

    train_data, test_data = [], []
    for user in result_df_.user_id.unique():
        temp_df = result_df_[result_df_.user_id == user]
        positive_label = temp_df[temp_df.feedback == 1]
        negative_label = temp_df[temp_df.feedback == 0]
        if len(positive_label) < min_positive or len(negative_label) < min_negative:
            continue

        ptrain_indices = np.random.choice(
            positive_label.index.values.tolist(), size=npx, replace=False
        )
        ptest_indices = np.setdiff1d(positive_label.index.values, ptrain_indices)
        ntrain_indices = np.random.choice(
            negative_label.index.values.tolist(), size=nnx, replace=False
        )
        ntest_indices = np.setdiff1d(negative_label.index.values, ntrain_indices)

        train_indices = np.append(ptrain_indices, ntrain_indices)
        test_indices = np.append(ptest_indices, ntest_indices)
        train_data.extend(temp_df.loc[train_indices].values.tolist())
        test_data.extend(temp_df.loc[test_indices].values.tolist())

    df_train = pd.DataFrame(train_data, columns=result_df_.columns)
    df_test = pd.DataFrame(test_data, columns=result_df_.columns)
    return df_train, df_test


def real_split(result_df_, vals):
    """Split each user's rows into a random train sample and a test remainder.

    For every distinct ``user_id``, ``min(vals, n // 2)`` rows are sampled
    (``n`` being that user's row count), so at most half of a user's rows end
    up in the train set. All rows whose index label was not sampled go to the
    test set.

    ``random``, ``np``, ``tf`` and ``SEED`` are presumably provided by
    ``from config import *`` (confirm). ``DataFrame.sample`` is called without
    ``random_state``, so it relies on the global NumPy seed set here.

    Args:
        result_df_: DataFrame with a ``user_id`` column.
        vals: Upper bound on the number of training rows per user.

    Returns:
        ``(train_df, test_df)``: two DataFrames with the same columns as
        ``result_df_`` and a fresh default index.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    tf.random.set_seed(SEED)

    columns = result_df_.columns
    train_rows, test_rows = [], []
    for user in result_df_.user_id.unique():
        user_df = result_df_[result_df_['user_id'] == user]
        n_train = min(vals, len(user_df) // 2)
        sampled = user_df.sample(n_train)
        # Everything not picked for training is held out for testing.
        held_out = user_df[~user_df.index.isin(sampled.index.values)]
        train_rows.extend(sampled.values)
        test_rows.extend(held_out.values)

    train_df = pd.DataFrame(train_rows, columns=columns)
    test_df = pd.DataFrame(test_rows, columns=columns)
    return train_df, test_df


def get_parameters_from_string(input_str):
    """Parse ``key: value`` pairs out of a string into a dict.

    Every ``word: value`` occurrence is extracted, where the value runs up to
    the next ``,``, ``}`` or ``]``. Each value is stripped and converted:

    * numeric literals (``5``, ``-0.1``, ``1e-3``) become ``int``/``float``;
    * ``True`` / ``False`` become ``bool``;
    * values already wrapped in matching quotes become the unquoted string;
    * anything else is treated as a plain string.

    A value that still cannot be parsed as a Python literal (for example
    ``007``, or a quoted value containing an expression) is kept as its raw
    stripped text.

    Args:
        input_str: Text containing zero or more ``key: value`` pairs.

    Returns:
        Dict mapping each key to its converted value. If a key occurs more
        than once, the last occurrence wins.
    """
    # Compile once per call instead of re-resolving the patterns per value.
    pair_re = re.compile(r'(\w+):\s*([^,\}\]]+)')
    number_re = re.compile(r'^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?$')

    ps = {}
    for key, value in pair_re.findall(input_str):
        raw = value.strip()
        val_str = raw
        if not number_re.fullmatch(raw) and raw not in ("True", "False"):
            already_quoted = (
                (raw.startswith('"') and raw.endswith('"'))
                or (raw.startswith("'") and raw.endswith("'"))
            )
            if not already_quoted:
                # Bare words are wrapped so literal_eval reads them as strings.
                val_str = f'"{raw}"'
        try:
            ps[key] = literal_eval(val_str)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # SECURITY NOTE: this fallback used to call eval() on the same
            # text. That either re-raised the same error or executed an
            # arbitrary expression embedded in the input, so the raw text is
            # kept instead.
            ps[key] = raw
    return ps


def load_data(dataset: str, input_dir: str):
    """Load the dataset selected by ``dataset`` from ``input_dir``.

    Args:
        dataset: ``'synthetic'`` to use ``load_synthetic_data`` or ``'real'``
            to use ``load_real_data``.
        input_dir: Passed to the chosen loader as its ``input_dir`` keyword.

    Returns:
        Whatever the chosen loader returns.

    Raises:
        ValueError: If ``dataset`` is neither ``'synthetic'`` nor ``'real'``.
    """
    # Reject unknown names first so the happy path below stays flat.
    if dataset not in ('synthetic', 'real'):
        raise ValueError(f"Unknown dataset: {dataset}")

    loader = load_synthetic_data if dataset == 'synthetic' else load_real_data
    return loader(input_dir=input_dir)



