import os
import hydra
from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate, get_original_cwd
import pickle
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# repeat_static is needed for static-feature handling; the other imports mirror the training script.
from src.utils.utils import repeat_static
from src.utils.utils import set_seed

def _atomic_pickle_dump(obj, path):
    """Pickle ``obj`` to ``path`` without ever leaving a truncated file there.

    The object is first written to a sibling temporary file. That file is
    moved onto ``path`` with ``os.replace`` only after ``pickle.dump`` has
    finished and the file has been closed. If pickling or the rename fails,
    the temporary file is removed and the exception propagates. Any file
    previously stored at ``path`` is then left untouched.

    Args:
        obj: Any picklable object.
        path: Destination file path.
    """
    # The PID suffix keeps concurrent runs that target the same path from
    # sharing one temporary file.
    tmp_path = f"{path}.tmp{os.getpid()}"
    try:
        with open(tmp_path, 'wb') as file:
            pickle.dump(obj, file)
        # The temp file sits in the same directory, so this is a same-filesystem
        # rename. It is atomic on POSIX.
        os.replace(tmp_path, path)
    finally:
        # On success the temp file has already been renamed onto ``path``.
        # This cleanup only matters when dumping or renaming failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


@hydra.main(config_name='config.yaml', config_path='../configs/', version_base=None)
def main(args: DictConfig):
    """Generate a dataset collection for one (dataset, model, seed) and pickle it.

    This only builds and stores data; no model is trained here. The
    collection is written to
    ``<original_cwd>/data/processed/<dataset.name>/<lowercased model.name>/seed_<exp.seed>.pkl``.
    That path presumably matches where the training script loads from
    (TODO confirm against train).

    Args:
        args: Hydra config. Reads ``model.name``, ``dataset.name``,
            ``exp.seed`` and, optionally, ``dataset.static_size``. The whole
            ``args.dataset`` node is passed to ``hydra.utils.instantiate`` to
            build the dataset collection.
    """
    # Models whose data goes through process_data_encoder(); every other
    # model uses process_data_multi().
    encoder_models = frozenset({'crn', 'ct', 'rmsn'})
    # Models that get 2-D static features passed through repeat_static().
    static_repeat_models = frozenset({'gift', 'vcip', 'actin'})

    # Hydra may change the working directory for each run, so resolve the
    # output location against the directory the script was launched from.
    original_cwd = get_original_cwd()
    model_type = args.model.name.lower()
    dataset_name = args.dataset.name
    seed = args.exp.seed

    data_dir = os.path.join(original_cwd, 'data/processed', dataset_name, model_type)
    os.makedirs(data_dir, exist_ok=True)
    path = os.path.join(data_dir, f"seed_{seed}.pkl")

    # Seed before instantiating the dataset so that any randomness in
    # generation is presumably reproducible (TODO confirm set_seed covers the
    # RNGs the dataset uses).
    set_seed(seed)

    print(f"Generating and saving dataset to: {path}")

    dataset_collection = instantiate(args.dataset, _recursive_=True)

    if model_type in encoder_models:
        print("Processing data with encoder method...")
        dataset_collection.process_data_encoder()
    else:
        print("Processing data with multi method...")
        dataset_collection.process_data_multi()

    # A config value of `null` yields None. Treat it as "no static features"
    # instead of crashing on `None > 0`.
    static_size = getattr(args.dataset, 'static_size', 0) or 0
    if model_type in static_repeat_models and static_size > 0:
        static_dims = len(dataset_collection.train_f.data['static_features'].shape)
        if static_dims == 2:
            # NOTE(review): presumably repeat_static broadcasts per-sample
            # static features along the time axis; verify in src.utils.utils.
            dataset_collection = repeat_static(dataset_collection)
            print("Static features repeated for sequence compatibility")

    # Write atomically so a failed or interrupted dump never leaves a
    # truncated pickle at the path the training side reads from.
    _atomic_pickle_dump(dataset_collection, path)
    print(f"Dataset object saved at: {path}")

if __name__ == "__main__":
    # main() is called with no arguments: the @hydra.main decorator builds
    # `args` from config.yaml (under ../configs/) plus any command-line overrides.
    main()
