import warnings

warnings.filterwarnings("ignore", category=UserWarning, module="pkg_resources")

import os
import numpy as np

from risk_morl.utils.misc import fix_seed
import fire
import pandas as pd
from gymnasium.wrappers import RescaleAction, TimeLimit
import mo_gymnasium
from risk_morl.mo_sac_net import MOSAC, EWPSAC, MOSACAblation, MarginalMOSAC
from copy import deepcopy
from risk_morl.utils.env_util import MODummyVecEnv, reward_dim, MOSubprocVecEnv
from risk_morl.utils.mo_utils import constrained_simplex_test, constrained_simplex_finance_test, \
    vec_constrained_simplex_test
from risk_morl.utils.risk_measures import RiskMeasureGenerator
from typing import Literal, Sequence
from portfolio.port_mo_env import StockTradingMOEnv
from functools import partial

# Integer constants mixed into derived seeds (named as primes; primality is not
# checked anywhere in this file). Presumably they keep related RNG streams from
# sharing a seed value -- TODO confirm with the author.
PRIME_1 = 8610829  # added to the RNG seed inside shuffle()
PRIME_2 = 8572451  # multiplier in shuffle()'s index pick and in load_env()'s 'test' seed
PRIME_3 = 1497233  # not referenced in the visible part of this file


class SaveUtil(object):
    """Write and read the artefacts of one experiment run.

    Every file lives in ``folder`` (or ``folder/env_id`` when ``env_id`` is
    given) and its name starts with ``path``, which acts as a run prefix.
    """

    def __init__(self, folder, path, env_id=None):
        """Remember where artefacts go.

        Args:
            folder: Base output directory.
            path: File-name prefix shared by all artefacts of the run.
            env_id: Optional sub-directory name appended to ``folder``.
        """
        self.base_folder = folder
        self.env_id = env_id
        if self.env_id:
            self.folder = os.path.join(self.base_folder, self.env_id)
        else:
            self.folder = self.base_folder
        self.path = path

    def _target(self, suffix):
        """Return the full path of the artefact named ``<path><suffix>``."""
        return os.path.join(self.folder, self.path + suffix)

    def make_folder(self):
        """Create the output directory (and any missing parents)."""
        # ``self.folder`` is either ``base_folder`` itself or a child of it, so
        # a single recursive makedirs covers both directories.
        os.makedirs(self.folder, exist_ok=True)

    def save(self, model, result_valid: pd.DataFrame, result_test: pd.DataFrame, details: pd.DataFrame):
        """Save the model checkpoint plus validation/test result tables.

        Args:
            model: Object exposing ``save(file_path)``.
            result_valid: Written to ``<path>result_valid.csv``.
            result_test: Written to ``<path>result_test.csv``.
            details: Written to ``<path>test_info.csv``.
        """
        self.make_folder()
        model.save(self._target(".pth"))
        result_valid.to_csv(self._target("result_valid.csv"), index=False)
        result_test.to_csv(self._target("result_test.csv"), index=False)
        details.to_csv(self._target("test_info.csv"), index=False)

    def load(self, model):
        """Call ``model.load`` on the run's ``.pth`` file and return its result."""
        return model.load(self._target(".pth"))

    def save_only_test(self, result_test: pd.DataFrame, details: pd.DataFrame):
        """Write re-test results to ``<path>result_retest.csv`` / ``<path>retest_info.csv``."""
        # Create the directory first so this also works on a fresh output folder.
        self.make_folder()
        result_test.to_csv(self._target("result_retest.csv"), index=False)
        details.to_csv(self._target("retest_info.csv"), index=False)

    def save_only_train(self, result_train: pd.DataFrame):
        """Write ``result_train`` to ``<path>result_train.csv``."""
        # Create the directory first so this also works on a fresh output folder.
        self.make_folder()
        result_train.to_csv(self._target("result_train.csv"), index=False)


def load_env(mode: Literal['train', 'valid', 'test'] = 'train', seed: int = 42):
    """Build a ``StockTradingMOEnv`` wrapped in ``RescaleAction`` to [-1, 1].

    The data frame is read from a CSV file in the current working directory
    and indexed by its first column.

    Args:
        mode: ``'train'`` reads ``train_df.csv``; ``'test'`` reads
            ``valid_df.csv``. ``'valid'`` appears in the annotation but has no
            configuration and is rejected.
        seed: Passed to the environment as ``seed`` in train mode; in test
            mode it is first transformed to ``(seed * PRIME_2) % 2**31``.

    Returns:
        The rescaled environment.

    Raises:
        KeyError: If ``mode`` is neither ``'train'`` nor ``'test'``.
    """
    supported_modes = ('train', 'test')
    if mode not in supported_modes:
        # Same exception type as the former bare dict lookup, but naming the
        # accepted values.
        raise KeyError(f"unsupported mode {mode!r}; expected one of {supported_modes}")

    if mode == 'train':
        csv_path = 'train_df.csv'
        env_options = {
            "randomize_day": True,
            "bidding": "adv_uniform",
            "stop_loss_calculation": 'low',
            "seed": seed,
        }
    else:
        # NOTE(review): test mode reads 'valid_df.csv' -- confirm this is intended.
        csv_path = 'valid_df.csv'
        env_options = {
            "randomize_day": False,
            "bidding": "uniform",
            "stop_loss_calculation": 'close',
            "seed": (seed * PRIME_2) % int(2 ** 31),
        }

    df = pd.read_csv(csv_path)
    df = df.set_index(df.columns[0])
    return RescaleAction(StockTradingMOEnv(df, **env_options), min_action=-1, max_action=1)


def shuffle(seed, i):
    """Derive a new integer seed from ``seed`` and the index ``i``.

    A NumPy generator seeded with ``seed + PRIME_1 + i`` draws ``i + 30``
    integers from ``[0, 2**31)``; the one at position
    ``(i * PRIME_2) % (i + 3)`` is returned as a plain ``int``.
    """
    generator = np.random.default_rng(seed + PRIME_1 + i)
    candidates = generator.integers(0, 2 ** 31, size=(i + 30,))
    position = (i * PRIME_2) % (i + 3)
    return int(candidates[position].item())


class Main(object):
    """Command-line entry point for the "finance" experiments.

    The constructor seeds the run, spawns the vectorised training and
    validation environments and builds the single test environment. Each
    public method either trains a model and evaluates it, or reloads a saved
    model and re-evaluates it. Results are written through ``SaveUtil`` under
    ``<folder>/finance``.
    """
    model: MOSAC

    def __init__(self,
                 batch_size: int = 256,
                 gamma: float = 0.99,
                 num_env: int = 1,
                 folder: str = 'results',
                 num_eval_grid: int = 6,
                 num_eval_per_test: int = 100,
                 vec_test: int = 25,
                 *,
                 seed: int = 42
                 ):
        """Seed the run and create all environments.

        Args:
            batch_size: Forwarded to every model constructor.
            gamma: Forwarded to every model constructor.
            num_env: Number of worker environments in the training vec-env.
            folder: Base output directory handed to ``SaveUtil``.
            num_eval_grid: Passed as ``num_grids`` to the evaluation sweeps.
            num_eval_per_test: Passed as ``num_execution`` to the evaluation
                sweeps.
            vec_test: Number of worker environments in the validation vec-env
                (also passed as ``n_env`` to the validation sweep).
            seed: Base seed for ``fix_seed``, the workers and the models.
        """
        fix_seed(seed)
        self.env_name = "finance"
        self.env = self._spawn_train_mode_workers(seed, num_env)
        # The training vec-env is tagged with a fixed reward dimension of 3.
        self.env.reward_dim = 3
        self.vec_test = vec_test
        # NOTE(review): validation workers are also built with
        # load_env('train', ...) and the same shuffle(seed + i, i) seeds as the
        # training workers -- confirm this overlap is intended.
        self.valid_env = self._spawn_train_mode_workers(seed, self.vec_test)
        self.test_env = load_env('test', seed=seed)

        self.reward_dim = reward_dim(self.test_env)
        self.batch_size = batch_size
        self.gamma = gamma
        self.folder = folder
        self.seed = seed
        self.num_eval_grid = num_eval_grid
        self.num_eval_per_test = num_eval_per_test

    @staticmethod
    def _spawn_train_mode_workers(seed, n_workers):
        """Return a ``MOSubprocVecEnv`` of ``n_workers`` train-mode environments."""
        env_fns = []
        for i in range(1, n_workers + 1):
            worker_seed = shuffle(seed + i, i)
            # Bind the seed as a default argument so each factory keeps its own
            # value instead of the loop's last one.
            env_fns.append(lambda worker_seed=worker_seed: load_env('train', worker_seed))
        return MOSubprocVecEnv(env_fns)

    @staticmethod
    def wrap_env(env):
        """Wrap ``env`` in ``RescaleAction`` with bounds [-1, 1]."""
        return RescaleAction(env, min_action=-1., max_action=1.)

    @staticmethod
    def _risk_name(risk, risk_param):
        """Return the risk tag used in output file names."""
        if risk == 'neutral':
            return 'neutral'
        return f'{risk}_{risk_param:.2f}'

    def _make_risk_measure(self, risk, risk_param, index):
        """Instantiate the risk measure through ``RiskMeasureGenerator``."""
        generator = RiskMeasureGenerator(alpha=risk_param,
                                         reward_dim=self.reward_dim, index=index,
                                         name=risk)
        return generator()

    def _run_test_sweep(self, name):
        """Evaluate ``self.model`` on the test env; returns ``(results, details)``."""
        return constrained_simplex_finance_test(self.test_env, self.model, index=2, min_val=0.05,
                                                max_val=0.95,
                                                num_grids=self.num_eval_grid,
                                                reward_dim=self.reward_dim, name=name,
                                                num_execution=self.num_eval_per_test)

    def _run_valid_sweep(self, name):
        """Evaluate ``self.model`` on the validation vec-env; returns the results."""
        return vec_constrained_simplex_test(self.valid_env, self.model, index=2, min_val=0.05, max_val=0.95,
                                            num_grids=self.num_eval_grid,
                                            reward_dim=self.reward_dim, name=name,
                                            num_execution=self.num_eval_per_test, n_env=self.vec_test)

    def _evaluate_and_save(self, save_util, name):
        """Run the test sweep, then the validation sweep, then save model and tables."""
        result_test, details = self._run_test_sweep(name)
        result_valid = self._run_valid_sweep(name)
        save_util.save(self.model, result_valid=result_valid, result_test=result_test, details=details)

    def _reload_and_revalidate(self, save_util, name):
        """Load the saved weights into ``self.model`` and store a fresh validation sweep."""
        save_util.load(self.model)
        result_valid = self._run_valid_sweep(name)
        # The validation table is stored through save_only_train, i.e. as
        # '<prefix>result_train.csv'.
        save_util.save_only_train(result_valid)

    def kirqn(self,
              learning_steps: int = 35000,
              critic_lr: float = 3e-4,
              actor_lr: float = 3e-4,
              risk: Literal['cvar', 'wang', 'triangle'] = 'cvar',
              risk_param: float = 0.5,
              index: Sequence[int] = (-1,),
              ent_coef: float | Literal['auto'] = 'auto',
              target_entropy: float | Literal['auto'] = 'auto',
              truncation_lower: int = 0,
              truncation_upper: int = 1,
              test_interval: int = int(1e+5),
              train_frequency: int = 1,
              opt_class: Literal['sgd', 'adam', 'adabelief'] = 'adam',

              ):
        """Train a ``MOSAC`` model, evaluate it and save everything.

        Artefacts use the prefix ``MOSAC_<risk tag>_<seed>``.

        Args:
            learning_steps: First argument of ``model.learn``.
            critic_lr, actor_lr, ent_coef, target_entropy: Forwarded to the
                ``MOSAC`` constructor.
            risk, risk_param, index: Passed to ``RiskMeasureGenerator`` as
                ``name``, ``alpha`` and ``index``. ``risk == 'neutral'`` gets
                the file tag ``neutral`` although it is not in the annotation.
            truncation_lower, truncation_upper, opt_class: Placed in
                ``policy_kwargs``.
            test_interval, train_frequency: Forwarded to ``model.learn``.
        """
        risk_measure = self._make_risk_measure(risk, risk_param, index)
        risk_name = self._risk_name(risk, risk_param)

        save_util = SaveUtil(self.folder, f'MOSAC_{risk_name}_{self.seed}', env_id=self.env_name)
        self.model = MOSAC(env=self.env, test_env=self.test_env, ent_coef=ent_coef,
                           batch_size=self.batch_size, gamma=self.gamma, actor_lr=actor_lr,
                           critic_lr=critic_lr, risk_measure=risk_measure, seed=self.seed,
                           target_entropy=target_entropy, comonotone=False,
                           actor_marginal_risk=False,
                           policy_kwargs={"truncation_lower": truncation_lower,
                                          "truncation_upper": truncation_upper,
                                          "opt_class": opt_class})
        self.model.learn(learning_steps, log_interval=1, train_frequency=train_frequency,
                         test_env=self.test_env, test_interval=test_interval)
        self._evaluate_and_save(save_util, name='kirqn')

    def marginal(self,
                 learning_steps: int,
                 critic_lr: float = 3e-4,
                 actor_lr: float = 3e-4,
                 risk: Literal['cvar', 'wang', 'triangle'] = 'cvar',
                 risk_param: float = 0.5,
                 index: Sequence[int] = (-1,),
                 ent_coef: float | Literal['auto'] = 'auto',
                 truncation_lower: int = 0,
                 truncation_upper: int = 1,
                 ):
        """Train a ``MarginalMOSAC`` model, evaluate it and save everything.

        Artefacts use the prefix ``MarginalIQN_<risk tag>_<seed>``; training
        runs with ``train_frequency=1`` and no periodic test environment.
        Parameters are forwarded as in ``kirqn``.
        """
        risk_measure = self._make_risk_measure(risk, risk_param, index)
        risk_name = self._risk_name(risk, risk_param)

        save_util = SaveUtil(self.folder, f'MarginalIQN_{risk_name}_{self.seed}', env_id=self.env_name)
        self.model = MarginalMOSAC(env=self.env, test_env=self.test_env, ent_coef=ent_coef,
                                   batch_size=self.batch_size, gamma=self.gamma, actor_lr=actor_lr,
                                   critic_lr=critic_lr, risk_measure=risk_measure, seed=self.seed,
                                   policy_kwargs={"truncation_lower": truncation_lower,
                                                  "truncation_upper": truncation_upper})
        self.model.learn(learning_steps, log_interval=1, train_frequency=1)
        self._evaluate_and_save(save_util, name='mo_sac')

    def load_test_kriqn(self,
                        critic_lr: float = 3e-4,
                        actor_lr: float = 3e-4,
                        risk: Literal['cvar', 'wang', 'triangle'] = 'cvar',
                        risk_param: float = 0.5,
                        index: Sequence[int] = (-1,),
                        ent_coef: float | Literal['auto'] = 'auto',
                        truncation_lower: int = 0,
                        truncation_upper: int = 1,
                        ):
        """Reload a saved ``MOSAC`` model and re-run the validation sweep.

        The checkpoint is looked up under the same ``MOSAC_<risk tag>_<seed>``
        prefix that ``kirqn`` writes; the new validation table is stored as
        ``<prefix>result_train.csv``.
        """
        risk_measure = self._make_risk_measure(risk, risk_param, index)
        risk_name = self._risk_name(risk, risk_param)

        save_util = SaveUtil(self.folder, f'MOSAC_{risk_name}_{self.seed}', env_id=self.env_name)
        self.model = MOSAC(env=self.env, test_env=self.test_env, ent_coef=ent_coef,
                           batch_size=self.batch_size, gamma=self.gamma, actor_lr=actor_lr,
                           critic_lr=critic_lr, risk_measure=risk_measure, seed=self.seed,
                           policy_kwargs={"truncation_lower": truncation_lower,
                                          "truncation_upper": truncation_upper})
        self._reload_and_revalidate(save_util, name='mo_sac')

    def ewp(self,
            learning_steps: int,
            critic_lr: float = 3e-4,
            actor_lr: float = 3e-4,
            ent_coef: float | Literal['auto'] = 'auto',
            train_frequency: int = 1,
            truncation_lower: int = 0,
            truncation_upper: int = 1,
            test_interval: int = int(1e+5),
            ):
        """Train an ``EWPSAC`` model, evaluate it and save everything.

        Artefacts use the prefix ``EWP_<seed>``. No risk measure is involved;
        the remaining parameters are forwarded as in ``kirqn``.
        """
        save_util = SaveUtil(self.folder, f'EWP_{self.seed}', env_id=self.env_name)
        self.model = EWPSAC(env=self.env, test_env=self.test_env, ent_coef=ent_coef,
                            batch_size=self.batch_size, gamma=self.gamma, actor_lr=actor_lr,
                            critic_lr=critic_lr, seed=self.seed,
                            policy_kwargs={"truncation_lower": truncation_lower,
                                           "truncation_upper": truncation_upper})
        self.model.learn(learning_steps, log_interval=1, train_frequency=train_frequency,
                         test_env=self.test_env, test_interval=test_interval)
        self._evaluate_and_save(save_util, name='ewp')

    def load_test_ewp(self,
                      critic_lr: float = 3e-4,
                      actor_lr: float = 3e-4,

                      ent_coef: float | Literal['auto'] = 'auto',
                      truncation_lower: int = 0,
                      truncation_upper: int = 1,
                      ):
        """Reload a saved ``EWPSAC`` model and re-run the validation sweep.

        The checkpoint is looked up under the ``EWP_<seed>`` prefix that
        ``ewp`` writes; the new validation table is stored as
        ``<prefix>result_train.csv``.
        """
        save_util = SaveUtil(self.folder, f'EWP_{self.seed}', env_id=self.env_name)
        self.model = EWPSAC(env=self.env, test_env=self.test_env, ent_coef=ent_coef,
                            batch_size=self.batch_size, gamma=self.gamma, actor_lr=actor_lr,
                            critic_lr=critic_lr, seed=self.seed,
                            policy_kwargs={"truncation_lower": truncation_lower,
                                           "truncation_upper": truncation_upper})
        self._reload_and_revalidate(save_util, name='ewp')


if __name__ == '__main__':
    # Hand the Main class to python-fire, which builds the command-line
    # interface from it when this file is run as a script.

    fire.Fire(Main)
