#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import copy
import os
from typing import Dict, Optional

import torch
import torch.nn as nn

class BudgetController:
    """Map a scalar compute budget onto a set of model configuration values.

    A budget equal to one of the preset keys in ``budget_levels``
    (64, 128, 256, 512, 1024) uses that preset entry. Any other value is
    interpolated by :meth:`_calculate_dynamic_budget_config`. Applying a
    budget pushes the resulting settings (plus ``use_nbc=True``) into
    ``model.config`` and ``model.seg_mod.config`` when those exist.

    The preset table is treated as read-only: every consumer (the model,
    its ``seg_mod``, and callers of :meth:`get_budget_info`) receives its
    own deep copy, so mutating a returned or applied config never alters
    the presets.
    """

    def __init__(self, model, config: Dict):
        """Create a controller for *model*.

        Args:
            model: Object whose ``config`` mapping (and ``seg_mod.config``,
                if present) is updated by :meth:`set_budget`. Presumably an
                ``nn.Module``; only attribute access and ``.update`` are
                relied on here.
            config: Stored as ``self.config``; not otherwise read by this
                class.
        """
        self.model = model
        self.config = config
        # None until set_budget() is first called.
        self.current_budget = None

        # Preset configurations keyed by budget value.
        self.budget_levels = {
            64: {
                'node_budget': 4,
                'max_segments': 8,
                'num_heads': 2,
                'head_dim': 16,
                'gma_scales': [8, 16],
                'filter_orders': [3, 5],
                'num_filters': 4,
                'wavelet_levels': 2,
                'fft_bins': 16,
                'use_fft': True,
                'use_wavelet': False,
                'frequency_branch_enabled': False
            },
            128: {
                'node_budget': 8,
                'max_segments': 12,
                'num_heads': 4,
                'head_dim': 24,
                'gma_scales': [8, 16, 32],
                'filter_orders': [3, 5, 7],
                'num_filters': 6,
                'wavelet_levels': 3,
                'fft_bins': 24,
                'use_fft': True,
                'use_wavelet': True,
                'frequency_branch_enabled': True
            },
            256: {
                'node_budget': 16,
                'max_segments': 20,
                'num_heads': 6,
                'head_dim': 32,
                'gma_scales': [8, 16, 32, 64],
                'filter_orders': [3, 5, 7, 9],
                'num_filters': 8,
                'wavelet_levels': 4,
                'fft_bins': 32,
                'use_fft': True,
                'use_wavelet': True,
                'frequency_branch_enabled': True
            },
            512: {
                'node_budget': 24,
                'max_segments': 32,
                'num_heads': 8,
                'head_dim': 32,
                'gma_scales': [8, 16, 32, 64],
                'filter_orders': [3, 5, 7, 9],
                'num_filters': 8,
                'wavelet_levels': 4,
                'fft_bins': 32,
                'use_fft': True,
                'use_wavelet': True,
                'frequency_branch_enabled': True
            },
            1024: {
                'node_budget': 32,
                'max_segments': 48,
                'num_heads': 8,
                'head_dim': 32,
                'gma_scales': [8, 16, 32, 64],
                'filter_orders': [3, 5, 7, 9],
                'num_filters': 8,
                'wavelet_levels': 4,
                'fft_bins': 32,
                'use_fft': True,
                'use_wavelet': True,
                'frequency_branch_enabled': True
            }
        }

    def _config_for(self, budget: int) -> Dict:
        """Return a fresh, caller-owned configuration dict for *budget*.

        Preset budgets are deep-copied so the caller can mutate the result
        (including list values) without touching ``budget_levels``.
        """
        if budget in self.budget_levels:
            return copy.deepcopy(self.budget_levels[budget])
        return self._calculate_dynamic_budget_config(budget)

    def set_budget(self, budget: int):
        """Select *budget* and apply its configuration to the model.

        Side effects:
            * ``self.current_budget`` is set to *budget*.
            * The process-wide environment variable ``USE_NBC`` is set to
              ``'True'``.
            * ``model.config`` / ``model.seg_mod.config`` are updated (see
              :meth:`_apply_budget_config`).
        """
        self.current_budget = budget
        budget_config = self._config_for(budget)

        os.environ['USE_NBC'] = 'True'

        self._apply_budget_config(budget_config)

    def _apply_budget_config(self, budget_config: Dict):
        """Merge *budget_config* plus ``use_nbc=True`` into the model configs.

        *budget_config* itself is not modified. Each target receives its
        own deep copy, so list values (e.g. ``gma_scales``) are never shared
        between the model, its ``seg_mod`` and this controller. Targets
        that are missing or ``None`` are skipped.
        """
        applied = dict(budget_config, use_nbc=True)

        model_config = getattr(self.model, 'config', None)
        if model_config is not None:
            model_config.update(copy.deepcopy(applied))

        # getattr(None, 'config', None) is None, so a missing/None seg_mod
        # is skipped without a separate check.
        seg_mod = getattr(self.model, 'seg_mod', None)
        seg_config = getattr(seg_mod, 'config', None)
        if seg_config is not None:
            seg_config.update(copy.deepcopy(applied))

    def _calculate_dynamic_budget_config(self, budget: int) -> Dict:
        """Build a configuration for a budget that has no preset entry.

        Structural sizes are linearly interpolated over budgets 32..1024
        (values outside that range are clamped to the end points) and
        truncated with ``int``. The remaining settings (``gma_scales``,
        ``filter_orders``, ``num_filters``, ``wavelet_levels``, ``fft_bins``,
        ``use_fft``, ``use_wavelet``, ``frequency_branch_enabled``) are
        picked by tier: ``budget <= 64``, ``budget <= 256``, or larger.

        Returns:
            A newly built dict with the same keys as a preset entry.
        """
        budget_lo, budget_hi = 32, 1024
        ratio = (budget - budget_lo) / (budget_hi - budget_lo)
        ratio = max(0, min(1, ratio))

        def lerp(lo, hi):
            # int() truncates; all results here are non-negative.
            return int(lo + ratio * (hi - lo))

        if budget <= 64:
            tier = {
                'gma_scales': [8, 16],
                'filter_orders': [3, 5],
                'num_filters': 4,
                'wavelet_levels': 2,
                'fft_bins': 16,
                'use_fft': True,
                'use_wavelet': False,
                'frequency_branch_enabled': False,
            }
        elif budget <= 256:
            tier = {
                'gma_scales': [8, 16, 32],
                'filter_orders': [3, 5, 7],
                'num_filters': 6,
                'wavelet_levels': 3,
                'fft_bins': 24,
                'use_fft': True,
                'use_wavelet': True,
                'frequency_branch_enabled': True,
            }
        else:
            tier = {
                'gma_scales': [8, 16, 32, 64],
                'filter_orders': [3, 5, 7, 9],
                'num_filters': 8,
                'wavelet_levels': 4,
                'fft_bins': 32,
                'use_fft': True,
                'use_wavelet': True,
                'frequency_branch_enabled': True,
            }

        return {
            'node_budget': lerp(2, 32),
            'max_segments': lerp(4, 48),
            'num_heads': lerp(2, 8),
            'head_dim': lerp(16, 32),
            **tier,
        }

    def get_budget(self) -> Optional[int]:
        """Return the budget last passed to :meth:`set_budget`, or ``None``."""
        return self.current_budget

    def get_budget_info(self) -> Dict:
        """Return a caller-owned copy of the current budget's configuration.

        Returns ``{}`` when no budget has been set. The result never
        contains ``use_nbc``; that flag is added only when the config is
        applied to the model.
        """
        if self.current_budget is None:
            return {}
        return self._config_for(self.current_budget)

def add_budget_control_to_model(model: nn.Module, config: Dict) -> nn.Module:
    """Attach budget-control methods to *model* and return it.

    A new :class:`BudgetController` is created for *model*. The following
    instance attributes are then set on *model*:

    * ``set_budget(budget)``: delegates to the controller and then
      refreshes ``model.budget``.
    * ``get_budget()`` / ``get_budget_info()``: delegate to the controller.
    * ``budget``: mirrors the current budget; ``None`` until the first
      ``model.set_budget`` call.

    Args:
        model: The module to augment; it is modified in place.
        config: Forwarded to :class:`BudgetController`.

    Returns:
        The same *model* object.
    """
    controller = BudgetController(model, config)

    def set_budget(budget: int):
        controller.set_budget(budget)
        # Keep the plain attribute in sync so readers of ``model.budget``
        # always see the budget most recently set through the model.
        model.budget = controller.get_budget()

    model.set_budget = set_budget
    model.get_budget = controller.get_budget
    model.get_budget_info = controller.get_budget_info

    model.budget = controller.get_budget()

    return model