# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/dqn/#dqn_ataripy
import os
import random
import time
from dataclasses import dataclass

import safety_gymnasium
import pandas as pd
import gymnasium as gym
import gym_trading_env
from gym_trading_env.downloader import download
import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tyro
from stable_baselines3.common.atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter

from CybORG import CybORG
from CybORG.Agents import B_lineAgent, RedMeanderAgent
from ChallengeWrapper2 import ChallengeWrapper2
import inspect

@dataclass
class Args:
    """Command-line arguments for this DQN training script (parsed with ``tyro.cli``)."""

    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    torch_deterministic: bool = True
    """whether to set `torch.backends.cudnn.deterministic` (on by default)"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    track: bool = False
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "cleanRL"
    """the wandb's project name"""
    wandb_entity: str = None  # may be None
    """the entity (team) of wandb's project"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (check out `videos` folder)"""
    save_model: bool = False
    """whether to save model into the `runs/{run_name}` folder"""
    upload_model: bool = False
    """whether to upload the saved model to huggingface"""
    hf_entity: str = ""
    """the user or org name of the model repository from the Hugging Face Hub"""
    atari: bool = False
    """whether `env_id` is an Atari game (applies the Atari wrapper stack in `make_env`)"""
    safety: bool = False
    """whether `env_id` is a Safety-Gymnasium task (actions are picked from a small discrete set)"""
    trade: bool = False
    """whether `env_id` is a gym-trading-env task (vector observations, MLP Q-network)"""

    # Algorithm specific arguments
    env_id: str = "BreakoutNoFrameskip-v4"
    """the id of the environment (NOTE: currently overwritten by the hard-coded list in `__main__`)"""
    total_timesteps: int = 10000000
    """total timesteps of the experiments"""
    learning_rate: float = 1e-4
    """the learning rate of the optimizer"""
    num_envs: int = 1
    """the number of parallel game environments"""
    buffer_size: int = 500000
    """the replay memory buffer size"""
    gamma: float = 0.99
    """the discount factor gamma"""
    tau: float = 1.0
    """the target network update rate"""
    target_network_frequency: int = 1000
    """the timesteps it takes to update the target network"""
    batch_size: int = 32
    """the batch size of sample from the replay memory"""
    start_e: float = 1
    """the starting epsilon for exploration"""
    end_e: float = 0.01
    """the ending epsilon for exploration"""
    exploration_fraction: float = 0.10
    """the fraction of `total-timesteps` it takes from start-e to go end-e"""
    learning_starts: int = 80000
    """timestep to start learning"""
    train_frequency: int = 4
    """the frequency of training"""

def make_env(env_id, seed, idx, capture_video, run_name, atari):
    """Return a thunk that builds one wrapped environment for ``gym.vector.SyncVectorEnv``.

    The branch is chosen by the ``atari`` flag first, then by substrings of
    ``env_id``:
    - ``"Safe"`` selects Safety-Gymnasium.
    - ``"CarRacing"`` selects discrete CarRacing with stacked grayscale frames.
    - ``"cage"`` selects the CybORG Scenario2 environment with a B_line red agent.
    - ``"Trading"`` selects gym-trading-env on a local BTC/USDT pickle.

    Args:
        env_id: gym id, or a string containing one of the keywords above.
        seed: seed applied to the created env's action space.
        idx: index of this env within the vector env; only ``idx == 0`` records video.
        capture_video: record videos (used by the Atari and CarRacing branches only).
        run_name: sub-folder of ``videos/`` used for recordings.
        atari: build the Atari preprocessing wrapper stack around ``gym.make(env_id)``.

    Returns:
        A zero-argument callable that creates the environment.

    Raises:
        ValueError: when the thunk is called and ``env_id`` matches no branch
            (previously this surfaced as an opaque ``UnboundLocalError``).
    """

    def thunk():
        if atari:
            if capture_video and idx == 0:
                env = gym.make(env_id, render_mode="rgb_array")
                env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
            else:
                env = gym.make(env_id)
            env = gym.wrappers.RecordEpisodeStatistics(env)

            env = NoopResetEnv(env, noop_max=30)
            env = MaxAndSkipEnv(env, skip=4)
            env = EpisodicLifeEnv(env)
            if "FIRE" in env.unwrapped.get_action_meanings():
                env = FireResetEnv(env)
            env = ClipRewardEnv(env)
            env = gym.wrappers.ResizeObservation(env, (84, 84))
            env = gym.wrappers.GrayScaleObservation(env)
            env = gym.wrappers.FrameStack(env, 4)
        elif "Safe" in env_id:
            # NOTE(review): unlike the other branches this env is not wrapped in
            # RecordEpisodeStatistics, so no episodic returns get logged for it.
            # Safety-Gymnasium's step() also returns a cost term next to the reward;
            # confirm SyncVectorEnv can consume it without a conversion wrapper.
            env = safety_gymnasium.make(env_id, render_mode=None)
        elif "CarRacing" in env_id:
            # continuous=False selects the discrete-action variant so DQN applies.
            if capture_video and idx == 0:
                env = gym.make(env_id, render_mode="rgb_array", continuous=False)
                env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
            else:
                env = gym.make(env_id, continuous=False)
            env = gym.wrappers.RecordEpisodeStatistics(env)
            env = gym.wrappers.ResizeObservation(env, (84, 84))
            env = gym.wrappers.GrayScaleObservation(env)
            env = gym.wrappers.FrameStack(env, 4)
        elif "cage" in env_id:
            # Presumably inspect.getfile(CybORG) ends in "/CybORG.py" (10 chars), so
            # stripping it leaves the package directory -- TODO confirm on upgrades.
            path = str(inspect.getfile(CybORG))
            path = path[:-10] + "/Shared/Scenarios/Scenario2.yaml"

            red_agent = B_lineAgent
            env = ChallengeWrapper2(
                env=CybORG(path, "sim", agents={"Red": red_agent}),
                agent_name="Blue",
                max_steps=100,
            )
            env = gym.wrappers.RecordEpisodeStatistics(env)
        elif "Trading" in env_id:
            # To fetch fresh data, run once:
            # download(exchange_names = ["bitfinex2"],
            #     symbols= ["BTC/USDT"],
            #     timeframe= "1h",
            #     dir = "data",
            #     since= datetime.datetime(year= 2020, month= 1, day=1),
            #     until = datetime.datetime(year = 2024, month = 1, day = 1),
            # )
            # NOTE: read_pickle can execute arbitrary code -- only load trusted local files.
            df = pd.read_pickle("./data/bitfinex2-BTCUSDT-1h.pkl")

            # df has the columns "open", "high", "low", "close", "volume".
            # Feature: ( close[t] - close[t-1] ) / close[t-1]
            df["feature_close"] = df["close"].pct_change()
            # Feature: open[t] / close[t]
            df["feature_open"] = df["open"] / df["close"]
            # Feature: high[t] / close[t]
            df["feature_high"] = df["high"] / df["close"]
            # Feature: low[t] / close[t]
            df["feature_low"] = df["low"] / df["close"]
            # Feature: volume[t] / max(*volume[t-7*24:t+1])  (one-week rolling max)
            df["feature_volume"] = df["volume"] / df["volume"].rolling(7 * 24).max()
            # pct_change and the rolling window leave NaNs in the leading rows.
            df.dropna(inplace=True)
            # Each step, the environment returns the 5 "feature_*" columns above.

            env = gym.make(
                "TradingEnv",
                name="BTCUSD",
                df=df,  # dataset with the custom features
                positions=[-1 + (i * 0.2) for i in range(11)],  # -1.0 .. 1.0 in steps of 0.2
                trading_fees=0.01 / 100,  # 0.01% per stock buy / sell (Binance fees)
                borrow_interest_rate=0.0003 / 100,  # 0.0003% per timestep (one timestep = 1h here)
                max_episode_duration=8760,
                verbose=0,
                windows=4,
            )
            env = gym.wrappers.RecordEpisodeStatistics(env)
        else:
            raise ValueError(
                f"Unsupported env_id {env_id!r}: pass atari=True or use an id containing "
                "'Safe', 'CarRacing', 'cage' or 'Trading'"
            )
        env.action_space.seed(seed)
        return env

    return thunk

class Discretizer:
    """Lookup table turning discrete action indices into concrete action vectors."""

    def __init__(self, actions):
        # Converted to a NumPy array so both scalar and array indices select rows.
        self.actions = np.array(actions)

    def __len__(self):
        """Number of entries in the action table."""
        return self.actions.__len__()

    def __call__(self, x, dim=False):
        """Return the table row(s) selected by index (or index array) ``x``.

        ``dim`` is accepted for call compatibility but is not used.
        """
        table = self.actions
        chosen = table[x]
        return chosen

# ALGO LOGIC: initialize agent here:
class QNetwork(nn.Module):
    """DQN Q-value network whose architecture depends on the environment family.

    Args:
        env: the vector environment; ``single_observation_space.shape[0]`` sizes the
            MLP input and ``single_action_space.n`` sizes the output where needed.
        image: Nature-DQN conv net over 4 stacked 84x84 frames (pixels divided by 255).
        safety: MLP over a fixed 4-entry action table (``self.discretizer``).
        trade: MLP with one output per env action.
        cage: MLP over a fixed list of 27 CybORG action ids; ``self.act`` maps a
            network output index to the env action id.

    The flags are checked in the order above and the first true one wins.

    Raises:
        ValueError: if every flag is false (there would be no network to run).
    """

    def __init__(self, env, image, safety, trade, cage):
        super().__init__()
        if image:
            self.network = nn.Sequential(
                nn.Conv2d(4, 32, 8, stride=4),
                nn.ReLU(),
                nn.Conv2d(32, 64, 4, stride=2),
                nn.ReLU(),
                nn.Conv2d(64, 64, 3, stride=1),
                nn.ReLU(),
                nn.Flatten(),
                nn.Linear(3136, 512),  # 64 channels * 7 * 7 for an 84x84 input
                nn.ReLU(),
                nn.Linear(512, env.single_action_space.n),
            )
            self.norm = 255
        elif safety:
            self.safety = True
            self.discretizer = Discretizer(torch.tensor([[0, 0], [1, 0], [0, 1], [1, 1]]))
            # Alternative 9-action table:
            # Discretizer(torch.tensor([[0,0], [-1, 0], [1, 0], [0, -1], [0, 1], [-1, 1], [-1, -1], [1, -1], [1, 1]]))
            self.network = self._mlp(env.single_observation_space.shape[0], len(self.discretizer))
            self.norm = 1
        elif trade:
            self.network = self._mlp(env.single_observation_space.shape[0], env.single_action_space.n)
            self.norm = 1
        elif cage:
            action_space = torch.tensor(
                [133, 134, 135, 139, 3, 4, 5, 9, 16, 17, 18, 22, 11, 12, 13, 14,
                 141, 142, 143, 144, 132, 2, 15, 24, 25, 26, 27]
            )
            self.network = self._mlp(env.single_observation_space.shape[0], len(action_space))
            self.norm = 1
            self.act = ActionConverter(action_space)
        else:
            raise ValueError("QNetwork needs one of image/safety/trade/cage to be True")

    @staticmethod
    def _mlp(in_features, out_features):
        """Two hidden ReLU layers of 64 units, shared by the vector-observation branches."""
        return nn.Sequential(
            nn.Linear(in_features, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, out_features),
        )

    def forward(self, x):
        """Return Q-values for the observation batch ``x`` after dividing by ``self.norm``."""
        return self.network(x / self.norm)

class ActionConverter:
    """Translate a Q-network output index into the corresponding environment action id."""

    def __init__(self, actions):
        # Stored unchanged, so a tensor table keeps tensor indexing semantics.
        self.actions = actions

    def __len__(self):
        """Number of selectable actions."""
        size = len(self.actions)
        return size

    def __call__(self, index):
        """Return the action id(s) stored at ``index``."""
        table = self.actions
        return table[index]

def linear_schedule(start_e: float, end_e: float, duration: float, t: int) -> float:
    """Linearly anneal from ``start_e`` to ``end_e`` over ``duration`` steps, then hold.

    Args:
        start_e: value at ``t == 0``.
        end_e: value reached at ``t == duration`` and held afterwards.
        duration: length of the annealing phase; may be fractional (the caller
            passes ``exploration_fraction * total_timesteps``). A non-positive
            duration means "no annealing" and ``end_e`` is returned directly.
        t: current step.

    Returns:
        The interpolated value, clamped at ``end_e`` in the direction of travel.
    """
    if duration <= 0:
        # Avoid ZeroDivisionError (e.g. exploration_fraction=0): jump to the end value.
        return end_e
    slope = (end_e - start_e) / duration
    value = slope * t + start_e
    # Clamp towards end_e so both decreasing and increasing schedules stop there.
    return max(value, end_e) if end_e <= start_e else min(value, end_e)


if __name__ == "__main__":
    import stable_baselines3 as sb3

    # Compare the numeric major version: a plain string comparison would rank
    # e.g. "10.0" below "2.0".
    if int(sb3.__version__.split(".")[0]) < 2:
        raise ValueError(
            """Ongoing migration: run the following command to install the new dependencies:
            poetry run pip install "stable_baselines3==2.0.0a1" "gymnasium[atari,accept-rom-license]==0.28.1"  "ale-py==0.8.1" 
            """
        )
    args = tyro.cli(Args)
    assert args.num_envs == 1, "vectorized envs are not supported at the moment"

    # NOTE: the environments to train on are hard-coded here; this loop
    # overwrites any --env-id given on the command line.
    environs = ["cage"]

    for e in environs:
        args.env_id = e
        # Derived (non-CLI) flag selecting the CybORG CAGE branches. Set it before
        # logging so it is recorded with the other hyperparameters.
        args.cage = "cage" in args.env_id
        run_name = f"{args.env_id}__{args.exp_name}"
        if args.track:
            import wandb

            wandb.init(
                project=args.wandb_project_name,
                entity=args.wandb_entity,
                sync_tensorboard=True,
                config=vars(args),
                name=run_name,
                monitor_gym=True,
                save_code=True,
            )
        writer = SummaryWriter(f"runs/{run_name}")
        writer.add_text(
            "hyperparameters",
            "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
        )

        # TRY NOT TO MODIFY: seeding
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = args.torch_deterministic

        device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

        # env setup
        envs = gym.vector.SyncVectorEnv(
            [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name, args.atari) for i in range(args.num_envs)]
        )
        print(envs.single_action_space)

        # Vector-observation envs (safety / trade / cage) use an MLP; everything
        # else gets the convolutional network over stacked frames.
        image = not (args.safety or args.trade or args.cage)
        q_network = QNetwork(envs, image, args.safety, args.trade, args.cage).to(device)
        optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)
        target_network = QNetwork(envs, image, args.safety, args.trade, args.cage).to(device)
        target_network.load_state_dict(q_network.state_dict())

        # For safety envs the buffer stores the Discretizer index (what the network
        # outputs), not the action vector that is sent to the environment.
        buffer_action_space = (
            gym.spaces.Discrete(len(q_network.discretizer)) if args.safety else envs.single_action_space
        )
        rb = ReplayBuffer(
            args.buffer_size,
            envs.single_observation_space,
            buffer_action_space,
            device,
            optimize_memory_usage=True,
            handle_timeout_termination=False,
        )
        start_time = time.time()

        model_dir = f"dqn_models/{run_name}"
        model_path = f"{model_dir}/{args.exp_name}.cleanrl_model"
        if args.save_model:
            # makedirs (not mkdir): also creates dqn_models/ and tolerates re-runs.
            os.makedirs(model_dir, exist_ok=True)
        # ~10 intermediate checkpoints; max(1, ...) avoids modulo by zero on tiny runs.
        checkpoint_every = max(1, args.total_timesteps // 10)

        # TRY NOT TO MODIFY: start the game
        obs, _ = envs.reset(seed=args.seed)
        for global_step in range(args.total_timesteps):
            if args.save_model and global_step % checkpoint_every == 0:
                torch.save(q_network.state_dict(), model_path)
                print(f"model saved to {model_path}")

            # ALGO LOGIC: put action logic here
            epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
            if random.random() < epsilon:
                if args.safety:
                    actions = np.random.randint(0, len(q_network.discretizer), len(envs.envs))
                elif args.cage:
                    actions = np.random.randint(0, len(q_network.act), len(envs.envs))
                else:
                    actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
            else:
                q_values = q_network(torch.Tensor(obs).to(device))
                actions = torch.argmax(q_values, dim=1).cpu().numpy()

            # TRY NOT TO MODIFY: execute the game and log data.
            # `actions` holds network indices; safety and cage map them to env actions.
            if args.safety:
                env_actions = q_network.discretizer(actions)
            elif args.cage:
                env_actions = q_network.act(actions).cpu().numpy()
            else:
                env_actions = actions
            next_obs, rewards, terminations, truncations, infos = envs.step(env_actions)

            # TRY NOT TO MODIFY: record rewards for plotting purposes
            if "final_info" in infos:
                for info in infos["final_info"]:
                    if info and "episode" in info:
                        print(f"global_step={global_step}, episodic_return={info['episode']['r']}", end = "\r")
                        writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                        writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)

            # TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
            real_next_obs = next_obs.copy()
            for idx, trunc in enumerate(truncations):
                if trunc:
                    real_next_obs[idx] = infos["final_observation"][idx]
            rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

            # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
            obs = next_obs

            # ALGO LOGIC: training.
            if global_step > args.learning_starts:
                if global_step % args.train_frequency == 0:
                    data = rb.sample(args.batch_size)
                    with torch.no_grad():
                        target_max, _ = target_network(data.next_observations.float()).max(dim=1)
                        td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())
                    old_val = q_network(data.observations.float()).gather(1, data.actions).squeeze()
                    loss = F.mse_loss(td_target, old_val)

                    if global_step % 100 == 0:
                        writer.add_scalar("losses/td_loss", loss, global_step)
                        writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
                        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

                    # optimize the model
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                # update target network
                if global_step % args.target_network_frequency == 0:
                    for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()):
                        target_network_param.data.copy_(
                            args.tau * q_network_param.data + (1.0 - args.tau) * target_network_param.data
                        )

        if args.save_model:
            torch.save(q_network.state_dict(), model_path)
            print(f"model saved to {model_path}")
            # from cleanrl_utils.evals.dqn_eval import evaluate

            # episodic_returns = evaluate(
            #     model_path,
            #     make_env,
            #     args.env_id,
            #     eval_episodes=10,
            #     run_name=f"{run_name}-eval",
            #     Model=QNetwork,
            #     device=device,
            #     epsilon=0.05,
            # )
            # for idx, episodic_return in enumerate(episodic_returns):
            #     writer.add_scalar("eval/episodic_return", episodic_return, idx)

            # if args.upload_model:
            #     from cleanrl_utils.huggingface import push_to_hub

            #     repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
            #     repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
            #     push_to_hub(args, episodic_returns, repo_id, "DQN", f"runs/{run_name}", f"videos/{run_name}-eval")

        envs.close()
        writer.close()
        # wandb is only imported when tracking; calling finish() unconditionally
        # raised NameError at the end of every untracked run.
        if args.track:
            wandb.finish()
