import sys

sys.dont_write_bytecode = True

import os
import gym
import gymnasium
from vizdoom import gymnasium_wrapper
import math
import copy
import random
import datetime
import numpy as np
import cv2
from time import sleep
import wandb
import pickle

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.lines import Line2D
import matplotlib
import pickle
import pandas as pd
from collections import deque, namedtuple

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR, StepLR, CyclicLR
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from transformers import get_cosine_schedule_with_warmup

from components.drawing import Arrow3D


def set_seed(seed):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy's global RNG, and torch (CPU and
    all CUDA devices). It also forces deterministic, non-benchmarked cuDNN
    kernels and exports the seed as ``PYTHONHASHSEED``.

    NOTE(review): ``PYTHONHASHSEED`` is read at interpreter start-up, so
    setting it here presumably only affects child processes — confirm.

    Args:
        seed: Integer seed applied to all generators.
    """
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(seed)

    # cuDNN picks kernels non-deterministically unless told otherwise.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ["PYTHONHASHSEED"] = str(seed)

    print("Global seeds set to", seed)


# One stored transition. In this file the `state_hash` slot receives the
# env's `info` value (see the callers of ReplayBuffer.add).
Experience = namedtuple(
    "Experience", "state action reward next_state done state_hash"
)


class ReplayBuffer:
    """FIFO store of ``Experience`` tuples with uniform random batch sampling.

    Transitions are appended with :meth:`add`. :meth:`sample` draws a batch
    (without replacement within the batch) and stacks it into torch tensors.
    """

    def __init__(self, buffer_size, batch_size, seed, memory=None):
        """
        Args:
            buffer_size: Capacity of the deque created when ``memory`` is not
                given; once full, the oldest entry is evicted on each add.
            batch_size: Number of experiences returned by :meth:`sample`.
            seed: Accepted for API compatibility but unused. Sampling draws
                from the global ``random`` module state (seed it with
                ``set_seed``).
            memory: Optional existing container to use as storage instead of
                a fresh ``deque(maxlen=buffer_size)``.
        """
        # `is None`, not `== None`: containers such as NumPy arrays overload
        # `==` elementwise, which makes `memory == None` ambiguous and raise.
        self.memory = deque(maxlen=buffer_size) if memory is None else memory
        self.batch_size = batch_size

    def add(self, state, action, reward, next_state, done, state_hash):
        """Append one transition to memory."""
        self.memory.append(
            Experience(state, action, reward, next_state, done, state_hash)
        )

    def sample(self):
        """Draw ``batch_size`` experiences uniformly at random and batch them.

        Returns:
            ``(states, actions, rewards, next_states, dones)`` as torch
            tensors. The state arrays are stacked along a new leading batch
            axis (float). Scalar actions (long), rewards and dones (float)
            come out with shape ``(batch, 1)``.

        Raises:
            ValueError: from ``random.sample`` when memory holds fewer than
                ``batch_size`` experiences.
        """
        # Filter missing entries once instead of in every stacking pass.
        experiences = [
            e for e in random.sample(self.memory, k=self.batch_size) if e is not None
        ]

        states = torch.from_numpy(
            np.vstack([e.state[None, :] for e in experiences])
        ).float()
        actions = torch.from_numpy(np.vstack([e.action for e in experiences])).long()
        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences])).float()
        next_states = torch.from_numpy(
            np.vstack([e.next_state[None, :] for e in experiences])
        ).float()
        # Booleans go through uint8 before becoming a float tensor.
        dones = torch.from_numpy(
            np.vstack([e.done for e in experiences]).astype(np.uint8)
        ).float()

        return (states, actions, rewards, next_states, dones)

    def __len__(self):
        return len(self.memory)

    def save(self, file_path):
        """Pickle the whole memory container to ``file_path``."""
        with open(file_path, "wb") as f:
            pickle.dump(self.memory, f)
        print(f"Replay buffer saved to {file_path}.")

    def load(self, file_path):
        """Replace memory with the container unpickled from ``file_path``.

        Security: ``pickle.load`` can execute arbitrary code. Only load files
        this program produced or that come from a trusted source.
        """
        with open(file_path, "rb") as f:
            self.memory = pickle.load(f)
        print(f"Replay buffer loaded from {file_path}.")
        print(f"Replay buffer contains {len(self.memory)} experiences.")


class VizdoomSingleRoom(gymnasium.Env):
    """Goal-reaching wrapper around the ``VizdoomSingleRoom-v0`` Gymnasium env.

    Exposes four discrete actions. It returns 64x64 channels-first float32
    frames scaled to [0, 1]. It replaces the wrapped env's reward with one
    based on the player's normalized distance to the goal position (180, 180).
    """

    # Goal position compared against the position read from "gamevariables".
    _GOAL = np.array([180, 180])
    # Normalized goal distance below which the goal counts as reached.
    _GOAL_RADIUS = 0.1
    # Discrete action index -> (binary, continuous) parts of the wrapped
    # env's composite action.
    _ACTION_SPEC = {0: (0, 0), 1: (0, -36), 2: (0, 36), 3: (1, 0)}
    # Episodes are truncated once more than this many steps have been taken.
    _MAX_STEPS = 2500

    def __init__(self, render_mode="human", num_stack=4):
        """
        Args:
            render_mode: Passed through to ``gymnasium.make``.
            num_stack: Stored as ``self.num_stack``. This class itself does
                not stack frames.
        """
        self.env = gymnasium.make("VizdoomSingleRoom-v0", render_mode=render_mode)
        self.num_stack = num_stack
        # Narrow the wrapped env's Dict observation space to its "screen" part.
        self.env.observation_space = self.env.observation_space.spaces["screen"]
        self.observation_space = gymnasium.spaces.Box(
            low=0, high=1, shape=(3, 64, 64), dtype=np.float32
        )
        self.action_space = self.env.action_space
        self.curr_step = 0
        self.curr_pos = np.array([0, 0])
        # Distance between (-224, -224) and (224, 224), presumably the room's
        # opposite corners; used to normalize goal distances.
        self.max_dist = np.linalg.norm(np.array([-224, -224]) - np.array([224, 224]))
        # Seed handed over by seed() and consumed by the next reset().
        self._pending_seed = None

    def _read_observation(self, raw_obs):
        """Update ``curr_pos`` from ``raw_obs`` and return the preprocessed screen.

        Assumes ``raw_obs["screen"]`` is an (H, W, C) image (it is transposed
        to channels-first below). ``raw_obs["gamevariables"]`` is presumably
        the player's (x, y) position — confirm against the scenario config.
        """
        self.curr_pos = np.array(raw_obs["gamevariables"])
        screen = cv2.resize(raw_obs["screen"], (64, 64), interpolation=cv2.INTER_AREA)
        # Scale in float64 and then cast, so frames match the declared float32
        # observation_space and take half the memory.
        return (np.asarray(screen) / 255.0).transpose(2, 0, 1).astype(np.float32)

    def step(self, action, savefig=False):
        """Apply a discrete action; return ``(obs, reward, done, truncated, info)``.

        Args:
            action: One of 0, 1, 2, 3 (see ``_ACTION_SPEC``).
            savefig: Unused; kept for backward compatibility.

        Raises:
            ValueError: if ``action`` is not one of 0-3.
        """
        self.curr_step += 1

        if action not in [0, 1, 2, 3]:
            raise ValueError(f"Invalid action: {action}")
        binary, continuous = self._ACTION_SPEC[action]
        # Build fresh arrays every step so nothing is shared with the wrapped env.
        composite_action = {
            "binary": binary,
            "continuous": np.array([continuous], dtype=np.float32),
        }

        # NOTE(review): the wrapped env's own reward/terminated/truncated are
        # discarded. If the scenario has its own timeout, stepping may
        # continue past it — confirm.
        raw_obs, _, _, _, info = self.env.step(composite_action)
        obs = self._read_observation(raw_obs)

        reward = self.reward()
        done = self.is_done()
        truncated = self.curr_step > self._MAX_STEPS

        return obs, reward, done, truncated, info

    def get_num_actions(self):
        """Return the number of discrete actions (4)."""
        return len(self._ACTION_SPEC)

    def reset(self, **kwargs):
        """Reset the wrapped env and return ``(obs, info)``.

        A seed stored by :meth:`seed` is forwarded here, unless the caller
        passes ``seed`` explicitly.
        """
        self.curr_step = 0
        if self._pending_seed is not None and "seed" not in kwargs:
            kwargs["seed"] = self._pending_seed
        self._pending_seed = None
        raw_obs, info = self.env.reset(**kwargs)
        return self._read_observation(raw_obs), info

    def _goal_distance(self):
        """Distance from ``curr_pos`` to the goal, divided by ``max_dist``."""
        return np.linalg.norm(self._GOAL - self.curr_pos) / self.max_dist

    def reward(self):
        """Return 10.0 inside the goal radius, else the negative normalized distance."""
        dist = self._goal_distance()
        return 10.0 if dist < self._GOAL_RADIUS else -dist

    def is_done(self):
        """Return True once the player is within the goal radius."""
        return bool(self._goal_distance() < self._GOAL_RADIUS)

    def render(self, mode="human"):
        """Render the wrapped env.

        ``mode`` is ignored. Gymnasium's ``render()`` takes no arguments and
        uses the ``render_mode`` chosen at construction, so forwarding ``mode``
        raised ``TypeError``. The parameter is kept for backward compatibility.
        """
        return self.env.render()

    def close(self):
        self.env.close()

    def seed(self, seed=None):
        """Arrange for the wrapped env to be seeded with ``seed`` on the next reset.

        Gymnasium environments are seeded through ``reset(seed=...)`` rather
        than a ``seed()`` method. The value is therefore stored and applied by
        the next :meth:`reset` call that does not pass its own seed.
        """
        self._pending_seed = seed


def main(global_step=150000, seed=33, buffer_size=1000000, only_right=False):
    """Collect random-policy transitions from VizdoomSingleRoom and pickle them.

    Args:
        global_step: Total number of environment steps to take.
        seed: Seed for the global RNGs (via ``set_seed``) and for the first
            environment reset.
        buffer_size: Replay buffer capacity; older transitions are evicted
            once it is full.
        only_right: If true, transitions that took action 1 are executed but
            not stored.
    """
    set_seed(seed)
    replays = ReplayBuffer(buffer_size=buffer_size, batch_size=250, seed=seed)
    env = VizdoomSingleRoom(render_mode=None, num_stack=1)
    try:
        # Seed the first reset so the env's own randomness also follows `seed`
        # (set_seed above only covers random/NumPy/torch).
        # The env's `info` value is what gets stored in the `state_hash` slot.
        state, state_hash = env.reset(seed=seed)

        episode = 0
        print("Episode", episode)
        for _ in range(global_step):
            action = np.random.choice([0, 1, 2, 3])
            next_state, reward, done, truncated, next_state_hash = env.step(action)
            if not only_right or action != 1:
                replays.add(state, action, reward, next_state, done, state_hash)

            state, state_hash = next_state, next_state_hash
            if done or truncated:
                episode += 1
                print("Episode", episode)
                state, state_hash = env.reset()
    finally:
        # Release the ViZDoom process even if collection is interrupted.
        env.close()

    os.makedirs("src/envs/dataset", exist_ok=True)
    replays.save("src/envs/dataset/replay_buffer_rl.pickle")


# Run data collection with the default arguments when executed as a script.
if __name__ == "__main__":
    main()
