import run 
import sys
import random
import torch
import d3rlpy

# Seed Python's built-in RNG with a fixed value so the `random` module
# produces the same sequence on every invocation.
random.seed(99)

# The three positional command-line arguments, each parsed as an integer.
n, index, steps = int(sys.argv[1]), int(sys.argv[2]), int(sys.argv[3])
# `num` is an alias of `n`; both are handed to main() as separate arguments.
num = n

def main(env, policy, num, n, true, index, steps):
    """Run one experiment and format its outcome as a single log line.

    Calls ``run.run`` and, for each of the eight average rewards it returns,
    reports the absolute deviation from ``true`` together with the raw value.

    Args:
        env: Environment, forwarded unchanged to ``run.run``.
        policy: Policy, forwarded unchanged to ``run.run``.
        num: Forwarded to ``run.run`` as ``num``.
        n: Forwarded to ``run.run`` as ``n`` and echoed in the output.
        true: Reference value that every average reward is compared against.
        index: Forwarded to ``run.run`` as ``index`` and echoed in the output.
        steps: Forwarded to ``run.run`` as ``train_step`` and echoed in the
            output.

    Returns:
        A newline-terminated string of comma-separated ``name = value`` pairs:
        the run settings first, then the eight absolute errors, then the
        eight average rewards.
    """
    (
        syn_1_100, syn_2_100,
        syn_1_500, syn_2_500,
        syn_1_1000, syn_2_1000,
        batch_1, batch_2,
    ) = run.run(env=env, policy=policy, num=num, n=n, index=index, train_step=steps)

    # Raw average rewards, in the order they appear in the output line.
    rewards = [
        ('average_reward_SYN_1_100', syn_1_100),
        ('average_reward_SYN_2_100', syn_2_100),
        ('average_reward_SYN_1_500', syn_1_500),
        ('average_reward_SYN_2_500', syn_2_500),
        ('average_reward_SYN_1_1000', syn_1_1000),
        ('average_reward_SYN_2_1000', syn_2_1000),
        ('average_reward_1', batch_1),
        ('average_reward_2', batch_2),
    ]

    # One error label per reward above, position for position.
    error_labels = [
        'syn_1_error_100', 'syn_2_error_100',
        'syn_1_error_500', 'syn_2_error_500',
        'syn_1_error_1000', 'syn_2_error_1000',
        'batch1_error', 'batch2_error',
    ]
    errors = [
        (label, abs(reward - true))
        for label, (_, reward) in zip(error_labels, rewards)
    ]

    fields = [('steps', steps), ('n', n), ('index', index)] + errors + rewards
    return ', '.join('{} = {}'.format(label, value) for label, value in fields) + '\n'

# Never summarize tensors when printing: an infinite threshold disables the
# "..." elision torch applies to large tensors.
torch.set_printoptions(threshold=float('inf'))

# Only the environment half of the returned pair is used; the first element
# is discarded.
_, env = d3rlpy.datasets.get_pendulum()
cql_pretrained = d3rlpy.load_learnable("cql_pretrained.d3")

# truevalue.txt is read in full and parsed as one float; it is passed to
# main() as the reference value that errors are measured against.
with open('truevalue.txt') as file:
    value = float(file.read())
UnderlyingTrue = torch.tensor(value)

res = main(env, cql_pretrained, num, n, UnderlyingTrue, index, steps)
print(res)

# Opened in append mode so the result line is added to whatever the file
# already contains.
with open('outputs_mVariation.txt', 'a') as file:
    file.write(res)