import sys

sys.dont_write_bytecode = True

import os
import gym
import gymnasium
from vizdoom import gymnasium_wrapper
import math
import copy
import random
import datetime
import numpy as np
import cv2
from time import sleep
import wandb
import pickle

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.lines import Line2D
import matplotlib
import pickle
import pandas as pd
from collections import deque, namedtuple

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR, StepLR, CyclicLR
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from transformers import get_cosine_schedule_with_warmup

from components.drawing import Arrow3D

# Global matplotlib font settings applied to every figure made by this module.
font = dict(weight="bold", size=11)
matplotlib.rc("font", **font)


def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also puts cuDNN in deterministic mode and turns off its benchmark
    autotuner, which trades speed for run-to-run reproducibility.

    Note: assigning PYTHONHASHSEED here only influences child processes
    launched afterwards; the running interpreter's hash seed is fixed at
    startup.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ["PYTHONHASHSEED"] = str(seed)

    print("Global seeds set to", seed)


# A single stored transition (s, a, r, s', done) plus an opaque state hash.
Experience = namedtuple(
    typename="Experience",
    field_names="state action reward next_state done state_hash",
)


class ReplayBuffer:
    """Container of ``Experience`` transitions with uniform random sampling.

    Args:
        buffer_size: ``maxlen`` of the internal deque when no ``memory`` is
            given (oldest entries are evicted once full). Ignored when
            ``memory`` is supplied.
        batch_size: default number of transitions returned by ``sample``.
        seed: accepted for interface compatibility; not used here.
        memory: optional pre-existing sequence of experiences to wrap as-is.
    """

    def __init__(self, buffer_size, batch_size, seed, memory=None):
        # `is None` (not `== None`) so array-like memories don't trigger
        # element-wise comparison / ambiguous truth-value errors.
        if memory is None:
            self.memory = deque(maxlen=buffer_size)
        else:
            self.memory = memory
        self.batch_size = batch_size

    def add(self, state, action, reward, next_state, done, state_hash):
        """Add a new experience to memory."""
        e = Experience(state, action, reward, next_state, done, state_hash)
        self.memory.append(e)

    def sample(self, batch_size=None):
        """Draw a uniform random minibatch (without replacement).

        Args:
            batch_size: number of transitions; defaults to ``self.batch_size``.

        Returns:
            ``(states, actions, rewards, next_states, dones)`` CPU tensors.
            States get a new leading batch axis; actions are ``long``, the
            rest ``float``. ``state_hash`` is not returned.

        Raises:
            ValueError: from ``random.sample`` if the buffer holds fewer
                transitions than requested.
        """
        k = self.batch_size if batch_size is None else batch_size
        experiences = [e for e in random.sample(self.memory, k=k) if e is not None]

        states = torch.from_numpy(
            np.vstack([e.state[None, :] for e in experiences])
        ).float()
        actions = torch.from_numpy(np.vstack([e.action for e in experiences])).long()
        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences])).float()
        next_states = torch.from_numpy(
            np.vstack([e.next_state[None, :] for e in experiences])
        ).float()
        dones = torch.from_numpy(
            np.vstack([e.done for e in experiences]).astype(np.uint8)
        ).float()

        return (states, actions, rewards, next_states, dones)

    def __len__(self):
        return len(self.memory)

    def save(self, file_path):
        """Pickle the underlying memory container to ``file_path``."""
        with open(file_path, "wb") as f:
            pickle.dump(self.memory, f)
        print(f"Replay buffer saved to {file_path}.")

    def load(self, file_path):
        """Replace the memory with the object unpickled from ``file_path``.

        Security note: ``pickle.load`` can execute arbitrary code; only load
        files you produced yourself or otherwise trust.
        """
        with open(file_path, "rb") as f:
            self.memory = pickle.load(f)
        print(f"Replay buffer loaded from {file_path}.")
        print(f"Replay buffer contains {len(self.memory)} experiences.")


def split_buffer(buffer, split_ratio=0.8):
    """Randomly split ``buffer`` into disjoint train and test ReplayBuffers.

    Args:
        buffer: source buffer; only its ``memory`` and ``batch_size`` are read.
        split_ratio: fraction in [0, 1] of experiences assigned to the train
            split; everything else goes to the test split.

    Returns:
        ``(train_buffer, test_buffer)``, both using ``buffer.batch_size``.

    Raises:
        ValueError: if ``split_ratio`` is outside [0, 1].
    """
    if not (0 <= split_ratio <= 1):
        raise ValueError("split_ratio must be between 0 and 1.")

    # Shuffle a copy so the source buffer's order is left untouched.
    shuffled_memory = list(buffer.memory)
    random.shuffle(shuffled_memory)

    # One shared cut point keeps the splits disjoint and together they cover
    # every experience, for any split_ratio.
    split_idx = int(len(shuffled_memory) * split_ratio)
    train_set = shuffled_memory[:split_idx]
    test_set = shuffled_memory[split_idx:]

    train_buffer = ReplayBuffer(
        buffer_size=len(train_set),
        batch_size=buffer.batch_size,
        memory=train_set,
        seed=33,
    )
    test_buffer = ReplayBuffer(
        buffer_size=len(test_set),
        batch_size=buffer.batch_size,
        memory=test_set,
        seed=33,
    )

    return train_buffer, test_buffer


class RandomAgent:
    """Baseline agent that picks a uniformly random discrete action."""

    def __init__(self, env):
        # env must provide get_num_actions(); it is queried on every act().
        self.env = env

    def act(self, *_args, **_kwargs):
        """Return a uniformly random action id in ``[0, env.get_num_actions())``.

        All arguments are ignored; they only let this agent stand in for
        agents whose ``act`` takes an observation.
        """
        num_actions = self.env.get_num_actions()
        # np.arange(n) matches the previous hard-coded [0, 1, 2, 3] exactly
        # for n == 4 (same dtype and RNG consumption).
        return np.random.choice(np.arange(num_actions))


class QNetwork(nn.Module):
    """CNN mapping a batch of 3-channel images to 4 action values.

    Three conv layers feed a two-layer MLP head. The head's input width of
    1024 equals 64 channels * 4 * 4, which is what the conv stack yields for
    a 64x64 input, so inputs are expected as (batch, 3, 64, 64).
    ``env`` is accepted for interface compatibility but not used.
    """

    def __init__(self, env):
        super().__init__()
        conv_trunk = [
            nn.Conv2d(3, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1),
            nn.ReLU(),
        ]
        mlp_head = [
            nn.Flatten(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 4),
        ]
        # A single Sequential keeps the original state_dict keys (network.N.*).
        self.network = nn.Sequential(*conv_trunk, *mlp_head)

    def forward(self, x):
        """Return Q-values of shape (batch, 4)."""
        q_values = self.network(x)
        return q_values


class MBAgent:
    """DQN agent trained offline from a pre-filled replay buffer.

    ``train`` never adds transitions. It repeatedly samples minibatches from
    ``self.replays``, which is loaded from ``rb_path`` when given. Every 1000
    steps it also runs one greedy evaluation episode in a fresh
    ``VizdoomSingleRoom``. Metrics are logged through the global ``wandb``
    module.
    """

    def __init__(
        self,
        env: gym.Env,
        batch_size: int,
        learn_every: int,
        device: str = "cpu",
        debug: bool = False,
        buffer_size=10000,
        seed=2022,
        wandb_run=None,
        gamma=0.99,
        tau=1,
        target_network_frequency=500,
        start_e=1,
        end_e=0.05,
        exploration_fraction=0.5,
        train_frequency=10,
        lr_dqn=2.5e-4,
        rb_path=None,
    ):
        self.env = env
        self.device = device
        self.batch_size = batch_size
        self.learn_every = learn_every  # stored; not used by train() below
        self.debug = debug
        self.replays = ReplayBuffer(
            buffer_size=buffer_size, batch_size=batch_size, seed=seed
        )
        if rb_path is not None:
            self.replays.load(rb_path)
        self.seed = seed

        self.wandb_run = wandb_run

        # DQN parameters
        self.gamma = gamma
        self.tau = tau
        self.target_network_frequency = target_network_frequency
        self.start_e = start_e
        self.end_e = end_e
        self.exploration_fraction = exploration_fraction
        self.train_frequency = train_frequency  # stored; not used by train() below
        self.lr_dqn = lr_dqn

        # Global training step. learn_dqn keeps it in sync and linear_schedule
        # reads it; initialised here so both work from the very first call.
        self.t = 0

        self.q_network = QNetwork(env).to(device)
        self.optimizer_dqn = torch.optim.Adam(self.q_network.parameters(), lr=lr_dqn)
        self.target_network = QNetwork(env).to(device)
        self.target_network.load_state_dict(self.q_network.state_dict())

    def linear_schedule(self, num_episodes=100, ep_len=2500):
        """Epsilon decayed linearly in ``self.t``, floored at ``end_e``.

        The decay spans ``exploration_fraction * num_episodes * ep_len`` steps.
        Not called by ``train``.
        """
        slope = (self.end_e - self.start_e) / (
            self.exploration_fraction * num_episodes * ep_len
        )
        return max(slope * self.t + self.start_e, self.end_e)

    def learn_dqn(self, step):
        """Run one Double-DQN gradient step on a sampled minibatch.

        Args:
            step: global training step. It is logged, copied into ``self.t``
                and used to decide when to update the target network.
        """
        self.t = step
        states, actions, rewards, next_states, dones = self.replays.sample(
            batch_size=self.batch_size
        )
        states = states.to(self.device)
        next_states = next_states.to(self.device)
        actions = actions.to(self.device)  # assumed (batch, 1) long — see ReplayBuffer.sample
        rewards = rewards.to(self.device)
        dones = dones.to(self.device)

        with torch.no_grad():
            # Double DQN: the online network chooses the greedy action in the
            # NEXT state; the target network evaluates that action.
            next_actions = self.q_network(next_states).argmax(dim=1, keepdim=True)
            target_max = (
                self.target_network(next_states).gather(1, next_actions).squeeze(1)
            )
            td_target = rewards.flatten() + self.gamma * target_max * (
                1 - dones.flatten()
            )
        # squeeze(1) rather than squeeze() keeps a (batch,) shape even when
        # batch_size == 1.
        old_val = self.q_network(states).gather(1, actions).squeeze(1)
        loss = F.mse_loss(old_val, td_target)

        wandb.log(
            {
                "losses/td_loss": loss.item(),
                "losses/q_values": old_val.mean().item(),
                "dqn_step": step,
            }
        )

        # optimize the model
        self.optimizer_dqn.zero_grad()
        loss.backward()
        self.optimizer_dqn.step()

        # Polyak update of the target network (tau == 1 is a hard copy).
        if step % self.target_network_frequency == 0:
            for target_network_param, q_network_param in zip(
                self.target_network.parameters(), self.q_network.parameters()
            ):
                target_network_param.data.copy_(
                    self.tau * q_network_param.data
                    + (1.0 - self.tau) * target_network_param.data
                )

    def evaluate_dqn(self, max_steps: int = 2500):
        """Play one greedy episode in a fresh env and return its total reward.

        Args:
            max_steps: step cap for the episode (default 2500).
        """
        cum_reward = 0
        env = VizdoomSingleRoom(render_mode=None)
        try:
            state, _info = env.reset()
            for _ in range(max_steps):
                # Inference only: no autograd graph needed.
                with torch.no_grad():
                    q_values = self.q_network(
                        torch.from_numpy(state[None]).float().to(self.device)
                    )
                # Plain int action: what env.step validates against.
                action = int(torch.argmax(q_values, dim=1).item())
                state, reward, done, _info = env.step(action)
                cum_reward += reward
                if done:
                    break
        finally:
            env.close()
        return cum_reward

    def train(self, global_step=300000):
        """Run ``global_step`` learning steps, evaluating every 1000 (incl. step 0)."""
        for step in range(global_step):
            if step % 1000 == 0:
                print("training dqn step: ", step)
                cum_reward = self.evaluate_dqn()

                wandb.log({"metrics/episode_reward": cum_reward, "dqn_step": step})
            self.learn_dqn(step)


class VizdoomSingleRoom(gymnasium.Env):
    """Discrete-action wrapper around the ``VizdoomSingleRoom-v0`` env.

    * Actions are ints 0-3, mapped to the wrapped env's
      ``{"binary", "continuous"}`` action (see ``_ACTION_TABLE``).
    * Observations are the ``screen`` image resized to 64x64 with INTER_AREA,
      divided by 255 and transposed to channel-first ``(3, 64, 64)``.
    * Reward and termination are recomputed from the distance between
      ``obs["gamevariables"]`` and a goal at (0, 0). The game variables are
      presumably the agent's x/y position; confirm against the scenario
      config.
    * ``step`` returns the legacy 4-tuple ``(obs, reward, done, info)``, not
      gymnasium's 5-tuple; callers in this file rely on that.
    """

    # Action id -> (binary, continuous) components of the wrapped env action.
    _ACTION_TABLE = {0: (0, 0.0), 1: (0, -36.0), 2: (0, 36.0), 3: (1, 0.0)}

    def __init__(self, render_mode="human", num_stack=4):
        self.env = gymnasium.make("VizdoomSingleRoom-v0", render_mode=render_mode)
        self.num_stack = num_stack  # stored only; no frame stacking is applied
        self.env.observation_space = self.env.observation_space.spaces["screen"]

        self.observation_space = gymnasium.spaces.Box(
            low=0, high=1, shape=(3, 64, 64), dtype=np.float32
        )
        self.action_space = self.env.action_space
        self.curr_step = 0
        self.curr_pos = np.array([0, 0])
        self._goal = np.array([0, 0])
        # Normaliser for distances: diagonal of a 224x224 area.
        self.max_dist = np.linalg.norm(np.array([-0, -0]) - np.array([224, 224]))

    @staticmethod
    def _preprocess(screen):
        """Resize to 64x64, scale by 1/255 and convert HWC -> CHW."""
        resized = cv2.resize(screen, (64, 64), interpolation=cv2.INTER_AREA)
        return (np.array(resized) / 255.0).transpose(2, 0, 1)

    def _split_obs(self, obs):
        """Record ``obs["gamevariables"]`` as ``curr_pos``; return the screen."""
        self.curr_pos = np.array(obs["gamevariables"])
        return obs["screen"]

    def _goal_distance(self):
        """Distance from ``curr_pos`` to the goal, divided by ``max_dist``."""
        return np.linalg.norm(self._goal - self.curr_pos) / self.max_dist

    def step(self, action, savefig=False):
        """Apply discrete ``action`` (0-3) and return ``(obs, reward, done, info)``.

        ``savefig`` is accepted for compatibility and ignored.

        Raises:
            ValueError: if ``action`` is not one of 0, 1, 2, 3.
        """
        self.curr_step += 1

        if action not in (0, 1, 2, 3):
            raise ValueError(f"Invalid action: {action}")
        # Accept plain ints as well as 1-element numpy arrays / tensors.
        key = action.item() if hasattr(action, "item") else action
        binary, continuous = self._ACTION_TABLE[key]
        env_action = {
            "binary": binary,
            "continuous": np.array([continuous], dtype=np.float32),
        }

        # Forward the step to the wrapped environment
        obs, _env_reward, terminated, truncated, info = self.env.step(env_action)
        obs = self._preprocess(self._split_obs(obs))

        reward = self.reward()
        # Also end when the wrapped env ends, so callers never keep stepping
        # a finished episode.
        # NOTE(review): if the wrapped env zero-fills gamevariables on its
        # terminal step, the goal check above would also fire there and
        # award +10; confirm and handle if so.
        done = self.is_done() or bool(terminated) or bool(truncated)

        return obs, reward, done, info

    def get_num_actions(self):
        return 4

    def reset(self, **kwargs):
        """Reset the wrapped env (kwargs forwarded); return ``(obs, info)``."""
        obs, info = self.env.reset(**kwargs)
        return self._preprocess(self._split_obs(obs)), info

    def reward(self):
        """+10 within 0.1 normalised distance of the goal, else -distance."""
        dist = self._goal_distance()
        return 10.0 if dist < 0.1 else -dist

    def is_done(self):
        """True when ``curr_pos`` is within 0.1 normalised distance of the goal."""
        return bool(self._goal_distance() < 0.1)

    def render(self, mode="human"):
        """Forward to the wrapped env's ``render()``.

        Gymnasium's ``render`` takes no arguments (the mode is fixed at
        construction via ``render_mode``), so ``mode`` is ignored.
        """
        return self.env.render()

    def close(self):
        # Forward the close call to the wrapped environment
        self.env.close()

    def seed(self, seed=None):
        """Seed the action space and the wrapped env.

        Gymnasium removed ``Env.seed``; seeding goes through
        ``reset(seed=...)``, so this call also resets the episode.
        """
        self.action_space.seed(seed)
        self.reset(seed=seed)


def main(
    global_step=300000,
    seed=27,  # 27,33,5,55
    batch_size=250,
    learn_every=1,
    buffer_size=1000000,
    debug=False,
    device="cuda",
    lr_dqn=1e-4,
    target_network_frequency=10000,
    tau=1,
    exploration_fraction=0.1,
    gamma=0.99,
    rb_path=None,
):
    """Run one seeded DQN experiment on VizdoomSingleRoom, tracked in W&B.

    Seeds all RNGs, opens a W&B run named ``dqn_<seed>``, trains an
    ``MBAgent`` for ``global_step`` steps (optionally from a replay buffer
    pickled at ``rb_path``), then closes the env and finishes the run.
    """
    set_seed(seed)

    tracked_config = {
        "dqn_lr": lr_dqn,
        "architecture": "CNN",
        "global step": global_step,
        "seed": seed,
        "prior": "DQN",
        "batch_size": batch_size,
        "tau": tau,
        "target_network_frequency": target_network_frequency,
        "gamma": gamma,
    }
    # Explicit run name so each seed is easy to find in the W&B project.
    run = wandb.init(project="vizdoom_rl", name=f"dqn_{seed}", config=tracked_config)

    # Every logged series is plotted against the DQN training step.
    wandb.define_metric("dqn_step")
    for metric_name in (
        "metrics/episode_reward",
        "metrics/epsilon",
        "losses/td_loss",
        "losses/q_values",
    ):
        wandb.define_metric(metric_name, step_metric="dqn_step")

    env = VizdoomSingleRoom(render_mode=None, num_stack=1)

    agent = MBAgent(
        env=env,
        batch_size=batch_size,
        learn_every=learn_every,
        buffer_size=buffer_size,
        debug=debug,
        device=device,
        seed=seed,
        wandb_run=run,
        lr_dqn=lr_dqn,
        target_network_frequency=target_network_frequency,
        exploration_fraction=exploration_fraction,
        tau=tau,
        gamma=gamma,
        rb_path=rb_path,
    )

    agent.train(global_step=global_step)
    env.close()
    run.finish()


if __name__ == "__main__":
    # One independent training run per seed, all from the same offline buffer.
    seeds = [5, 55, 33, 27, 96]
    replay_buffer_path = (
        "/rl-transitions/src/envs/vizdoom/dataset/replay_buffer_rl_150k_hard.pickle"
    )

    for seed in seeds:
        # main() takes `rb_path`; passing `path_rb` raised TypeError.
        main(seed=seed, rb_path=replay_buffer_path)
