# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import dataclasses
from collections import OrderedDict
import typing as tp

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from hydra.core.config_store import ConfigStore

from url_benchmark import utils
from url_benchmark.in_memory_replay_buffer import ReplayBuffer
from url_benchmark.dmc import TimeStep
from .ddpg import MetaDict, Encoder
from copy import deepcopy


# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)

class Critic(nn.Module):
    """Twin Q-value heads mapping an input feature vector to one value per discrete action.

    Both heads (``q1`` and ``q2``) share one architecture, chosen by ``q_type``:
      * ``"linear"``: a single ``nn.Linear(input_dim, action_dim)``;
      * ``"mlp"``: one ReLU hidden layer of width ``hidden_dim``;
      * any other value: three ReLU hidden layers of width ``hidden_dim``.
    The module is moved to ``device`` once both heads are built.
    """

    def __init__(self, input_dim, action_dim, hidden_dim=256, q_type="mlp", device="cpu"):
        super().__init__()
        self.q_type = q_type
        # Build q1 completely before q2 so parameters are drawn from the RNG in a fixed order.
        self.q1 = self._build_head(input_dim, action_dim, hidden_dim, q_type)
        self.q2 = self._build_head(input_dim, action_dim, hidden_dim, q_type)
        self.to(device)

    @staticmethod
    def _build_head(input_dim, action_dim, hidden_dim, q_type):
        """Return one Q head for ``q_type`` (see the class docstring for the variants)."""
        if q_type == "linear":
            return nn.Linear(input_dim, action_dim)
        num_hidden = 1 if q_type == "mlp" else 3
        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
        for _ in range(num_hidden - 1):
            layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.ReLU()])
        layers.append(nn.Linear(hidden_dim, action_dim))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the pair ``(q1(x), q2(x))``."""
        first = self.q1(x)
        second = self.q2(x)
        return first, second


@dataclasses.dataclass
class DiscreteRLAgentConfig:
    """Hydra structured config for the discrete-action Q-learning agent.

    NOTE(review): ``_target_`` names ``DiscreteGCRLAgent`` and ``name`` is
    ``"discrete_gcrl"``, but the agent class defined alongside this config is
    ``DiscreteRLAgent`` and the config is stored under ``"discrete_rl"``.
    ``DiscreteRLAgent.__init__`` also takes ``(encoder_output_dim, cfg)`` rather
    than these fields as keyword arguments. Confirm how this config is
    instantiated and whether ``_target_`` points at the intended class.
    """
    # @package agent
    _target_: str = "url_benchmark.agent.encoder_rl.DiscreteGCRLAgent"  # Hydra instantiation target (see NOTE above)
    name: str = "discrete_gcrl"
    obs_type: str = "features"  # or 'pixels'
    obs_shape: tp.Tuple[int, ...] = (0,)  # placeholder; presumably overwritten by the caller — TODO confirm
    action_shape: tp.Tuple[int, ...] = (0,)  # (num_actions,); DiscreteRLAgent asserts exactly one entry
    device: str = "cuda"  # device for the critic networks and batches
    lr: float = 1e-4  # Adam learning rate for the critic
    hidden_dim: int = 512  # hidden width of the Critic heads (unused when q_type == "linear")
    num_expl_steps: int = 1000  # act(): uniformly random actions while step < num_expl_steps
    batch_size: int = 256  # transitions sampled from the replay buffer per update
    update_every_steps: int = 2  # update() trains only when step % update_every_steps == 0
    target_tau: float = 0.01  # soft-update coefficient passed to utils.soft_update_params
    use_tb: bool = False
    use_wandb: bool = False

    expl_eps: float = 0.2  # act(): epsilon for epsilon-greedy once the random warm-up is over
    q_type: str = "mlp"  # "mlp" or "linear"; any other value gives Critic's 3-hidden-layer MLP



# Register the structured config with Hydra's ConfigStore under group "agent",
# option name "discrete_rl".
cs = ConfigStore.instance()
cs.store(group="agent", name="discrete_rl", node=DiscreteRLAgentConfig)



class DiscreteRLAgent:
    """Q-learning agent for discrete actions built on a frozen, externally supplied encoder.

    Action values come from a twin-headed :class:`Critic` and are the
    element-wise minimum of its two heads. Learning targets use a
    Double-DQN-style rule: the online critic selects the next action and the
    target critic evaluates it. No encoder is created here. Call
    :meth:`load_encoder` before :meth:`act`, :meth:`update_q` / :meth:`update`
    or :meth:`q_function`.
    """

    # pylint: disable=unused-argument
    def __init__(self, encoder_output_dim, cfg: DiscreteRLAgentConfig):
        """Store the configuration and build the critic networks.

        Args:
            encoder_output_dim: size of the feature vector produced by the
                encoder later passed to :meth:`load_encoder`; used as the
                critic's input size.
            cfg: agent configuration; ``cfg.action_shape`` must contain exactly
                one entry, the number of discrete actions.
        """
        self.cfg = cfg
        assert len(cfg.action_shape) == 1
        self.action_dim = cfg.action_shape[0]
        self.solved_meta: tp.Any = None
        self.encoder_output_dim = encoder_output_dim
        # models: built by init_networks so construction and re-initialisation cannot drift apart
        self.init_networks()

    def init_networks(self):
        """(Re)create the critic, a target critic with identical weights, and the critic's Adam optimizer.

        Calling this again discards any learned critic weights and optimizer state.
        """
        cfg = self.cfg
        self.critic = Critic(self.encoder_output_dim, self.action_dim, cfg.hidden_dim, cfg.q_type, cfg.device)
        self.critic_target = Critic(self.encoder_output_dim, self.action_dim, cfg.hidden_dim, cfg.q_type, cfg.device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.q_opt = torch.optim.Adam(self.critic.parameters(), lr=cfg.lr)

    def load_encoder(self, encoder: nn.Module):
        """Install a deep copy of ``encoder`` on ``cfg.device`` with all its parameters frozen.

        The copy is never added to ``q_opt``, so only the critic is trained.
        """
        self.encoder = deepcopy(encoder)
        self.encoder.to(self.cfg.device)
        for param in self.encoder.parameters():
            param.requires_grad = False

    def train(self, training: bool = True) -> None:
        """Set train/eval mode on the critic and, if one has been loaded, the encoder."""
        self.training = training
        # The encoder only exists after load_encoder(); don't crash if train() is called earlier.
        encoder = getattr(self, "encoder", None)
        for net in (encoder, self.critic):
            if net is not None:
                net.train(training)

    def get_goal_meta(self, goal_array: np.ndarray) -> MetaDict:
        """Return an empty meta dict; goals are not used by this agent (kept for API compatibility)."""
        return OrderedDict()

    def act(self, obs, meta, step, eval_mode) -> tp.Any:
        """Select a discrete action for a single observation.

        The greedy action maximises ``min(Q1, Q2)``. When ``eval_mode`` is false,
        a uniformly random action is used instead while
        ``step < cfg.num_expl_steps``, and with probability ``cfg.expl_eps``
        afterwards.

        Returns:
            The chosen action index as a Python int.
        """
        del meta  # unused; kept for API compatibility
        obs = torch.as_tensor(obs, device=self.cfg.device, dtype=torch.float32).unsqueeze(0)  # type: ignore
        with torch.no_grad():  # acting never needs an autograd graph
            Q1, Q2 = self.critic(self.encoder(obs))
            Q = torch.min(Q1, Q2)
        action = Q.max(1)[1]

        if not eval_mode:
            # Short-circuiting keeps RNG usage unchanged: np.random.rand() is drawn only after warm-up.
            explore = step < self.cfg.num_expl_steps or np.random.rand() < self.cfg.expl_eps
            if explore:
                action = torch.randint_like(action, self.action_dim)
        return action.item()

    @staticmethod
    def _as_column(x):
        """Reshape a 1-D ``(batch,)`` tensor to ``(batch, 1)``; return anything else unchanged."""
        if isinstance(x, torch.Tensor) and x.dim() == 1:
            return x.unsqueeze(1)
        return x

    def update_q(self,
                 obs: torch.Tensor,
                 action: torch.Tensor,
                 reward: torch.Tensor,
                 discount: torch.Tensor,
                 next_obs: torch.Tensor,
                 step: int) -> tp.Dict[str, float]:
        """Take one gradient step on both Q-heads, then soft-update the target critic.

        Args:
            obs: batch of observations, passed through the encoder.
            action: action indices, shape ``(batch,)`` or ``(batch, 1)``.
            reward: rewards, shape ``(batch,)`` or ``(batch, 1)``.
            discount: per-transition discounts, shape ``(batch,)`` or ``(batch, 1)``.
            next_obs: batch of next observations, passed through the encoder.
            step: unused; kept for API compatibility.

        Returns:
            Scalar metrics: critic loss, mean and max of the chosen Q-values,
            and the mean TD target.
        """
        metrics: tp.Dict[str, float] = {}

        # A (batch,) reward or discount would silently broadcast against the
        # (batch, 1) Q-values into a (batch, batch) target, so normalise to columns.
        action = self._as_column(action).long()
        reward = self._as_column(reward)
        discount = self._as_column(discount)

        # Always encode observations before Q. The encoder is not in q_opt (and is
        # frozen by load_encoder), so no gradient needs to flow through it.
        with torch.no_grad():
            h = self.encoder(obs)
            h_next = self.encoder(next_obs)

        # Q(s,a) on encoded
        q1, q2 = self.critic(h)
        q1_a = q1.gather(1, action)
        q2_a = q2.gather(1, action)

        # Double Q-learning target: the online critic picks a', the target critic evaluates it.
        with torch.no_grad():
            next_q1_target, next_q2_target = self.critic_target(h_next)
            next_q_online = torch.min(*self.critic(h_next))
            next_actions = next_q_online.argmax(dim=1, keepdim=True)
            next_q_target = torch.min(next_q1_target, next_q2_target).gather(1, next_actions)
            target = reward + discount * next_q_target

        loss = F.mse_loss(q1_a, target) + F.mse_loss(q2_a, target)

        self.q_opt.zero_grad(set_to_none=True)
        loss.backward()
        self.q_opt.step()

        utils.soft_update_params(self.critic, self.critic_target, self.cfg.target_tau)

        metrics['critic_loss'] = loss.item()
        metrics['q1'] = q1_a.mean().item()
        metrics['q2'] = q2_a.mean().item()
        metrics['target'] = target.mean().item()
        metrics['q1_max'] = q1_a.max().item()
        metrics['q2_max'] = q2_a.max().item()

        return metrics

    def aug_and_encode(self, obs: torch.Tensor) -> torch.Tensor:
        """Apply ``self.aug`` (when one has been assigned) and encode ``obs``.

        This class never creates an ``aug`` module. Without an externally
        assigned one, this is plain encoding.
        """
        aug = getattr(self, "aug", None)
        if aug is not None:
            obs = aug(obs)
        return self.encoder(obs)

    def sample_goals(self, idxs, batch_size, N, M, L, p_curgoal, p_trajgoal, p_randomgoal, geom_sample,
                     discount=None):
        """Sample one goal index per entry of ``idxs`` from a flat buffer of ``N`` states.

        The buffer is treated as consecutive trajectories of length ``L``, with
        terminal states at ``L-1, 2L-1, ...``. A trailing partial trajectory is
        taken to end at ``N-1``. Each goal is:
          * the current state, with probability ``p_curgoal``;
          * otherwise a later state of the same trajectory (or the final state),
            overall probability ``p_trajgoal``. It is drawn with geometric
            offsets when ``geom_sample`` is set, and uniformly (after rounding)
            between ``idx + 1`` and the trajectory end otherwise;
          * otherwise a uniformly random state in ``[0, N)``.

        Args:
            idxs: integer array of current-state indices (length ``batch_size``).
            batch_size: number of goals to sample.
            N: number of states in the buffer.
            M: unused; kept for API compatibility.
            L: trajectory length used to locate terminal states.
            p_curgoal: probability of using the current state as the goal.
            p_trajgoal: probability of a goal from the same trajectory.
            p_randomgoal: unused; the random-goal probability is implied as
                ``1 - p_curgoal - p_trajgoal``.
            geom_sample: use geometric offsets with success probability
                ``1 - discount`` instead of uniform sampling.
            discount: discount used for geometric sampling. If omitted, falls
                back to ``self.cfg.train.discount`` when that attribute exists.

        Returns:
            Integer array of ``batch_size`` goal indices.

        Raises:
            ValueError: if ``geom_sample`` is set and no discount is available.
        """
        # Random goals.
        random_goal_idxs = np.random.randint(0, N, size=batch_size)

        # Goals from the same trajectory (excluding the current state, unless it is the final state).
        terminal_locs = np.arange(L - 1, N, L)
        if terminal_locs.size == 0 or terminal_locs[-1] != N - 1:
            # Without this, searchsorted would index past the end for idxs in a trailing partial trajectory.
            terminal_locs = np.append(terminal_locs, N - 1)
        final_state_idxs = terminal_locs[np.searchsorted(terminal_locs, idxs)]
        if geom_sample:
            if discount is None:
                discount = getattr(getattr(self.cfg, "train", None), "discount", None)
            if discount is None:
                raise ValueError(
                    "geometric goal sampling needs a discount: pass discount=... or provide cfg.train.discount"
                )
            # Geometric sampling.
            offsets = np.random.geometric(p=1 - discount, size=batch_size)  # in [1, inf)
            traj_goal_idxs = np.minimum(idxs + offsets, final_state_idxs)
        else:
            # Uniform sampling.
            distances = np.random.rand(batch_size)  # in [0, 1)
            traj_goal_idxs = np.round(
                (np.minimum(idxs + 1, final_state_idxs) * distances + final_state_idxs * (1 - distances))
            ).astype(int)

        if p_curgoal == 1.0:
            # Also avoids the division by (1 - p_curgoal) below.
            return idxs

        goal_idxs = np.where(
            np.random.rand(batch_size) < p_trajgoal / (1.0 - p_curgoal), traj_goal_idxs, random_goal_idxs
        )
        # Goals at the current state.
        return np.where(np.random.rand(batch_size) < p_curgoal, idxs, goal_idxs)

    def update(self, replay_loader: ReplayBuffer, step: int, reward_fn: tp.Callable[[torch.Tensor], torch.Tensor]) -> tp.Dict[str, float]:
        """Sample a batch, relabel its rewards with ``reward_fn`` and run :meth:`update_q`.

        Updates only happen when ``step`` is a multiple of
        ``cfg.update_every_steps``; on other steps an empty dict is returned.

        Args:
            replay_loader: buffer whose ``sample(batch_size)`` result supports
                ``.to(device)`` and exposes ``obs``, ``action``, ``discount``
                and ``next_obs``.
            step: current environment/training step.
            reward_fn: called on the batch's next observations (on CPU). It
                should return rewards of shape ``(batch,)`` or ``(batch, 1)``.

        Returns:
            The metrics from :meth:`update_q`, or ``{}`` when skipped.
        """
        metrics: tp.Dict[str, float] = {}

        if step % self.cfg.update_every_steps != 0:
            return metrics

        batch = replay_loader.sample(self.cfg.batch_size).to(self.cfg.device)
        obs, next_obs = batch.obs, batch.next_obs

        # Rewards are recomputed from next observations; update_q accepts (batch,) or (batch, 1).
        reward = reward_fn(next_obs.cpu()).to(self.cfg.device)

        metrics.update(self.update_q(obs=obs, action=batch.action.long(), reward=reward,
                                     discount=batch.discount, next_obs=next_obs, step=step))
        return metrics

    def q_function(self, obs):
        """Return ``min(Q1, Q2)`` for every action, one row per observation in ``obs``."""
        obs = obs.to(self.cfg.device)
        h = self.encoder(obs)
        Q1, Q2 = self.critic(h)
        return torch.min(Q1, Q2)

    def plot_q_function(self, work_dir, step, env, bf_action=None):
        """Evaluate the greedy policy on every state of ``env`` and plot its value function.

        Uses ``env.get_state_list()``, ``env.get_obs_from_state(state)`` and
        ``env.plot_v_function(...)``. If ``bf_action`` is given, it is indexed
        as ``bf_action[(state[1], state[0])]`` and must support ``in``. Each
        state's greedy action is then checked against it, and the per-state
        comparison and totals are printed.

        Returns:
            ``(num_pos, num_neg)`` (greedy action among / not among the given
            optimal actions) when ``bf_action`` is provided, otherwise ``None``.
        """
        state_list = env.get_state_list()
        print('in plot_q_function')
        # Cast to float32 exactly as act() does; float64 observations would not match float32 weights.
        obs_list = torch.cat(
            [torch.as_tensor(env.get_obs_from_state(state), dtype=torch.float32).unsqueeze(0)
             for state in state_list],
            dim=0,
        ).to(self.cfg.device)
        with torch.no_grad():
            q_list = self.q_function(obs_list)
        # Hand every plotted tensor over on CPU.
        v_list = q_list.max(dim=1)[0].cpu()
        a_list = q_list.argmax(dim=1).cpu()
        env.plot_v_function(work_dir, obs_list.cpu(), v_list, a_list, f"training_step_{step}_v_function")

        if bf_action is None:
            return None

        num_pos = 0
        num_neg = 0
        for state, policy_action in zip(state_list, a_list.tolist()):
            optimal_actions = bf_action[(state[1], state[0])]
            print('State: ', state, ' | Optimal Action: ', optimal_actions, ' | Policy action: ', policy_action)
            if policy_action in optimal_actions:
                num_pos += 1
            else:
                num_neg += 1
        print('Positive: ', num_pos, ' | Negative: ', num_neg)
        return num_pos, num_neg





