import gym
import numpy as np
import argparse
import csv
from d3rlpy.algos import CQL
from d3rlpy.algos import BCQ, BC, BEAR
from d3rlpy.metrics.scorer import evaluate_on_environment
import d3rlpy

def _evaluate_cql(env_name, params_path, model_path, out_name, n_trials=50):
    """Load a trained CQL policy and score it repeatedly on a gym environment.

    Each trial makes one call to the scorer built by d3rlpy's
    ``evaluate_on_environment`` and records the value it returns. The
    per-trial scores and their mean are written to ``out_name`` as CSV:
    a ``Trial,Score`` header, one ``[i, score]`` row per trial, then a
    ``["Mean", mean]`` row. The scores array and its mean are also printed.

    Args:
        env_name: Environment id passed to ``gym.make``.
        params_path: Path to the ``params.json`` used by ``CQL.from_json``.
        model_path: Path to the weights file passed to ``cql.load_model``.
        out_name: Destination CSV path; an existing file is overwritten.
        n_trials: Number of scorer calls to make; must be at least 1.

    Returns:
        The per-trial scores as a NumPy array.

    Raises:
        ValueError: If ``n_trials`` is less than 1.
    """
    if n_trials < 1:
        # np.mean of an empty list yields nan (plus a RuntimeWarning);
        # fail loudly instead of writing a meaningless "Mean" row.
        raise ValueError(f"n_trials must be >= 1, got {n_trials}")

    env = gym.make(env_name)
    try:
        scorer = evaluate_on_environment(env)
        cql = CQL.from_json(params_path)
        cql.load_model(model_path)

        score_list = []
        # newline='' is what the csv module requires to avoid blank rows on Windows.
        with open(out_name, 'w', newline='') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(["Trial", "Score"])
            for i in range(n_trials):
                score = scorer(cql)
                score_list.append(score)
                writer.writerow([i, score])
                # Report only the newest score; reprinting the whole list
                # every trial makes the log grow quadratically.
                print(f"trial {i}: {score}")

            scores = np.array(score_list)
            mean_score = np.mean(scores)
            writer.writerow(["Mean", mean_score])
    finally:
        # Release the environment even if loading, scoring or file I/O fails.
        env.close()

    print(scores, mean_score)
    return scores


# NOTE(review): 'hopper-medium-v0' / 'halfcheetah-medium-v0' look like D4RL
# dataset ids; gym only knows them once d4rl has registered them, and this file
# does not import d4rl. 'Walker2d-v2' below is a plain MuJoCo id instead —
# confirm the three evaluations are meant to use comparable environments.

def hopper(out_name, n_trials=50):
    """Evaluate the hopper CQL checkpoint and write scores to ``out_name``."""
    return _evaluate_cql(
        'hopper-medium-v0',
        './hopper_cql_CPP_20231213232816/params.json',
        './hopper_cql_20231213232816/model.pt',
        out_name,
        n_trials,
    )


def half(out_name, n_trials=50):
    """Evaluate the half-cheetah CQL checkpoint and write scores to ``out_name``."""
    # NOTE(review): the params directory timestamp (2024…) differs from the
    # model directory timestamp (2023…), unlike the other two evaluations —
    # confirm both files come from the same training run.
    return _evaluate_cql(
        'halfcheetah-medium-v0',
        './half_cql_CPP_20241214232816/params.json',
        './half_cql_20231214232816/model.pt',
        out_name,
        n_trials,
    )


def walker2d(out_name, n_trials=50):
    """Evaluate the walker2d CQL checkpoint and write scores to ``out_name``."""
    return _evaluate_cql(
        'Walker2d-v2',
        './walk2d_cql_CPP_20240113232816/params.json',
        './walk2d_cql_20240113232816/model.pt',
        out_name,
        n_trials,
    )

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Evaluate a trained CQL policy and write per-trial scores to CSV.")
    parser.add_argument('--out_name', type=str, required=True,
                        help="Path of the CSV file to write (overwritten).")
    # Choose the task on the command line instead of toggling commented-out
    # calls; defaults to walker2d.
    parser.add_argument('--env', choices=('hopper', 'half', 'walker2d'),
                        default='walker2d',
                        help="Which environment/checkpoint pair to evaluate.")
    args = parser.parse_args()

    evaluators = {'hopper': hopper, 'half': half, 'walker2d': walker2d}
    evaluators[args.env](args.out_name)
