import torch
from torch import nn
from torch.distributions import Categorical, Normal
import numpy as np

def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Orthogonally initialise ``layer.weight`` and fill its bias with a constant.

    Args:
        layer: a module exposing a ``weight`` tensor and optionally a ``bias``
            tensor (e.g. ``nn.Linear`` or ``nn.Conv2d``).
        std: gain passed to ``torch.nn.init.orthogonal_``.
        bias_const: value written into every element of the bias, if present.

    Returns:
        The same ``layer`` object, initialised in place, so the call can be
        nested directly inside ``nn.Sequential(...)``.
    """
    torch.nn.init.orthogonal_(layer.weight, std)
    # Layers created with bias=False have ``bias is None``; skip them instead of
    # crashing inside ``constant_``.
    bias = getattr(layer, "bias", None)
    if bias is not None:
        torch.nn.init.constant_(bias, bias_const)
    return layer

class QNetwork:
    """Placeholder for a Q-value network.

    It currently defines no layers, holds no state and is not an
    ``nn.Module``.
    """

    def __init__(self):
        """Create an empty placeholder instance; nothing is built."""

class LSTM_Agent(nn.Module):
    """Recurrent actor-critic agent.

    Pipeline: CNN encoder over stacked depth frames -> single-layer LSTM
    (hidden size 512) -> separate actor and critic MLP heads. The policy is a
    diagonal Gaussian whose mean comes from ``actor_head`` and whose
    state-independent log-std is the learned parameter ``actor_logstd``.
    """

    def __init__(self, envs, args):
        super().__init__()
        self.num_frames = args.num_frames
        self.depth_height = 84
        self.depth_width = 84

        # Length of one flattened observation, as checked by process_obs.
        self.input_size = self.num_frames * self.depth_height * self.depth_width

        # Number of continuous action dimensions. The actor head and the
        # log-std parameter must agree on this, so it is computed once.
        action_dim = int(np.prod(envs.single_action_space.shape))

        # CNN for depth processing (stacked frames only).
        self.cnn = nn.Sequential(
            layer_init(nn.Conv2d(self.num_frames, 32, 8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, 3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Infer the flattened CNN feature size with a dummy forward pass.
        with torch.no_grad():
            dummy_depth = torch.zeros(1, self.num_frames, self.depth_height, self.depth_width)
            cnn_output_size = self.cnn(dummy_depth).shape[1]

        self.lstm = nn.LSTM(cnn_output_size, 512)
        for name, param in self.lstm.named_parameters():
            if "bias" in name:
                nn.init.constant_(param, 0)
            elif "weight" in name:
                nn.init.orthogonal_(param, 1.0)

        # Critic head: consumes the 512-dim LSTM output and emits a scalar value.
        self.critic_head = nn.Sequential(
            layer_init(nn.Linear(512, 256)),
            nn.Tanh(),
            layer_init(nn.Linear(256, 256)),
            nn.Tanh(),
            layer_init(nn.Linear(256, 1), std=1.0),
        )

        # Actor head: consumes the 512-dim LSTM output and emits the action mean.
        self.actor_head = nn.Sequential(
            layer_init(nn.Linear(512, 256)),
            nn.Tanh(),
            layer_init(nn.Linear(256, 256)),
            nn.Tanh(),
            layer_init(nn.Linear(256, action_dim), std=0.01),
        )

        self.actor_logstd = nn.Parameter(torch.zeros(1, action_dim))

    def process_obs(self, x):
        """Encode flattened stacked depth frames with the CNN.

        Args:
            x: tensor of shape (batch, num_frames * depth_height * depth_width).

        Returns:
            CNN features of shape (batch, cnn_output_size).

        Raises:
            ValueError: if ``x.shape[1]`` is not ``self.input_size``.
        """
        if x.shape[1] != self.input_size:
            raise ValueError(f"Input size mismatch: got {x.shape[1]}, expected {self.input_size}")

        # Reshape x to (batch, num_frames, depth_height, depth_width).
        depth = x.reshape(-1, self.num_frames, self.depth_height, self.depth_width)
        return self.cnn(depth)

    def get_states(self, x, lstm_state, done):
        """Run the LSTM step by step over a rollout segment.

        Args:
            x: flattened observations with T * B rows. They are reshaped to
                (T, B, features), so rows are interpreted time-major.
            lstm_state: tuple ``(h, c)``, each of shape (1, B, 512). B is read
                from ``h.shape[1]``.
            done: T * B flags (bool, int or float). Where a flag is 1/True,
                that environment's recurrent state is zeroed before the step.

        Returns:
            ``(hidden, lstm_state)``: LSTM outputs of shape (T * B, 512) and
            the final ``(h, c)`` tuple.
        """
        hidden = self.process_obs(x)

        batch_size = lstm_state[0].shape[1]
        hidden = hidden.reshape((-1, batch_size, self.lstm.input_size))
        # Cast the flags to the state's float dtype so that (1 - done) is a
        # 0/1 keep-mask. Bitwise ``~`` on an integer tensor gives -1/-2, which
        # would negate or double the state instead of keeping or resetting it.
        done = done.reshape((-1, batch_size)).to(dtype=lstm_state[0].dtype)
        new_hidden = []
        for h, d in zip(hidden, done):
            keep = (1.0 - d).view(1, -1, 1)
            h, lstm_state = self.lstm(
                h.unsqueeze(0),
                (keep * lstm_state[0], keep * lstm_state[1]),
            )
            new_hidden.append(h)
        new_hidden = torch.flatten(torch.cat(new_hidden), 0, 1)
        return new_hidden, lstm_state

    def get_value(self, x, lstm_state, done):
        """Return critic values of shape (T * B, 1); the new LSTM state is discarded."""
        hidden, _ = self.get_states(x, lstm_state, done)
        return self.critic_head(hidden)

    def get_mean_std(self, x, lstm_state, done):
        """Return ``(action_mean, action_std)`` of the Gaussian policy.

        The new LSTM state is discarded.
        """
        hidden, _ = self.get_states(x, lstm_state, done)
        action_mean = self.actor_head(hidden)
        action_std = torch.exp(self.actor_logstd.expand_as(action_mean))
        return action_mean, action_std

    def get_action_and_value(self, x, lstm_state, done, action=None):
        """Sample or evaluate an action and compute the value.

        Returns:
            ``(action, log_prob, entropy, value, lstm_state)``. ``log_prob`` and
            ``entropy`` are summed over action dimensions. If ``action`` is
            None, a new action is sampled from the policy.
        """
        hidden, lstm_state = self.get_states(x, lstm_state, done)
        value = self.critic_head(hidden)

        action_mean = self.actor_head(hidden)
        action_logstd = self.actor_logstd.expand_as(action_mean)
        action_std = torch.exp(action_logstd)
        probs = Normal(action_mean, action_std)

        if action is None:
            action = probs.sample()

        return (
            action,
            probs.log_prob(action).sum(1),
            probs.entropy().sum(1),
            value,
            lstm_state
        )

class Agent_ANN(nn.Module):
    """Feed-forward actor-critic for flat vector observations.

    Two independent 64-unit tanh MLPs. ``critic`` maps an observation to a
    scalar value. ``actor_mean`` maps it to the mean of a diagonal Gaussian
    policy, whose log-std is the learned, state-independent ``actor_logstd``.
    """

    def __init__(self, envs):
        super().__init__()
        obs_dim = np.array(envs.single_observation_space.shape).prod()
        act_dim = np.prod(envs.single_action_space.shape)

        # Value network: obs -> 64 -> 64 -> 1.
        self.critic = nn.Sequential(
            layer_init(nn.Linear(obs_dim, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 1), std=1.0),
        )
        # Policy-mean network: obs -> 64 -> 64 -> act_dim, with a small final gain.
        self.actor_mean = nn.Sequential(
            layer_init(nn.Linear(obs_dim, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, act_dim), std=0.01),
        )
        self.actor_logstd = nn.Parameter(torch.zeros(1, act_dim))

    def get_value(self, x):
        """Return the critic's value estimate for observations ``x``."""
        value = self.critic(x)
        return value

    def get_action_and_value(self, x, action=None):
        """Sample (or score the given) action and return it with its statistics.

        Returns:
            ``(action, log_prob, entropy, value)``. ``log_prob`` and ``entropy``
            are summed over action dimensions.
        """
        mean = self.actor_mean(x)
        std = self.actor_logstd.expand_as(mean).exp()
        dist = Normal(mean, std)
        chosen = dist.sample() if action is None else action
        log_prob = dist.log_prob(chosen).sum(1)
        entropy = dist.entropy().sum(1)
        return chosen, log_prob, entropy, self.critic(x)

class Agent(nn.Module):
    """Feed-forward actor-critic over stacked depth frames.

    A shared CNN encodes the observation; separate MLP heads produce the value
    and the mean of a diagonal Gaussian policy. The policy's state-independent
    log-std is the learned parameter ``actor_logstd``.
    """

    def __init__(self, envs, args):
        super().__init__()

        # Make sure these match the environment's dimensions.
        self.num_frames = args.num_frames
        self.depth_height = 84
        self.depth_width = 84

        # Length of one flattened observation, as checked by process_obs.
        self.input_size = self.num_frames * self.depth_height * self.depth_width

        # Number of continuous action dimensions. The actor head and the
        # log-std parameter must agree on this, so it is computed once.
        action_dim = int(np.prod(envs.single_action_space.shape))

        # CNN for depth processing (stacked frames only).
        self.cnn = nn.Sequential(
            layer_init(nn.Conv2d(self.num_frames, 32, 8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, 3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Infer the flattened CNN feature size with a dummy forward pass.
        with torch.no_grad():
            dummy_depth = torch.zeros(1, self.num_frames, self.depth_height, self.depth_width)
            cnn_output_size = self.cnn(dummy_depth).shape[1]

        # Critic head: consumes only the CNN depth features.
        self.critic_head = nn.Sequential(
            layer_init(nn.Linear(cnn_output_size, 1024)),
            nn.Tanh(),
            layer_init(nn.Linear(1024, 256)),
            nn.Tanh(),
            layer_init(nn.Linear(256, 1), std=1.0),
        )

        # Actor head: consumes only the CNN depth features and emits the action mean.
        self.actor_head = nn.Sequential(
            layer_init(nn.Linear(cnn_output_size, 1024)),
            nn.Tanh(),
            layer_init(nn.Linear(1024, 256)),
            nn.Tanh(),
            layer_init(nn.Linear(256, action_dim), std=0.01),
        )

        self.actor_logstd = nn.Parameter(torch.zeros(1, action_dim))

    def process_obs(self, x):
        """Encode flattened stacked depth frames with the CNN.

        Args:
            x: tensor of shape (batch, num_frames * depth_height * depth_width).

        Returns:
            CNN features of shape (batch, cnn_output_size).

        Raises:
            ValueError: if ``x.shape[1]`` is not ``self.input_size``.
        """
        if x.shape[1] != self.input_size:
            raise ValueError(f"Input size mismatch: got {x.shape[1]}, expected {self.input_size}")

        # Reshape x to (batch, num_frames, depth_height, depth_width).
        depth = x.reshape(-1, self.num_frames, self.depth_height, self.depth_width)
        return self.cnn(depth)

    def get_value(self, x):
        """Return critic values of shape (batch, 1)."""
        depth_features = self.process_obs(x)
        return self.critic_head(depth_features)

    def get_mean_std(self, x):
        """Return ``(action_mean, action_std)`` of the Gaussian policy."""
        depth_features = self.process_obs(x)
        action_mean = self.actor_head(depth_features)
        action_logstd = self.actor_logstd.expand_as(action_mean)
        action_std = torch.exp(action_logstd)
        return action_mean, action_std

    def get_action_and_value(self, x, action=None):
        """Sample or evaluate an action and compute the value.

        Returns:
            ``(action, log_prob, entropy, value)``. ``log_prob`` and ``entropy``
            are summed over action dimensions. If ``action`` is None, a new
            action is sampled from the policy.
        """
        # Encode once and share the features between actor and critic.
        # Calling get_value(x) here would run the CNN a second time.
        depth_features = self.process_obs(x)
        action_mean = self.actor_head(depth_features)
        action_logstd = self.actor_logstd.expand_as(action_mean)
        action_std = torch.exp(action_logstd)
        probs = Normal(action_mean, action_std)

        if action is None:
            action = probs.sample()

        return (
            action,
            probs.log_prob(action).sum(1),
            probs.entropy().sum(1),
            self.critic_head(depth_features),
        )

