import torch
import numpy as np
from src.nas.rl.controller import RLController
from src.models.autoencoder import Autoencoder
from src.utils.training import train_autoencoder, evaluate_autoencoder
from src.utils.objective_functions import calculate_fitness_score, count_parameters

class RLSearch:
    """Reinforcement-learning driven architecture search for autoencoders.

    A controller samples architecture configs; each candidate is trained,
    evaluated, scored with the configured fitness function, and the score is
    fed back to the controller as its reward.
    """

    def __init__(self, search_space, train_data, val_data, max_iterations=100,
                 fitness_type="mdl", precision_bits=7):
        """
        Neural Architecture Search using Reinforcement Learning

        Args:
            search_space: AutoencoderSearchSpace instance
            train_data: Training dataset. Must expose a ``features`` attribute
                whose ``shape`` is ``(n_samples, input_dim)``; it is also
                handed unchanged to ``train_autoencoder``.
            val_data: Validation dataset, handed unchanged to
                ``evaluate_autoencoder``.
            max_iterations: Maximum number of search iterations
            fitness_type: Type of fitness function ('negative_loss' or 'mdl')
            precision_bits: Number of bits per parameter for MDL calculation
        """
        self.search_space = search_space
        self.train_data = train_data
        self.val_data = val_data
        self.max_iterations = max_iterations
        # Dataset dimensions are read once from the training features.
        self.n_samples, self.input_dim = train_data.features.shape[:2]
        self.fitness_type = fitness_type
        self.precision_bits = precision_bits

        # History of architectures and their performances.
        # NOTE: never reset, so repeated search() calls keep appending to it.
        self.history = []

        # The controller will be initialized in the search method with the proper device
        self.controller = None

    @staticmethod
    def _to_float(value):
        """Convert a scalar score (tensor-like or plain number) to a float.

        Tensor-like objects go through ``.cpu().item()`` as before; plain
        Python/numpy scalars, which have no ``.cpu()``, are accepted as well
        instead of raising ``AttributeError``.
        """
        if hasattr(value, "cpu"):
            value = value.cpu()
        if hasattr(value, "item"):
            value = value.item()
        return float(value)

    def search(self, num_epochs_per_arch=10, device='cpu'):
        """Execute neural architecture search.

        Args:
            num_epochs_per_arch: Number of training epochs for every sampled
                architecture.
            device: Device passed to the controller, the candidate models and
                the train/evaluate helpers.

        Returns:
            Tuple ``(best_arch, history)``: the architecture config with the
            highest reward (``None`` if no reward exceeded ``-inf``, e.g. when
            ``max_iterations`` is 0) and ``self.history``, the list of
            per-iteration result dicts.
        """
        # Initialize controller with the proper device
        self.controller = RLController(self.input_dim, self.search_space, device=device)

        use_mdl = self.fitness_type == "mdl"  # loop-invariant, checked once
        best_arch = None
        best_score = float('-inf')  # Higher score is better

        for iteration in range(self.max_iterations):
            print(f"Search Iteration {iteration+1}/{self.max_iterations}")

            # Sample architecture from controller
            arch_config = self.controller.sample_architecture()
            print(f"Sampled Architecture: {arch_config}")

            # Create and train autoencoder
            model = Autoencoder(self.input_dim, arch_config).to(device)
            train_loss = train_autoencoder(
                model, self.train_data, num_epochs=num_epochs_per_arch, device=device
            )
            val_loss = evaluate_autoencoder(model, self.val_data, device=device)

            # Calculate decoder minimum capacity (Nn - Nm from Schuster and Krogh 2021)
            decoder_min_capacity = self.n_samples * (self.input_dim - arch_config['latent_dim'])
            print(f"Decoder minimum capacity: {decoder_min_capacity}")

            # Calculate fitness score (higher is better)
            fitness = calculate_fitness_score(
                model,
                val_loss,
                decoder_min_capacity,
                fitness_type=self.fitness_type,
                precision_bits=self.precision_bits
            )

            # Use fitness as reward (higher reward is better)
            reward = self._to_float(fitness)

            # Update controller policy
            self.controller.update_policy([reward])

            # Record results
            result = {
                'iteration': iteration,
                'architecture': arch_config,
                'train_loss': train_loss,
                'val_loss': val_loss,
                'reward': reward
            }

            # Add MDL score if using MDL fitness
            if use_mdl:
                mdl_score = -reward  # MDL is negative of reward
                result['mdl_score'] = mdl_score
                print(f"Train Loss: {train_loss:.6f}, Validation Loss: {val_loss:.6f}, "
                      f"Parameters: {count_parameters(model)}, "
                      f"MDL Score: {mdl_score:.2f}")
            else:
                print(f"Train Loss: {train_loss:.6f}, Validation Loss: {val_loss:.6f}")

            self.history.append(result)

            # Update best architecture
            if reward > best_score:
                best_score = reward
                best_arch = arch_config
                if use_mdl:
                    print(f"New best architecture found! MDL Score: {-best_score:.2f}")
                else:
                    print(f"New best architecture found! Validation loss: {val_loss:.6f}")

        print("Search completed!")
        print(f"Best architecture: {best_arch}")
        if best_arch is None:
            # best_score is still -inf here; reporting "-best_score" would
            # print a meaningless "inf".
            print("No valid architecture score was recorded.")
        elif use_mdl:
            print(f"Best MDL score: {-best_score:.2f}")
        else:
            print(f"Best validation loss: {-best_score:.6f}")

        return best_arch, self.history
