import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

class RNNEncoder(nn.Module):
    """Encode a batch of trajectory windows into a diagonal Gaussian over a latent z.

    Each timestep's leading ``state_dim + action_dim`` features are embedded with a
    linear layer + ReLU, fed through a single-layer GRU, and the GRU's final hidden
    state is projected to a mean and a log standard deviation of size ``z_dim``.
    """

    def __init__(self, state_dim, action_dim, hidden_dim, z_dim):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim
        self.z_dim = z_dim

        # Submodules are created in this order so parameter initialisation consumes
        # the RNG identically (and the state_dict keys stay the same).
        self.input_fc = nn.Linear(state_dim + action_dim, hidden_dim)
        self.rnn = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.output_fc = nn.Linear(hidden_dim, z_dim * 2)

    def forward(self, trajectories):
        """Return ``(mean, std)`` of the latent posterior, each of width ``z_dim``.

        ``trajectories`` must be 3-D (batch, time, features). Feature columns beyond
        ``state_dim + action_dim`` (e.g. next states) are ignored.
        """
        # Unpacking enforces a 3-D input exactly as before; the sizes are unused.
        _batch, _steps, _ = trajectories.shape

        sa_width = self.state_dim + self.action_dim
        features = F.relu(self.input_fc(trajectories[..., :sa_width]))

        _, final_hidden = self.rnn(features)
        # final_hidden is (num_layers=1, batch, hidden); drop the layer axis.
        head = self.output_fc(final_hidden[0])

        mean, log_std = head.chunk(2, dim=-1)
        return mean, log_std.exp()

class EnsembleDynamicsModel(nn.Module):
    """Ensemble of independent MLPs, each predicting a diagonal Gaussian next state.

    Every member maps ``concat(state, action, z)`` to ``2 * state_dim`` outputs,
    which are read as a mean and a log standard deviation over the next state.
    """

    def __init__(self, state_dim, action_dim, z_dim, hidden_dim, ensemble_size=3):
        super(EnsembleDynamicsModel, self).__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.z_dim = z_dim
        self.hidden_dim = hidden_dim
        self.ensemble_size = ensemble_size

        self.models = nn.ModuleList()
        for _ in range(ensemble_size):
            model = nn.Sequential(
                nn.Linear(state_dim + action_dim + z_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, state_dim * 2)
            )
            self.models.append(model)

    def _split_params(self, params):
        """Split raw member output into ``(mean, std)`` along the last axis."""
        mean = params[..., :self.state_dim]
        log_std = params[..., self.state_dim:]
        return mean, torch.exp(log_std)

    def forward(self, state, action, z, model_idx=None):
        """Predict the next-state Gaussian.

        Args:
            state, action, z: tensors concatenated along the last axis as member input.
            model_idx: if given, evaluate only that member; otherwise evaluate all.

        Returns:
            ``(mean, std)``. With ``model_idx`` each has the member's output shape;
            without it, the per-member results are stacked along a new leading axis
            of size ``ensemble_size``.
        """
        # The input is identical for every member, so concatenate it only once.
        x = torch.cat([state, action, z], dim=-1)
        if model_idx is not None:
            return self._split_params(self.models[model_idx](x))

        means, stds = zip(*(self._split_params(model(x)) for model in self.models))
        return torch.stack(means), torch.stack(stds)

    def compute_log_prob(self, state, action, next_state, z, model_idx=None):
        """Log-likelihood of ``next_state``, summed over the state dimensions.

        With ``model_idx`` the result is for that member only. Without it, the
        result has a leading ``ensemble_size`` axis with one entry per member;
        ``next_state`` broadcasts against the stacked member predictions.
        """
        mean, std = self.forward(state, action, z, model_idx)
        return Normal(mean, std).log_prob(next_state).sum(dim=-1)

class ReplayBuffer:
    """Fixed-capacity ring buffer of transitions tagged with integer trajectory ids.

    Stores states, actions, next states, rewards, terminal flags, a trajectory id
    per transition and, when ``z_dim`` is not None, a per-transition latent ``z``.
    Once ``max_size`` transitions have been added, the oldest slots are overwritten.
    """

    def __init__(self, state_dim, action_dim, z_dim=256, max_size=int(1e6)):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.z_dim = z_dim
        self.max_size = max_size
        self.ptr = 0   # next slot to write (wraps around)
        self.size = 0  # number of valid slots, capped at max_size

        self.states = np.zeros((max_size, state_dim))
        self.actions = np.zeros((max_size, action_dim))
        self.next_states = np.zeros((max_size, state_dim))
        self.rewards = np.zeros((max_size, 1))
        self.terminals = np.zeros((max_size, 1))
        self.trajectory_ids = np.zeros((max_size, 1), dtype=int)

        # NOTE(review): with the defaults this reserves a (1e6, 256) float64 array
        # (~2 GB) for z even if no latents are ever stored; pass z_dim=None to skip.
        if z_dim is not None:
            self.z = np.zeros((max_size, z_dim))
        else:
            self.z = None

    @staticmethod
    def _unpack_transition(transition):
        """Split a 5-tuple (s, a, s', r, done) or 6-tuple (+ z) into six fields."""
        if len(transition) == 5:
            state, action, next_state, reward, terminal = transition
            return state, action, next_state, reward, terminal, None
        if len(transition) == 6:
            state, action, next_state, reward, terminal, z = transition
            return state, action, next_state, reward, terminal, z
        raise ValueError("Invalid transition length")

    def _next_trajectory_id(self):
        """Return one more than the largest stored trajectory id (0 when empty)."""
        if self.size == 0:
            return 0
        return int(self.trajectory_ids[:self.size].max()) + 1

    def add(self, state, action, next_state, reward, terminal, trajectory_id, z=None):
        """Write one transition at ``ptr``, overwriting the oldest slot when full.

        ``z`` is stored only when the buffer tracks latents (``z_dim`` is not None);
        otherwise it is ignored.
        """
        self.states[self.ptr] = state
        self.actions[self.ptr] = action
        self.next_states[self.ptr] = next_state
        self.rewards[self.ptr] = reward
        self.terminals[self.ptr] = terminal
        self.trajectory_ids[self.ptr] = trajectory_id

        if self.z_dim is not None and z is not None:
            self.z[self.ptr] = z

        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def add_trajectory(self, trajectory):
        """Append all transitions of ``trajectory`` under a single fresh trajectory id.

        Raises:
            ValueError: if a transition is not a 5- or 6-tuple.
        """
        trajectory_id = self._next_trajectory_id()
        for transition in trajectory:
            state, action, next_state, reward, terminal, z = self._unpack_transition(transition)
            self.add(state, action, next_state, reward, terminal, trajectory_id, z)

    def sample_trajectories(self, batch_size, seq_len):
        """Sample contiguous windows of length ``seq_len`` from distinct trajectories.

        Only trajectories with at least ``seq_len`` stored transitions are eligible.
        Up to ``batch_size`` of them are drawn without replacement, capped at the
        number of eligible trajectories, and one random window is cut from each.

        Returns:
            An array of shape ``(n, seq_len, 2 * state_dim + action_dim)`` holding
            ``[state, action, next_state]`` per step with ``n <= batch_size``, or
            ``None`` if nothing can be sampled.
        """
        if self.size == 0:
            return None

        ids = self.trajectory_ids[:self.size].ravel()
        unique_ids, counts = np.unique(ids, return_counts=True)
        # Filter before sampling so short trajectories do not silently shrink the batch.
        eligible = unique_ids[counts >= seq_len]
        if len(eligible) == 0:
            return None

        # np.random.choice(replace=False) raises if asked for more than it has.
        n_samples = min(batch_size, len(eligible))
        sampled_ids = np.random.choice(eligible, n_samples, replace=False)

        states = self.states[:self.size]
        actions = self.actions[:self.size]
        next_states = self.next_states[:self.size]

        # NOTE(review): rows are taken in buffer-index order. A trajectory that spans
        # the ring-buffer wrap point is therefore not in temporal order.
        trajectories = []
        for traj_id in sampled_ids:
            traj_mask = ids == traj_id
            traj_states = states[traj_mask]
            traj_actions = actions[traj_mask]
            traj_next_states = next_states[traj_mask]

            start_idx = np.random.randint(0, len(traj_states) - seq_len + 1)
            end_idx = start_idx + seq_len

            trajectories.append(np.concatenate([
                traj_states[start_idx:end_idx],
                traj_actions[start_idx:end_idx],
                traj_next_states[start_idx:end_idx],
            ], axis=1))

        if not trajectories:
            return None

        return np.array(trajectories)

    def sample(self, batch_size):
        """Uniformly sample ``batch_size`` transitions (with replacement) as FloatTensors.

        Returns ``(states, actions, next_states, rewards, terminals)``. A sixth
        element holding ``z`` is appended when the buffer tracks latents.
        """
        ind = np.random.randint(0, self.size, size=batch_size)

        samples = (
            torch.FloatTensor(self.states[ind]),
            torch.FloatTensor(self.actions[ind]),
            torch.FloatTensor(self.next_states[ind]),
            torch.FloatTensor(self.rewards[ind]),
            torch.FloatTensor(self.terminals[ind])
        )

        if self.z_dim is not None:
            samples += (torch.FloatTensor(self.z[ind]),)

        return samples

    def relabel_with_z(self, z):
        """Overwrite the latent of every stored transition with the rows of ``z``.

        If the buffer was created without latents, the ``z`` store is allocated
        using ``z.shape[1]`` as its width.

        Raises:
            ValueError: if ``len(z)`` differs from the number of stored transitions.
                Previously this case was silently ignored.
        """
        if len(z) != self.size:
            raise ValueError(
                f"relabel_with_z expects {self.size} latents (one per stored "
                f"transition), got {len(z)}"
            )

        if self.z_dim is None:
            self.z_dim = z.shape[1]
            self.z = np.zeros((self.max_size, self.z_dim))

        self.z[:self.size] = z

    def convert_dataset(self, dataset):
        """Append an iterable of 5- or 6-tuple transitions, splitting trajectories on terminals.

        Transitions are assigned to a fresh trajectory id. A new id is started after
        every transition whose terminal flag is truthy.
        NOTE(review): episodes that end by time-limit without a terminal flag are
        merged with the following episode.

        Raises:
            ValueError: if a transition is not a 5- or 6-tuple.
        """
        # The original referenced an undefined ``trajectory_id`` here (NameError).
        trajectory_id = self._next_trajectory_id()
        for transition in dataset:
            state, action, next_state, reward, terminal, z = self._unpack_transition(transition)
            self.add(state, action, next_state, reward, terminal, trajectory_id, z)
            if np.any(terminal):
                trajectory_id += 1

    def convert_to_dataset(self):
        """Return the valid portion of the buffer as a dict of NumPy arrays (views, not copies)."""
        dataset = {
            'observations': self.states[:self.size],
            'actions': self.actions[:self.size],
            'next_observations': self.next_states[:self.size],
            'rewards': self.rewards[:self.size],
            'terminals': self.terminals[:self.size],
        }

        if self.z_dim is not None:
            dataset['z'] = self.z[:self.size]

        return dataset

    def get_trajectory(self, trajectory_id):
        """Return all stored transitions with ``trajectory_id``, in buffer-index order."""
        traj_mask = self.trajectory_ids[:self.size].flatten() == trajectory_id
        result = {
            'states': self.states[:self.size][traj_mask],
            'actions': self.actions[:self.size][traj_mask],
            'next_states': self.next_states[:self.size][traj_mask],
            'rewards': self.rewards[:self.size][traj_mask],
            'terminals': self.terminals[:self.size][traj_mask]
        }

        if self.z_dim is not None:
            result['z'] = self.z[:self.size][traj_mask]

        return result

    def get_all_trajectories(self):
        """Return ``get_trajectory`` dicts for every stored id, in ascending id order."""
        unique_trajectory_ids = np.unique(self.trajectory_ids[:self.size])
        return [self.get_trajectory(traj_id) for traj_id in unique_trajectory_ids]
