import os
from risk_morl.utils.misc import fix_seed
import fire
import pandas as pd
from gymnasium.wrappers import RescaleAction, TimeLimit
import mo_gymnasium
from risk_morl.mo_sac_net import MOSAC, EWPSAC, MOSACAblation, MarginalMOSAC, MOSACAblationPositionalEncoding
from copy import deepcopy
from risk_morl.utils.env_util import MODummyVecEnv, reward_dim
from risk_morl.utils.mo_utils import test
from risk_morl.utils.risk_measures import RiskMeasureGenerator
from typing import Literal, Sequence


class SaveUtil(object):
    """Resolve and create the on-disk locations for one run's outputs.

    Outputs go to ``folder`` or, when ``env_id`` is truthy, to
    ``folder/env_id``. The model is written to ``<path>.pth`` and the results
    table to ``<path>result.csv`` inside that directory.
    """

    def __init__(self, folder, path, env_id=None):
        self.base_folder = folder
        self.env_id = env_id
        # Per-environment sub-folder only when an env id was supplied.
        self.folder = os.path.join(folder, env_id) if self.env_id else folder
        self.path = path

    def _model_file(self):
        """Return the ``.pth`` path used for saving/loading the model."""
        return os.path.join(self.folder, self.path) + ".pth"

    def _result_file(self):
        """Return the CSV path used for the results table."""
        # No separator between the run name and "result.csv"; kept as-is so
        # file names stay unchanged.
        return os.path.join(self.folder, self.path + "result.csv")

    def make_folder(self):
        """Create the base folder and, if ``env_id`` was given, its sub-folder."""
        os.makedirs(self.base_folder, exist_ok=True)
        if not self.env_id:
            return
        os.makedirs(self.folder, exist_ok=True)

    def save(self, model, results: pd.DataFrame):
        """Save the model first, then the results table (index omitted)."""
        self.save_model(model)
        self.save_result(results)

    def save_result(self, results: pd.DataFrame):
        """Write ``results`` to the results CSV, without the DataFrame index."""
        self.make_folder()
        results.to_csv(self._result_file(), index=False)

    def save_model(self, model):
        """Call ``model.save`` with the ``.pth`` path."""
        self.make_folder()
        model.save(self._model_file())

    def load(self, model):
        """Return whatever ``model.load`` returns for the ``.pth`` path."""
        return model.load(self._model_file())


class Main(object):
    """Experiment launcher, exposed on the command line via ``fire.Fire(Main)``.

    The constructor builds the training and evaluation environments. Each
    public method then trains (or loads) one agent variant, evaluates it with
    ``test`` and persists the model and/or results with ``SaveUtil`` under
    ``<folder>/<env_name>/``.

    Parameters shared by the training commands:
        learning_steps: step count passed to ``model.learn``.
        critic_lr, actor_lr: learning rates forwarded to the model constructor.
        risk: ``name`` given to ``RiskMeasureGenerator``. ``'neutral'`` also
            drops ``risk_param`` from the save name.
        risk_param: ``alpha`` given to ``RiskMeasureGenerator``; formatted
            with two decimals in the save name.
        index: ``index`` given to ``RiskMeasureGenerator``.
        ent_coef: forwarded to the model constructor (default ``'auto'``).
        truncation_lower, truncation_upper: forwarded in ``policy_kwargs``.
    """
    # Most recently built/loaded agent. Other variants (MarginalMOSAC,
    # EWPSAC, ...) are stored here too.
    model: MOSAC

    def __init__(self,
                 env_name: str,
                 max_episode_steps: int = 500,
                 batch_size: int = 256,
                 gamma: float = 0.99,
                 num_env: int = 1,
                 folder: str = 'results',
                 num_eval_grid: int = 10,
                 num_eval_per_test: int = 100,
                 *,
                 seed: int = 42
                 ):
        """Call ``fix_seed(seed)`` and build the training and evaluation envs.

        Args:
            env_name: id passed to ``mo_gymnasium.make``.
            max_episode_steps: forwarded to ``mo_gymnasium.make`` for every env.
            batch_size, gamma: forwarded to every model constructor.
            num_env: number of sub-environments in the training ``MODummyVecEnv``.
            folder: root output folder handed to ``SaveUtil``.
            num_eval_grid, num_eval_per_test: forwarded to ``test``.
            seed: keyword-only; used by ``fix_seed``, the models and save names.
        """
        fix_seed(seed)
        self.env_name = env_name
        # One env factory per sub-environment; the factories are identical.
        self.env = MODummyVecEnv([lambda: self.wrap_env(mo_gymnasium.make(env_name,
                                                                          max_episode_steps=max_episode_steps))
                                  for _ in range(num_env)])
        self.test_env = self.wrap_env(mo_gymnasium.make(env_name,
                                                        max_episode_steps=max_episode_steps))
        self.reward_dim = reward_dim(self.test_env)
        self.batch_size = batch_size
        self.gamma = gamma
        self.folder = folder
        self.seed = seed
        self.num_eval_grid = num_eval_grid
        self.num_eval_per_test = num_eval_per_test

    @staticmethod
    def wrap_env(env):
        """Rescale ``env``'s actions to the range [-1, 1] via ``RescaleAction``."""
        return RescaleAction(env, min_action=-1., max_action=1.)

    # ------------------------------------------------------------------
    # Private helpers shared by the commands below.
    # ------------------------------------------------------------------
    @staticmethod
    def _risk_name(risk: str, risk_param: float) -> str:
        """Return the risk tag embedded in save names, e.g. ``'cvar_0.50'``."""
        if risk == 'neutral':
            return 'neutral'
        return f'{risk}_{risk_param:.2f}'

    def _risk_measure(self, risk, risk_param, index):
        """Build the risk measure produced by ``RiskMeasureGenerator``."""
        return RiskMeasureGenerator(alpha=risk_param,
                                    reward_dim=self.reward_dim, index=index,
                                    name=risk)()

    def _save_util(self, run_name):
        """Return a ``SaveUtil`` for ``<run_name>_<seed>`` in this env's folder."""
        return SaveUtil(self.folder, f'{run_name}_{self.seed}', env_id=self.env_name)

    def _build(self, model_cls, *, ent_coef, actor_lr, critic_lr, policy_kwargs, **extra_kwargs):
        """Instantiate ``model_cls`` with the constructor arguments every variant shares.

        ``extra_kwargs`` carries variant-specific arguments (e.g. ``risk_measure``).
        """
        return model_cls(env=self.env, test_env=self.test_env, ent_coef=ent_coef,
                         batch_size=self.batch_size, gamma=self.gamma, actor_lr=actor_lr,
                         critic_lr=critic_lr, seed=self.seed,
                         policy_kwargs=policy_kwargs, **extra_kwargs)

    def _learn(self, learning_steps):
        """Train ``self.model`` for ``learning_steps`` steps."""
        self.model.learn(learning_steps, log_interval=1, train_frequency=1)

    def _evaluate(self, label):
        """Run ``test`` on ``self.model``; ``label`` is the variant tag passed to it."""
        return test(self.test_env, self.model, self.num_eval_grid,
                    self.reward_dim, label, self.num_eval_per_test)

    def _train_eval_save(self, save_util, learning_steps, label):
        """Train, evaluate, then save both the model and the results table."""
        self._learn(learning_steps)
        result = self._evaluate(label)
        save_util.save(self.model, result)

    def _setup_no_tqc(self, critic_lr, actor_lr, risk, risk_param, index, ent_coef):
        """Build the no-TQC ``MOSAC`` variant into ``self.model`` and return its SaveUtil.

        Shared by ``kriqn_no_tqc``, ``kriqn_load_test`` and
        ``kriqn_no_tqc_train_only`` so that all three use the same
        ``MOSACAblation_NO_TQC_*`` file names.
        """
        risk_measure = self._risk_measure(risk, risk_param, index)
        save_util = self._save_util(f'MOSACAblation_NO_TQC_{self._risk_name(risk, risk_param)}')
        self.model = self._build(MOSAC, ent_coef=ent_coef, actor_lr=actor_lr, critic_lr=critic_lr,
                                 policy_kwargs={"truncation_lower": 0, "truncation_upper": 0, "proj_tqc": 0},
                                 risk_measure=risk_measure)
        return save_util

    # ------------------------------------------------------------------
    # CLI commands.
    # ------------------------------------------------------------------
    def marginal(self,
                 learning_steps: int,
                 critic_lr: float = 3e-4,
                 actor_lr: float = 3e-4,
                 risk: Literal['cvar', 'wang', 'triangle', 'neutral'] = 'cvar',
                 risk_param: float = 0.5,
                 index: Sequence[int] = (-1,),
                 ent_coef: float | Literal['auto'] = 'auto',
                 truncation_lower: int = 0,
                 truncation_upper: int = 1,
                 ):
        """Train and evaluate ``MarginalMOSAC``; saved as ``MarginalIQN_<risk>_<seed>``."""
        risk_measure = self._risk_measure(risk, risk_param, index)
        save_util = self._save_util(f'MarginalIQN_{self._risk_name(risk, risk_param)}')
        self.model = self._build(MarginalMOSAC, ent_coef=ent_coef, actor_lr=actor_lr, critic_lr=critic_lr,
                                 policy_kwargs={"truncation_lower": truncation_lower,
                                                "truncation_upper": truncation_upper},
                                 risk_measure=risk_measure)
        self._train_eval_save(save_util, learning_steps, 'marginal')

    def kriqn(self,
              learning_steps: int,
              critic_lr: float = 3e-4,
              actor_lr: float = 3e-4,
              risk: Literal['cvar', 'wang', 'triangle', 'neutral'] = 'cvar',
              risk_param: float = 0.5,
              index: Sequence[int] = (-1,),
              ent_coef: float | Literal['auto'] = 'auto',
              truncation_lower: int = 0,
              truncation_upper: int = 1,
              ):
        """Train and evaluate ``MOSAC``; saved as ``MOSAC_<risk>_<seed>``."""
        risk_measure = self._risk_measure(risk, risk_param, index)
        save_util = self._save_util(f'MOSAC_{self._risk_name(risk, risk_param)}')
        self.model = self._build(MOSAC, ent_coef=ent_coef, actor_lr=actor_lr, critic_lr=critic_lr,
                                 policy_kwargs={"truncation_lower": truncation_lower,
                                                "truncation_upper": truncation_upper},
                                 risk_measure=risk_measure)
        self._train_eval_save(save_util, learning_steps, 'kriqn')

    def kriqn_ablation(self,
                       learning_steps: int,
                       critic_lr: float = 3e-4,
                       actor_lr: float = 3e-4,
                       risk: Literal['cvar', 'wang', 'triangle', 'neutral'] = 'cvar',
                       risk_param: float = 0.5,
                       index: Sequence[int] = (-1,),
                       ent_coef: float | Literal['auto'] = 'auto',
                       truncation_lower: int = 0,
                       truncation_upper: int = 1,
                       ):
        """Train and evaluate ``MOSACAblation``; saved as ``MOSACAblation_<risk>_<seed>``."""
        risk_measure = self._risk_measure(risk, risk_param, index)
        save_util = self._save_util(f'MOSACAblation_{self._risk_name(risk, risk_param)}')
        self.model = self._build(MOSACAblation, ent_coef=ent_coef, actor_lr=actor_lr, critic_lr=critic_lr,
                                 policy_kwargs={"truncation_lower": truncation_lower,
                                                "truncation_upper": truncation_upper},
                                 risk_measure=risk_measure)
        self._train_eval_save(save_util, learning_steps, 'kriqn_ablation')

    def kriqn_pe(self,
                 learning_steps: int,
                 critic_lr: float = 3e-4,
                 actor_lr: float = 3e-4,
                 risk: Literal['cvar', 'wang', 'triangle', 'neutral'] = 'cvar',
                 risk_param: float = 0.5,
                 index: Sequence[int] = (-1,),
                 ent_coef: float | Literal['auto'] = 'auto',
                 truncation_lower: int = 0,
                 truncation_upper: int = 1,
                 ):
        """Train and evaluate ``MOSACAblationPositionalEncoding``.

        Saved as ``MOSACAblation_pe_<risk>_<seed>``.
        """
        risk_measure = self._risk_measure(risk, risk_param, index)
        save_util = self._save_util(f'MOSACAblation_pe_{self._risk_name(risk, risk_param)}')
        self.model = self._build(MOSACAblationPositionalEncoding, ent_coef=ent_coef,
                                 actor_lr=actor_lr, critic_lr=critic_lr,
                                 policy_kwargs={"truncation_lower": truncation_lower,
                                                "truncation_upper": truncation_upper},
                                 risk_measure=risk_measure)
        self._train_eval_save(save_util, learning_steps, 'kriqn_ablation_pe')

    def kriqn_no_tqc(self,
                     learning_steps: int,
                     critic_lr: float = 3e-4,
                     actor_lr: float = 3e-4,
                     risk: Literal['cvar', 'wang', 'triangle', 'neutral'] = 'cvar',
                     risk_param: float = 0.5,
                     index: Sequence[int] = (-1,),
                     ent_coef: float | Literal['auto'] = 'auto',
                     ):
        """Train and evaluate ``MOSAC`` with both truncations and ``proj_tqc`` set to 0."""
        save_util = self._setup_no_tqc(critic_lr, actor_lr, risk, risk_param, index, ent_coef)
        self._train_eval_save(save_util, learning_steps, 'kriqn_no_tqc')

    def kriqn_no_proj_tqc(self,
                          learning_steps: int,
                          critic_lr: float = 3e-4,
                          actor_lr: float = 3e-4,
                          risk: Literal['cvar', 'wang', 'triangle', 'neutral'] = 'cvar',
                          risk_param: float = 0.5,
                          index: Sequence[int] = (-1,),
                          ent_coef: float | Literal['auto'] = 'auto',

                          ):
        """Train and evaluate ``MOSAC`` with truncation 2/2 and ``proj_tqc`` set to 0."""
        risk_measure = self._risk_measure(risk, risk_param, index)
        save_util = self._save_util(f'MOSACAblation_NO_Proj_TQC_{self._risk_name(risk, risk_param)}')
        self.model = self._build(MOSAC, ent_coef=ent_coef, actor_lr=actor_lr, critic_lr=critic_lr,
                                 policy_kwargs={"truncation_lower": 2, "truncation_upper": 2, "proj_tqc": 0},
                                 risk_measure=risk_measure)
        self._train_eval_save(save_util, learning_steps, 'kriqn_no_proj_tqc')

    def kriqn_no_marginal_tqc(self,
                              learning_steps: int,
                              critic_lr: float = 3e-4,
                              actor_lr: float = 3e-4,
                              risk: Literal['cvar', 'wang', 'triangle', 'neutral'] = 'cvar',
                              risk_param: float = 0.5,
                              index: Sequence[int] = (-1,),
                              ent_coef: float | Literal['auto'] = 'auto',

                              ):
        """Train and evaluate ``MOSAC`` with both truncations set to 0.

        ``proj_tqc`` is left at the model's default.
        """
        risk_measure = self._risk_measure(risk, risk_param, index)
        save_util = self._save_util(f'MOSACAblation_NO_Marginal_TQC_{self._risk_name(risk, risk_param)}')
        self.model = self._build(MOSAC, ent_coef=ent_coef, actor_lr=actor_lr, critic_lr=critic_lr,
                                 policy_kwargs={"truncation_lower": 0, "truncation_upper": 0},
                                 risk_measure=risk_measure)
        # NOTE: this evaluation label differs from the method-name pattern the
        # other commands use; kept unchanged for existing result files.
        self._train_eval_save(save_util, learning_steps, 'mo_sac_no_marginal_tqc')

    def kriqn_load_test(self,
                        learning_steps: int,
                        critic_lr: float = 3e-4,
                        actor_lr: float = 3e-4,
                        risk: Literal['cvar', 'wang', 'triangle', 'neutral'] = 'cvar',
                        risk_param: float = 0.5,
                        index: Sequence[int] = (-1,),
                        ent_coef: float | Literal['auto'] = 'auto',
                        ):
        """Load a saved no-TQC model, evaluate it, and write only the results CSV.

        Reads the ``MOSACAblation_NO_TQC_*`` checkpoint written by
        ``kriqn_no_tqc`` / ``kriqn_no_tqc_train_only``. ``learning_steps`` is
        accepted for CLI symmetry with those commands but is not used.
        """
        save_util = self._setup_no_tqc(critic_lr, actor_lr, risk, risk_param, index, ent_coef)
        # NOTE(review): assumes ``model.load`` returns the loaded model — confirm.
        self.model = save_util.load(self.model)
        result = self._evaluate('kriqn_no_tqc')
        save_util.save_result(result)

    def kriqn_no_tqc_train_only(self,
                                learning_steps: int,
                                critic_lr: float = 3e-4,
                                actor_lr: float = 3e-4,
                                risk: Literal['cvar', 'wang', 'triangle', 'neutral'] = 'cvar',
                                risk_param: float = 0.5,
                                index: Sequence[int] = (-1,),
                                ent_coef: float | Literal['auto'] = 'auto',

                                ):
        """Train the no-TQC ``MOSAC`` variant and save only the model (no evaluation)."""
        save_util = self._setup_no_tqc(critic_lr, actor_lr, risk, risk_param, index, ent_coef)
        self._learn(learning_steps)
        save_util.save_model(self.model)

    def ewp(self,
            learning_steps: int,
            critic_lr: float = 3e-4,
            actor_lr: float = 3e-4,
            index: Sequence[int] = (-1,),
            ent_coef: float | Literal['auto'] = 'auto',
            truncation_lower: int = 0,
            truncation_upper: int = 1,
            ):
        """Train and evaluate ``EWPSAC`` (no risk measure); saved as ``EWP_<seed>``.

        ``index`` is accepted for CLI symmetry with the other commands but is
        not used.
        """
        save_util = self._save_util('EWP')
        self.model = self._build(EWPSAC, ent_coef=ent_coef, actor_lr=actor_lr, critic_lr=critic_lr,
                                 policy_kwargs={"truncation_lower": truncation_lower,
                                                "truncation_upper": truncation_upper})
        self._train_eval_save(save_util, learning_steps, 'ewp')

if __name__ == '__main__':
    # Hand the Main class to python-fire, which builds the command-line
    # interface from it (constructor arguments plus one command per method).
    fire.Fire(Main)
