import d3rlpy
from d3rlpy.datasets import get_atari
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.metrics.scorer import td_error_scorer
from d3rlpy.metrics.scorer import discounted_sum_of_advantage_scorer
from d3rlpy.metrics.scorer import average_value_estimation_scorer
from d3rlpy.algos.torch.sac_impl import SACImpl
from d3rlpy.algos.torch.cql_impl import CQLImpl
from sklearn.model_selection import train_test_split
import argparse
import csv
import numpy as np
import torch
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union


def main(args):
    """Train a behavior-cloning (BC) policy on a D4RL dataset.

    The BC model is rebuilt from the d3rlpy hyperparameter JSON at
    ``args.model`` and trained for 500,000 steps (5,000 per epoch) on 80% of
    the dataset's episodes; the remaining 20% are passed as evaluation
    episodes. Logs go to ``args.dataset`` and TensorBoard output to
    ``./run/<dataset>``.

    Args:
        args: parsed CLI namespace providing ``dataset`` (D4RL dataset id
            passed to ``get_d4rl``), ``model`` (path to a params JSON),
            ``seed`` (passed to ``d3rlpy.seed``), ``gpu`` (device index
            passed as ``use_gpu``) and ``exp_name`` (experiment name; ``None``
            leaves the choice to d3rlpy).
    """
    dataset, env = d3rlpy.datasets.get_d4rl(args.dataset)
    # NOTE(review): presumably d3rlpy.seed also seeds NumPy's global RNG,
    # which train_test_split draws from since no random_state is given.
    d3rlpy.seed(args.seed)
    train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

    # Honour --gpu instead of always using device 0; d3rlpy's use_gpu
    # accepts an int device id as well as a bool.
    bc = d3rlpy.algos.BC.from_json(args.model, use_gpu=args.gpu)
    bc.fit(
        train_episodes,
        eval_episodes=test_episodes,
        n_steps=500000,
        n_steps_per_epoch=5000,
        experiment_name=args.exp_name,
        logdir=args.dataset,
        tensorboard_dir='./run/' + args.dataset,
        scorers={
            # BC learns no Q-function: value-based scorers (td_error,
            # discounted_advantage, value_scale) call predict_value, which
            # BC does not support, so only the environment return is used.
            'environment': evaluate_on_environment(env),
        },
    )


if __name__ == '__main__':
    # Other dataset / params-file pairings used in previous runs:
    #   hopper-medium-v0   -> ./hopper_bc_m_params.json
    #   walker2d-medium-v0 -> ./walk_bc_m_params.json
    # The defaults below target halfcheetah-medium-v0.
    cli = argparse.ArgumentParser()
    cli.add_argument('--dataset', type=str, default='halfcheetah-medium-v0')
    cli.add_argument('--model', type=str, default='./half_bc_m_params.json')
    cli.add_argument('--seed', type=int, default=0)
    cli.add_argument('--gpu', type=int, default=0)
    cli.add_argument('--exp_name', type=str, default=None)

    args = cli.parse_args()
    main(args)
    print("end train")

