'''
This script converts the CSV files in a folder to an MDPDataset and trains an offline agent on it. Alternatively, it can load a previously saved MDPDataset from a file and train the agent on that.
You can visualize the training process in TensorBoard by running: tensorboard --logdir=offline_runs
You can also experiment with different hyperparameters such as weight_temp, n_epochs, etc. by changing the arguments in the main function.
'''
import wandb
# Patch TensorBoard logging so that event files written under "offline_runs"
# are mirrored to Weights & Biases.
# NOTE(review): the wandb docs describe calling the patch before wandb.init;
# confirm the ordering requirement for the installed wandb version.
wandb.tensorboard.patch(root_logdir="offline_runs")
# Start a W&B run with default settings as soon as this module is loaded
# (import-time side effect). train_offline_agent calls wandb.init again with
# an explicit project and run name.
wandb.init()

import csv
from d3rlpy.dataset import MDPDataset
import numpy as np
from models import *
from d3rlpy.models.encoders import VectorEncoderFactory
import numpy as np
import datetime
import re
from d3rlpy.datasets import MDPDataset
from d3rlpy.algos import CQL, IQL, BC, TD3PlusBC
from d3rlpy.algos import DiscreteCQL, DiscreteSAC, DiscreteBCQ, DiscreteBC, DoubleDQN, DQN
from d3rlpy.metrics.scorer import discounted_sum_of_advantage_scorer
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.metrics.scorer import td_error_scorer
from d3rlpy.metrics.scorer import average_value_estimation_scorer
from sklearn.model_selection import train_test_split
from d3rlpy.preprocessing.scalers import StandardScaler
from d3rlpy.preprocessing.action_scalers import MinMaxActionScaler
from d3rlpy.preprocessing.reward_scalers import MinMaxRewardScaler
# reference: https://d3rlpy.readthedocs.io/en/v1.1.1/tutorials/play_with_mdp_dataset.html

def csv_to_mdp_dataset(csv_file):
    '''
    description: convert a csv file to MDPDataset
    input:
        csv_file: path to a csv file whose first row is a header and whose
            remaining rows hold, in this order: state, action, reward, terminal
            - state: whitespace-separated floats, optionally wrapped in square
              brackets, e.g. "[0.1 0.2 0.3]"
            - action: integer (the dataset is built with discrete_action=True)
            - reward: float
            - terminal: a Python literal such as True / False / 0 / 1
    output:
        dataset: MDPDataset
    raises:
        ValueError: if the file contains no transitions or a row cannot be parsed
    '''
    # Imported locally so the function does not depend on a module-level import.
    import ast

    states = []
    actions = []
    rewards = []
    terminals = []

    # newline='' is what the csv module expects, so that quoted fields that
    # contain line breaks are handed to the reader unmodified.
    with open(csv_file, 'r', newline='') as file:
        csv_reader = csv.reader(file)
        next(csv_reader, None)  # Skip the header row (tolerates an empty file)

        # Data rows start on line 2 because of the header.
        for row_number, row in enumerate(csv_reader, start=2):
            if not row:
                # csv.reader yields [] for blank lines; there is nothing to parse.
                continue
            try:
                state = [float(value) for value in row[0].strip('[]').split()]
                action = int(row[1])
                reward = float(row[2])
                # SECURITY: this used to be eval(row[3]), which executes any
                # expression found in the file. literal_eval accepts the same
                # literals (True, False, 0, 1, ...) without running code.
                terminal = ast.literal_eval(row[3].strip())
            except (ValueError, SyntaxError, IndexError) as err:
                raise ValueError(
                    "{}: cannot parse row {}: {!r}".format(csv_file, row_number, row)
                ) from err

            states.append(state)
            actions.append(action)
            rewards.append(reward)
            terminals.append(terminal)

    if not states:
        raise ValueError("{}: no transitions found".format(csv_file))

    # change states into numpy
    states = np.array(states)
    dataset = MDPDataset(states, actions, rewards, terminals, discrete_action=True)
    return dataset


def train_offline_agent(dataset, algo = 'CQL', use_gpu = True, postfix = None, args_parse = None):
    '''
    description: train an offline agent on the dataset and save the trained model
    input:
        dataset: MDPDataset
        algo: one of "CQL", "DiscreteBC", "DiscreteSAC", "DQN", "DiscreteBCQ", "DoubleDQN"
        use_gpu: whether to use gpu (replaced by args_parse.use_gpu when args_parse is given)
        postfix: optional string appended to the experiment name; None appends nothing
        args_parse: optional namespace (e.g. the result of argparse) with the attributes
            actor_learning_rate, critic_learning_rate, temp_learning_rate, n_critics,
            n_epochs, weight_temp and use_gpu, and optionally learning_rate.
            When None, the defaults hard-coded below are used.
            weight_temp is only used to build the experiment name in this function.
    output:
        None
        model saved as models/rl_model_<experiment name>.pt under the current directory
    raises:
        NotImplementedError: if algo is not one of the names listed above
    '''
    # Imported locally so the function also works when this module is imported
    # by other code instead of being executed as a script.
    import os

    # parse in the hyperparameters
    # NOTE(review): a batch_size option exists on the command line but is not
    # forwarded to any algorithm here; each algorithm keeps its own default.
    if args_parse is not None:
        actor_learning_rate = args_parse.actor_learning_rate
        critic_learning_rate = args_parse.critic_learning_rate
        temp_learning_rate = args_parse.temp_learning_rate
        n_critics = args_parse.n_critics
        n_epochs = args_parse.n_epochs
        weight_temp = args_parse.weight_temp
        use_gpu = args_parse.use_gpu
        # learning_rate used to be defined only in the else-branch, so the
        # DiscreteSAC / DiscreteBCQ branches raised NameError whenever
        # args_parse was supplied. Fall back to the same default value.
        learning_rate = getattr(args_parse, 'learning_rate', 1.25e-4)
    else:
        actor_learning_rate = 0.0001
        critic_learning_rate = 0.0001
        learning_rate = 1.25e-4
        temp_learning_rate = 0.001
        n_critics = 6
        n_epochs = 50
        weight_temp = 100.0
        # use_gpu keeps the value passed by the caller (it used to be
        # overwritten with False here, which made the parameter a no-op).

    # setup the requested algorithm
    # NOTE(review): several keyword arguments below (actor_/critic_encoder_factory,
    # actor_/critic_/temp_learning_rate, n_critics, and n_epochs on the constructors)
    # may not be parameters of every algorithm they are passed to; verify against
    # the installed d3rlpy version that they are not silently ignored.
    if algo == "CQL":
        alg = DiscreteCQL(
            scaler=StandardScaler(dataset),
            actor_encoder_factory=VectorEncoderFactory(use_batch_norm = True, dropout_rate = 0.5, hidden_units=[16, 16]),
            critic_encoder_factory=VectorEncoderFactory(use_batch_norm = True, dropout_rate = 0.5, hidden_units=[16, 16]),
            use_gpu=use_gpu,
        )
    elif algo == "DiscreteBC":
        alg = DiscreteBC(
            scaler=StandardScaler(dataset),
            critic_learning_rate=critic_learning_rate,
            actor_learning_rate=actor_learning_rate,
            temp_learning_rate=temp_learning_rate,
            n_critics=n_critics,
            use_gpu=use_gpu,
            actor_encoder_factory=VectorEncoderFactory(use_batch_norm = True, dropout_rate = 0.5, hidden_units=[16, 16]),
            critic_encoder_factory=VectorEncoderFactory(use_batch_norm = True, dropout_rate = 0.5, hidden_units=[16, 16]),
        )
    elif algo == "DiscreteSAC":
        # https://d3rlpy.readthedocs.io/en/v0.41/references/generated/d3rlpy.algos.DiscreteSAC.html#
        alg = DiscreteSAC(
            scaler=StandardScaler(dataset),
            critic_learning_rate=critic_learning_rate,
            actor_learning_rate=actor_learning_rate,
            temp_learning_rate=temp_learning_rate,
            learning_rate = learning_rate,
            n_critics=n_critics,
            n_epochs=n_epochs,
            use_gpu=use_gpu,
        )
    elif algo == "DQN":
        alg = DQN(
            scaler=StandardScaler(dataset),
            critic_learning_rate=critic_learning_rate,
            actor_learning_rate=actor_learning_rate,
            temp_learning_rate=temp_learning_rate,
            # learning_rate = learning_rate,
            # n_critics=n_critics,
            n_epochs=n_epochs,
            use_gpu=use_gpu,
        )
    elif algo == "DiscreteBCQ":
        alg = DiscreteBCQ(
            scaler=StandardScaler(dataset),
            critic_learning_rate=critic_learning_rate,
            actor_learning_rate=actor_learning_rate,
            temp_learning_rate=temp_learning_rate,
            learning_rate = learning_rate,
            # n_critics=n_critics,
            n_epochs=n_epochs,
            # actor_encoder_factory=VectorEncoderFactory(use_batch_norm = True, dropout_rate = 0.5, hidden_units=[16, 16]),
            # critic_encoder_factory=VectorEncoderFactory(use_batch_norm = True, dropout_rate = 0.5, hidden_units=[16, 16]),
            use_gpu=use_gpu,
        )
    elif algo == "DoubleDQN":
        alg = DoubleDQN(
            scaler=StandardScaler(dataset),
            # action_scaler=MinMaxActionScaler(dataset),
            # reward_scaler=MinMaxRewardScaler(dataset),
            use_gpu=use_gpu,
        )
    else:
        raise NotImplementedError("Algorithm not implemented or not suitable for discrete actions!")

    # name the experiment; a missing postfix no longer shows up as the text "None"
    name_suffix = "" if postfix is None else postfix
    exp_name = "ALGO{}_TEMP{}_EPOCH{}{}".format(algo, weight_temp, n_epochs, name_suffix)
    wandb.init(project="offline-rl", name=exp_name)

    # split train and test episodes (the test split is currently not used below)
    train_episodes, test_episodes = train_test_split(dataset, test_size=0.01)

    # start training
    # fitting and store the losses in the tensorboard
    # NOTE(review): eval_episodes is commented out; confirm whether d3rlpy
    # evaluates the scorers at all without evaluation episodes.
    alg.fit(train_episodes,
            # eval_episodes=test_episodes,
            n_epochs=n_epochs,
            scorers={
                # 'environment': evaluate_on_environment(env),
                # 'advantage': discounted_sum_of_advantage_scorer, # smaller is better
                'td_error': td_error_scorer, # smaller is better
                'value_scale': average_value_estimation_scorer # smaller is better
            }, tensorboard_dir='offline_runs/runs_{}'.format(exp_name))

    # evaluate the model's reward on the test dataset
    # test_score = alg._evaluate(test_episodes) # TODO (LISA): this is not working, need another way to evaluate
    # print("Test score:", test_score)

    # make a model directory and store the trained weights in it
    os.makedirs("models", exist_ok=True)
    model_path = "models/rl_model_{}.pt".format(exp_name)
    alg.save_model(model_path)
    print("Model saved!", model_path)



if __name__ == "__main__":
    # Script entry point: build (or load) the MDPDataset, then train one agent.
    import argparse
    import os
    import subprocess

    def _str2bool(value):
        # argparse's type=bool treats every non-empty string (including
        # "False") as True, so boolean flags are parsed explicitly instead.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))

    parser = argparse.ArgumentParser()
    parser.add_argument('--csv_file_folder', type=str, default="../trajectories")
    parser.add_argument('--algo', type=str, default="DiscreteBC", help="CQL, DiscreteBC, DiscreteSAC, DiscreteBCQ, DoubleDQN")
    parser.add_argument('--use_gpu', type=_str2bool, default=False)
    parser.add_argument('--weight_temp', type=float, default=100.0)
    parser.add_argument('--n_epochs', type=int, default=40)
    parser.add_argument('--save_dataset', type=_str2bool, default=True)
    parser.add_argument('--save_dataset_postfix', type=str, default='_full')
    parser.add_argument('--direct_load', type=_str2bool, default=True)
    # One or more heuristic names; type=list used to split a single name
    # into its individual characters.
    parser.add_argument('--heuristic', type=str, nargs='+', default=None)
    parser.add_argument('--load_dataset_name', type=str, default='../trajectories/offline_trajectories_full.h5')
    # parse in the hyperparameters
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--actor_learning_rate', type=float, default=0.00001)
    parser.add_argument('--critic_learning_rate', type=float, default=0.00001)
    parser.add_argument('--temp_learning_rate', type=float, default=0.0001)
    parser.add_argument('--postfix_exp', type=str, default='_fulldata')
    parser.add_argument('--n_critics', type=int, default=6)

    args = parser.parse_args()

    if args.direct_load:
        # If we already have the dataset, we can directly load it
        dataset = MDPDataset.load(args.load_dataset_name)
        print("Dataset loaded!")
        print("Dataset name: {}".format(args.load_dataset_name))
    else:
        # If not, we need to parse the csv files and save it
        csv_files = []
        # sorted() makes the order in which trajectories are merged reproducible
        for file in sorted(os.listdir(args.csv_file_folder)):
            # Only CSV files in the folder are trajectories
            if not file.endswith(".csv"):
                continue
            # if there is a specific requirement on the heuristic
            if args.heuristic is not None:
                # split the filename (without extension) by underscore,
                # the third part is the heuristic
                name_parts = os.path.splitext(file)[0].split('_')
                if len(name_parts) < 3 or name_parts[2] not in args.heuristic:
                    continue
            csv_files.append(os.path.join(args.csv_file_folder, file))

        # Merge all trajectories into a single MDPDataset. The first file
        # creates the dataset; later files are appended to it. Errors are no
        # longer swallowed, so a failing file cannot silently replace the
        # trajectories collected so far.
        dataset = None
        for csvf in csv_files:
            trajectory = csv_to_mdp_dataset(csvf)
            if dataset is None:
                # First trajectory
                dataset = trajectory
            else:
                dataset.extend(trajectory)
                print("Extending dataset...{}".format(csvf))

        if dataset is None:
            raise FileNotFoundError(
                "No matching CSV trajectory files found in {}".format(args.csv_file_folder)
            )

        if args.save_dataset:
            # store the merged MDPDataset so later runs can use --direct_load
            dataset.dump('../trajectories/offline_trajectories{}.h5'.format(args.save_dataset_postfix))
            print("Dataset ../trajectories/offline_trajectories{}.h5 saved!".format(args.save_dataset_postfix))
        else:
            input("Press Enter to continue...")

    # train the model
    # NOTE(review): the hyperparameters parsed above (n_epochs, learning rates,
    # n_critics, weight_temp, ...) are not forwarded to train_offline_agent, so
    # training runs with that function's built-in defaults. Pass args_parse=args
    # once that code path has been checked for every supported algorithm.
    # train_offline_agent(dataset, algo = 'CQL', use_gpu = False, weight_temp = 100.0, n_epochs = 2)
    train_offline_agent(dataset, algo = args.algo, use_gpu = args.use_gpu, postfix = args.postfix_exp)
    # train_offline_agent(dataset, algo = 'DiscreteSAC', use_gpu = True, weight_temp = 100.0, n_epochs = 2)
    wandb.finish()


