from typing import Dict, List, Tuple

import gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
# from IPython.display import clear_output
from torch.nn.utils import clip_grad_norm_
from tqdm import trange
import os
import cv2
from datetime import datetime



from utils import Network, PrioritizedReplayBuffer, ReplayBuffer, ResultRecorder
from attacker import Attacker



class DQNAgent:
    """DQN agent interacting with an environment.

    The agent combines the pieces visible in this class: a prioritized replay
    buffer (PER), an optional n-step buffer, a categorical (distributional)
    loss with a Double-DQN style action choice, and noisy networks (the noise
    of both networks is reset after every update; there is no epsilon-greedy
    exploration). When ``is_poison`` is set, an ``Attacker`` rewrites the
    sampled rewards before the 1-step loss is computed.

    Attribute:
        env (gym.Env): openAI Gym environment
        memory (PrioritizedReplayBuffer): replay memory to store transitions
        batch_size (int): batch size for sampling
        target_update (int): period for target model's hard update
        seed (int): seed passed to every ``env.reset`` call
        gamma (float): discount factor
        beta (float): PER importance-sampling exponent, increased in ``train``
        prior_eps (float): constant added to every new priority
        dqn (Network): model to train and select actions
        dqn_target (Network): target model to update
        optimizer (torch.optim): optimizer for training dqn
        transition (list): transition information including
                           state, action, reward, next_state, done
        v_min (float): min value of support
        v_max (float): max value of support
        atom_size (int): the unit number of support
        support (torch.Tensor): support for categorical dqn
        use_n_step (bool): whether to use n_step memory
        n_step (int): step number to calculate n-step td error
                      (only set when ``use_n_step`` is True)
        memory_n (ReplayBuffer): n-step replay buffer
                                 (only set when ``use_n_step`` is True)
        folder_name (string): store the name of current folder
        is_test (bool): when True, transitions are not stored
        is_poison (bool): whether the reward-poisoning attacker is active
        attacker (Attacker): only set when ``is_poison`` is True
        recorder (ResultRecorder): collects training statistics
    """

    def __init__(
        self,
        env: "gym.Env",
        memory_size: int,
        batch_size: int,
        target_update: int,
        seed: int,
        gamma: float = 0.99,
        # PER parameters
        alpha: float = 0.2,
        beta: float = 0.6,
        prior_eps: float = 1e-6,
        # Categorical DQN parameters
        v_min: float = 0.0,
        v_max: float = 500.0,
        atom_size: int = 51,
        # N-step Learning
        n_step: int = 3,
        folder_name: str = "cartPole_exp",
        is_test: bool = False,
        is_poison: bool = False
    ):
        """Initialization.

        Args:
            env (gym.Env): openAI Gym environment
            memory_size (int): length of memory
            batch_size (int): batch size for sampling
            target_update (int): period for target model's hard update
            seed (int): seed handed to ``env.reset``
            gamma (float): discount factor
            alpha (float): determines how much prioritization is used
            beta (float): determines how much importance sampling is used
            prior_eps (float): guarantees every transition can be sampled
            v_min (float): min value of support
            v_max (float): max value of support
            atom_size (int): the unit number of support
            n_step (int): step number to calculate n-step td error
            folder_name (str): folder (relative to the working directory)
                that receives checkpoints and plots
            is_test (bool): start in test mode (no transitions are stored)
            is_poison (bool): create an ``Attacker`` and let it modify the
                sampled rewards during training
        """
        obs_dim = env.observation_space.shape[0]
        action_dim = env.action_space.n

        self.env = env
        self.batch_size = batch_size
        self.target_update = target_update
        self.seed = seed
        self.gamma = gamma
        # NoisyNet: All attributes related to epsilon are removed
        self.folder_name = folder_name

        # device: cpu / gpu
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        print(self.device)

        # PER
        # memory for 1-step Learning
        self.beta = beta
        self.prior_eps = prior_eps
        self.memory = PrioritizedReplayBuffer(
            obs_dim, memory_size, batch_size, alpha=alpha, gamma=gamma
        )

        # memory for N-step Learning
        self.use_n_step = n_step > 1
        if self.use_n_step:
            self.n_step = n_step
            self.memory_n = ReplayBuffer(
                obs_dim, memory_size, batch_size, n_step=n_step, gamma=gamma
            )

        # Categorical DQN parameters
        self.v_min = v_min
        self.v_max = v_max
        self.atom_size = atom_size
        self.support = torch.linspace(
            self.v_min, self.v_max, self.atom_size
        ).to(self.device)

        # networks: dqn, dqn_target
        self.dqn = Network(
            obs_dim, action_dim, self.atom_size, self.support
        ).to(self.device)
        self.dqn_target = Network(
            obs_dim, action_dim, self.atom_size, self.support
        ).to(self.device)
        self.dqn_target.load_state_dict(self.dqn.state_dict())
        self.dqn_target.eval()

        # optimizer
        self.optimizer = optim.Adam(self.dqn.parameters())

        # transition to store in memory
        self.transition = list()

        # mode: train / test
        self.is_test = is_test

        self.is_poison = is_poison

        # Attacker (only exists in poisoning mode; guard every access)
        if self.is_poison:
            self.attacker = Attacker(
                dim_state=obs_dim, n_actions=action_dim, gamma=self.gamma
            )
        self.recorder = ResultRecorder(
            is_test=is_test, folder_name=self.folder_name, input_dim=obs_dim
        )

    def _reset_env(self):
        """Reset the environment with ``self.seed`` and return the observation.

        Newer gym versions return ``(observation, info)`` from ``reset`` while
        older ones return the observation alone; ``step`` already copes with
        both step APIs, so the reset result is normalised the same way.
        """
        result = self.env.reset(seed=self.seed)
        if (
            isinstance(result, tuple)
            and len(result) == 2
            and isinstance(result[1], dict)
        ):
            return result[0]
        return result

    def select_action(self, state: np.ndarray) -> np.ndarray:
        """Select the greedy action for ``state``.

        Outside test mode the pair ``[state, action]`` is remembered in
        ``self.transition`` so that ``step`` can complete it.

        Returns:
            The argmax of the online network's output as a 0-d numpy array.
        """
        # NoisyNet: no epsilon greedy action selection
        with torch.no_grad():  # action selection never needs a graph
            selected_action = self.dqn(
                torch.FloatTensor(state).to(self.device)
            ).argmax()
        selected_action = selected_action.detach().cpu().numpy()

        if not self.is_test:
            self.transition = [state, selected_action]

        return selected_action

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, np.float64, bool]:
        """Take an action and return the response of the env.

        Accepts both the 5-tuple and the older 4-tuple ``env.step`` result.
        The environment reward is reduced by ``|next_state[2]|`` and by
        ``|next_state[0]| / 10``.
        NOTE(review): these indices presumably are pole angle and cart
        position of CartPole - confirm before using another environment.

        Returns:
            ``(next_state, shaped_reward, done)``.
        """
        result = self.env.step(action)
        if isinstance(result, tuple) and len(result) == 5:
            next_state, reward, terminated, truncated, _ = result
            done = terminated or truncated
        else:
            next_state, reward, done = result[:3]
        reward -= np.abs(next_state[2])
        reward -= np.abs(next_state[0]/10.0)

        if not self.is_test:
            self.transition += [reward, next_state, done]

            # N-step transition: the n-step buffer hands back a 1-step
            # transition (or something falsy while it is still filling up)
            if self.use_n_step:
                one_step_transition = self.memory_n.store(*self.transition)
            # 1-step transition
            else:
                one_step_transition = self.transition

            # add a single step transition
            if one_step_transition:
                self.memory.store(*one_step_transition)

        return next_state, reward, done

    def update_model(self) -> float:
        """Update the model by gradient descent.

        Returns:
            The scalar (importance-weighted) loss of this update.
        """
        # PER needs beta to calculate weights
        samples = self.memory.sample_batch(self.beta)
        weights = torch.FloatTensor(
            samples["weights"].reshape(-1, 1)
        ).to(self.device)
        indices = samples["indices"]

        # Attacker modifies the reward of the 1-step batch only
        if self.is_poison:
            samples['rews'] = self.attacker.learn(
                samples['obs'], samples['acts'], samples['rews'],
                samples['next_obs'], samples['done']
            )

        # 1-step Learning loss
        elementwise_loss = self._compute_dqn_loss(samples, self.gamma)

        # PER: importance sampling before average
        loss = torch.mean(elementwise_loss * weights)

        # N-step Learning loss
        # we are gonna combine 1-step loss and n-step loss so as to
        # prevent high-variance. The original rainbow employs n-step loss only.
        if self.use_n_step:
            gamma = self.gamma ** self.n_step
            samples = self.memory_n.sample_batch_from_idxs(indices)
            elementwise_loss_n_loss = self._compute_dqn_loss(samples, gamma)
            elementwise_loss += elementwise_loss_n_loss

            # PER: importance sampling before average
            loss = torch.mean(elementwise_loss * weights)

        self.optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(self.dqn.parameters(), 10.0)
        self.optimizer.step()

        # PER: update priorities
        loss_for_prior = elementwise_loss.detach().cpu().numpy()
        new_priorities = loss_for_prior + self.prior_eps
        self.memory.update_priorities(indices, new_priorities)

        # NoisyNet: reset noise
        self.dqn.reset_noise()
        self.dqn_target.reset_noise()

        return loss.item()

    def train(self, num_frames: int):
        """Train the agent for ``num_frames`` environment steps.

        Every frame: act, store the transition, anneal the PER beta, and -
        once the buffer holds at least one batch - run one gradient update.
        In poisoning mode a checkpoint is written on frame 1 and every 1000
        frames. On the last frame the curves are plotted and the network is
        saved.
        """
        self.is_test = False

        state = self._reset_env()
        update_cnt = 0
        losses = []
        scores = []
        score = 0

        for frame_idx in trange(1, num_frames + 1):
            action = self.select_action(state)
            next_state, reward, done = self.step(action)

            state = next_state
            score += reward
            # NoisyNet: removed decrease of epsilon

            # PER: increase beta
            fraction = min(frame_idx / num_frames, 1.0)
            self.beta = self.beta + fraction * (1.0 - self.beta)

            # if episode ends
            if done:
                state = self._reset_env()
                scores.append(score)
                score = 0

            # if training is ready
            if len(self.memory) >= self.batch_size:
                loss = self.update_model()
                losses.append(loss)
                update_cnt += 1

                # if hard update is needed
                if update_cnt % self.target_update == 0:
                    self._target_hard_update()

                # Record necessary information into recorder
                if self.is_poison:
                    self.recorder.record(
                        self.dqn, self.attacker.reward, self.attacker.Q, loss,
                        self.attacker.rho,
                        self.attacker.optim_Q.param_groups[0]['lr'],
                        update_cnt
                    )

            if self.is_poison and (frame_idx % 1000 == 0 or frame_idx == 1):
                self._save_checkpoint(frame_idx)

            # plotting at end
            if frame_idx == num_frames:
                self._plot(frame_idx, scores, losses)
                self._save_network()

        self.env.close()
        self.recorder.flush()

    def _save_checkpoint(self, frame_idx: int) -> None:
        """Write a timestamped learner + attacker checkpoint (poison mode)."""
        # 1. Create a subfolder named by eps value
        subfolder = f"model_epoch_{self.attacker.eps}"
        dir_path = os.path.join(os.getcwd(), self.folder_name, subfolder)
        os.makedirs(dir_path, exist_ok=True)
        # 2. Get current timestamp
        timestr = datetime.now().strftime("%Y%m%d_%H%M%S")
        # 3. Save model checkpoint
        path = os.path.join(dir_path, f"model_epoch_{frame_idx}_{timestr}.pth")
        torch.save({
            'dqn_state_dict': self.dqn.state_dict(),
            'attacker_reward_state_dict': self.attacker.reward.state_dict(),
            'attacker_Q_state_dict': self.attacker.Q.state_dict(),
        }, path)
        print(f"[frame {frame_idx}][{timestr}] Model saved to {path}")

    def _save_network(self):
        """Save the trained network(s) into ``<cwd>/<folder_name>``.

        In poisoning mode the file is ``network_<eps>.pt`` and also contains
        the attacker's reward and Q networks. Without an attacker only the
        learner is stored, as ``network.pt`` (the attacker attribute does not
        exist in that mode, so it must not be touched).
        """
        dir_path = os.path.join(os.getcwd(), self.folder_name)
        os.makedirs(dir_path, exist_ok=True)

        payload = {'dqn_state_dict': self.dqn.state_dict()}
        if self.is_poison:
            payload['attacker_reward_state_dict'] = (
                self.attacker.reward.state_dict()
            )
            payload['attacker_Q_state_dict'] = self.attacker.Q.state_dict()
            file_name = f"network_{self.attacker.eps}.pt"
        else:
            file_name = "network.pt"

        torch.save(payload, os.path.join(dir_path, file_name))

    def test(self, is_record_video=False, max_episode_steps=500) -> None:
        """Run one evaluation episode from a fixed initial state.

        Loads the learner weights from ``<cwd>/network-good.pt``, forces the
        simulator state to ``[1.0, 0.0, 0.0, 0.0]`` and rolls out at most
        ``max_episode_steps`` steps (at least one step is always taken).

        NOTE(review): ``self.is_test`` is left untouched here; if it is False,
        ``select_action``/``step`` keep storing transitions during the test.

        Args:
            is_record_video: when True, rendered frames are annotated and
                written to ``cartpole_with_x_coord.mp4``; otherwise the env is
                simply rendered each step and position statistics are printed.
            max_episode_steps: upper bound on the number of steps.
        """
        path = os.path.join(os.getcwd(), "network-good.pt")  # network-good.pt
        initial_state = [1.0, 0.0, 0.0, 0.0]  # Initial state for testing

        # map_location lets a checkpoint saved on GPU load on a CPU-only host
        checkpoint = torch.load(path, map_location=self.device)
        self.dqn.load_state_dict(checkpoint["dqn_state_dict"])

        out = None
        if is_record_video:
            frame_width = 600
            frame_height = 400
            out = cv2.VideoWriter(
                'cartpole_with_x_coord.mp4',
                cv2.VideoWriter_fourcc(*'H264'),
                20,
                (frame_width, frame_height),
            )

        try:
            self._reset_env()
            # Directly update the state. Assigning on a gym wrapper would only
            # create an attribute on the wrapper, so write to the base env.
            getattr(self.env, "unwrapped", self.env).state = initial_state
            state = initial_state

            done = False
            score = 0
            ite = 0

            if not is_record_video:
                max_pos = 0
                while not done:
                    if np.abs(state[0]) >= max_pos:
                        max_pos = np.abs(state[0])
                    action = self.select_action(state)
                    next_state, reward, done = self.step(action)
                    print("ite, current state is", ite, state, done)
                    self.env.render()

                    state = next_state
                    score += reward
                    ite += 1
                    if ite >= max_episode_steps:
                        break
                print("The maximum pos is {}".format(max_pos))
                print("Final pos is {}".format(state[0]))
            else:
                print("self.env.state is", state)

                while not done:
                    # Render environment and get video frame
                    frame = self.env.render()
                    # Extract x coordinate
                    x = state[0]
                    print("frame is,", frame.shape)
                    # gym renders RGB while OpenCV writes BGR; the conversion
                    # also yields a contiguous array that putText can draw on
                    frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                    # Overlay information on CartPole image
                    cv2.putText(frame, f'x = {x:.2f}, rewards = {score:.2f}', (20, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0,0,0), 2, cv2.LINE_AA)
                    cv2.putText(frame, f'Trigger: > 0.5, move right', (20, 80), cv2.FONT_HERSHEY_SIMPLEX, 1, (0,0,0), 2, cv2.LINE_AA)
                    out.write(frame)

                    action = self.select_action(state)
                    next_state, reward, done = self.step(action)
                    print("Current state is,", ite, state, action, done)
                    print("Now next state is, ", next_state)

                    state = next_state
                    score += reward
                    ite += 1
                    if ite >= max_episode_steps:
                        break
            print("score: ", score)
        finally:
            # always release the env and the video file, even on errors
            self.env.close()
            if out is not None:
                out.release()
                cv2.destroyAllWindows()

    def _compute_dqn_loss(self, samples: Dict[str, np.ndarray], gamma: float) -> torch.Tensor:
        """Return the per-sample categorical DQN loss.

        Args:
            samples: batch with the keys ``obs``, ``next_obs``, ``acts``,
                ``rews`` and ``done``.
            gamma: discount applied to the support of the next-state
                distribution (``self.gamma ** n`` for the n-step loss).

        Returns:
            1-D tensor with one cross-entropy value per sample.
        """
        device = self.device  # for shortening the following lines
        state = torch.FloatTensor(samples["obs"]).to(device)
        next_state = torch.FloatTensor(samples["next_obs"]).to(device)
        action = torch.LongTensor(samples["acts"]).to(device)
        reward = torch.FloatTensor(samples["rews"].reshape(-1, 1)).to(device)
        done = torch.FloatTensor(samples["done"].reshape(-1, 1)).to(device)

        # use the real batch dimension instead of assuming self.batch_size
        batch = state.size(0)
        rows = range(batch)

        # Categorical DQN algorithm
        delta_z = float(self.v_max - self.v_min) / (self.atom_size - 1)

        with torch.no_grad():
            # Double DQN: online net picks the action, target net rates it
            next_action = self.dqn(next_state).argmax(1)
            next_dist = self.dqn_target.dist(next_state)
            next_dist = next_dist[rows, next_action]

            t_z = reward + (1 - done) * gamma * self.support
            t_z = t_z.clamp(min=self.v_min, max=self.v_max)
            b = (t_z - self.v_min) / delta_z
            l = b.floor().long()
            u = b.ceil().long()

            # When b falls exactly on an atom, l == u and both interpolation
            # weights (u - b) and (b - l) are zero, which would silently drop
            # that probability mass. Give the whole mass to the lower atom.
            lower_w = (u.float() - b) + (l == u).float()
            upper_w = b - l.float()

            # start index of every sample's row in the flattened projection
            offset = (
                torch.arange(batch, device=device).unsqueeze(1)
                * self.atom_size
            )

            proj_dist = torch.zeros(next_dist.size(), device=device)
            proj_dist.view(-1).index_add_(
                0, (l + offset).view(-1), (next_dist * lower_w).view(-1)
            )
            proj_dist.view(-1).index_add_(
                0, (u + offset).view(-1), (next_dist * upper_w).view(-1)
            )

        dist = self.dqn.dist(state)
        log_p = torch.log(dist[rows, action])
        elementwise_loss = -(proj_dist * log_p).sum(1)

        return elementwise_loss

    def _target_hard_update(self):
        """Hard update: target <- local."""
        self.dqn_target.load_state_dict(self.dqn.state_dict())

    def _plot(
        self,
        frame_idx: int,
        scores: List[float],
        losses: List[float],
    ):
        """Plot scores and losses and save them as ``dqn_result.png``.

        The image is written to ``<cwd>/<folder_name>``, which is created if
        it does not exist yet.
        """
        # mean of the last 10 episode scores; nan when no episode finished
        recent_score = np.mean(scores[-10:]) if scores else float("nan")

        plt.close()
        plt.figure(figsize=(20, 5))
        plt.subplot(131)
        plt.title('frame %s. score: %s' % (frame_idx, recent_score))
        plt.plot(scores)
        plt.subplot(132)
        plt.title('loss')
        plt.plot(losses)

        dir_path = os.path.join(os.getcwd(), self.folder_name)
        os.makedirs(dir_path, exist_ok=True)
        plt.savefig(os.path.join(dir_path, "dqn_result.png"))