import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np

class RLController(nn.Module):
    """REINFORCE controller that samples autoencoder architectures.

    An ``nn.LSTMCell`` is unrolled once per architectural decision. Each
    decision (decoder depth, one width per decoder layer, latent dimension,
    L1-weight bin, one on/off flag per candidate skip connection) is sampled
    from a categorical distribution given by a linear head applied to the
    current hidden state; the sampled index is then embedded and fed back
    into the LSTM so later decisions are conditioned on earlier ones.
    """

    # Number of bins the [l1_min, l1_max] interval is discretized into.
    L1_BINS = 10
    # Lower bound on the shared embedding vocabulary (the historical size),
    # so parameter shapes are unchanged whenever every head fits into it.
    MIN_EMBEDDING_VOCAB = 20

    def __init__(self, input_dim, search_space, hidden_size=64, device='cpu'):
        """
        RL controller that generates autoencoder architectures

        Args:
            input_dim: Input dimension of the data (stored; not used by this class)
            search_space: AutoencoderSearchSpace instance; must expose
                min/max_decoder_depth, min/max_width, min/max_latent_dim,
                l1_min, l1_max and encoder_architecture
            hidden_size: Hidden size of the controller LSTM
            device: Device to put the model on ('cpu', 'cuda:0', etc.)
        """
        super().__init__()
        self.search_space = search_space
        self.hidden_size = hidden_size
        self.input_dim = input_dim
        self.device = device

        ss = search_space
        n_depth = ss.max_decoder_depth - ss.min_decoder_depth + 1
        # Widths and latent dims are powers of two; these heads pick an exponent.
        n_width = int(np.log2(ss.max_width)) - int(np.log2(ss.min_width)) + 1
        n_latent = int(np.log2(ss.max_latent_dim)) - int(np.log2(ss.min_latent_dim)) + 1

        # LSTM controller
        self.lstm = nn.LSTMCell(self.hidden_size, self.hidden_size)

        # One embedding table is shared by every decision type, so it must be
        # able to index the largest head's output; otherwise a large sampled
        # index would be out of range for the lookup.
        vocab_size = max(self.MIN_EMBEDDING_VOCAB, n_depth, n_width, n_latent,
                         self.L1_BINS, 2)
        self.embedding = nn.Embedding(vocab_size, self.hidden_size)

        # Prediction heads for the different architectural parameters
        self.depth_head = nn.Linear(self.hidden_size, n_depth)
        self.width_head = nn.Linear(self.hidden_size, n_width)
        self.latent_dim_head = nn.Linear(self.hidden_size, n_latent)
        self.l1_head = nn.Linear(self.hidden_size, self.L1_BINS)
        self.skip_head = nn.Linear(self.hidden_size, 2)  # binary: skip connection off/on

        # Move model to device, then build the optimizer over the moved params
        self.to(device)
        self.optimizer = optim.Adam(self.parameters(), lr=0.001)

        # Per-decision log-probabilities of the most recent sample.
        self.log_probs = []
        # Summed log-probability of each architecture sampled since the last
        # policy update, in sampling order (one entry per architecture).
        self.saved_log_probs = []
        # Rewards buffered for the current policy update.
        self.rewards = []

    def _reset_state(self, batch_size=1):
        """Reset the LSTM hidden and cell states to zeros."""
        self.h = torch.zeros(batch_size, self.hidden_size, device=self.device)
        self.c = torch.zeros(batch_size, self.hidden_size, device=self.device)

    def _decide(self, head):
        """Sample one decision from ``head``, record its log-prob, advance the LSTM.

        Returns:
            The sampled category index as a Python int.
        """
        dist = torch.distributions.Categorical(logits=head(self.h))
        idx = dist.sample()
        self.log_probs.append(dist.log_prob(idx))
        # Feed the choice back in so later decisions are conditioned on it.
        self.h, self.c = self.lstm(self.embedding(idx), (self.h, self.c))
        return idx.item()

    def sample_architecture(self):
        """Sample an architecture from the controller policy.

        The per-decision log-probabilities of this sample are left in
        ``self.log_probs``; their sum is appended to ``self.saved_log_probs``
        so ``update_policy`` can pair each sampled architecture with its reward.

        Returns:
            dict with keys 'decoder_depth', 'decoder_width' (list, one width per
            decoder layer), 'latent_dim', 'l1_weight', 'skip_connections'
            (list of (from_idx, to_idx) pairs; from_idx == -1 is the latent),
            'encoder_depth' and 'encoder_width' (copied from the search space).
        """
        ss = self.search_space
        self._reset_state()
        self.log_probs = []

        depth = self._decide(self.depth_head) + ss.min_decoder_depth

        min_width_exp = int(np.log2(ss.min_width))
        width = [2 ** (self._decide(self.width_head) + min_width_exp)
                 for _ in range(depth)]

        latent_dim = 2 ** (self._decide(self.latent_dim_head)
                           + int(np.log2(ss.min_latent_dim)))

        # Bin k maps linearly onto [l1_min, l1_max]; first and last bins hit the ends.
        l1_bin = self._decide(self.l1_head)
        l1_weight = ss.l1_min + (ss.l1_max - ss.l1_min) * (
            l1_bin / float(max(self.L1_BINS - 1, 1)))

        # Candidate skip connections: latent (-1) -> decoder layers 1..depth
        # (layer 0 is already fed directly by the latent), then layer
        # from_idx -> decoder layer to_idx.
        # NOTE(review): the "encoder layer" sources are enumerated with the
        # decoder depth, not encoder_architecture['encoder_depth'] — confirm.
        candidates = [(-1, to_idx) for to_idx in range(1, depth + 1)]
        candidates += [(from_idx, to_idx)
                       for from_idx in range(depth)
                       for to_idx in range(depth + 1)]

        skip_connections = []
        for connection in candidates:
            if self._decide(self.skip_head) == 1:
                skip_connections.append(connection)

        # One scalar log-prob per architecture, used by update_policy.
        self.saved_log_probs.append(torch.stack(self.log_probs).sum())

        return {
            'decoder_depth': depth,
            'decoder_width': width,
            'latent_dim': latent_dim,
            'l1_weight': l1_weight,
            'skip_connections': skip_connections,
            'encoder_depth': ss.encoder_architecture['encoder_depth'],
            'encoder_width': ss.encoder_architecture['encoder_width'],
        }

    def update_policy(self, rewards):
        """Update controller policy using REINFORCE.

        Args:
            rewards: a scalar reward or a sequence of rewards, one per sampled
                architecture, in sampling order. They are matched to the most
                recently sampled architectures; earlier unmatched samples are
                discarded. Rewards are standardized when more than one is given.

        Raises:
            RuntimeError: if rewards are given but no architecture was sampled
                since the last update.
            ValueError: if more rewards are given than architectures were sampled.
        """
        if isinstance(rewards, torch.Tensor):
            rewards = rewards.detach().flatten().tolist()
        elif np.isscalar(rewards):
            rewards = [rewards]
        rewards = [float(r) for r in rewards]

        n = len(self.rewards) + len(rewards)
        if n == 0:
            return  # nothing to learn from
        if not self.saved_log_probs:
            raise RuntimeError(
                "update_policy called before any architecture was sampled")
        if n > len(self.saved_log_probs):
            raise ValueError(
                f"got {n} rewards but only {len(self.saved_log_probs)} "
                "architectures were sampled since the last update")
        self.rewards.extend(rewards)

        # float32 so mean()/std() work even when callers pass integer rewards.
        reward_t = torch.tensor(self.rewards, dtype=torch.float32, device=self.device)
        if n > 1:
            # Normalize rewards for stability (std is undefined for one sample).
            reward_t = (reward_t - reward_t.mean()) / (reward_t.std() + 1e-9)

        episode_log_probs = torch.stack(self.saved_log_probs[-n:])
        policy_loss = -(episode_log_probs * reward_t).sum()

        self.optimizer.zero_grad()
        policy_loss.backward()
        self.optimizer.step()

        # Graphs built before this step are now stale; drop them all.
        self.log_probs = []
        self.saved_log_probs = []
        self.rewards = []
