import sys

sys.dont_write_bytecode = True

import os
import gym
import gymnasium
from vizdoom import gymnasium_wrapper
import math
import copy
import random
import datetime
import numpy as np
import cv2
from time import sleep
import wandb
import pickle

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.lines import Line2D
import matplotlib
import pickle
import pandas as pd
from collections import deque, namedtuple

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR, StepLR, CyclicLR
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from transformers import get_cosine_schedule_with_warmup

from components.drawing import Arrow3D

# Global matplotlib font settings, applied as a side effect of importing this module.
font = {"weight": "bold", "size": 11}

matplotlib.rc("font", **font)


def set_seed(seed):
    """Seed every random number generator used by this script.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    puts cuDNN into deterministic mode, and stores the seed in the
    ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # cuDNN needs both switches: force deterministic kernels and disable the
    # autotuner, which may otherwise select different kernels between runs.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # Also publish the seed as the hash seed in the process environment.
    os.environ["PYTHONHASHSEED"] = str(seed)

    print("Global seeds set to", seed)


# Record type for one stored transition. ReplayBuffer.add builds these;
# ReplayBuffer.sample reads every field except ``state_hash``.
Experience = namedtuple(
    "Experience",
    field_names=["state", "action", "reward", "next_state", "done", "state_hash"],
)


class ReplayBuffer:
    """Store of ``Experience`` tuples with uniform random mini-batch sampling.

    Args:
        buffer_size: Capacity of the deque created when ``memory`` is not given.
        batch_size: Default number of experiences returned by ``sample``.
        seed: Accepted for interface compatibility; the buffer does not use it
            (sampling draws from the global ``random`` module state).
        memory: Optional existing container of experiences to use instead of a
            fresh deque.
    """

    def __init__(self, buffer_size, batch_size, seed, memory=None):
        # Identity check on purpose: ``memory == None`` would call the
        # container's own ``__eq__``, which is element-wise for array types.
        self.memory = deque(maxlen=buffer_size) if memory is None else memory
        self.batch_size = batch_size

    def add(self, state, action, reward, next_state, done, state_hash):
        """Add a new experience to memory."""
        self.memory.append(
            Experience(state, action, reward, next_state, done, state_hash)
        )

    def sample(self, batch_size=None):
        """Draw a random mini-batch and return it as stacked tensors.

        Args:
            batch_size: Number of experiences to draw; defaults to the
                ``batch_size`` given at construction.

        Returns:
            Tuple ``(states, actions, rewards, next_states, dones)``. States
            and next states are stacked along a new leading axis; the other
            three hold one row per experience. ``actions`` is ``long``, all
            others are ``float``. ``state_hash`` is not returned.

        Raises:
            ValueError: Propagated from ``random.sample`` when more
                experiences are requested than are stored.
        """
        k = self.batch_size if batch_size is None else batch_size
        batch = [e for e in random.sample(self.memory, k=k) if e is not None]

        states = self._stack([e.state[None, :] for e in batch]).float()
        actions = self._stack([e.action for e in batch]).long()
        rewards = self._stack([e.reward for e in batch]).float()
        next_states = self._stack([e.next_state[None, :] for e in batch]).float()
        # Done flags go through uint8 before being turned into 0.0 / 1.0 floats.
        dones = torch.from_numpy(
            np.vstack([e.done for e in batch]).astype(np.uint8)
        ).float()

        return (states, actions, rewards, next_states, dones)

    @staticmethod
    def _stack(rows):
        """Vertically stack ``rows`` with NumPy and wrap the result as a tensor."""
        return torch.from_numpy(np.vstack(rows))

    def __len__(self):
        return len(self.memory)

    def save(self, file_path):
        """Save the replay buffer to a file."""
        with open(file_path, "wb") as f:
            pickle.dump(self.memory, f)
        print(f"Replay buffer saved to {file_path}.")

    def load(self, file_path):
        """Load the replay buffer from a file, replacing the current memory."""
        # SECURITY: pickle.load can execute arbitrary code contained in the
        # file. Only load buffers that were produced by trusted code.
        with open(file_path, "rb") as f:
            self.memory = pickle.load(f)
        print(f"Replay buffer loaded from {file_path}.")
        print(f"Replay buffer contains {len(self.memory)} experiences.")


class EncoderCNN(nn.Module):
    """Convolutional encoder mapping an image batch to ``output_dim`` features.

    Four stride-2, 4x4, unpadded convolutions (32, 64, 128, 256 channels) are
    followed by a five-layer MLP (256, 128, 64, 32, ``output_dim``), with ELU
    between layers and no activation on the output.

    Args:
        output_dim: Size of the output feature vector.
        input_dim: ``(channels, height, width)`` of the input images.

    Raises:
        ValueError: If the input is too small to survive the four convolutions.
    """

    def __init__(self, output_dim, input_dim):
        super().__init__()
        self.num_channel, self.h, self.w = input_dim

        self.gate = nn.ELU()
        # Not read by this class; kept because it is a public attribute.
        self.hidden = 256

        kernel, stride = 4, 2
        channels = [self.num_channel, 32, 64, 128, 256]
        out_h, out_w = int(self.h), int(self.w)
        conv_layers = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            conv_layers.append(
                nn.Conv2d(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=(kernel, kernel),
                    stride=stride,
                )
            )
            conv_layers.append(self.gate)
            # Output size of an unpadded convolution along each spatial axis.
            out_h = (out_h - kernel) // stride + 1
            out_w = (out_w - kernel) // stride + 1
        if out_h < 1 or out_w < 1:
            raise ValueError(
                f"input of size {self.h}x{self.w} is too small for the encoder"
            )
        self.conv_encoder = nn.Sequential(*conv_layers)

        # Derived from the input size instead of assuming 64x64 images
        # (for which this evaluates to 256 * 2 * 2).
        self.flatten_dim_after_conv = channels[-1] * out_h * out_w
        factor = 2
        widths = [
            self.flatten_dim_after_conv,
            128 * factor,
            64 * factor,
            32 * factor,
            16 * factor,
            output_dim,
        ]
        fc_layers = []
        for w_in, w_out in zip(widths[:-1], widths[1:]):
            fc_layers.append(nn.Linear(w_in, w_out))
            fc_layers.append(self.gate)
        fc_layers.pop()  # no activation after the output layer
        self.fc_after_conv = nn.Sequential(*fc_layers)

    def forward(self, x):
        """Encode a ``(batch, channels, height, width)`` tensor to ``(batch, output_dim)``."""
        x = self.conv_encoder(x)
        x = torch.flatten(x, start_dim=1)
        return self.fc_after_conv(x)


class Transition(nn.Module):
    """Two-layer ELU MLP over the concatenation of a latent and an action vector.

    Args:
        input_dim: Width of the concatenated ``[z, actions]`` input.
        output_dim: Width of the output.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.output_dim = output_dim

        hidden_units = 64
        layers = (
            nn.Linear(input_dim, hidden_units),
            nn.ELU(),
            nn.Linear(hidden_units, output_dim),
        )
        self.fc = nn.Sequential(*layers)

    def forward(self, z, actions):
        """Return the network output for ``z`` and ``actions`` joined on the last axis."""
        joint = torch.cat((z, actions), dim=-1)
        return self.fc(joint)


class Reward(nn.Module):
    """Two-layer ELU MLP over the concatenation of a latent and an action vector.

    Args:
        input_dim: Width of the concatenated ``[z, actions]`` input.
        output_dim: Width of the output.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.output_dim = output_dim

        hidden_units = 64
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_units),
            nn.ELU(),
            nn.Linear(hidden_units, output_dim),
        )

    def forward(self, z, actions):
        """Return the network output for ``z`` and ``actions`` joined on the last axis."""
        joint = torch.cat((z, actions), dim=-1)
        return self.fc(joint)


class QNetwork(nn.Module):
    """MLP producing one Q-value per action from a latent state vector.

    Args:
        env: Accepted for interface compatibility; not used by the network.
        latent_dim: Width of the input vector. Defaults to 3, the value that
            was previously hard-coded.
        num_actions: Number of Q-values produced. Defaults to 4, the value
            that was previously hard-coded.
    """

    def __init__(self, env, latent_dim=3, num_actions=4):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(latent_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, num_actions),
        )

    def forward(self, x):
        """Return Q-values of shape ``(batch, num_actions)`` for latents ``x``."""
        return self.network(x)


class MBAgent:
    """Model-based DQN agent that learns Q-values on encoder latents.

    Each update samples real transitions from the replay buffer, encodes them,
    and adds synthetic one-step transitions produced by the transition and
    reward models for every action. A Q-network is then trained on the
    combined batch.

    Args:
        env: Environment exposing ``observation_space`` and ``get_num_actions()``.
        batch_size: Number of real transitions sampled per update.
        device: Torch device for all networks and tensors.
        debug: Stored on the instance; not read by any method of this class.
        buffer_size: Capacity of the replay buffer.
        seed: Passed to the replay buffer and stored on the instance.
        gamma: Discount factor.
        tau: Interpolation factor for target-network updates (1 = hard copy).
        target_network_frequency: Steps between target-network updates.
        train_frequency: Stored on the instance; not read by any method of
            this class.
        lr_dqn: Adam learning rate for the Q-network.
        path_rb: Optional path of a pickled replay buffer to load.
        latent_dim: Width of the encoder output.
    """

    def __init__(
        self,
        env: "gymnasium.Env",
        batch_size: int,
        device: str = "cpu",
        debug: bool = False,
        buffer_size=10000,
        seed=2022,
        gamma=0.99,
        tau=1,
        target_network_frequency=500,
        train_frequency=10,
        lr_dqn=2.5e-4,
        path_rb=None,
        latent_dim=3,
    ):
        self.env = env
        self.device = device
        self.batch_size = batch_size
        self.debug = debug
        self.replays = ReplayBuffer(
            buffer_size=buffer_size, batch_size=batch_size, seed=seed
        )
        if path_rb is not None:
            self.replays.load(path_rb)
        self.seed = seed

        self.num_actions = self.env.get_num_actions()

        # World model: encoder, latent transition model and reward model.
        self.encoder = EncoderCNN(
            input_dim=self.env.observation_space.shape, output_dim=latent_dim
        ).to(self.device)
        self.transition = Transition(
            input_dim=latent_dim + self.num_actions, output_dim=latent_dim
        ).to(self.device)
        self.reward_model = Reward(
            input_dim=latent_dim + self.num_actions, output_dim=1
        ).to(self.device)

        # DQN parameters
        self.gamma = gamma
        self.tau = tau
        self.target_network_frequency = target_network_frequency
        self.train_frequency = train_frequency
        self.lr_dqn = lr_dqn

        # NOTE: QNetwork(env) builds a network with a 3-wide input and 4
        # outputs, so this agent only works with latent_dim == 3 and 4 actions.
        self.q_network = QNetwork(env).to(device)
        self.optimizer_dqn = torch.optim.Adam(self.q_network.parameters(), lr=lr_dqn)
        self.target_network = QNetwork(env).to(device)
        self.target_network.load_state_dict(self.q_network.state_dict())

    def load_world_model(self, path):
        """Load ``transition.pt``, ``encoder.pt`` and ``reward.pt`` from ``path``."""
        modules = {
            "transition.pt": self.transition,
            "encoder.pt": self.encoder,
            "reward.pt": self.reward_model,
        }
        for file_name, module in modules.items():
            # map_location lets weights saved on another device (e.g. a GPU)
            # load onto the device this agent actually runs on.
            module.load_state_dict(
                torch.load(
                    os.path.join(path, file_name),
                    weights_only=True,
                    map_location=self.device,
                )
            )

    def _encode(self, observations):
        """Encode observations and wrap the first latent coordinate into [0, 2*pi).

        Modifies the encoder output in place, so call it under ``torch.no_grad()``.
        """
        z = self.encoder(observations)
        z[:, 0] = torch.remainder(z[:, 0], 2 * np.pi)
        return z

    def generate_trajectory(self, states, num_steps):
        """Generate one synthetic model transition per (state, action) pair.

        Args:
            states: Batch of raw observations accepted by the encoder.
            num_steps: Unused; exactly one step is generated.

        Returns:
            Tuple ``(z, actions, rewards, next_z, dones)`` with
            ``num_actions * batch`` rows each. Rows are laid out in blocks:
            rows ``a * batch .. (a + 1) * batch - 1`` pair every encoded state
            with action ``a``. ``dones`` is all zeros.
        """
        with torch.no_grad():
            z = self._encode(states)
            batch = z.size(0)

            action_ids = torch.arange(self.num_actions, device=self.device)
            one_hot_actions = torch.eye(self.num_actions, device=self.device)

            # The latents, the one-hot model inputs and the action labels must
            # share the same block layout. Tiling the one-hot rows with
            # ``repeat`` instead would feed the models action ``i % num_actions``
            # on row ``i`` while labelling it ``i // batch``.
            z_expanded = z.repeat(self.num_actions, 1)
            one_hot_actions_expanded = one_hot_actions.repeat_interleave(batch, dim=0)
            actions = action_ids.repeat_interleave(batch).view(-1, 1)

            d = self.transition(z_expanded, one_hot_actions_expanded)
            next_z = z_expanded + d
            next_z[:, 0] = torch.remainder(next_z[:, 0], 2 * np.pi)
            rewards = self.reward_model(z_expanded, one_hot_actions_expanded).view(
                -1, 1
            )
            dones = torch.zeros_like(rewards)

        return z_expanded, actions, rewards, next_z, dones

    def learn_dqn(self, step=0):
        """Run one Double-DQN update on real plus synthetic transitions.

        Args:
            step: Global step; used for logging and for scheduling the
                target-network update.
        """
        states, actions, rewards, next_states, dones = (
            t.to(self.device) for t in self.replays.sample()
        )

        with torch.no_grad():
            z = self._encode(states)
            next_z = self._encode(next_states)

        z_synth, actions_synth, rewards_synth, next_z_synth, _ = (
            self.generate_trajectory(states, num_steps=1)
        )

        states = torch.cat([z, z_synth], dim=0)
        actions = torch.cat([actions, actions_synth], dim=0)
        rewards = torch.cat([rewards, rewards_synth], dim=0)
        next_states = torch.cat([next_z, next_z_synth], dim=0)
        # As before, the synthetic rows reuse the done flag of the real
        # transition that started from the same state (the all-zero synthetic
        # flags are ignored). One copy for the real rows plus one per action.
        dones = dones.repeat(1 + self.num_actions, 1)

        with torch.no_grad():
            # Double DQN: the online network chooses the action in the NEXT
            # state and the target network evaluates it.
            best_next_action = self.q_network(next_states).argmax(dim=1, keepdim=True)
            target_max = (
                self.target_network(next_states).gather(1, best_next_action).squeeze(1)
            )
            td_target = rewards.flatten() + self.gamma * target_max * (
                1 - dones.flatten()
            )
        old_val = self.q_network(states).gather(1, actions).squeeze(1)
        loss = F.mse_loss(old_val, td_target)

        wandb.log(
            {"losses/td_loss": loss.item(), "losses/q_values": old_val.mean().item()},
            step=step,
        )

        # optimize the model
        self.optimizer_dqn.zero_grad()
        loss.backward()
        self.optimizer_dqn.step()

        # update target network
        if step % self.target_network_frequency == 0:
            with torch.no_grad():
                for target_param, q_param in zip(
                    self.target_network.parameters(), self.q_network.parameters()
                ):
                    target_param.copy_(
                        self.tau * q_param + (1.0 - self.tau) * target_param
                    )

    def evaluate_dqn(self):
        """Run one greedy episode (at most 2500 steps) in a fresh environment.

        Returns:
            Sum of the rewards collected during the episode.
        """
        cum_reward = 0
        env = VizdoomSingleRoom(render_mode=None)
        try:
            state, _ = env.reset()
            # Evaluation needs no gradients; skipping them avoids building an
            # autograd graph on every step.
            with torch.no_grad():
                for _ in range(2500):
                    obs = torch.as_tensor(
                        state[None], dtype=torch.float32, device=self.device
                    )
                    q_values = self.q_network(self._encode(obs))
                    action = torch.argmax(q_values, dim=1).cpu().numpy()
                    state, reward, done, truncated, _ = env.step(action)
                    cum_reward += reward
                    if done or truncated:
                        break
        finally:
            # Release the environment even if a step raises.
            env.close()
        return cum_reward

    def train(self, global_step: int):
        """Run ``global_step`` DQN updates, evaluating every 1000 steps."""
        for step in range(global_step):
            if step % 1000 == 0:
                print("training dqn step: ", step)
                cum_reward = self.evaluate_dqn()
                wandb.log({"metrics/episode_reward": cum_reward}, step=step)
            self.learn_dqn(step=step)


class VizdoomSingleRoom(gymnasium.Env):
    """Wrapper around ``VizdoomSingleRoom-v0`` with 4 discrete actions.

    Observations are the screen image resized to 64x64, scaled to [0, 1] and
    transposed to channel-first ``(3, 64, 64)`` float32 arrays. The reward and
    termination signal are computed from the agent position reported in the
    ``gamevariables`` entry of the wrapped env's observation.

    Args:
        render_mode: Passed to ``gymnasium.make``.
        num_stack: Stored on the instance; not used by this class.
    """

    # Discrete action id -> (binary component, continuous component) of the
    # dict action sent to the wrapped environment.
    _ACTION_TABLE = {0: (0, 0), 1: (0, -36), 2: (0, 36), 3: (1, 0)}
    _GOAL_POSITION = (180, 180)
    _GOAL_THRESHOLD = 0.1  # normalized distance below which the goal counts as reached
    _MAX_STEPS = 2500  # steps after which ``truncated`` is reported
    _FRAME_SIZE = (64, 64)

    def __init__(self, render_mode="human", num_stack=4):
        self.env = gymnasium.make("VizdoomSingleRoom-v0", render_mode=render_mode)
        self.num_stack = num_stack
        self.env.observation_space = self.env.observation_space.spaces["screen"]
        self.observation_space = gymnasium.spaces.Box(
            low=0, high=1, shape=(3, 64, 64), dtype=np.float32
        )
        self.action_space = self.env.action_space
        self.curr_step = 0
        self.curr_pos = np.array([0, 0])
        # Normalization constant: distance between (-224, -224) and (224, 224).
        self.max_dist = np.linalg.norm(np.array([-224, -224]) - np.array([224, 224]))

    def _process_observation(self, obs):
        """Record the agent position and return the preprocessed screen image."""
        frame, position = obs["screen"], obs["gamevariables"]
        self.curr_pos = np.array(position)

        frame = cv2.resize(frame, self._FRAME_SIZE, interpolation=cv2.INTER_AREA)
        # Cast to float32 so the observation matches the declared
        # observation_space dtype.
        frame = (np.array(frame) / 255.0).astype(np.float32)
        return frame.transpose(2, 0, 1)

    def _goal_distance(self):
        """Return the distance from the current position to the goal, divided by ``max_dist``."""
        offset = np.array(self._GOAL_POSITION) - self.curr_pos
        return np.linalg.norm(offset) / self.max_dist

    def step(self, action, savefig=False):
        """Apply a discrete action and return the processed transition.

        Args:
            action: One of 0, 1, 2, 3, or a single-element NumPy array holding one.
            savefig: Unused.

        Returns:
            ``(obs, reward, done, truncated, info)``. ``reward`` and ``done``
            come from this wrapper, not from the wrapped env; ``truncated`` is
            true once more than 2500 steps were taken since the last reset.

        Raises:
            ValueError: If ``action`` is not one of the four valid ids.
        """
        self.curr_step += 1

        if isinstance(action, np.ndarray):
            action = int(action.item())

        if action not in [0, 1, 2, 3]:
            raise ValueError(f"Invalid action: {action}")
        binary, continuous = self._ACTION_TABLE[action]
        # Build a fresh array per call so the wrapped env never sees shared state.
        env_action = {
            "binary": binary,
            "continuous": np.array([continuous], dtype=np.float32),
        }

        obs, _, _, _, info = self.env.step(env_action)
        obs = self._process_observation(obs)

        reward = self.reward()
        done = self.is_done()
        truncated = self.curr_step > self._MAX_STEPS

        return obs, reward, done, truncated, info

    def get_num_actions(self):
        """Return the number of discrete actions accepted by ``step``."""
        return 4

    def reset(self, **kwargs):
        """Reset the wrapped env and return ``(processed observation, info)``."""
        self.curr_step = 0
        obs, info = self.env.reset(**kwargs)
        return self._process_observation(obs), info

    def reward(self):
        """Return 10.0 when within the goal threshold, else minus the normalized distance."""
        dist = self._goal_distance()
        if dist < self._GOAL_THRESHOLD:
            return 10.0
        return -dist

    def is_done(self):
        """Return True when the agent is within the goal threshold."""
        return bool(self._goal_distance() < self._GOAL_THRESHOLD)

    def render(self, mode="human"):
        """Render the wrapped environment.

        ``mode`` is kept for backward compatibility but ignored: Gymnasium
        environments take the render mode at construction time and their
        ``render()`` accepts no arguments.
        """
        return self.env.render()

    def close(self):
        """Close the wrapped environment."""
        self.env.close()

    def seed(self, seed=None):
        """Forward ``seed`` to the wrapped environment's ``seed`` method."""
        # NOTE(review): this relies on the wrapped env exposing ``seed()``.
        # The Gymnasium Env API seeds through ``reset(seed=...)`` instead --
        # confirm this call still works with the installed versions.
        self.env.seed(seed)


def main(
    global_step=300000,
    seed=27,
    batch_size=250,
    latent_dim=3,
    buffer_size=1000000,
    debug=False,
    device="cuda",
    lr_dqn=5e-5,
    target_network_frequency=5000,
    tau=1,
    exploration_fraction=0.5,
    gamma=0.9,
    path_rb=None,
    path_weight=None,
):
    """Train a model-based DQN agent for one seed and log the run to wandb.

    Args:
        global_step: Number of DQN updates to run.
        seed: Seed for all random number generators; also part of the run name.
        batch_size: Number of real transitions sampled per update.
        latent_dim: Width of the encoder output.
        buffer_size: Capacity of the replay buffer.
        debug: Forwarded to the agent.
        device: Torch device for the agent.
        lr_dqn: Learning rate of the Q-network optimizer.
        target_network_frequency: Steps between target-network updates.
        tau: Interpolation factor for target-network updates.
        exploration_fraction: Unused; kept for interface compatibility.
        gamma: Discount factor.
        path_rb: Optional path of a pickled replay buffer to load.
        path_weight: Optional directory containing the world-model weights.
    """
    set_seed(seed)
    run = wandb.init(
        project="vizdoom_rl",
        name=f"dqn_encoder_{seed}",
        config={
            "global step": global_step,
            "seed": seed,
            "prior": "test_new",
            "lr dqn": lr_dqn,
            "target network frequency": target_network_frequency,
            "tau": tau,
            "gamma": gamma,
            "batch_size": batch_size,
        },
    )

    env = None
    try:
        env = VizdoomSingleRoom(render_mode=None)

        agent = MBAgent(
            env=env,
            latent_dim=latent_dim,
            batch_size=batch_size,
            buffer_size=buffer_size,
            debug=debug,
            device=device,
            seed=seed,
            lr_dqn=lr_dqn,
            target_network_frequency=target_network_frequency,
            tau=tau,
            gamma=gamma,
            path_rb=path_rb,
        )

        if path_weight is not None:
            print("Loading weights from: ", path_weight)
            agent.load_world_model(path_weight)

        agent.train(global_step=global_step)
    finally:
        # Release the environment and close the wandb run even when setup or
        # training raises, so a failed seed does not leave either one open.
        if env is not None:
            env.close()
        run.finish()


# Script entry point: run one full training per seed, all starting from the
# same replay buffer and the same pre-trained world-model weights.
if __name__ == "__main__":
    seeds = [5, 55, 33, 27, 96]
    for seed in seeds:
        # NOTE(review): the two paths use different root directory names
        # ("rl-transitions" vs "rl-transition") -- confirm both are correct.
        main(
            seed=seed,
            path_rb="rl-transitions/src/envs/vizdoom/dataset/replay_buffer_rl_150k_hard.pickle",
            path_weight="rl-transition/src/envs/vizdoom/weight/150k_trans_100k_steps_hard_3",
        )
