import numpy as np
import random
from sklearn.model_selection import train_test_split
from configs import DataConfig, ModelConfig, model_lookup

# Maps the integer year-range code read from ``data_config.past_years`` /
# ``data_config.future_years`` to the list of year keys (strings) used to
# index the per-topic data dictionaries.
data_map = {
    20172019: [str(year) for year in range(2017, 2020)],
    20172022: [str(year) for year in range(2017, 2023)],
    20202022: [str(year) for year in range(2020, 2023)],
    20232024: [str(year) for year in range(2023, 2025)],
}

def concat_year_data(X, past_years, future_years, topic, layer, random_seed):
    """Build a labelled past-vs-future dataset for one topic and layer.

    Rows collected from ``past_years`` are labelled 0 and rows collected from
    ``future_years`` are labelled 1. The past rows are randomly subsampled
    without replacement so that there are never more past rows than future
    rows.

    NOTE: only the past side is subsampled. If there are more future rows
    than past rows, all past rows are kept (in shuffled order) and the two
    classes remain imbalanced.

    Args:
        X: Nested mapping indexed as ``X[topic][year][layer]``. Each leaf is
            an array with at least two dimensions; leaves are concatenated
            along the first axis.
        past_years: Iterable of year keys whose rows form the "past" class.
        future_years: Iterable of year keys whose rows form the "future"
            class.
        topic: Key selecting the topic inside ``X``.
        layer: Key selecting the layer inside ``X[topic][year]``.
        random_seed: Seed for the subsampling of past rows.

    Returns:
        Tuple ``(X, y)``: the subsampled past rows followed by all future
        rows, and the matching label vector (zeros, then ones).
    """
    # Use a private, seeded generator. Writing ``np.random.seed = seed``
    # would replace NumPy's seeding function with an int instead of seeding,
    # leaving the draw below unseeded and breaking later seed() calls.
    rng = np.random.default_rng(random_seed)

    X_past = np.concatenate([X[topic][year][layer] for year in past_years])
    X_future = np.concatenate([X[topic][year][layer] for year in future_years])

    # Cap the number of past rows at the number of future rows.
    samp_size = min(X_past.shape[0], X_future.shape[0])
    sample_idx = rng.choice(X_past.shape[0], size=samp_size, replace=False)
    X_past = X_past[sample_idx, :]

    y_past = np.zeros(X_past.shape[0])
    y_future = np.ones(X_future.shape[0])

    X = np.concatenate([X_past, X_future])
    y = np.concatenate([y_past, y_future])

    return X, y


def get_single_topic_data(X_dict, data_config, topic, layer, random_seed):
    """Return a stratified train/test split for a single topic.

    Args:
        X_dict: Nested mapping indexed as ``X_dict[topic][year][layer]``.
        data_config: Object whose ``past_years`` and ``future_years``
            attributes are keys of ``data_map``.
        topic: Topic whose data is used.
        layer: Layer key selecting the arrays inside each year.
        random_seed: Seed forwarded to ``concat_year_data`` and used as
            ``random_state`` for ``train_test_split``.

    Returns:
        Tuple ``(X_train, X_test, y_train, y_test)`` with 25% of the rows
        held out for testing, stratified on the past/future label.
    """
    # The seed is passed explicitly to both consumers below. Assigning it to
    # ``np.random.seed`` would overwrite NumPy's seeding function with an
    # int rather than seed anything, so no global seeding is done here.
    past_years = data_map[data_config.past_years]
    future_years = data_map[data_config.future_years]

    X, y = concat_year_data(X_dict, past_years, future_years, topic, layer, random_seed)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.25, stratify=y, random_state=random_seed
    )

    return X_train, X_test, y_train, y_test


def get_hold_one_out_data(X_dict, data_config, test_topic, layer, random_seed):
    """Build a leave-one-topic-out split.

    Every topic in ``data_config.topics`` except ``test_topic`` contributes
    to the training set; ``test_topic`` alone forms the test set.

    Args:
        X_dict: Nested mapping indexed as ``X_dict[topic][year][layer]``.
        data_config: Object providing ``past_years`` and ``future_years``
            (keys of ``data_map``) and an iterable ``topics``.
        test_topic: Topic to hold out for testing.
        layer: Layer key selecting the arrays inside each year.
        random_seed: Seed forwarded to ``concat_year_data`` for every topic.

    Returns:
        Tuple ``(X_train, X_test, y_train, y_test)``.

    Raises:
        ValueError: If ``test_topic`` does not occur in
            ``data_config.topics`` (there would be no test set), or, from
            ``np.concatenate``, if no training topic remains.
    """
    # The seed is handed to concat_year_data explicitly. Assigning it to
    # ``np.random.seed`` would overwrite NumPy's seeding function with an
    # int rather than seed anything, so no global seeding is done here.
    past_years = data_map[data_config.past_years]
    future_years = data_map[data_config.future_years]
    X_train, y_train = [], []
    X_test, y_test = None, None

    for topic in data_config.topics:
        print(topic)  # progress output
        if topic != test_topic:
            X, y = concat_year_data(X_dict, past_years, future_years, topic, layer, random_seed)
            X_train.append(X)
            y_train.append(y)
        else:
            X_test, y_test = concat_year_data(X_dict, past_years, future_years, topic, layer, random_seed)
        print(len(X_train))  # number of training topics collected so far

    # Fail with a clear message instead of an UnboundLocalError at return.
    if X_test is None:
        raise ValueError(
            f"test_topic {test_topic!r} was not found in data_config.topics"
        )

    X_train = np.concatenate(X_train)
    y_train = np.concatenate(y_train)

    return X_train, X_test, y_train, y_test

def get_mixed_data(X_dict, data_config, layer, random_seed):
    """Pool all topics together and return a stratified train/test split.

    Args:
        X_dict: Nested mapping indexed as ``X_dict[topic][year][layer]``.
        data_config: Object providing ``past_years`` and ``future_years``
            (keys of ``data_map``) and an iterable ``topics``.
        layer: Layer key selecting the arrays inside each year.
        random_seed: Seed forwarded to ``concat_year_data`` for every topic
            and used as ``random_state`` for ``train_test_split``.

    Returns:
        Tuple ``(X_train, X_test, y_train, y_test)`` with 25% of the pooled
        rows held out for testing, stratified on the past/future label.
    """
    # The seed is passed explicitly to both consumers below. Assigning it to
    # ``np.random.seed`` would overwrite NumPy's seeding function with an
    # int rather than seed anything, so no global seeding is done here.
    past_years = data_map[data_config.past_years]
    future_years = data_map[data_config.future_years]
    X_list, y_list = [], []

    for topic in data_config.topics:
        X, y = concat_year_data(X_dict, past_years, future_years, topic, layer, random_seed)
        X_list.append(X)
        y_list.append(y)

    X = np.concatenate(X_list)
    y = np.concatenate(y_list)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.25, stratify=y, random_state=random_seed
    )

    return X_train, X_test, y_train, y_test




                


