import torch
import torch.nn as nn
import torch.quantization
from typing import List, Dict, Any
import numpy as np

class ModelCompression:
    """Apply in-place pruning and eager-mode static quantization to a model."""

    def __init__(self, model: nn.Module):
        # The wrapped model; every method below mutates it in place.
        self.model = model

    def prune_model(self, amount: float = 0.3):
        """Randomly zero out a fraction of Conv2d/Linear weights (model pruning).

        Every weight element is drawn uniformly from [0, 1) and kept only if
        the draw exceeds ``amount``. This is random, unstructured pruning, not
        magnitude-based. Tensors stay dense: pruned entries become zeros, but
        the parameter count and memory footprint are unchanged.

        A bias entry is zeroed only when *every* weight feeding that output
        feature/channel was pruned.

        Args:
            amount: Approximate fraction of weights to zero. Values below 0
                keep everything; values >= 1 zero everything.
        """
        for module in self.model.modules():
            if not isinstance(module, (nn.Conv2d, nn.Linear)):
                continue
            keep = torch.rand_like(module.weight) > amount
            module.weight.data.mul_(keep)
            if module.bias is not None:
                # Collapse all non-output dims so the mask is (out,) for both
                # Linear (2-D weight) and Conv2d (4-D weight), matching bias.
                # The previous `keep.any(dim=1)` gave (out, kh, kw) for convs
                # and failed to multiply into the 1-D bias.
                alive = keep.reshape(keep.shape[0], -1).any(dim=1)
                module.bias.data.mul_(alive)

    def quantize_model(
        self,
        dtype: torch.dtype = torch.qint8,
        calibration_data: Any = None,
    ):
        """Apply eager-mode post-training static quantization ('fbgemm').

        Sets the default 'fbgemm' qconfig on the model, inserts observers
        with ``prepare``, optionally feeds calibration inputs through the
        model, then swaps modules for quantized versions with ``convert``.
        Everything happens in place.

        Args:
            dtype: Currently ignored. The 'fbgemm' default qconfig decides
                the quantized dtypes. Kept for interface compatibility.
            calibration_data: Optional iterable of sample inputs. Each item is
                passed to the model between ``prepare`` and ``convert`` so the
                observers see real activation ranges. A tuple/list item is
                unpacked as positional arguments. When omitted, no
                calibration runs and the observers have seen no data, so the
                resulting quantization parameters are presumably defaults
                rather than data-derived.
        """
        self.model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
        torch.quantization.prepare(self.model, inplace=True)
        if calibration_data is not None:
            # NOTE(review): static PTQ calibration is normally done in eval()
            # mode; this method does not change the model's train/eval mode.
            with torch.no_grad():
                for batch in calibration_data:
                    if isinstance(batch, (tuple, list)):
                        self.model(*batch)
                    else:
                        self.model(batch)
        torch.quantization.convert(self.model, inplace=True)

class EdgeOptimizer:
    """Run the prune-then-quantize pipeline for edge deployment and report size."""

    def __init__(self, model: nn.Module):
        self.model = model
        # Compression helper operating on the same model object (in place).
        self.compression = ModelCompression(model)

    def optimize_for_edge(self, compression_ratio: float = 0.3):
        """Optimize the model for edge devices, in place.

        Steps: prune with ``compression_ratio``, apply static quantization,
        then estimate the resulting size. Calling this more than once re-runs
        both steps on the already-processed model.

        Args:
            compression_ratio: Pruning amount forwarded to ``prune_model``.

        Returns:
            Estimated model size in MB after optimization.
        """
        # 1. Pruning
        self.compression.prune_model(compression_ratio)

        # 2. Quantization
        self.compression.quantize_model()

        # 3. Size estimate
        model_size = self._estimate_model_size()
        return model_size

    def _estimate_model_size(self) -> float:
        """Estimate the model's tensor footprint in MB (bytes / 1024**2).

        Counts every registered parameter and buffer. It then also counts
        tensors that appear only in ``state_dict()``, possibly nested inside
        tuples/lists. This matters after eager-mode quantization: converted
        modules keep their weights in packed objects that ``parameters()``
        and ``buffers()`` do not expose, but that are serialized into the
        state dict. State-dict entries that alias an already-counted tensor,
        such as the detached copies of parameters or tied weights, are
        skipped so nothing is double-counted.
        """
        counted = set()
        total_bytes = 0

        def storage_key(tensor):
            # Identify a tensor by where its data lives. Returns None when
            # there is no usable pointer.
            try:
                ptr = tensor.data_ptr()
            except RuntimeError:  # e.g. sparse tensors have no single data ptr
                return None
            if ptr == 0:  # empty or meta tensor: no real storage to key on
                return None
            return (ptr, tensor.nelement(), tensor.element_size())

        # 1) Registered parameters and buffers, counted exactly as before.
        for param in self.model.parameters():
            total_bytes += param.nelement() * param.element_size()
            key = storage_key(param)
            if key is not None:
                counted.add(key)
        for buffer in self.model.buffers():
            total_bytes += buffer.nelement() * buffer.element_size()
            key = storage_key(buffer)
            if key is not None:
                counted.add(key)

        # 2) Extra tensors visible only through the state dict.
        def visit(value):
            nonlocal total_bytes
            if isinstance(value, torch.Tensor):
                key = storage_key(value)
                if key is None or key in counted:
                    return
                counted.add(key)
                total_bytes += value.nelement() * value.element_size()
            elif isinstance(value, (tuple, list)):
                for item in value:
                    visit(item)

        # Keep the dict alive while walking so freshly created tensors in it
        # cannot have their memory reused (and pointers collide) mid-walk.
        state = self.model.state_dict()
        for value in state.values():
            visit(value)

        return total_bytes / 1024**2

class AdaptiveCompression:
    """Repeatedly optimize a model with a rising pruning ratio until it fits."""

    def __init__(self, model: nn.Module):
        self.model = model
        # Optimizer wrapping the same model object; it mutates it in place.
        self.optimizer = EdgeOptimizer(model)

    def adaptive_compress(self, target_size_mb: float, max_iterations: int = 10):
        """Compress adaptively until the estimated size reaches the target.

        Starts from a ratio of 0.3. Before each optimization pass the ratio is
        raised by 0.1 (capped at 0.9), so the first pass uses 0.4. The loop
        stops when the size is at or below ``target_size_mb``, when a pass
        fails to shrink the model, or after ``max_iterations`` passes.

        The no-progress exit matters because, with the visible pipeline, once
        the model is quantized ``prune_model`` no longer matches any modules.
        Further passes would then change nothing while inflating the reported
        ratio.

        Args:
            target_size_mb: Desired maximum estimated size, in MB.
            max_iterations: Upper bound on optimization passes.

        Returns:
            ``(current_size, compression_ratio)``: the final estimated size in
            MB, and the ratio used by the last pass (0.3 if no pass ran).
        """
        current_size = self.optimizer._estimate_model_size()
        compression_ratio = 0.3  # initial ratio; bumped before every pass

        for _ in range(max_iterations):
            if current_size <= target_size_mb:
                break

            # Raise the ratio for this pass.
            compression_ratio = min(0.9, compression_ratio + 0.1)
            self.optimizer.optimize_for_edge(compression_ratio)
            new_size = self.optimizer._estimate_model_size()
            if new_size >= current_size:
                # No progress: repeating the pipeline cannot help further.
                current_size = new_size
                break
            current_size = new_size

        return current_size, compression_ratio