from rl.mr_iqn_td3 import ModelRiskTD3PlusBC
from rl.iqn_td3_plusbc import QuantileTD3PlusBC
from rl.utils.d4rl_wrapper import D4RLPreprocessor
import os
from misc.rng_modules import fix_seed
import fire
from copy import deepcopy
import pandas as pd
import optax


class RunExperiments(object):
    def __init__(self,
                 env_id: str = 'hopper-medium-replay-v2',
                 risk_type: str = 'wang',
                 log_file_name: str = 'd4rl_TEST.csv',
                 risk_eta: float = -0.5,
                 epoch_length: int = int(1e+6),
                 n_critics: int = 3,
                 smooth: bool = True,
                 gpu: int | str = 1,
                 seed: int = 0,
                 ):
        os.environ['CUDA_VISIBLE_DEVICES'] = f'{gpu}'
        os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
        self.env_id = env_id
        self.risk_type = risk_type
        self.risk_eta = risk_eta
        self.seed = seed
        self.epoch_length = epoch_length
        self.n_critics = n_critics
        self.logs = []
        self.smooth = smooth
        try:
            self.previous_log = pd.read_csv(log_file_name, )
        except FileNotFoundError:
            self.previous_log = pd.DataFrame()
        self.log_file_name = log_file_name

    @property
    def processor(self):
        return D4RLPreprocessor(env_id=self.env_id, normalize_reward=False, seed=self.seed)

    def mr_td3_plus_bc(self,
                       ):
        fix_seed(self.seed)
        print(f"SEED {self.seed}")
        print("TRAIN MR.")
        print("*****************************************************************")
        processor = self.processor
        model = ModelRiskTD3PlusBC(env=processor.env,
                                   buffer=processor.get_replay_buffer(),
                                   risk_type=self.risk_type,
                                   risk_eta=self.risk_eta,
                                   n_critics=self.n_critics,
                                   smooth=True,
                                   fourier_feature_critic=True,
                                   learning_rate=3e-4,
                                   opt_class=optax.adam,
                                   q_learning_scale=5,
                                   seed=self.seed)

        log = model.train(epoch=1, len_epoch=self.epoch_length,
                          n_eval=10, final_eval=1000,
                          eval_interval=1, normalizer=processor.normalized_score)
        log['n_critics'] = self.n_critics
        log['seed'] = self.seed
        log['epoch_length'] = self.epoch_length
        log['model'] = 'MR_TD3_plus_bc'
        self.logs.append(deepcopy(log))
        current_log = pd.DataFrame(self.logs)
        new_data = pd.concat([self.previous_log, current_log])
        new_data.to_csv(self.log_file_name, index=False)
        model.save(path=f'{os.getcwd()}/models/{self.env_id}_MR_{self.seed}')

    def iqn_td3_plus_bc(self):
        fix_seed(self.seed)
        print(f"SEED {self.seed}")
        print("TRAIN TD3_PLUS_BC.")
        print("*****************************************************************")
        processor = D4RLPreprocessor(env_id=self.env_id,
                                     normalize_reward=False,
                                     )

        model = QuantileTD3PlusBC(env=processor.env,
                                  buffer=processor.get_replay_buffer(),
                                  risk_type=self.risk_type,
                                  risk_eta=self.risk_eta,
                                  n_critics=self.n_critics,
                                  smooth=self.smooth,
                                  q_learning_scale=5,
                                  seed=self.seed)
        log = model.train(epoch=100, len_epoch=self.epoch_length, n_eval=10, final_eval=1000,
                          eval_interval=1, normalizer=processor.normalized_score)
        log['n_critics'] = self.n_critics
        log['seed'] = self.seed
        log['epoch_length'] = self.epoch_length
        log['model'] = 'TD3_PLUS_BC'
        self.logs.append(deepcopy(log))

        current_log = pd.DataFrame(self.logs)
        new_data = pd.concat([self.previous_log, current_log])
        new_data.to_csv(self.log_file_name, index=False)


if __name__ == "__main__":
    # Hand the class to python-fire, which builds the command-line interface
    # from it when this file is run as a script.
    fire.Fire(component=RunExperiments)
