import gymnasium
from typing import Callable, Optional, Any, List, Union
import numpy as np
from gymnasium.core import ObsType, RenderFrame
from rl.utils.replay_buffer import ReplayBuffer
from rl.utils.base_wrapper import NormalizedGymnasiumBoxEnv
from rl.base_offline import BaseOffline
from datetime import datetime
from rl.utils.risk_utils import evaluate_risk_measure
import pandas as pd
from copy import deepcopy
import gym_mtsim
import os

# Default currency pairs traded by ``forex_custom_env`` (``trading_symbols``).
forex_default_symbols = ('EURUSD',  'USDJPY', )
# Default ``fee`` for ``forex_custom_env``. Despite the ``_map`` suffix this is a
# single scalar, passed unchanged to ``gym_mtsim.MtEnv(fee=...)``.
forex_fee_map = 1e-3


def forex_custom_env(hedge: bool = True,
                     trading_symbols=forex_default_symbols,
                     window_size: int = 15,
                     symbol_max_orders: int = 2,
                     fee=forex_fee_map,
                     log_reward: bool = False,
                     data_class: str = 'test',
                     ):
    """Build a ``gym_mtsim.MtEnv`` backed by pickled forex price data.

    The data file is ``symbols_forex_<data_class>.pkl`` in the same directory
    as this module. The simulator uses leverage 100, an initial balance of
    50,000 and a stop-out level of 0.1.

    Args:
        hedge: forwarded to ``MtSimulator(hedge=...)``.
        trading_symbols: symbols the env may trade (converted to a list).
        window_size: observation window length forwarded to ``MtEnv``.
        symbol_max_orders: maximum open orders per symbol.
        fee: trading fee forwarded to ``MtEnv``.
        log_reward: forwarded to ``MtEnv(log_reward=...)``.
        data_class: data split suffix of the pickle file (e.g. ``'train'``,
            ``'test'``).

    Returns:
        The configured ``gym_mtsim.MtEnv`` instance.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    symbols_file = os.path.join(module_dir, f"symbols_forex_{data_class}.pkl")

    simulator_config = dict(
        symbols_filename=symbols_file,
        hedge=hedge,
        leverage=100,
        balance=int(5e+4),
        stop_out_level=0.1,
    )
    simulator = gym_mtsim.simulator.MtSimulator(**simulator_config)

    env_config = dict(
        trading_symbols=list(trading_symbols),
        window_size=window_size,
        symbol_max_orders=symbol_max_orders,
        fee=fee,
        randomize_initial_balance=False,
        time_split=True,
        preprocess=None,
        done_if_equity_zero=True,
        loss_cut=None,
        log_reward=log_reward,
    )
    return gym_mtsim.MtEnv(simulator, **env_config)


def stock_env(trading_symbols=('QQQ', 'SPY', 'IWM', 'IYY'),
              window_size: int = 10,
              symbol_max_orders: int = 4,
              fee=0.2,
              data_class: str = 'train',
              ):
    """Build a ``gym_mtsim.MtEnv`` backed by pickled stock price data.

    The data file is ``symbols_stocks_<data_class>.pkl`` in the same
    directory as this module. The simulator is created with ``hedge=False``,
    leverage 10, an initial balance of 50,000 and a stop-out level of 0.1.

    Args:
        trading_symbols: tickers the env may trade (converted to a list).
        window_size: observation window length forwarded to ``MtEnv``.
        symbol_max_orders: maximum open orders per symbol.
        fee: trading fee forwarded to ``MtEnv``.
        data_class: data split suffix of the pickle file (e.g. ``'train'``,
            ``'test'``).

    Returns:
        The configured ``gym_mtsim.MtEnv`` instance.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    pickle_file = os.path.join(here, f"symbols_stocks_{data_class}.pkl")

    simulator = gym_mtsim.simulator.MtSimulator(
        symbols_filename=pickle_file,
        hedge=False,
        leverage=10,
        balance=int(5e+4),
        stop_out_level=0.1,
    )
    env_options = {
        'trading_symbols': list(trading_symbols),
        'window_size': window_size,
        'symbol_max_orders': symbol_max_orders,
        'fee': fee,
        'preprocess': None,
        'randomize_initial_balance': False,
        'time_split': True,
        'done_if_equity_zero': True,
        'loss_cut': None,
    }
    return gym_mtsim.MtEnv(simulator, **env_options)


class RandomLengthEnv(gym_mtsim.MtEnv):
    """``gym_mtsim.MtEnv`` subclass with a custom ``reset``.

    NOTE(review): despite the name, nothing visible here randomizes the
    episode length -- ``minimal_length`` and ``np_rng`` are stored but never
    read by ``reset``, which always starts at ``self._start_tick``. TODO
    confirm whether random start/end ticks were intended.
    """
    def __init__(
            self,
            original_simulator: gym_mtsim.MtSimulator,
            trading_symbols: List[str],
            window_size: int,
            time_points: Optional[List[datetime]] = None,
            hold_threshold: float = 0.5,
            close_threshold: float = 0.5,
            fee: Union[float, Callable[[str], float]] = 0.0005,
            symbol_max_orders: int = 1,
            multiprocessing_processes: Optional[int] = None,
            render_mode: Optional[str] = None,
            minimal_length: int = 15,
            seed: int = 42
    ) -> None:
        """Forward the standard ``MtEnv`` arguments to the parent constructor.

        Args:
            minimal_length: stored as ``self.minimal_length``; not used by the
                code in this class.
            seed: seeds ``self.np_rng`` (a NumPy ``Generator``); not used by
                the code in this class.
        """
        super().__init__(original_simulator, trading_symbols, window_size,
                         time_points=time_points, hold_threshold=hold_threshold,
                         close_threshold=close_threshold, fee=fee,
                         symbol_max_orders=symbol_max_orders,
                         multiprocessing_processes=multiprocessing_processes,
                         render_mode=render_mode
                         )
        self.np_rng = np.random.default_rng(seed)
        self.minimal_length = minimal_length

    def reset(self, seed=None, options=None):
        """Start a new episode from a fresh copy of the original simulator.

        Returns:
            ``(observation, info)`` for the first tick of the episode.
        """
        # Call gymnasium's base reset directly (skipping MtEnv.reset) so only
        # the generic seeding/option handling runs; episode state is rebuilt
        # manually below.
        gymnasium.Env.reset(self, seed=seed, options=options)
        # Work on a deep copy so the pristine original simulator is reusable.
        self.simulator = deepcopy(self.original_simulator)
        # Start the episode with equity equal to the (copied) balance.
        self.simulator.equity = deepcopy(self.simulator.balance)
        self._truncated = False
        self._current_tick = self._start_tick
        # Keep the simulator clock in sync with the starting tick.
        self.simulator.current_time = self.time_points[self._current_tick]
        # History starts with one info snapshot; the returned ``info`` below is
        # produced by a separate ``_create_info()`` call.
        self.history = [self._create_info()]
        observation = self._get_observation()
        info = self._create_info()
        return observation, info


class MTSimWrapper(gymnasium.Env):
    """Gymnasium adapter around the ``gym_mtsim`` trading environments.

    The wrapped env's dict observations are flattened into one vector by
    ``observation_preprocessor`` and the action space is exposed as
    ``Box(-1, 1)`` with the wrapped env's action shape. Per episode the
    wrapper tracks ``order_statistics`` (accepted orders, counted only in
    data-collection mode) and ``pure_return`` (equity change in percent of
    the wrapped env's ``initial_balance``).
    """

    def __init__(self,
                 env_id: str,
                 observation_preprocessor: Optional[Callable] = None,
                 preprocessed_observation_space: Optional[gymnasium.Space] = None,
                 seed: int = 42,
                 data_type: str = 'train',
                 log_reward: bool = False,
                 ):
        """Create the wrapped environment.

        Args:
            env_id: ``"<env_type>-<hedge>-<version>"``; underscores are
                accepted in place of dashes. ``env_type`` is ``"forex"`` or
                ``"stock"``. Hedging is enabled iff the second part is
                ``"hedge"`` (only the forex env uses it). ``version`` is
                ignored.
            observation_preprocessor: maps a raw dict observation to an
                array; defaults to :meth:`default_preprocessor`.
            preprocessed_observation_space: space of the preprocessed
                observation. When omitted, an unbounded ``Box`` whose shape
                is inferred from one preprocessed sample is used.
            seed: seed for ``self.np_rng``.
            data_type: ``"test"`` selects the test data split; every other
                value (including ``"valid"``) selects the training split.
                ``"data_collection"`` additionally enables the order-count
                reward shaping in :meth:`step`.
            log_reward: forwarded to the forex env on the training split.
                When true, :meth:`step` multiplies every reward by 100.

        Raises:
            ValueError: if ``env_id`` does not split into exactly three parts.
            NotImplementedError: if ``env_type`` is not supported.
        """
        env_type, hedge_part, _version = env_id.replace('_', '-').split('-')
        hedge = (hedge_part == 'hedge')

        # NOTE(review): there is no validation split yet -- 'valid' loads the
        # training data. Switching to data_class='valid' requires
        # symbols_forex_valid.pkl / symbols_stocks_valid.pkl next to this module.
        data_class = 'test' if data_type == 'test' else 'train'
        if env_type == 'forex':
            if data_class == 'test':
                # NOTE(review): log_reward is not forwarded on the test split,
                # although step() still scales rewards by 100 -- confirm intent.
                self.wrapped = forex_custom_env(hedge=hedge, data_class='test')
            else:
                self.wrapped = forex_custom_env(hedge=hedge, data_class='train', log_reward=log_reward)
        elif env_type == 'stock':
            self.wrapped = stock_env(data_class=data_class)
        else:
            raise NotImplementedError(f"Unsupported env type {env_type!r}; expected 'forex' or 'stock'")

        self.action_space = gymnasium.spaces.box.Box(-1, 1, shape=self.wrapped.action_space.shape)
        # Sorted so the flattened observation layout is deterministic.
        self.observation_key = sorted(self.wrapped.observation_space.keys())
        if observation_preprocessor is None:
            self.observation_preprocessor: Callable = self.default_preprocessor
        else:
            self.observation_preprocessor: Callable = observation_preprocessor

        if preprocessed_observation_space is None:
            placeholder = self.observation_preprocessor(self.wrapped.observation_space.sample())
            self.observation_space = gymnasium.spaces.box.Box(low=-np.inf, high=np.inf, shape=placeholder.shape)
        else:
            self.observation_space = preprocessed_observation_space
        self.np_rng = np.random.default_rng(seed)
        self.data_collection = (data_type == 'data_collection')
        self.order_statistics = 0
        self.pure_return = 0.
        self.log_reward = log_reward

    def default_preprocessor(self, obs):
        """Flatten a dict observation into one arcsinh-squashed 1-D array.

        Entries are flattened and concatenated in ``self.observation_key``
        order, then passed through ``np.arcsinh`` (near-linear around zero,
        logarithmic for large magnitudes). ``obs`` is not modified.
        """
        flat = np.concatenate([np.asarray(obs[key]).flatten() for key in self.observation_key], axis=-1)
        return np.arcsinh(flat)

    def reset(
            self,
            *,
            seed: int | None = None,
            options: dict[str, Any] | None = None) -> tuple[ObsType, dict[str, Any]]:
        """Reset the wrapped env and the per-episode statistics.

        In data-collection mode, the statistics of the episode that just
        ended are printed before being cleared.
        """
        observation, info = self.wrapped.reset(seed=seed, options=options)

        if self.data_collection:
            print("ORDER STATISTICS", self.order_statistics, "PURE RETURN", self.pure_return, "%")

        self.order_statistics = 0
        self.pure_return = 0.
        return self.observation_preprocessor(observation), info

    def step(self, action: np.ndarray):
        """Advance the wrapped env by one step.

        Returns:
            ``(observation, reward, done, truncated, info)``, where ``done``
            is already ``terminated or truncated``. In data-collection mode
            the reward gets +1 per accepted order (``order_id`` set and empty
            ``error``), or -1 if no order was accepted. With ``log_reward``
            the resulting reward is multiplied by 100.
        """
        equity_before = self.wrapped.simulator.equity
        obs, reward, terminated, truncated, info = self.wrapped.step(action)
        done = terminated or truncated
        equity_after = self.wrapped.simulator.equity

        # Equity change in percent of the starting balance.
        self.pure_return += 100 * (equity_after - equity_before) / self.wrapped.initial_balance

        if self.data_collection:
            orders: dict = info['orders']
            accepted = sum(1 for order in orders.values()
                           if order['order_id'] is not None and order['error'] == '')
            self.order_statistics += accepted
            # +1 per accepted order, or -1 when none was accepted this step.
            reward += accepted if accepted > 0 else -1

        if self.log_reward:
            reward = reward * 100
        return self.observation_preprocessor(obs), reward, done, truncated, info

    def render(self) -> RenderFrame | list[RenderFrame] | None:
        """Delegate rendering to the wrapped env."""
        return self.wrapped.render()


class MTSimPreprocessor(object):
    """Offline-RL experiment harness for the MTSim trading environments.

    Loads a replay buffer from disk and builds train / validation / test
    ``MTSimWrapper`` environments. Optionally standardizes observations in
    the buffer and in all three environments with the buffer's per-feature
    mean and standard deviation.
    """

    def __init__(self,
                 env_id: str,
                 path: str,
                 normalize_obs: bool = True,
                 normalize_reward: bool = False,
                 timeout_terminal=True,
                 seed: int = 0,
                 ):
        """Load the replay buffer and construct the environments.

        Args:
            env_id: environment id understood by ``MTSimWrapper``.
            path: file passed to ``ReplayBuffer.from_npz``.
            normalize_obs: if true, standardize buffer observations and wrap
                every environment in ``NormalizedGymnasiumBoxEnv`` with the
                same statistics.
            normalize_reward: forwarded to ``ReplayBuffer.from_npz``. The
                environments' rewards are not rescaled here.
            timeout_terminal: forwarded as ``terminal_if_timeout``.
            seed: seed for the environments and the evaluation RNG.
        """
        self.gymnasium_env = MTSimWrapper(env_id, seed=seed, data_type='train')
        # NOTE(review): MTSimWrapper only special-cases data_type='test', so this
        # "validation" env currently runs on the training data split.
        self.valid_env = MTSimWrapper(env_id, data_type='valid', seed=seed)
        self.test_env = MTSimWrapper(env_id, data_type='test', seed=seed)

        self.normalize_reward = normalize_reward
        self.d4rl_gym_env = self.gymnasium_env.wrapped
        self.buffer = ReplayBuffer.from_npz(path,
                                            normalize_reward=normalize_reward,
                                            terminal_if_timeout=timeout_terminal
                                            )

        # Per-feature dataset statistics. Flooring the std keeps constant
        # features from causing a division by zero.
        self.obs_mean = self.buffer.observations.mean(axis=0, keepdims=True)
        self.obs_std = self.buffer.observations.std(axis=0, keepdims=True).clip(1e-12)
        self.np_rng = np.random.default_rng(seed)
        self.seed = seed
        if normalize_obs:
            def normalize(x, mean, std):
                return (x - mean) / std.clip(1e-12, )

            self.buffer.observations = normalize(self.buffer.observations, self.obs_mean, self.obs_std)
            self.buffer.next_observations = normalize(self.buffer.next_observations, self.obs_mean, self.obs_std)
            self.gymnasium_env = NormalizedGymnasiumBoxEnv(self.gymnasium_env,
                                                           obs_mean=self.obs_mean, obs_std=self.obs_std)
            self.test_env = NormalizedGymnasiumBoxEnv(self.test_env,
                                                      obs_mean=self.obs_mean, obs_std=self.obs_std
                                                      )
            self.valid_env = NormalizedGymnasiumBoxEnv(self.valid_env,
                                                       obs_mean=self.obs_mean, obs_std=self.obs_std)
            print("ENV HAS BEEN LOADED!")

    @property
    def env(self):
        """The (possibly observation-normalized) training environment."""
        return self.gymnasium_env

    def get_replay_buffer(self):
        """Return the loaded (possibly normalized) replay buffer."""
        return self.buffer

    def evaluate(self, model: Any, n_eval: int, env: gymnasium.Env, seed: Optional[int] = None):
        """Roll out ``n_eval`` full episodes of ``model`` on ``env``.

        Each episode is reset with a seed drawn from ``rng``. ``rng`` is
        ``self.np_rng`` when ``seed`` is None, so successive calls differ.
        Otherwise it is a fresh generator seeded with ``seed``, so calls with
        the same seed reuse the same reset seeds.

        Returns:
            float64 array of undiscounted episode returns, sorted ascending.
        """
        scores = []
        if seed is None:
            rng = self.np_rng
        else:
            rng = np.random.default_rng(seed)

        for _ in range(n_eval):
            _seed = int(rng.integers(0, 2 ** 32 - 1))
            obs, _ = env.reset(seed=_seed)
            done = False
            score = 0
            while not done:
                action = model.predict(obs)
                obs, reward, done, timeout, info = env.step(action)
                done = done or timeout
                score += reward
            scores.append(score)
        scores.sort()
        return np.asarray(scores, dtype=np.float64)

    def train_pipeline(self,
                       model: BaseOffline,
                       model_save_path: str,
                       num_epoch: int, len_epoch: int,
                       risk_type: str, risk_eta: float,
                       eval_interval: int = 5,
                       valid_eval: int = 100,
                       test_eval: int = 1000,
                       ):
        """Train ``model`` epoch by epoch with periodic validation and testing.

        The model is evaluated after training epochs 0, ``eval_interval``,
        ``2 * eval_interval``, ..., and once more after the final epoch.
        Each evaluation records the validation risk measure over
        ``valid_eval`` episodes, and the test mean/std/risk measure over
        ``test_eval`` episodes. The reported test result is the one paired
        with the highest validation score.

        The model saved to ``model_save_path`` is the final one, not the
        checkpoint selected by validation.

        Returns:
            ``pandas.DataFrame`` with one row per evaluation and columns
            ``valid``, ``test_mean``, ``test_std``, ``test_risk``, ``seed``.

        Raises:
            ValueError: if ``eval_interval`` is not positive. This is checked
                up front, instead of failing with ZeroDivisionError after the
                first training epoch.
        """
        if eval_interval <= 0:
            raise ValueError(f"eval_interval must be a positive integer, got {eval_interval!r}")

        valid_scores = []
        test_scores = []

        def validate():
            # important! Fixing seed for each evaluation so that the consistency can be guaranteed
            valid_returns = self.evaluate(model, n_eval=valid_eval,
                                          env=self.valid_env, seed=self.seed)
            valid_score = evaluate_risk_measure(valid_returns, risk_type, risk_eta)
            test_returns = self.evaluate(model, n_eval=test_eval,
                                         env=self.test_env, seed=self.seed)
            test_mean, test_std = np.mean(test_returns), np.std(test_returns)
            test_risk = evaluate_risk_measure(test_returns, risk_type, risk_eta)
            return valid_score, (test_mean, test_std, test_risk)

        for e in range(num_epoch):
            model.epoch_learn(len_epoch)
            if e % eval_interval == 0:
                valid_risk, (test_mean, test_std, test_risk) = validate()

                print(f"EPOCH: {e}")
                print(f"VALID SCORE  {valid_risk:.4f}, "
                      f"TEST SCORE {test_mean:.4f}+/-{test_std:.4f}, "
                      f"{risk_type}@{risk_eta:.2f}: {test_risk:.4f}")
                valid_scores.append(valid_risk)
                test_scores.append((test_mean, test_std, test_risk))

        valid_risk, (test_mean, test_std, test_risk) = validate()
        valid_scores.append(valid_risk)
        test_scores.append((test_mean, test_std, test_risk))
        valid_scores = np.asarray(valid_scores)

        # NOTE(review): argmax assumes a larger risk measure is better for
        # every risk_type -- confirm against evaluate_risk_measure.
        index = np.argmax(valid_scores)
        test_at = test_scores[index]
        mu, sigma, risk = test_at
        array_test = np.asarray(test_scores)

        df = {"valid": valid_scores,
              "test_mean": array_test[:, 0],
              "test_std": array_test[:, 1],
              "test_risk": array_test[:, 2],
              "seed": np.ones_like(valid_scores, dtype=np.int32) * self.seed,
              }
        print(f"VALID SCORE {valid_scores[index]:.4f}%, TEST SCORE {mu:.4f}+/-{sigma:.4f}%, "
              f"{risk_type}@{risk_eta:.2f}: {risk:.4f}%")
        model.save(model_save_path)
        return pd.DataFrame.from_dict(df)


if __name__ == "__main__":
    # Smoke check: build the training forex env and show the latest USDJPY rows.
    demo_env = forex_custom_env(hedge=True, data_class='train')
    usdjpy_frame = demo_env.original_simulator.symbols_data['USDJPY']
    print(usdjpy_frame.tail(10))
