import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
from lpcmdp.algorithm.utils import get_f_div_fn

class Qnet(torch.nn.Module):
    """Two-layer fully connected Q-network.

    Maps a batch of state vectors to one value per action via a single
    ReLU hidden layer.

    Args:
        state_dim: Number of input features of the first linear layer.
        hidden_dim: Width of the hidden layer.
        action_dim: Number of outputs of the second linear layer.
    """

    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        # Attribute names are part of the state-dict layout; keep them stable.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        """Return the output of ``fc2`` applied to ``relu(fc1(x))``."""
        hidden = torch.relu(self.fc1(x))
        q_values = self.fc2(hidden)
        return q_values
    
class DDQN:
    """Double DQN agent for a discrete action space.

    The online network ``q_net`` selects the greedy next action and the
    target network ``target_q_net`` evaluates it. The target network is
    overwritten with the online weights every ``target_update`` calls to
    :meth:`update`.

    Args:
        state_dim: Size of a (flattened) state vector.
        hidden_dim: Hidden-layer width of both Q-networks.
        action_dim: Number of discrete actions.
        learning_rate: Learning rate of the Adam optimizer on ``q_net``.
        gamma: Discount factor applied to the bootstrapped next-state value.
        epsilon: Probability of taking a uniformly random action.
        target_update: Number of :meth:`update` calls between target syncs.
    """

    def __init__(self,
                 state_dim,
                 hidden_dim,
                 action_dim,
                 learning_rate,
                 gamma,
                 epsilon,
                 target_update):
        self.action_dim = action_dim
        self.q_net = Qnet(state_dim, hidden_dim, self.action_dim)
        self.target_q_net = Qnet(state_dim, hidden_dim, self.action_dim)
        # Start both networks from the same weights so the very first
        # bootstrap target is not produced by an unrelated random network.
        self.target_q_net.load_state_dict(self.q_net.state_dict())
        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=learning_rate)
        self.gamma = gamma
        self.epsilon = epsilon
        self.target_update = target_update
        self.count = 0  # number of update() calls performed so far

    @staticmethod
    def _to_tensor(data, dtype):
        """Convert array-like ``data`` to a tensor of ``dtype``.

        Going through ``np.asarray`` avoids the slow element-wise path that
        ``torch.tensor`` takes for Python lists of NumPy arrays.
        """
        if torch.is_tensor(data):
            return data.to(dtype)
        return torch.as_tensor(np.asarray(data), dtype=dtype)

    def take_action(self, state):
        """Return an epsilon-greedy action index for a single ``state``."""
        if np.random.random() < self.epsilon:
            action = np.random.randint(self.action_dim)
        else:
            state = self._to_tensor(state, torch.float).unsqueeze(0)
            # Pure inference: no autograd graph is needed.
            with torch.no_grad():
                action = self.q_net(state).argmax().item()
        return action

    def max_q_value(self, state):
        """Return the largest online Q-value for a single ``state`` as a float."""
        state = self._to_tensor(state, torch.float).unsqueeze(0)
        with torch.no_grad():
            return self.q_net(state).max().item()

    def update(self, transition_dict):
        """Run one Double-DQN gradient step on a batch of transitions.

        Args:
            transition_dict: Mapping with the keys ``'states'``, ``'actions'``,
                ``'rewards'``, ``'next_states'`` and ``'dones'``, each holding
                one entry per transition. States are flattened to one row per
                transition; the other entries become column vectors.
        """
        batch_size = len(transition_dict['states'])
        states = self._to_tensor(transition_dict['states'], torch.float).reshape(batch_size, -1)
        # gather() requires int64 indices regardless of how actions were stored.
        actions = self._to_tensor(transition_dict['actions'], torch.int64).reshape(-1, 1)
        rewards = self._to_tensor(transition_dict['rewards'], torch.float).reshape(-1, 1)
        next_states = self._to_tensor(transition_dict['next_states'], torch.float).reshape(
            len(transition_dict['next_states']), -1)
        dones = self._to_tensor(transition_dict['dones'], torch.float).reshape(-1, 1)

        q_values = self.q_net(states).gather(1, actions)

        # The target is a constant w.r.t. the optimized parameters. Building it
        # under no_grad avoids accumulating never-zeroed gradients on the
        # target network and skips needless graph construction.
        with torch.no_grad():
            # Double DQN: online net picks the action, target net scores it.
            max_action = self.q_net(next_states).max(dim=1, keepdim=True).indices
            max_next_q_values = self.target_q_net(next_states).gather(1, max_action)
            # (1 - dones) removes the bootstrap term on terminal transitions.
            q_targets = rewards + self.gamma * max_next_q_values * (1 - dones)

        # mse_loss already averages over the batch.
        dqn_loss = F.mse_loss(q_values, q_targets)
        self.optimizer.zero_grad()
        dqn_loss.backward()
        self.optimizer.step()

        if self.count % self.target_update == 0:
            self.target_q_net.load_state_dict(self.q_net.state_dict())

        self.count += 1


def mlp(sizes, activation, output_activation=nn.Identity):
    """
    Build a fully connected network as an ``nn.Sequential``.

    Every consecutive pair in ``sizes`` becomes one ``nn.Linear`` layer, and
    each linear layer is followed by a freshly instantiated activation module.

    Args:
        sizes (list): Layer widths, from input size to output size.
        activation (nn.Module): Activation class instantiated after every
            linear layer except the last one.
        output_activation (nn.Module): Activation class instantiated after
            the last linear layer. Defaults to nn.Identity.

    Returns:
        nn.Sequential: The stacked linear and activation modules.
    """
    final_idx = len(sizes) - 2  # index of the last linear layer
    modules = []
    for idx, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        modules.append(nn.Linear(n_in, n_out))
        nonlinearity = output_activation if idx >= final_idx else activation
        modules.append(nonlinearity())
    return nn.Sequential(*modules)

# class QCritic(nn.Module):

#     def __init__(self, obs_dim, act_dim, hidden_sizes, activation, num_q=1, mean=0.01, var=0.01):
#         super().__init__()
#         self.q_net = mlp([obs_dim + act_dim] + list(hidden_sizes) + [1], nn.ReLU, output_activation=activation)
#         self.net_mean = mean
#         self.net_var = var
        
#         # for layer in self.q_net:
#         #     if isinstance(layer, torch.nn.Linear):
#         # #         # nn.init.xavier_normal_(layer.weight)
#         #         nn.init.normal_(layer.weight, self.net_mean, self.net_var)

#     def forward(self, obs, act=None):
#         data = obs if act is None else torch.cat([obs, act], dim=-1)
#         return self.q_net(data)
    
class MLPActor(nn.Module):
    """
    Deterministic MLP policy whose output layer is squashed by tanh.

    Args:
        obs_dim (int): Size of the network input.
        act_dim (int): Size of the network output.
        hidden_sizes (list): Widths of the hidden layers.
        activation (Type[nn.Module]): The activation function to use between layers.
        act_limit (float, optional): Factor the tanh output is multiplied by.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, act_limit=1):
        super().__init__()
        self.act_limit = act_limit
        layer_widths = [obs_dim, *hidden_sizes, act_dim]
        self.pi = mlp(layer_widths, activation, nn.Tanh)

    def forward(self, obs):
        # The tanh output is rescaled by act_limit to span the action range.
        squashed = self.pi(obs)
        return self.act_limit * squashed
    
class BC(nn.Module):
    """
    Behavior Cloning (BC)

    Trains an MLP actor to reproduce dataset actions by minimizing the mean
    squared error between predicted and recorded actions.

    Args:
        state_dim (int): dimension of the state space.
        action_dim (int): dimension of the action space.
        max_action (float): Maximum action value.
        a_hidden_sizes (list, optional): List of integers specifying the sizes
            of the layers in the actor network. Defaults to ``[128, 128]``.
        device (str, optional): Device to run the model on (e.g. 'cpu' or 'cuda:0').
    """

    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 max_action: float,
                 a_hidden_sizes: list = None,
                 device: str = "cpu"):

        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_action = max_action
        # A fresh list per instance; a list literal as the default would be
        # shared (and mutable) across every BC object.
        self.a_hidden_sizes = [128, 128] if a_hidden_sizes is None else a_hidden_sizes
        self.device = device
        # Created by setup_optimizers(); actor_loss() refuses to run without it.
        self.actor_optim = None

        self.actor = MLPActor(self.state_dim, self.action_dim, self.a_hidden_sizes,
                              nn.ReLU, self.max_action).to(self.device)

    def actor_loss(self, observations, actions):
        """Take one optimizer step on the MSE between predicted and given actions.

        Args:
            observations: Batch of observations fed to the actor.
            actions: Target actions with the same shape as the actor output.

        Returns:
            tuple: ``(loss_actor, stats_actor)`` where ``loss_actor`` is the
            loss tensor and ``stats_actor`` maps ``"loss/actor_loss"`` to its
            Python float value.

        Raises:
            RuntimeError: If :meth:`setup_optimizers` has not been called.
        """
        optimizer = getattr(self, "actor_optim", None)
        if optimizer is None:
            raise RuntimeError(
                "BC.actor_loss() called before BC.setup_optimizers()")
        pred_actions = self.actor(observations)
        loss_actor = F.mse_loss(pred_actions, actions)
        optimizer.zero_grad()
        loss_actor.backward()
        optimizer.step()
        stats_actor = {"loss/actor_loss": loss_actor.item()}
        return loss_actor, stats_actor

    def setup_optimizers(self, actor_lr):
        """Create the Adam optimizer over the actor parameters."""
        self.actor_optim = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)

    def act(self, obs):
        '''
        Given a single obs, return the action as a NumPy array.
        '''
        if not torch.is_tensor(obs):
            # Accept lists/tuples as well as ndarrays.
            obs = np.asarray(obs)
        # Add the leading batch axis the actor is called with.
        obs = torch.as_tensor(obs, dtype=torch.float32, device=self.device)[None, ...]
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            act = self.actor(obs)
        # .cpu() is a no-op for CPU tensors, so this is correct for any device
        # spelling (str or torch.device).
        return np.squeeze(act.detach().cpu().numpy(), axis=0)



        
    


