import numpy as np
import torch
from common.sac import ReplayBuffer, ReplayBufferNA

def bold(message):
    """Wrap *message* in ANSI escape codes so terminals render it in bold."""
    bold_on, reset = "\033[1m", "\033[0m"
    return "".join((bold_on, message, reset))

def get_transitions(target_state_buffer, expert_action_trajs, include_na=False):
    """Flatten per-trajectory expert data into aligned transition arrays.

    For each trajectory with states ``s`` and actions ``a`` the rows are
    ``obs = s[:-1]``, ``obs2 = s[1:]`` and ``actions = a[1:]``, i.e. the action
    at index ``i + 1`` is paired with the state at index ``i``.
    NOTE(review): this assumes the action trajectory is offset by one step
    relative to the states -- confirm against whoever builds
    ``expert_action_trajs``.

    Every transition is marked non-terminal (``dones`` is all zeros).

    A trajectory with fewer than two states yields zero ``obs``/``obs2``/``dones``
    rows (and, with ``include_na``, zero ``actions_2``/``a2_valid`` rows).

    Args:
        target_state_buffer: sequence of per-trajectory state sequences.
        expert_action_trajs: sequence of per-trajectory action sequences, one
            per state trajectory (normally of the same length as it).
        include_na: also return the expert's next actions and a validity mask.

    Returns:
        ``(obs, actions, obs2, dones)``, or, when ``include_na`` is true,
        ``(obs, actions, obs2, dones, actions_2, a2_valid)``. Here ``actions_2[i]``
        is the action following ``actions[i]`` in the same trajectory, and it is
        zero-filled for each trajectory's last transition. ``a2_valid`` is a bool
        mask that is ``False`` exactly on those zero-filled rows.

    Raises:
        AssertionError: if the two buffers hold a different number of trajectories.
        ValueError: if there are no trajectories at all (nothing to concatenate).
    """
    assert len(target_state_buffer) == len(expert_action_trajs), (
        "need exactly one action trajectory per state trajectory")
    obs, actions, obs2, dones = [], [], [], []
    actions_2, a2_valid = [], []
    for states, acts in zip(target_state_buffer, expert_action_trajs):
        # Clamp so an empty trajectory does not request a negative-length array.
        n = max(len(states) - 1, 0)
        obs.append(states[:-1])   # states[0 .. T-2]
        actions.append(acts[1:])  # paired row-for-row with the states above
        obs2.append(states[1:])   # states[1 .. T-1]
        dones.append(np.zeros([n]))
        assert len(obs[-1]) == len(dones[-1])
        if include_na:
            # Next action for each row, zero-padded after the last real one.
            # Truncating to n keeps a single-state trajectory from contributing
            # a stray padding row that has no matching observation.
            next_acts = np.concatenate([acts[2:], np.zeros_like(acts[0:1])], axis=0)
            actions_2.append(next_acts[:n])
            # Every row except the final (padded) one has a real next action.
            a2_valid.append(np.arange(n) < n - 1)

    obs = np.concatenate(obs, axis=0)
    actions = np.concatenate(actions, axis=0)
    obs2 = np.concatenate(obs2, axis=0)
    dones = np.concatenate(dones, axis=0)
    if include_na:
        return (obs, actions, obs2, dones,
                np.concatenate(actions_2, axis=0),
                np.concatenate(a2_valid, axis=0))
    return (obs, actions, obs2, dones)

def combined_shape(length, shape=None):
    """Return a shape tuple with *length* prepended to *shape*.

    ``shape`` may be ``None`` (result is ``(length,)``), a scalar
    (result is ``(length, shape)``) or an iterable of dimensions
    (result is ``(length, *shape)``).
    """
    if shape is None:
        return (length,)
    if np.isscalar(shape):
        return (length, shape)
    return (length,) + tuple(shape)

def count_vars(module):
    """Return the total number of scalar entries across ``module.parameters()``."""
    total = 0
    for param in module.parameters():
        total += np.prod(param.shape)
    return total

class ExpertReplayBuffer(ReplayBuffer):
    """
    Replay buffer that mixes learner transitions with a fixed set of expert transitions.

    Learner transitions live in the inherited ``ReplayBuffer`` storage. Expert
    transitions are built once, at construction, from the expert trajectories
    (via ``get_transitions``) and kept in ``self.expert_replay_buffer``.
    ``sample_batch`` draws ``int(batch_size * ratio)`` expert transitions and
    the rest from the learner storage, then concatenates them with the learner
    rows first.

    Args:
        obs_dim, act_dim, device, size: forwarded to both the inherited buffer
            and the expert buffer.
        erb_expert_samples: per-trajectory expert state sequences.
        erb_expert_action_trajs: per-trajectory expert action sequences, one
            per state trajectory.
        reward_fn: callable applied to a batch of observations whose result is
            turned into a float tensor of rewards. It is always used for expert
            samples, and for learner samples too when ``recalculate`` is true.
        ratio: fraction of each sampled batch taken from the expert buffer.
        recalculate: recompute learner rewards with ``reward_fn`` instead of
            using the stored ones.
        return_is_expert: add an ``"is_expert"`` entry to sampled batches
            (zeros for learner rows, ones for expert rows).
        na: store the expert's next actions too (using ``ReplayBufferNA``) and
            return them as ``"act2"``/``"a2_valid"``. These cover the expert
            rows only.
    """

    def __init__(self,
                 obs_dim,
                 act_dim,
                 device=torch.device('cpu'),
                 size=int(1e6),
                 erb_expert_samples=None,
                 erb_expert_action_trajs=None,
                 reward_fn=None,
                 ratio=0.5,
                 recalculate=True,
                 return_is_expert=False,
                 na=False):
        super().__init__(obs_dim=obs_dim, act_dim=act_dim, device=device, size=size)

        if na:
            print(bold("Using expert's next actions in order to refine estimate of V*"))

        transitions = get_transitions(erb_expert_samples, erb_expert_action_trajs, include_na=bool(na))
        if na:
            obs, act, next_obs, done, act2, a2_valid = transitions
            extra_columns = {"act2": act2, "a2_valid": a2_valid}
            buffer_cls = ReplayBufferNA
        else:
            obs, act, next_obs, done = transitions
            extra_columns = {}
            buffer_cls = ReplayBuffer

        self.expert_replay_buffer = buffer_cls(obs_dim=obs_dim, act_dim=act_dim, device=device, size=size)
        # Stored expert rewards are zero placeholders; sample_batch recomputes
        # expert rewards with reward_fn every time.
        self.expert_replay_buffer.store_batch(obs=obs,
                                              act=act,
                                              rew=np.zeros_like(done),
                                              next_obs=next_obs,
                                              done=done,
                                              **extra_columns)

        self.reward_fn = reward_fn
        self.ratio = ratio
        self.recalculate = recalculate
        # Not read or advanced anywhere in this class; presumably updated by
        # callers -- NOTE(review): confirm.
        self.global_steps = 0
        self.return_is_expert = return_is_expert
        self.na = na

        print("ratio", ratio, "recalculate", recalculate)
        print(f"return_is_expert = {self.return_is_expert}")
        print(f"na = {na}")

    def sample_batch(self, batch_size=32):
        """Sample a mixed learner/expert batch.

        Returns a dict of tensors with the learner batch's keys, each holding
        learner rows followed by expert rows. ``"rew"`` holds the recomputed (or,
        for learner rows with ``recalculate`` off, the stored) rewards. The dict
        optionally also has ``"is_expert"``, and expert-only ``"act2"`` and
        ``"a2_valid"`` when ``na`` is enabled.
        """
        n_expert = int(batch_size * self.ratio)
        # Learner rows are sampled before expert rows, as before, so RNG
        # consumption order is unchanged.
        learner = super().sample_batch(batch_size=batch_size - n_expert)
        expert = self.expert_replay_buffer.sample_batch(batch_size=n_expert)

        if self.recalculate:
            learner_rew = torch.FloatTensor(self.reward_fn(learner["obs"])).to(self.device)
        else:
            learner_rew = learner["rew"]
        expert_rew = torch.FloatTensor(self.reward_fn(expert["obs"])).to(self.device)

        # Concatenating every learner key (including "act") with its expert counterpart.
        merged = {}
        for key, learner_value in learner.items():
            merged[key] = torch.cat([learner_value, expert[key]], dim=0)
        merged["rew"] = torch.cat([learner_rew, expert_rew], dim=0)

        if self.return_is_expert:
            merged["is_expert"] = torch.cat(
                [torch.zeros_like(learner["done"]), torch.ones_like(expert["done"])], dim=0)

        if self.na:
            # Only the expert buffer carries next actions, so these cover the
            # expert rows (the tail of the batch) only.
            merged["act2"] = expert["act2"]
            merged["a2_valid"] = expert["a2_valid"]

        return merged