import os
import torch
import numpy as np
import pandas as pd
from rlkit.torch.networks import stochastic_actor2

def load_samples(env):
    """Read the expert observation/action pairs stored for ``env``.

    The file ``<this module's directory>/<env>/expert_data.parquet`` is read and
    its ``obs`` and ``action`` columns are each stacked into a single numpy
    array.

    Args:
        env: Name of the environment sub-directory next to this module.

    Returns:
        A ``(obs, actions)`` tuple of stacked numpy arrays.
    """
    data_file = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), env, 'expert_data.parquet'
    )
    frame = pd.read_parquet(data_file)
    # Each cell holds one sample; stack them row-wise into a dense array.
    return tuple(np.stack(frame[column].values) for column in ('obs', 'action'))

def load_transitions(env, obs_dim, action_dim, replay=False):
    """Load expert transitions for ``env``.

    Args:
        env: Environment sub-directory name. Parquet files are resolved
            relative to this module's directory; the policy checkpoint is
            resolved relative to the current working directory.
        obs_dim: Observation size passed to ``stochastic_actor2`` (replay only).
        action_dim: Action size passed to ``stochastic_actor2`` and used to
            derive the latent dimension from ``fc_mean.weight`` (replay only).
        replay: If False, read ``expert_data.parquet`` and return its stored
            transitions. If True, read observations from
            ``replay_obs_data.parquet`` and label them with the deterministic
            actions of the expert policy saved in ``expert_policy.pt``.

    Returns:
        ``(obs, actions, rewards, next_obs, dones)``. In replay mode only
        ``obs`` and ``actions`` are populated (the other three are None), and
        ``actions`` is whatever ``expert_policy.select_action`` returns.

    Raises:
        KeyError: If the checkpoint lacks ``shared_layer.fc2.bias`` or
            ``fc_mean.weight``, or a parquet file lacks an expected column.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))

    if not replay:
        parquet_path = os.path.join(current_dir, env, 'expert_data.parquet')
        df = pd.read_parquet(parquet_path)
        columns = ('obs', 'action', 'reward', 'next_obs', 'done')
        # Each cell holds one sample; stack every column into a dense array.
        return tuple(np.stack(df[col].values) for col in columns)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): unlike the parquet files, this path is relative to the
    # current working directory rather than to this module. Kept unchanged
    # for compatibility; confirm callers always run from the expected dir.
    policy_path = f'./reference_data/{env}/expert_policy.pt'
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    policy_params = torch.load(policy_path, map_location=device)

    # Recover the architecture hyper-parameters from the saved tensor sizes.
    # Direct indexing raises a KeyError naming the missing entry instead of
    # an UnboundLocalError further down.
    net_size = int(policy_params['shared_layer.fc2.bias'].numel())
    latent_action_dim = int(policy_params['fc_mean.weight'].numel() // action_dim)

    expert_policy = stochastic_actor2(obs_dim,
                                      action_dim,
                                      net_size,
                                      latent_dim=latent_action_dim).to(device)
    # The already-loaded state dict is reused; no need to read the file again.
    expert_policy.load_state_dict(policy_params)
    # NOTE(review): the policy is left in its default (train) mode; if
    # stochastic_actor2 contains dropout/batch-norm, expert_policy.eval()
    # may be wanted here -- confirm.

    parquet_path = os.path.join(current_dir, env, 'replay_obs_data.parquet')
    df = pd.read_parquet(parquet_path)
    obs = np.stack(df['obs'].values)
    # Pure inference: skip autograd bookkeeping for the whole dataset.
    with torch.no_grad():
        # float32, matching the legacy torch.Tensor(...) constructor.
        obs_tensor = torch.as_tensor(obs, dtype=torch.float32, device=device)
        actions = expert_policy.select_action(obs_tensor, deterministic=True)
    print("Expert actions are ready.")
    return obs, actions, None, None, None