from point_mass.pt_mass_wrapper import PointMassProcessor
from rl.iqn_td3_plusbc import QuantileTD3PlusBC
from rl.codac import CODAC
from rl.oraac import ORAAC
from rl.mr_iqn_td3 import ModelRiskTD3PlusBC
from functools import partial
import optax
from tqdm import trange
from misc.rng_modules import fix_seed
import numpy as np


class Main(object):
    """Train one agent on the point-mass dataset and record evaluation rollouts.

    Constructing an instance runs the whole pipeline as a side effect: it
    builds the ``PointMassProcessor``, instantiates the selected model, trains
    it, rolls out 100 evaluation episodes (saved as an ``.npz`` file in the
    current working directory) and finally saves the model under
    ``<cwd>/pt_mass_model/``.
    """

    # Maps the ``model_type`` string to the callable that builds the agent.
    # NOTE(review): ``default_options`` and the ``normalize_obs`` default also
    # mention 'CMVMR' and 'QLD', but those keys are absent here, so selecting
    # them raises ``KeyError`` in ``__init__``.
    model_class_map = {"MR": partial(ModelRiskTD3PlusBC, smooth=True, fourier_feature_critic=True),
                       "IQN": QuantileTD3PlusBC,
                       "CODAC": CODAC,
                       "ORAAC": ORAAC,
                       }

    def __init__(self,
                 model_type: str,
                 risk_type: str = 'cvar',
                 risk_eta: float = 0.1,
                 n_critics: int = 5,
                 epoch: int = int(2e+5),
                 learning_rate: float = 3e-4,
                 env_risk_prob: float = 0.9,
                 risk_var: float = 200,
                 *,
                 normalize_obs: bool | None = None,
                 seed: int = 0,
                 ):
        """Build, train, evaluate and save a model.

        Args:
            model_type: Key into ``model_class_map`` (and ``default_options``).
            risk_type: Forwarded to the model constructor as ``risk_type``.
            risk_eta: Forwarded to the model constructor as ``risk_eta``.
            n_critics: Forwarded to the model constructor as ``n_critics``.
            epoch: Passed to ``model.train`` as ``len_epoch`` (see ``learn``).
            learning_rate: Forwarded to the model constructor.
            env_risk_prob: Forwarded to ``PointMassProcessor`` as ``risk_prob``.
            risk_var: Forwarded unchanged to ``PointMassProcessor``; its
                integer truncation is used in output file names.
            normalize_obs: Whether the processor normalizes observations.
                ``None`` selects ``True`` for 'MR', 'IQN', 'CMVMR' and 'QLD'
                and ``False`` for every other model type.
            seed: Seed given to the processor, the model and ``fix_seed``.

        Raises:
            KeyError: If ``model_type`` is not a key of ``model_class_map``.
        """
        # Local import: at module level this file only imports ``os`` inside
        # its ``__main__`` guard, so relying on the global would raise
        # NameError whenever the module is imported instead of executed.
        import os

        self.type_name = model_type
        self.risk_var = int(risk_var)
        self.model_type = self.model_class_map[model_type]
        if normalize_obs is None:
            normalize_obs = model_type in ('MR', 'IQN', 'CMVMR', 'QLD')
        self.normalize_obs = normalize_obs
        option = self.default_options(model_type)
        self.processor = PointMassProcessor(path='./point_mass/pt_mass_data.pkl',
                                            seed=seed, risk_prob=env_risk_prob,
                                            risk_var=risk_var,
                                            normalize_obs=self.normalize_obs, box2d=True)
        self.epoch = epoch

        # Keyword arguments in ``option`` override same-named keywords that a
        # ``functools.partial`` entry of ``model_class_map`` already bound.
        self.model = self.model_type(
            env=self.processor.env,
            buffer=self.processor.buffer,
            risk_type=risk_type,
            risk_eta=risk_eta,
            learning_rate=learning_rate,
            n_critics=n_critics,
            opt_class=optax.adabelief,
            seed=seed,
            **option
        )

        self.seed = seed
        self.learn()
        self.collect_results()

        # Name the checkpoint after the model-type string. ``self.model_type``
        # is the class / partial object, whose repr is not a usable file name.
        save_dir = f'{os.getcwd()}/pt_mass_model'
        # Make sure the parent directory exists before handing the path to
        # the model. NOTE(review): ``model.save`` may already do this.
        os.makedirs(save_dir, exist_ok=True)
        self.model.save(f'{save_dir}/{self.type_name}_{self.risk_var}_{seed}')

    @staticmethod
    def default_options(model_type: str):
        """Return the extra constructor keyword arguments for ``model_type``.

        Raises:
            KeyError: If ``model_type`` has no entry in the option table.
        """
        option_map = {"MR": {"smooth": True, "q_learning_scale": 5},
                      "IQN": {"smooth": True, "q_learning_scale": 5, "fourier_feature_critic": True},
                      "CMVMR": {"smooth": True, "q_learning_scale": 2},
                      'QLD': {'smooth': True, 'q_learning_scale': 1},
                      "CODAC": {"auto_adjust_kl": True, },
                      "ORAAC": {"phi": 0.25}
                      }
        return option_map[model_type]

    def learn(self, ):
        """Train the model with a single epoch whose length is ``self.epoch``."""
        self.model.train(epoch=1, len_epoch=self.epoch, eval_interval=1, n_eval=10, final_eval=10)

    def collect_results(self):
        """Roll out 100 episodes with the trained model and save the transitions.

        For every environment step the observation, the chosen action, a flag
        marking the first step of an episode (``start``) and a flag marking
        the last one (``done``, i.e. terminated or timed out) are recorded.
        The four lists are written to
        ``obs_action_<model_type>_<risk_var>.npz`` in the working directory.

        Returns:
            A list with the summed reward of each of the 100 episodes.
        """
        fix_seed(self.seed)

        env = self.processor.env
        scores = []
        pair_data = {"obs": [], "action": [], "start": [], "done": []}
        for _ in trange(100):
            obs, info = env.reset()
            done = False
            score = 0
            first_step = True
            while not done:
                action = self.model.predict(obs)
                if self.normalize_obs:
                    # With normalization on, the observation is taken from
                    # ``info`` rather than from ``obs`` (presumably the raw,
                    # un-normalized observation; confirm against the env).
                    pair_data["obs"].append(info['denormalized_obs'].copy())
                else:
                    pair_data['obs'].append(obs.copy())
                pair_data["action"].append(action.copy())
                # One ``start`` entry per step: True only for the first step.
                pair_data['start'].append(first_step)
                first_step = False

                next_obs, reward, done, timeout, info = env.step(action)
                # A timeout ends the episode just like a terminal state.
                done = done or timeout
                pair_data['done'].append(done)
                obs = next_obs
                score += reward
            scores.append(score)
        print(f'PATH: obs_action_{self.type_name}_{int(self.risk_var)}')
        np.savez(f"obs_action_{self.type_name}_{int(self.risk_var)}.npz", **pair_data)
        return scores


if __name__ == '__main__':
    import os

    # Restrict the process to one GPU and set the XLA client flag that
    # disables up-front device-memory preallocation.
    gpu = 1
    os.environ.update({
        'CUDA_VISIBLE_DEVICES': str(gpu),
        'XLA_PYTHON_CLIENT_PREALLOCATE': 'false',
    })

    # Sweep order: seed (outermost) -> risk variance -> model type (innermost).
    types = ["MR", "ORAAC", "CODAC"]
    sweep = [(s, r, ty) for s in [0] for r in [5, 100] for ty in types]
    for s, r, ty in sweep:
        print(f"Model:{ty}\tRisk:{r}\tSeed:{s}")
        # Constructing Main runs training, evaluation and saving.
        Main(model_type=ty, risk_type='cvar', risk_eta=0.1, risk_var=r, seed=s)

