import d3rlpy
import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Unet1D, GaussianDiffusion1D, Trainer1D, Dataset1D

# Generate data from d3rlpy.datasets.get_pendulum

def TrueGenerator(env, policy, lim):
    '''
        Estimate the underlying true average reward of a policy.

        Runs `lim` independent data collections (128 steps each) of `policy`
        in `env`. For every collection the rewards of the first recorded
        episode are averaged, and the mean of these per-episode averages is
        returned.

        env: Environment to generate data
        policy: trained policy; must provide collect(env, buffer, n_steps=...)
        lim: number of collections used to mimic the underlying true value
             (must be at least 1)

        Returns: 0-dim torch tensor with the estimated average reward.

        Raises: ValueError if lim is smaller than 1.

        Side effect: every collection is dumped to
        "1Dtrained_policy_dataset_lim_{lim}.h5" in the working directory.
        The name depends on `lim` only, so each iteration overwrites the
        previous dump and only the last collection remains on disk.
    '''
    # Fail early with a clear message instead of an opaque error from
    # torch.stack on an empty list.
    if lim < 1:
        raise ValueError(f"lim must be a positive integer, got {lim!r}")

    dataset_path = f"1Dtrained_policy_dataset_lim_{lim}.h5"
    episode_means = []

    for _ in range(lim):
        # start data collection with a fresh buffer
        buffer = d3rlpy.dataset.create_fifo_replay_buffer(
            limit=100000, env=env)
        # n_steps is to specify the length of each episode
        policy.collect(env, buffer, n_steps=128)
        # save ReplayBuffer
        with open(dataset_path, "w+b") as g:
            buffer.dump(g)
        # read ReplayBuffer back from the dump
        with open(dataset_path, "rb") as g:
            lim_dataset = d3rlpy.dataset.ReplayBuffer.load(
                g, d3rlpy.dataset.InfiniteBuffer())

        # Only the first episode contributes. Reducing it to its mean right
        # away avoids keeping every episode in memory and does not require
        # all episodes to have the same length.
        first_episode = lim_dataset.episodes[0]
        episode_means.append(torch.tensor(first_episode.rewards).mean())

    # Underlying true average reward: mean of the per-episode averages
    return torch.stack(episode_means).mean()

if __name__ == "__main__":

    import sys

    # The number of collections is the single required CLI argument.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} LIM")
    lim = int(sys.argv[1])

    # Only the environment is used; the dataset returned with it is dropped.
    _, env = d3rlpy.datasets.get_pendulum()
    cql_pretrained = d3rlpy.load_learnable("cql_pretrained.d3")

    true_value = TrueGenerator(env, cql_pretrained, lim)
    numerical_value = true_value.item()

    # The file is opened in append mode, so results of repeated runs pile up.
    # Terminate each value with a newline to keep them separable.
    with open('truevalue.txt', 'a') as file:
        file.write(f"{numerical_value}\n")

    