import copy
from eztils.torch import soft_update_from_to
import gym
import numpy as np
import torch
from gridworld import SimpleGridWorld
from torchrl.data import ReplayBuffer, LazyTensorStorage
from tensordict.tensordict import TensorDict


def evaluate_policy(policy, env, num_episodes=100):
    """
    Roll out a tabular policy and collect the return of every episode.

    The environment is switched to test mode for the duration of the
    evaluation and is always switched back to train mode afterwards, even if
    an episode raises.

    Args:
        policy: Object indexable by observation (``policy[obs]``) that yields
            the action to take, e.g. a list with one action per state.
        env: Environment exposing ``to_test()``, ``to_train()``, ``reset()``
            and a ``step(action)`` returning ``(obs, reward, done, info)``.
        num_episodes: Number of episodes to evaluate.

    Returns:
        List with one entry per episode: the summed episode reward, converted
        to ``float`` and rounded to 3 decimals.
    """
    env.to_test()
    episode_rewards = []
    try:
        for _ in range(num_episodes):
            obs = env.reset()
            episode_reward = 0
            done = False
            while not done:
                obs, reward, done, _ = env.step(policy[obs])
                episode_reward += reward
            episode_rewards.append(round(float(episode_reward), 3))
    finally:
        # Never leave the (possibly shared) environment stuck in test mode.
        env.to_train()
    return episode_rewards


class DQN(torch.nn.Module):
    """Deep Q-Network agent for environments with discrete states and actions.

    Holds an online Q-network (``model``), a frozen copy used for bootstrapped
    targets (``target_model``), a Huber loss, an Adam optimizer and a replay
    buffer. States are fed to the networks as one-hot vectors.
    """

    def __init__(self, state_shape, action_shape, learning_rate=0.001, batch_size=1024):
        """Build the networks, optimizer and replay buffer.

        The agent maps X-states to Y-actions, e.g. the network output is
        [.1, .7, .05, 0.05, .05, .05]; the highest value (0.7) is the Q-value
        of the greedy action and its index (1) is the greedy action.

        Args:
            state_shape: Size of the one-hot state vector (number of states).
            action_shape: Number of discrete actions (network output size).
            learning_rate: Step size of the Adam optimizer.
            batch_size: Number of transitions sampled per training step.
        """
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.model = torch.nn.Sequential(
            torch.nn.Linear(state_shape, 24),
            torch.nn.ReLU(),
            torch.nn.Linear(24, 12),
            torch.nn.ReLU(),
            torch.nn.Linear(12, action_shape),
        )
        # The target network starts as an exact copy and is never trained by
        # the optimizer; it is only refreshed through soft updates in ``run``.
        self.target_model = copy.deepcopy(self.model)
        for param in self.target_model.parameters():
            param.requires_grad = False

        self.loss = torch.nn.HuberLoss()

        # Every input is moved to ``self.device`` before a forward pass, so the
        # networks have to live there as well.
        self.to(self.device)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)

        self.batch_size = batch_size
        self.replay_memory = ReplayBuffer(
            storage=LazyTensorStorage(max_size=50_000), batch_size=batch_size
        )

    def train(self, env, learning_rate=1, discount_factor=0.99, verbose=False):
        """Run one Q-learning update on a batch sampled from the replay buffer.

        NOTE: this method shadows ``torch.nn.Module.train(mode)``. To keep
        ``.eval()`` / ``.train(True)`` working, a ``bool`` first argument is
        forwarded to the base-class implementation.

        Args:
            env: Environment; only ``env.observation_space.n`` is read, to
                size the one-hot state encoding.
            learning_rate: Blend factor between the current Q-value and the
                bootstrapped target (1 means "use the target as is"). This is
                independent of the optimizer step size.
            discount_factor: Discount applied to the next-state value.
            verbose: Print the intermediate tensors when true.

        Returns:
            The scalar loss tensor of this update (or ``self`` when called
            with a ``bool`` to switch train/eval mode).
        """
        if isinstance(env, bool):
            return super().train(env)

        batch = self.replay_memory.sample(self.batch_size).to(self.device)
        num_states = env.observation_space.n
        one_hot = torch.nn.functional.one_hot

        # Flatten the index columns explicitly instead of calling a bare
        # ``squeeze()``, which would also drop the batch axis when it is 1.
        states = one_hot(batch["state"].reshape(-1), num_classes=num_states)
        next_states = one_hot(batch["next_state"].reshape(-1), num_classes=num_states)
        actions = batch["action"].reshape(-1, 1)
        rewards = batch["reward"].reshape(-1, 1).float()
        not_done = 1 - batch["done"].reshape(-1, 1).float()

        current_qs = self.model(states.float())
        # Q(s, a) of the action actually taken, one value per transition.
        taken_qs = current_qs.gather(1, actions)

        with torch.no_grad():
            future_qs = self.target_model(next_states.float())
            max_future_qs = future_qs.max(dim=1, keepdim=True).values
            # Bellman target; terminal transitions do not bootstrap.
            td_targets = rewards + discount_factor * max_future_qs * not_done

        q_targets = (1 - learning_rate) * taken_qs.detach() + learning_rate * td_targets

        if verbose:
            print("current_qs", current_qs)
            print("max_future_qs", td_targets)
            print("Q_targets", q_targets)
            print()

        return self.fit(taken_qs, q_targets, verbose=0, shuffle=True)

    def get_action(self, obs):
        """Return the greedy action index for every row of encoded states.

        Args:
            obs: Float tensor of one-hot encoded states, one state per row.

        Returns:
            Tensor with the argmax action per row.
        """
        # NOTE(review): actions come from the *target* network, which can lag
        # behind ``self.model`` between target updates - confirm this is intended.
        with torch.no_grad():
            current_qs = self.target_model(obs.to(self.device))
        return current_qs.argmax(dim=1)

    def fit(self, X, y, verbose=0, shuffle=True):
        """Take one optimizer step that pulls ``X`` towards ``y``.

        Args:
            X: Predicted Q-values (must carry gradients of ``self.model``).
            y: Target Q-values.
            verbose: Unused; kept for call compatibility.
            shuffle: Unused; kept for call compatibility.

        Returns:
            The loss tensor computed before the optimizer step.
        """
        self.optimizer.zero_grad()
        loss = self.loss(X, y)
        loss.backward()
        self.optimizer.step()
        return loss

    def generate_policy(self, env):
        """Return the greedy action for every state as a plain list.

        Args:
            env: Environment providing ``observation_space.n``.

        Returns:
            List whose i-th entry is the action chosen in state ``i``.
        """
        all_states = list(range(env.observation_space.n))
        encoded = DQN.encode_state(all_states, env).float()
        return self.get_action(encoded).tolist()

    @staticmethod
    def encode_state(obs, env):
        """One-hot encode a state index (or a list of indices).

        Args:
            obs: Integer state index or list of indices.
            env: Environment providing ``observation_space.n``.

        Returns:
            Integer tensor with a trailing axis of size ``observation_space.n``.
        """
        return torch.nn.functional.one_hot(
            torch.tensor(obs), num_classes=env.observation_space.n
        )

    def run(
        self,
        env,
        train_steps=10_000,
        tau=0.1,
        max_epsilon=0.1,
        min_epsilon=0.1,
        decay=0.01,
        update_target_every=8,
    ):
        """Train with epsilon-greedy exploration for ``train_steps`` steps.

        Every 100 environment steps the current greedy policy is evaluated
        for one episode and its return is recorded and printed.

        Args:
            env: Training environment (gym-style ``reset``/``step``/``close``,
                plus ``to_test``/``to_train`` for evaluation).
            train_steps: Total number of environment steps to collect.
            tau: Interpolation factor passed to the soft target update.
            max_epsilon: Initial exploration probability.
            min_epsilon: Asymptotic exploration probability.
            decay: Exponential decay rate of epsilon per episode.
            update_target_every: Target-network update period, in env steps.

        Returns:
            Tuple ``(returns, losses)``: the evaluation returns and the loss
            of every training update.
        """
        episode = 0
        total_time_steps = 0
        returns = []
        losses = []
        epsilon = max_epsilon

        while total_time_steps < train_steps:
            observation = env.reset()
            done = False

            while not done:
                total_time_steps += 1

                # Epsilon-greedy action selection.
                if np.random.rand() <= epsilon:
                    action = env.action_space.sample()
                else:
                    # Model input dims are (batch, env.observation_space.n).
                    encoded = DQN.encode_state(observation, env).float().unsqueeze(0)
                    with torch.no_grad():
                        predicted = self.model(encoded.to(self.device)).flatten()
                    action = np.argmax(predicted.cpu().numpy())
                new_observation, reward, done, _ = env.step(action)

                self.replay_memory.extend(
                    TensorDict(
                        dict(
                            state=torch.tensor([[observation]]),
                            action=torch.tensor([[action]]),
                            reward=torch.tensor([[reward]]),
                            next_state=torch.tensor([[new_observation]]),
                            done=torch.tensor([[done]]),
                        ),
                        batch_size=1,
                    ).to(self.device)
                )

                # Update the online network every 4th step and at episode end.
                if total_time_steps % 4 == 0 or done:
                    losses.append(self.train(env, verbose=False).item())
                    if (
                        total_time_steps >= 100
                        and total_time_steps % update_target_every == 0
                    ):
                        soft_update_from_to(self.model, self.target_model, tau=tau)

                observation = new_observation

                if total_time_steps % 100 == 0:
                    print("Evaluation at timestep {}".format(total_time_steps), end="")
                    # Evaluate on a copy: evaluation resets and steps the env,
                    # which would otherwise corrupt the training episode that
                    # is still in progress. Assumes ``env`` is deep-copyable.
                    eval_env = copy.deepcopy(env)
                    returns.append(
                        evaluate_policy(
                            self.generate_policy(env), eval_env, num_episodes=1
                        )[0]
                    )
                    print(": return {}".format(returns[-1]))

                if total_time_steps == train_steps:
                    break

            epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(
                -decay * episode
            )
            episode += 1

        env.close()
        return returns, losses


class GymWrapper(gym.Env):
    """Expose a tabular grid world through a gym-style ``reset``/``step`` API.

    Rewards are looked up in one of two ``[state, action]``-indexed tables,
    depending on whether the wrapper is in train or test mode.
    """

    def __init__(
        self,
        world: SimpleGridWorld,
        train_reward,
        true_reward,
        terminal,
        time_limit=None,
    ):
        """Store the world and reward tables and declare the discrete spaces.

        Args:
            world: Grid world providing ``num_actions``, ``num_states``,
                ``step(state, action)`` and ``is_terminal(state)``.
            train_reward: Reward table indexed ``[state, action]`` used while
                in train mode.
            true_reward: Reward table indexed ``[state, action]`` used while
                in test mode.
            terminal: Stored as ``self.terminal``; not read by this class.
            time_limit: Optional maximum number of steps per episode.
        """
        self.world = world
        self.train_reward = train_reward
        self.true_reward = true_reward
        self.terminal = terminal
        self.action_space = gym.spaces.Discrete(self.world.num_actions)
        self.observation_space = gym.spaces.Discrete(self.world.num_states)
        self.current_state = None
        self.episode_timesteps = 0
        self.training = True
        self.time_limit = time_limit

    def reset(self):
        """Start a new episode in state 0 and return that state index."""
        self.current_state = 0
        self.episode_timesteps = 0
        return self.current_state

    def to_train(self):
        """Use ``train_reward`` for subsequent steps."""
        self.training = True

    def to_test(self):
        """Use ``true_reward`` for subsequent steps."""
        self.training = False

    def step(self, action):
        """Apply ``action`` in the current state.

        Args:
            action: Action index; converted with ``int()`` before being
                passed to the world.

        Returns:
            Tuple ``(next_state, reward, done, info)`` where ``info`` is an
            empty dict.
        """
        next_state = self._sample_transition(self.current_state, action)

        # Reward is for the (state, action) pair just executed.
        reward_table = self.train_reward if self.training else self.true_reward
        reward = reward_table[self.current_state, action]

        # Termination is judged on the state the action was taken FROM, so an
        # episode ends on the step taken out of a terminal state.
        # NOTE(review): confirm this is intended rather than checking next_state.
        done = self.world.is_terminal(self.current_state)

        self.current_state = next_state

        self.episode_timesteps += 1
        if self.time_limit is not None and self.episode_timesteps >= self.time_limit:
            done = True

        return self.current_state, reward, done, {}

    def _sample_transition(self, state, action):
        """Ask the world for the successor of ``state`` under ``action``."""
        action_env = int(action)
        return self.world.step(state, action_env)


if __name__ == "__main__":
    # TODO: keep the discount used by DQN identical to the one used for
    # value iteration.
    world = SimpleGridWorld(5, debug=False)

    # Train and test rewards are the same table in this script.
    env = GymWrapper(
        world,
        train_reward=world.rewards,
        true_reward=world.rewards,
        terminal=world.num_states - 1,
        time_limit=100,
    )

    num_states = env.observation_space.n
    num_actions = env.action_space.n
    dqn = DQN(num_states, num_actions, learning_rate=1e-3)
    dqn = dqn.cuda() if torch.cuda.is_available() else dqn

    # Constant epsilon (decay disabled) and a hard target copy (tau=1)
    # every 500 steps.
    run_kwargs = {
        "max_epsilon": 0.01,
        "min_epsilon": 0.01,
        "decay": 0,
        "train_steps": 30_000,
        "update_target_every": 500,
        "tau": 1,
    }
    eval_returns, losses = dqn.run(env, **run_kwargs)
