# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=unused-import
import pdb
import copy
import math
import logging
import dataclasses
from collections import OrderedDict
import typing as tp

import numpy as np
from url_benchmark.agent.discrete_fb import DiscreteFBAgentConfig
import torch
from torch import nn
import torch.nn.functional as F
from hydra.core.config_store import ConfigStore

from url_benchmark import utils
# from url_benchmark import replay_buffer as rb
from url_benchmark.in_memory_replay_buffer import ReplayBuffer
from url_benchmark.dmc import TimeStep
from url_benchmark import goals as _goals
from .ddpg import MetaDict
from .fb_modules import IdentityMap
from .ddpg import Encoder
from .fb_modules import BackwardMap, mlp


# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


from .fb_ddpg import FBDDPGAgentConfig


@dataclasses.dataclass
class DiscreteGCRLAgentConfig(FBDDPGAgentConfig):
    """Configuration for ``DiscreteGCRLAgent``.

    Inherits every field of ``FBDDPGAgentConfig`` and sets the defaults
    listed below.
    """
    # @package agent
    _target_: str = "url_benchmark.agent.gcrl.DiscreteGCRLAgent"
    name: str = "discrete_gcrl"
    preprocess: bool = False
    expl_eps: float = 0.2
    # These two need type annotations: ``@dataclass`` only turns *annotated*
    # class attributes into fields. Without them the values are plain class
    # attributes that cannot be passed to the constructor (or overridden
    # through the structured config).
    boltzmann: bool = True
    temp: float = 100
    obs_type: str = "state"  # to be specified later



# Register the config class with Hydra's ConfigStore under group "agent",
# name "discrete_gcrl". This runs as a side effect of importing the module.
cs = ConfigStore.instance()
cs.store(group="agent", name="discrete_gcrl", node=DiscreteGCRLAgentConfig)



class Critic(nn.Module):
    """Twin goal-conditioned Q-network over a discrete action set.

    Both heads (``q1`` and ``q2``) are MLPs with three hidden layers that take
    the concatenation of the input features and the goal and output one value
    per action.
    """

    def __init__(self, input_dim, goal_dim, action_dim, hidden_dim=256, q_type="mlp", device="cpu"):
        super().__init__()
        # Stored for reference only; nothing in this class branches on it.
        self.q_type = q_type
        joint_dim = input_dim + goal_dim

        def build_head():
            # Linear/ReLU stack: joint_dim -> hidden -> hidden -> hidden -> action_dim.
            return nn.Sequential(
                nn.Linear(joint_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, action_dim),
            )

        # q1 is created before q2 so the parameter initialisation order is fixed.
        self.q1 = build_head()
        self.q2 = build_head()
        self.to(device)

    def forward(self, x, goal):
        """Return ``(q1, q2)`` evaluated on ``x`` concatenated with ``goal`` along the last dim."""
        joint = torch.cat([x, goal], dim=-1)
        first = self.q1(joint)
        second = self.q2(joint)
        return first, second
    
class DiscreteGCRLAgent:

    # pylint: disable=unused-argument
    def __init__(self,
                 **kwargs: tp.Any
                 ):
        cfg = DiscreteGCRLAgentConfig(**kwargs)
        self.cfg = cfg
        assert len(cfg.action_shape) == 1
        self.action_dim = cfg.action_shape[0]
        self.solved_meta: tp.Any = None

        # models
        if cfg.obs_type == 'pixels':
            self.aug: nn.Module = utils.RandomShiftsAug(pad=4)
            self.encoder: nn.Module = Encoder(cfg.obs_shape).to(cfg.device)
            self.obs_dim = self.encoder.repr_dim
        else:
            self.aug = nn.Identity()
            self.encoder = nn.Identity()
            self.obs_dim = cfg.obs_shape[0]
        if cfg.feature_dim < self.obs_dim:
            logger.warning(f"feature_dim {cfg.feature_dim} should not be smaller that obs_dim {self.obs_dim}")
        goal_dim = self.obs_dim
        if cfg.goal_space is not None:
            goal_dim = _goals.get_goal_space_dim(cfg.goal_space)

        self.critic = Critic(self.obs_dim, goal_dim, self.action_dim, cfg.hidden_dim).to(cfg.device)
        # build up the target network
        self.critic_target = Critic(self.obs_dim, goal_dim, self.action_dim, cfg.hidden_dim).to(cfg.device)
        # load the weights into the target networks
        self.critic_target.load_state_dict(self.critic.state_dict())
        # optimizers
        self.encoder_opt: tp.Optional[torch.optim.Adam] = None
        if cfg.obs_type == 'pixels':
            self.encoder_opt = torch.optim.Adam(self.encoder.parameters(), lr=cfg.lr)

        self.opt = torch.optim.Adam([{'params': self.critic.parameters()}],
                                       lr=cfg.lr)

        self.train()
        self.critic.train()

    def train(self, training: bool = True) -> None:
        self.training = training
        for net in [self.encoder, self.critic]:
            net.train(training)

    def init_from(self, other) -> None:
        # copy parameters over
        names = ["encoder"]
        if self.cfg.init_fb:
            names += ["critic", "critic_target"]
        for name in names:
            utils.hard_update_params(getattr(other, name), getattr(self, name))
        for key, val in self.__dict__.items():
            if isinstance(val, torch.optim.Optimizer):
                val.load_state_dict(copy.deepcopy(getattr(other, key).state_dict()))

    def get_goal_meta(self, goal_array: np.ndarray) -> MetaDict:
        desired_goal = torch.tensor(goal_array).unsqueeze(0).to(self.cfg.device)
        meta = OrderedDict()
        meta['z'] = desired_goal.squeeze(0).cpu().numpy()
        return meta

    def act(self, obs, meta, step, eval_mode) -> tp.Any:
        obs = torch.as_tensor(obs, device=self.cfg.device, dtype=torch.float32).unsqueeze(0)  # type: ignore
        h = self.encoder(obs)
        z = torch.as_tensor(meta['z'], device=self.cfg.device).unsqueeze(0)  # type: ignore
        Q1, Q2 = self.critic(h, z)
        Q = torch.min(Q1, Q2)
        action = Q.max(1)[1]

        if not eval_mode:
            if step < self.cfg.num_expl_steps:
                action = torch.randint_like(action, self.action_dim)
            else:
                action = torch.randint_like(action, self.action_dim) \
                    if np.random.rand() < self.cfg.expl_eps else action
        return action.item()


    def update_gcrl(
        self,
        obs: torch.Tensor,
        action: torch.Tensor,
        discount: torch.Tensor,
        next_obs: torch.Tensor,
        next_goal: torch.Tensor,
        step: int
    ) -> tp.Dict[str, float]:
        metrics: tp.Dict[str, float] = {} 
        # compute target successor measure

        with torch.no_grad():
            # implicit_reward = (torch.matmul(B, inv_cov) * z).sum(dim=1)  # batch_size
            next_Q1, next_Q2 = self.critic_target(self.encoder(next_obs), next_goal)
            next_Q = torch.min(next_Q1, next_Q2)
            max_next_Q = next_Q.max(1)[0]
            implicit_reward = torch.all(next_obs == next_goal, dim=-1)
            terminals = torch.all(next_obs == next_goal, dim=-1)
            implicit_reward = implicit_reward.float() - 1.0
            target_Q = implicit_reward.detach() + discount.squeeze(1) * max_next_Q * (~terminals)  # batch_size

        # if step % 100 == 0:
        #     import pdb; pdb.set_trace()
        Q1, Q2 = self.critic(self.encoder(obs), next_goal)
        Q1 = Q1.gather(1, action).squeeze(1)
        Q2 = Q2.gather(1, action).squeeze(1)
        # print('Q1, Q2, target_Q: ', Q1.mean().item(), Q2.mean().item(), target_Q.mean().item())
        q_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q)

        # if self.cfg.use_tb or self.cfg.use_wandb or self.cfg.use_hiplog:
        metrics['q_loss'] = q_loss.item()
        metrics['Q1'] = Q1.mean().item()
        metrics['Q2'] = Q2.mean().item()
        metrics['target_Q'] = target_Q.mean().item()

        # optimize FB
        if self.encoder_opt is not None:
            self.encoder_opt.zero_grad(set_to_none=True)
        self.opt.zero_grad(set_to_none=True)
        q_loss.backward()
        self.opt.step()
        if self.encoder_opt is not None:
            self.encoder_opt.step()
        return metrics


    def aug_and_encode(self, obs: torch.Tensor) -> torch.Tensor:
        obs = self.aug(obs)
        return self.encoder(obs)

    def update(self, replay_loader: ReplayBuffer, step: int) -> tp.Dict[str, float]:
        metrics: tp.Dict[str, float] = {}

        if step % self.cfg.update_every_steps != 0:
            return metrics

        batch = replay_loader.sample(self.cfg.batch_size)
        batch = batch.to(self.cfg.device)

        # pdb.set_trace()
        obs = batch.obs
        action = batch.action.type(torch.int64)
        discount = batch.discount
        next_obs = next_goal = batch.next_obs
        if self.cfg.goal_space is not None:
            assert batch.next_goal is not None
            next_goal = batch.next_goal

        # Create goals with a mix of next obs and random goals (HER)
        her_ratio = 0.5  # Ratio of HER goals
        batch_size = self.cfg.batch_size
        random_indices = torch.randint(0, batch_size, (batch_size,), device=self.cfg.device)
        random_goals = next_goal[random_indices]  # Sample random goals from the batch
        her_goals = next_obs  # Use next observations as HER goals

        # Mix HER goals and random goals
        goal_mask = torch.rand(batch_size, device=self.cfg.device) < her_ratio
        next_goal = torch.where(goal_mask.unsqueeze(-1), her_goals, random_goals)

        # next_goal = torch.zeros_like(next_goal)
        # next_goal[:, 14] = 1.0

        metrics.update(self.update_gcrl(obs=obs, action=action, discount=discount,
                                      next_obs=next_obs, next_goal=next_goal, step=step))
        
        if step % 1000 == 0:
            print('Step: ', step, 'q_loss: ', metrics['q_loss'], ' | Q1: ', metrics['Q1'], ' | Q2: ', metrics['Q2'], ' | target_Q: ', metrics['target_Q'])

        # update critic target
        utils.soft_update_params(self.critic, self.critic_target,
                                 self.cfg.fb_target_tau)
        return metrics
    
    def inference(self, replay_loader: ReplayBuffer, inf_logger, pos_goal, neg_goal, reward_fn) -> tp.Dict[str, tp.Any]:
        """No-op inference hook: ignores every argument and returns an empty dict."""
        return {}

    def q_function_inference(self, obs, goal, neg_goal, z):
        batch_size = obs.shape[0]
        goal = torch.tensor(goal).repeat(batch_size, 1).to(self.cfg.device)
        h = self.encoder(obs)
        Q1, Q2 = self.critic(h, goal)
        Q = torch.min(Q1, Q2)
        return Q
    
    def q_function_pos_neg(self, obs, goal, neg_goal):
        print('in q_function_pos_neg')
        h = self.encoder(obs)
        z = self.backward_net(goal) - self.backward_net(neg_goal)
        F1, F2 = self.forward_net(h, z)
        Q1, Q2 = [torch.einsum('sda, sd -> sa', Fi, z) for Fi in [F1, F2]]
        Q = torch.min(Q1, Q2)
        return Q
        
    def plot_q_function(self, work_dir, step, env, goal, bf_action=None):
        state_list = env.get_state_list()
        print('in plot_q_function')
        # print(state_list)
        obs_list = [torch.tensor(env.get_obs_from_state(state)).unsqueeze(0) for state in state_list] # implement this function
        # print(obs_list)
        # print(len(state_list))
        obs_list = torch.cat(obs_list, dim=0).to(self.cfg.device)
        goal = torch.tensor(goal).unsqueeze(0).repeat(obs_list.shape[0], 1).to(self.cfg.device)
        # print(obs_list.shape, goal.shape)
        q_list = self.q_function(obs_list, goal).detach()
        v_list = torch.max(q_list, dim=1)[0]
        # v_list = v_list
        a_list = torch.argmax(q_list, dim=1).cpu()
        # print(v_list, a_list)
        env.plot_v_function(work_dir, obs_list.cpu(), v_list, a_list, f"training_step_{step}_v_function") # write this function

        num_pos = 0
        num_neg = 0
        if bf_action is not None:
            for i in range(len(state_list)):
                print('State: ', state_list[i], ' | Optimal Action: ', bf_action[(state_list[i][1], state_list[i][0])], ' | Policy action: ', a_list[i].item())

                if a_list[i].item() in bf_action[(state_list[i][1], state_list[i][0])]:
                    num_pos += 1
                else:
                    num_neg += 1
            print('Positive: ', num_pos, ' | Negative: ', num_neg)
            return num_pos, num_neg
        
    def plot_q_function_pos_neg(self, work_dir, step, env, goal, neg_goal, bf_action=None):
        state_list = env.get_state_list()
        print('in plot_q_pos_neg_function')
        # print(state_list)
        obs_list = [torch.tensor(env.get_obs_from_state(state)).unsqueeze(0) for state in state_list] # implement this function
        # print(obs_list)
        # print(len(state_list))
        obs_list = torch.cat(obs_list, dim=0).to(self.cfg.device)
        goal = torch.tensor(goal).unsqueeze(0).repeat(obs_list.shape[0], 1).to(self.cfg.device)
        neg_goal = torch.tensor(neg_goal).unsqueeze(0).repeat(obs_list.shape[0], 1).to(self.cfg.device)
        # print(obs_list.shape, goal.shape)
        q_list = self.q_function_pos_neg(obs_list, goal, neg_goal).detach()
        v_list = torch.max(q_list, dim=1)[0]
        # v_list = v_list
        a_list = torch.argmax(q_list, dim=1).cpu()
        # print(v_list, a_list)
        env.plot_v_function(work_dir, obs_list.cpu(), v_list, a_list, f"training_step_{step}_v_function") # write this function

        num_pos = 0
        num_neg = 0
        if bf_action is not None:
            for i in range(len(state_list)):
                print('State: ', state_list[i], ' | Optimal Action: ', bf_action[(state_list[i][1], state_list[i][0])], ' | Policy action: ', a_list[i].item())

                if a_list[i].item() in bf_action[(state_list[i][1], state_list[i][0])]:
                    num_pos += 1
                else:
                    num_neg += 1
            print('Positive: ', num_pos, ' | Negative: ', num_neg)
            return num_pos, num_neg





