import os
import numpy as np
import torch


def generate_save_path(data_dir, num_tasks, support_size, query_size, mode):
    """Build the pickle file path for a pendulum dataset.

    The file name encodes the number of tasks, the support-set size, the
    query-set size and ``mode``; it is joined to ``data_dir`` with a literal
    forward slash.

    Returns:
        The path as a string ending in ``.pkl``.
    """
    return f'{data_dir}/pendulum_n_tasks_{num_tasks}_n_supp_{support_size}_n_query_{query_size}_{mode}.pkl'


def save_test_results(epoch, path, params, total_reward, final_reward):
    """Append one row of test results to a ``.npy`` file, creating it if needed.

    The stored array is 2-D with one row per call, laid out as
    ``[epoch, params[0], params[1], params[2], total_reward, final_reward]``.

    Args:
        epoch: Value stored in the first column.
        path: Destination file name (string or path-like). If it does not end
            in ``.npy`` the extension is added, which is the name ``np.save``
            writes to anyway.
        params: Indexable with at least three entries; the first three are
            stored.
        total_reward: Value stored in the fifth column.
        final_reward: Value stored in the sixth column.
    """
    # np.save appends ".npy" to file names that lack it. Resolve the real
    # on-disk name up front so the existence check and np.load look at the
    # same file np.save writes; otherwise an extension-less path would never
    # be found and each call would overwrite the previous results.
    target = os.fspath(path)
    if not target.endswith('.npy'):
        target += '.npy'

    rows = np.array([[epoch, params[0], params[1], params[2], total_reward, final_reward]])
    if os.path.exists(target):
        # Stack the new row underneath the previously saved ones.
        rows = np.append(np.load(target), rows, 0)
    np.save(target, rows)

def sample_params(rng, g_int, m_int, l_int):
    """Draw one ``(g, m, l)`` triple, each uniformly from its own interval.

    Args:
        rng: Object exposing ``uniform(low, high)``.
        g_int: Indexable whose entries 0 and 1 bound the first draw.
        m_int: Indexable whose entries 0 and 1 bound the second draw.
        l_int: Indexable whose entries 0 and 1 bound the third draw.

    Returns:
        A 3-tuple of the sampled values in the order g, m, l.
    """
    # Draws happen strictly in g, m, l order so the consumption of the
    # random stream is the same on every call.
    return tuple(
        rng.uniform(bounds[0], bounds[1])
        for bounds in (g_int, m_int, l_int)
    )


def to_numpy(tensor):
    """Convert ``tensor`` to a NumPy array.

    Args:
        tensor: A ``torch.Tensor``, a NumPy array, another array-like, or
            ``None``.

    Returns:
        * For a ``torch.Tensor``: a NumPy array obtained after detaching it
          from the autograd graph and moving it to the CPU.
        * For a NumPy array: the same array object, unchanged.
        * For ``None``: ``None``.
        * For anything else: the result of ``np.asarray``.
    """
    if tensor is None:
        # Keep None as None rather than wrapping it in an object array.
        return None
    if isinstance(tensor, torch.Tensor):
        # Detach first so the device copy is not recorded by autograd.
        return tensor.detach().cpu().numpy()
    # Previously non-tensor inputs fell through and silently produced None;
    # np.asarray returns ndarrays as-is and converts other array-likes.
    return np.asarray(tensor)


def params_to_gi(random_g, random_m, random_l, device):
    """Pack three parameter values into a ``(1, 3)`` float32 tensor on ``device``.

    Args:
        random_g: First value of the row.
        random_m: Second value of the row.
        random_l: Third value of the row.
        device: Target passed to ``Tensor.to``.

    Returns:
        A tensor of shape ``(1, 3)`` holding ``[random_g, random_m, random_l]``.
    """
    stacked = np.array([random_g, random_m, random_l])
    # from_numpy keeps the NumPy dtype; the cast to float32 happens afterwards.
    row = torch.from_numpy(stacked).view(1, -1)
    row = row.float()
    return row.to(device)