import torch
import numpy as np
# print('-----------------------------------------------------------------')
# total_radius = []
# for i in range(1):
    # radius = []
    # r = torch.load('PongNoFrameskip-v4_denoiser_0.1/test.log_nr-1_clean_m-100_sigma-' + str(0.1) + '_clip_nat_R-list-' + str(i+1) + '.pt')
    # for t in r:
        # t = list(t)
        # if t[2] < 0:
            # t[2] = 0
        # radius.append(t[2])
    # total_radius.append(radius)
# print(np.mean(np.asarray(total_radius)))

# total_radius = []
# for i in range(1):
    # radius = []
    # r = torch.load('FreewayNoFrameskip-v4_denoiser_0.12/test.log_nr-1_clean_m-100_sigma-' + str(0.12) + '_clip_nat_R-list-' + str(i+1) + '.pt')
    # for t in r:
        # t = list(t)
        # if t[2] < 0:
            # t[2] = 0
        # radius.append(t[2])
    # total_radius.append(radius)
# print(np.mean(np.asarray(total_radius)))

# total_radius = []
# for i in range(1):
    # radius = []
    # r = torch.load('RoadRunnerNoFrameskip-v4_denoiser_0.1/test.log_nr-1_clean_m-100_sigma-' + str(0.1) + '_clip_nat_R-list-' + str(i+1) + '.pt')
    # for t in r:
        # t = list(t)
        # if t[2] < 0:
            # t[2] = 0
        # radius.append(t[2])
    # total_radius.append(radius)
# print(np.mean(np.asarray(total_radius)))
# print("---------------------------------------------------------------------------------------")
# total_radius = []
# for i in range(1):
    # radius = []
    # r = torch.load('PongNoFrameskip-v4_rs_0.1/test.log_nr-1_clean_m-100_sigma-' + str(0.1) + '_clip_rs_R-list-' + str(i+1) + '.pt')
    # for t in r:
        # t = list(t)
        # if t[2] < 0:
            # t[2] = 0
        # radius.append(t[2])
    # total_radius.append(radius)
# print(np.mean(np.asarray(total_radius)))

# total_radius = []
# for i in range(1):
    # radius = []
    # r = torch.load('FreewayNoFrameskip-v4_rs_0.12/test.log_nr-1_clean_m-100_sigma-' + str(0.12) + '_clip_rs_R-list-' + str(i+1) + '.pt')
    # for t in r:
        # t = list(t)
        # if t[2] < 0:
            # t[2] = 0
        # radius.append(t[2])
    # total_radius.append(radius)
# print(np.mean(np.asarray(total_radius)))

# total_radius = []
# for i in range(1):
    # radius = []
    # r = torch.load('RoadRunnerNoFrameskip-v4_rs_0.1/test.log_nr-1_clean_m-100_sigma-' + str(0.1) + '_clip_rs_R-list-' + str(i+1) + '.pt')
    # for t in r:
        # t = list(t)
        # if t[2] < 0:
            # t[2] = 0
        # radius.append(t[2])
    # total_radius.append(radius)
# print(np.mean(np.asarray(total_radius)))
# print("---------------------------------------------------------------------------------------")
# total_radius = []
# for i in range(1):
    # radius = []
    # r = torch.load('PongNoFrameskip-v4_cov_denoiser_0.1/test.log_nr-1_clean_m-100_sigma-' + str(0.1) + '_clip_cov_R-list-' + str(i+1) + '.pt')
    # for t in r:
        # t = list(t)
        # if t[2] < 0:
            # t[2] = 0
        # radius.append(t[2])
    # total_radius.append(radius)
# print(np.mean(np.asarray(total_radius)))

# total_radius = []
# for i in range(1):
    # radius = []
    # r = torch.load('FreewayNoFrameskip-v4_cov_denoiser_0.12/test.log_nr-1_clean_m-100_sigma-' + str(0.12) + '_clip_cov_R-list-' + str(i+1) + '.pt')
    # for t in r:
        # t = list(t)
        # if t[2] < 0:
            # t[2] = 0
        # radius.append(t[2])
    # total_radius.append(radius)
# print(np.mean(np.asarray(total_radius)))

# total_radius = []
# for i in range(1):
    # radius = []
    # r = torch.load('RoadRunnerNoFrameskip-v4_cov_denoiser_0.1/test.log_nr-1_clean_m-100_sigma-' + str(0.1) + '_clip_cov_R-list-' + str(i+1) + '.pt')
    # for t in r:
        # t = list(t)
        # if t[2] < 0:
            # t[2] = 0
        # radius.append(t[2])
    # total_radius.append(radius)
# print(np.mean(np.asarray(total_radius)))
# print("---------------------------------------------------------------------------------------")
# total_radius = []
# for i in range(1):
    # radius = []
    # r = torch.load('PongNoFrameskip-v4_rad_denoiser_0.1/test.log_nr-1_clean_m-100_sigma-' + str(0.1) + '_clip_rad_R-list-' + str(i+1) + '.pt')
    # for t in r:
        # t = list(t)
        # if t[2] < 0:
            # t[2] = 0
        # radius.append(t[2])
    # total_radius.append(radius)
# print(np.mean(np.asarray(total_radius)))

# total_radius = []
# for i in range(1):
    # radius = []
    # r = torch.load('FreewayNoFrameskip-v4_rad_denoiser_0.12/test.log_nr-1_clean_m-100_sigma-' + str(0.12) + '_clip_rad_R-list-' + str(i+1) + '.pt')
    # for t in r:
        # t = list(t)
        # if t[2] < 0:
            # t[2] = 0
        # radius.append(t[2])
    # total_radius.append(radius)
# print(np.mean(np.asarray(total_radius)))

# total_radius = []
# for i in range(1):
    # radius = []
    # r = torch.load('RoadRunnerNoFrameskip-v4_rad_denoiser_0.1/test.log_nr-1_clean_m-100_sigma-' + str(0.1) + '_clip_rad_R-list-' + str(i+1) + '.pt')
    # for t in r:
        # t = list(t)
        # if t[2] < 0:
            # t[2] = 0
        # radius.append(t[2])
    # total_radius.append(radius)
# print(np.mean(np.asarray(total_radius)))
# print("---------------------------------------------------------------------------------------")
# total_radius = []
# for i in range(1):
#     radius = []
#     r = torch.load('PongNoFrameskip-v4_denoiser_adv_0.1/test.log_nr-1_clean_m-100_sigma-' + str(0.1) + '_clip_nat_R-list-' + str(i+1) + '.pt')
#     for t in r:
#         t = list(t)
#         if t[2] < 0:
#             t[2] = 0
#         radius.append(t[2])
#     total_radius.append(radius)
# print(np.mean(np.asarray(total_radius)))
#
# total_radius = []
# for i in range(1):
#     radius = []
#     r = torch.load('FreewayNoFrameskip-v4_denoiser_adv_0.12/test.log_nr-1_clean_m-100_sigma-' + str(0.12) + '_clip_nat_R-list-' + str(i+1) + '.pt')
#     for t in r:
#         t = list(t)
#         if t[2] < 0:
#             t[2] = 0
#         radius.append(t[2])
#     total_radius.append(radius)
# print(np.mean(np.asarray(total_radius)))
#
# total_radius = []
# for i in range(1):
#     radius = []
#     r = torch.load('RoadRunnerNoFrameskip-v4_denoiser_adv_0.1/test.log_nr-1_clean_m-100_sigma-' + str(0.1) + '_clip_nat_R-list-' + str(i+1) + '.pt')
#     for t in r:
#         t = list(t)
#         if t[2] < 0:
#             t[2] = 0
#         radius.append(t[2])
#     total_radius.append(radius)
# print(np.mean(np.asarray(total_radius)))
def radius_list_path(run_dir, sigma, tag, index=1):
    """Build the path of one saved radius list, using this script's file layout.

    Example result:
    ``PongNoFrameskip-v4_rad/test.log_nr-1_clean_m-100_sigma-0.1_clip_rad_R-list-1.pt``

    Args:
        run_dir: Directory holding the run's outputs, e.g. ``"PongNoFrameskip-v4_rad"``.
        sigma: Value embedded after ``sigma-``; formatted exactly like ``str(sigma)``.
        tag: Tag embedded after ``_clip_``, e.g. ``"rad"``, ``"nat"``, ``"cov"``.
        index: 1-based list number embedded as ``R-list-<index>``.

    Returns:
        The relative path as a string.
    """
    return f"{run_dir}/test.log_nr-1_clean_m-100_sigma-{sigma}_clip_{tag}_R-list-{index}.pt"


def mean_clipped_radius(paths):
    """Return the mean of field 2 of every record in the given result files.

    Each file is read with ``torch.load`` and is expected to hold an iterable of
    indexable records. Presumably these are tuples whose index 2 is a radius
    (TODO confirm against the code that writes these files). Negative values are
    clamped to 0 before averaging. Values from all files are pooled into a
    single mean, so lists of different lengths are handled.

    Args:
        paths: Iterable of file paths. A single ``str`` path is also accepted.

    Returns:
        ``numpy.float64`` mean of the clamped values. This is ``nan``, with a
        numpy RuntimeWarning, when the files contain no records.
    """
    if isinstance(paths, str):
        paths = [paths]
    radii = []
    for path in paths:
        # NOTE(review): torch.load unpickles data; only load result files you trust.
        for record in torch.load(path):
            # max(x, 0) keeps x unless 0 > x; same as the original `if x < 0: x = 0`.
            radii.append(max(record[2], 0))
    return np.mean(np.asarray(radii))


# (run directory, sigma, clip tag) for each run reported by this section.
RAD_RUNS = [
    ('PongNoFrameskip-v4_rad', 0.1, 'rad'),
    ('FreewayNoFrameskip-v4_rad', 0.12, 'rad'),
    ('RoadRunnerNoFrameskip-v4_rad', 0.1, 'rad'),
]
# Number of R-list files read per run (the original loops used range(1): R-list-1 only).
NUM_LISTS = 1

if __name__ == "__main__":
    print("---------------------------------------------------------------------------------------")
    for run_dir, sigma, tag in RAD_RUNS:
        run_paths = [radius_list_path(run_dir, sigma, tag, i + 1) for i in range(NUM_LISTS)]
        print(mean_clipped_radius(run_paths))
    print("---------------------------------------------------------------------------------------")
# total_radius = []
# for i in range(1):
#     radius = []
#     r = torch.load('PongNoFrameskip-v4_cov/test.log_nr-1_clean_m-100_sigma-' + str(0.1) + '_clip_cov_R-list-' + str(i+1) + '.pt')
#     for t in r:
#         t = list(t)
#         if t[2] < 0:
#             t[2] = 0
#         radius.append(t[2])
#     total_radius.append(radius)
# print(np.mean(np.asarray(total_radius)))

# total_radius = []
# for i in range(1):
#     radius = []
#     r = torch.load('FreewayNoFrameskip-v4_cov/test.log_nr-1_clean_m-100_sigma-' + str(0.12) + '_clip_cov_R-list-' + str(i+1) + '.pt')
#     for t in r:
#         t = list(t)
#         if t[2] < 0:
#             t[2] = 0
#         radius.append(t[2])
#     total_radius.append(radius)
# print(np.mean(np.asarray(total_radius)))

# total_radius = []
# for i in range(1):
#     radius = []
#     r = torch.load('RoadRunnerNoFrameskip-v4_cov/test.log_nr-1_clean_m-100_sigma-' + str(0.1) + '_clip_cov_R-list-' + str(i+1) + '.pt')
#     for t in r:
#         t = list(t)
#         if t[2] < 0:
#             t[2] = 0
#         radius.append(t[2])
#     total_radius.append(radius)
# print(np.mean(np.asarray(total_radius)))
# Separator line emitted after the (currently commented-out) `_cov` section.
_SEPARATOR_LINE = "---------------------------------------------------------------------------------------"
print(_SEPARATOR_LINE)
# total_radius = []
# for i in range(1):
#     radius = []
#     r = torch.load('PongNoFrameskip-v4/test.log_nr-1_clean_m-100_sigma-' + str(0.1) + '_clip_nat_R-list-' + str(i+1) + '.pt')
#     for t in r:
#         t = list(t)
#         if t[2] < 0:
#             t[2] = 0
#         radius.append(t[2])
#     total_radius.append(radius)
# print(np.mean(np.asarray(total_radius)))

# total_radius = []
# for i in range(1):
#     radius = []
#     r = torch.load('FreewayNoFrameskip-v4/test.log_nr-1_clean_m-100_sigma-' + str(0.12) + '_clip_nat_R-list-' + str(i+1) + '.pt')
#     for t in r:
#         t = list(t)
#         if t[2] < 0:
#             t[2] = 0
#         radius.append(t[2])
#     total_radius.append(radius)
# print(np.mean(np.asarray(total_radius)))

# total_radius = []
# for i in range(1):
#     radius = []
#     r = torch.load('RoadRunnerNoFrameskip-v4/test.log_nr-1_clean_m-100_sigma-' + str(0.1) + '_clip_nat_R-list-' + str(i+1) + '.pt')
#     for t in r:
#         t = list(t)
#         if t[2] < 0:
#             t[2] = 0
#         radius.append(t[2])
#     total_radius.append(radius)
# print(np.mean(np.asarray(total_radius)))