import optax
from rl.mr_iqn_td3 import ModelRiskTD3PlusBC
from rl.codac import CODAC
from rl.oraac import ORAAC
from rl.iqn_td3_plusbc import QuantileTD3PlusBC
from airsims.airsim_processor import AirsimPreprocessor
import os
from misc.rng_modules import fix_seed
import fire
from copy import deepcopy
import pandas as pd
import numpy as np


def test_eval(model, test_env, n_eval: int = 1000):
    """Roll out ``model`` on ``test_env`` and summarise the resulting scores.

    Args:
        model: Agent exposing ``evaluate(n_eval, env)`` and ``score_metric(scores)``;
            ``score_metric`` must return a 2-tuple whose first item is the risk score.
        test_env: Environment forwarded unchanged to ``model.evaluate``.
        n_eval: Forwarded as the first argument of ``model.evaluate``
            (presumably the number of evaluation episodes).

    Returns:
        ``(mean, std, neg_risk)``: the NumPy mean and population standard
        deviation (``ddof=0``) of the scores, and the first element returned by
        ``model.score_metric``.
    """
    score = model.evaluate(n_eval, test_env)
    neg_risk, _ = model.score_metric(score)
    mean, std = np.mean(score), np.std(score)
    return mean, std, neg_risk


# This helper's name starts with ``test_``; opt it out of pytest collection so that
# star-importing this module into a test file does not make pytest call it without
# its required arguments.
test_eval.__test__ = False


class RunExperiments(object):
    def __init__(self,
                 ip: str,
                 data_set_path: str = '/hard_v2.pkl',
                 risk_type: str = 'wang',
                 log_file_name: str = 'airsim_hard_normed_cmp.csv',
                 risk_eta: float = -1.0,
                 epoch_length: int = int(1e+5),
                 n_critics: int = 5,
                 gpu: int | str = 1,
                 seed: int = 0,
                 n_final_eval: int = 100,
                 gamma: float = 0.96,
                 hard: bool = True,
                 evaluation: bool = False,
                 ):
        os.environ['CUDA_VISIBLE_DEVICES'] = f'{gpu}'
        os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
        self.ip = ip
        self.hard = hard
        self.data_set_path = data_set_path
        self.data_set_path = f'{os.getcwd()}/{data_set_path}'
        self.risk_type = risk_type
        self.risk_eta = risk_eta
        self.seed = seed
        self.epoch_length = epoch_length
        self.n_critics = n_critics
        self.logs = []
        try:
            self.previous_log = pd.read_csv(log_file_name, )
        except FileNotFoundError:
            self.previous_log = pd.DataFrame()
        self.log_file_name = log_file_name
        self.opt_class = optax.adam
        self.eval = evaluation
        self.n_final_eval = n_final_eval
        self.gamma = gamma

    def build_processor(self, normalize_obs: bool = True):
        processor = AirsimPreprocessor(ip=self.ip,
                                       path=self.data_set_path,
                                       normalize_obs=normalize_obs,
                                       hard=self.hard,
                                       seed=2 ** 31 - self.seed
                                       )
        return processor

    def post_learning(self, model, log, name):
        log['seed'] = self.seed
        log['model'] = name

        self.logs.append(deepcopy(log))
        current_log = pd.DataFrame(self.logs)
        new_data = pd.concat([self.previous_log, current_log])
        new_data.to_csv(self.log_file_name, index=False)
        model.save(path=f'{os.getcwd()}/models/{name}_{self.seed}')

    def evaluate_model(self, model, processor, name):
        self.load_model(model, name=name)
        test_mean, test_std, test_neg_risk = test_eval(model, processor.test_env, n_eval=1000)
        print(f"TEST {test_mean} +/- {test_std}, Risk{model.risk_type}@{model.risk_eta} {test_neg_risk}")

    def mr_iqn(self, learning_rate: float = 3e-4,
               ff_feature: bool = True,
               smooth: bool = True,
               tqc: bool = True):
        fix_seed(self.seed)
        print(f"SEED {self.seed}")
        print("TRAIN MR.")
        print("*****************************************************************")
        name = 'MR_TD3'
        processor = self.build_processor(True)
        model = ModelRiskTD3PlusBC(env=processor.env,
                                   buffer=processor.get_replay_buffer(),
                                   risk_type=self.risk_type,
                                   risk_eta=self.risk_eta,
                                   n_critics=self.n_critics,
                                   smooth=smooth,
                                   opt_class=self.opt_class,
                                   fourier_feature_critic=ff_feature,
                                   learning_rate=learning_rate,
                                   tqc=not tqc,
                                   seed=self.seed,
                                   gamma=self.gamma,
                                   q_learning_scale=1,
                                   )
        log = model.train_airsim(epoch=1, len_epoch=self.epoch_length, final_eval=self.n_final_eval)
        self.post_learning(model, log, name)

    def iqn_td3_plus_bc(self, ff_feature: bool = True, smooth: bool = True, tqc: bool = True, ):
        fix_seed(self.seed)
        print(f"SEED {self.seed}")
        print("TRAIN TD3_PLUS_BC.")
        print("*****************************************************************")

        processor = self.build_processor(True)
        model = QuantileTD3PlusBC(env=processor.env,
                                  buffer=processor.get_replay_buffer(),
                                  risk_type=self.risk_type,
                                  risk_eta=self.risk_eta,
                                  n_critics=self.n_critics,
                                  opt_class=self.opt_class,
                                  smooth=smooth,
                                  fourier_feature_critic=ff_feature,
                                  tqc=not tqc,
                                  seed=self.seed,
                                  gamma=self.gamma,
                                  q_learning_scale=1.
                                  )
        log = model.train_airsim(epoch=1, len_epoch=self.epoch_length, final_eval=self.n_final_eval)
        self.post_learning(model, log, name='TD3')

    def codac(self,
              normalize_obs: bool = False,
              auto_adjust_kl: bool = True,
              target_kl: float = 10
              ):
        fix_seed(self.seed)
        print(f"SEED {self.seed}")
        print("TRAIN CODAC.")
        print("*****************************************************************")
        processor = self.build_processor(normalize_obs)
        model = CODAC(env=processor.env,
                      buffer=processor.get_replay_buffer(),
                      risk_type=self.risk_type, risk_eta=self.risk_eta,
                      n_critics=self.n_critics,
                      auto_adjust_kl=auto_adjust_kl,
                      target_diff=target_kl,
                      opt_class=self.opt_class,
                      seed=self.seed,
                      gamma=self.gamma
                      )
        log = model.train_airsim(epoch=1, len_epoch=self.epoch_length, final_eval=self.n_final_eval)
        self.post_learning(model, log, name='CODAC')

    def oraac(self, normalize_obs: bool = False):
        fix_seed(self.seed)
        print("TRAIN ORAAC")
        print("*****************************************************************")
        processor = self.build_processor(normalize_obs)
        model = ORAAC(env=processor.env,
                      buffer=processor.get_replay_buffer(),
                      risk_type=self.risk_type, risk_eta=self.risk_eta,
                      n_critics=self.n_critics,
                      opt_class=self.opt_class,
                      gamma=self.gamma,
                      seed=self.seed)
        log = model.train_airsim(epoch=1, len_epoch=self.epoch_length, final_eval=self.n_final_eval)
        self.post_learning(model, log, name='ORAAC')

    def save_model(self, model, path):
        model.save(path)

    def load_model(self, model, name):
        path = f'{os.getcwd()}/models/{name}_airsim_{self.seed}'
        model.load(path)


if __name__ == '__main__':
    # Script entry point: hand the RunExperiments class to Python Fire, which maps
    # the constructor's parameters and the class's methods onto the command line.
    fire.Fire(component=RunExperiments)
