import optax
from rl.mr_iqn_td3 import ModelRiskTD3PlusBC
from rl.codac import CODAC
from rl.oraac import ORAAC
from rl.iqn_td3_plusbc import QuantileTD3PlusBC
from rl.utils.mtsim_wrapper import MTSimPreprocessor
import os
from misc.rng_modules import fix_seed
import fire
from copy import deepcopy
import pandas as pd
import numpy as np


def test_eval(model, test_env, n_eval: int = 1000):
    score = model.evaluate(n_eval, test_env)
    neg_risk, _ = model.score_metric(score)
    mean, std = np.mean(score), np.std(score)
    return mean, std, neg_risk


class RunExperiments(object):
    def __init__(self,
                 env_id: str = 'forex-hedge-v0',
                 data_set_path: str = '/forex_hedge_data.npz',
                 risk_type: str = 'wang',
                 log_file_name: str = 'finance_result.csv',
                 risk_eta: float = -0.5,
                 epoch_length: int = int(3e+5),
                 n_critics: int = 5,
                 gpu: int | str = 1,
                 seed: int = 0,
                 evaluation: bool = False
                 ):
        os.environ['CUDA_VISIBLE_DEVICES'] = f'{gpu}'
        os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
        self.env_id = env_id
        self.data_set_path = f'{os.getcwd()}/{data_set_path}'
        self.risk_type = risk_type
        self.risk_eta = risk_eta
        self.seed = seed
        self.epoch_length = epoch_length
        self.n_critics = n_critics
        self.logs = []
        try:
            self.previous_log = pd.read_csv(log_file_name, )
        except FileNotFoundError:
            self.previous_log = pd.DataFrame()
        self.log_file_name = log_file_name
        self.opt_class = optax.adam
        self.eval = evaluation

    def post_learning(self, model, processor, log, name):
        test_mean, test_std, test_neg_risk = test_eval(model, processor.env, n_eval=1000)
        log['n_critics'] = self.n_critics
        log['seed'] = self.seed
        log['epoch_length'] = self.epoch_length
        log['model'] = name
        log['test_mean'] = test_mean
        log['test_std'] = test_std
        log['test_neg_risk'] = test_neg_risk
        print(f"TEST {test_mean} +/- {test_std}, Risk{model.risk_type}@{model.risk_eta} {test_neg_risk}")

        self.logs.append(deepcopy(log))
        current_log = pd.DataFrame(self.logs)
        new_data = pd.concat([self.previous_log, current_log])
        new_data.to_csv(self.log_file_name, index=False)
        model.save(path=f'{os.getcwd()}/models/{name}_{self.seed}')

    def evaluate_model(self, model, processor, name):
        self.load_model(model, name=name)
        test_mean, test_std, test_neg_risk = test_eval(model, processor.env, n_eval=1000)
        print(f"TEST {test_mean} +/- {test_std}, Risk{model.risk_type}@{model.risk_eta} {test_neg_risk}")

    def mr_td3_plus_bc(self, learning_rate: float = 3e-4,
                       ff_feature: bool = True, smooth: bool = True, tqc: bool = True):
        fix_seed(self.seed)
        print(f"SEED {self.seed}")
        print("TRAIN MR.")
        print("*****************************************************************")
        print(f"SEED {self.seed}")
        print("*****************************************************************")
        print(f"Full Fourier? {ff_feature}",)
        print(f"Smoothing? {smooth}",)
        print(f"TQC? {tqc}")

        name = 'MR_TD3'
        processor = MTSimPreprocessor(env_id=self.env_id,
                                      path=self.data_set_path,
                                      seed=self.seed,
                                      normalize_reward=False)
        model = ModelRiskTD3PlusBC(env=processor.env,
                                   buffer=processor.get_replay_buffer(),
                                   risk_type=self.risk_type,
                                   risk_eta=self.risk_eta,
                                   n_critics=self.n_critics,
                                   smooth=smooth,
                                   opt_class=self.opt_class,
                                   fourier_feature_critic=ff_feature,
                                   learning_rate=learning_rate,
                                   tqc=not tqc,
                                   seed=self.seed)
        log = model.train(epoch=1, len_epoch=self.epoch_length,
                          n_eval=10, final_eval=1000,
                          eval_interval=1)
        self.post_learning(model, processor, log, name=name)

    def iqn_td3_plus_bc(self, ff_feature: bool = True, smooth: bool = True, tqc: bool = True,):
        fix_seed(self.seed)
        print(f"SEED {self.seed}")
        print("TRAIN TD3_PLUS_BC.")
        print("*****************************************************************")
        print(f"Full Fourier? {ff_feature}",)
        print(f"Smoothing? {smooth}",)
        print(f"TQC? {tqc}")

        processor = MTSimPreprocessor(env_id=self.env_id,
                                      path=self.data_set_path,
                                      seed=self.seed,
                                      normalize_reward=False)
        model = QuantileTD3PlusBC(env=processor.env,
                                  buffer=processor.get_replay_buffer(),
                                  risk_type=self.risk_type,
                                  risk_eta=self.risk_eta,
                                  n_critics=self.n_critics,
                                  opt_class=self.opt_class,
                                  smooth=smooth,
                                  fourier_feature_critic=ff_feature,
                                  tqc=not tqc,
                                  seed=self.seed)
        log = model.train(epoch=1, len_epoch=self.epoch_length, n_eval=10, final_eval=1000,
                          eval_interval=1)
        self.post_learning(model, processor, log, name='TD3')

    def codac(self,
              normalize_obs: bool = True,
              auto_adjust_kl: bool = True,
              target_kl: float = 10
              ):
        fix_seed(self.seed)
        print(f"SEED {self.seed}")
        print("TRAIN CODAC.")
        print("*****************************************************************")
        processor = MTSimPreprocessor(env_id=self.env_id,
                                      path=self.data_set_path,
                                      seed=self.seed,
                                      normalize_obs=normalize_obs,  # codac does not normalize observation
                                      normalize_reward=False)
        model = CODAC(env=processor.env,
                      buffer=processor.get_replay_buffer(),
                      risk_type=self.risk_type, risk_eta=self.risk_eta,
                      n_critics=self.n_critics,
                      auto_adjust_kl=auto_adjust_kl,
                      target_diff=target_kl,
                      opt_class=self.opt_class,
                      seed=self.seed)

        log = model.train(epoch=1, len_epoch=self.epoch_length, n_eval=10, final_eval=1000,
                          eval_interval=1)
        self.post_learning(model, processor, log, name='CODAC')

    def oraac(self, normalize_obs: bool = True):
        fix_seed(self.seed)
        print("TRAIN ORAAC")
        print("*****************************************************************")
        processor = MTSimPreprocessor(env_id=self.env_id,
                                      path=self.data_set_path,
                                      seed=self.seed,
                                      normalize_reward=normalize_obs,
                                      normalize_obs=False)
        model = ORAAC(env=processor.env,
                      buffer=processor.get_replay_buffer(),
                      risk_type=self.risk_type, risk_eta=self.risk_eta,
                      n_critics=self.n_critics,
                      opt_class=self.opt_class,
                      seed=self.seed)
        log = model.train(epoch=1, len_epoch=self.epoch_length, n_eval=10, final_eval=1000,
                          eval_interval=1)
        self.post_learning(model, processor, log, name='ORAAC')

    def save_model(self, model, path):
        model.save(path)

    def load_model(self, model, name):
        path = f'{os.getcwd()}/models/{name}_{self.seed}'
        model.load(path)


if __name__ == '__main__':
    # Python Fire exposes RunExperiments (its constructor parameters and public
    # methods) as the command-line interface of this script.
    fire.Fire(component=RunExperiments)
