import random
import torch
import pyspiel
import os.path as osp
import os

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
game_name = "kuhn_poker"
player_number = 2
# Reuse game_name instead of repeating the literal so the two cannot drift.
game = pyspiel.load_game(game_name, {"players": player_number})
# NOTE(review): model_number (and `game`) are not used anywhere in this
# script as far as visible here — confirm before removing.
model_number = 1
# Experiment seed; its value matches the "seed_1" part of the output
# directory names written below.
seed = 1


# ---------------------------------------------------------------------------
# Load the expert dataset (Deep CFR samples) and the random-policy dataset.
# ---------------------------------------------------------------------------
offline_data_location_1 = "expert_data/kuhn_poker_2_player/game_kuhn_poker_seed_1_train_100_iteration_sample_episode_1000_deep_cfr.pth"
offline_data_location_2 = "random_data/kuhn_poker_2_player/game_kuhn_poker_seed_10000_episode_1000.pth"

# NOTE(review): torch.load unpickles the file, so only use trusted local
# files. From torch 2.6 on, torch.load defaults to weights_only=True. If these
# buffers hold arbitrary Python objects (not just tensors/primitives), the
# calls may need weights_only=False — confirm against the installed version.
# The buffers are presumably sequences (they are passed to random.sample
# later) — verify against the data producer.
offline_data_1 = torch.load(offline_data_location_1)
offline_data_2 = torch.load(offline_data_location_2)

# Seed Python's RNG so every sampled mixture is reproducible for a given
# `seed` (the output directories are labelled with it).
random.seed(seed)

# For each random-data proportion i/10 (i = 0..10) and each target size,
# build a hybrid dataset of exactly `data_number` samples:
# (1 - i/10) from the expert data and i/10 from the random data.
for i in range(11):
    for data_number in [10000, 20000, 50000]:
        # Integer arithmetic: random.sample requires an int count (a float
        # raises TypeError), and the two counts always sum to data_number.
        random_count = data_number * i // 10
        expert_count = data_number - random_count

        # Sample without replacement. random.sample raises ValueError if a
        # count exceeds the size of the corresponding dataset.
        slice_1 = random.sample(offline_data_1, expert_count)
        # The random slice must come from the random dataset
        # (offline_data_2), not the expert one.
        slice_2 = random.sample(offline_data_2, random_count)

        # Expert samples first, then random samples (no shuffling here).
        S = slice_1 + slice_2

        result_dir = osp.join(
            "hybrid_dataset",
            "{}_{}_players".format(game_name, player_number),
            "seed_{}_random_dataset_proportion_{}".format(seed, i),
        )
        os.makedirs(result_dir, exist_ok=True)

        buffer_name = "data_number_{}.pth".format(data_number)
        torch.save(S, osp.join(result_dir, buffer_name))


