import torch
import torch.nn as nn
from deguc.routing.router import HierarchicalRouter
from deguc.compression.group_lowrank import GroupSharedLowRank
from deguc.quantization.quantizer import SimpleWeightQuantizer
from deguc.offload.offload_manager import ExpertOffloadManager
from deguc.core.stats import global_stats
from deguc.distributed.communicator import DistributedCommunicatorPlaceholder

class DEGUCModel(nn.Module):
    """Expert model that wires together routing, compression, quantization and offload.

    The module owns a ``GroupSharedLowRank`` compression module (which runs the
    experts), a ``HierarchicalRouter`` (which picks experts per token), a
    ``SimpleWeightQuantizer`` (optional int8 path) and an
    ``ExpertOffloadManager`` (offload / reload of individual experts).
    A ``DistributedCommunicatorPlaceholder`` is also created; none of the
    methods in this class use it.
    """

    def __init__(self, input_dim=512, output_dim=512, num_initial_experts=16, init_groups=4,
                 rank=16, top_k=2, device=None,
                 enable_int8=False, weight_only_int8=True, try_full_int8=False,
                 param_dtype=None):
        """Build all sub-components and move the module to ``device``.

        Args:
            input_dim: Forwarded to the compression module and the router, and
                passed to the offload manager when an expert is reloaded.
            output_dim: Forwarded to the compression module, and passed to the
                offload manager when an expert is reloaded.
            num_initial_experts: Number of expert ids, ``0 .. n-1``.
            init_groups: Number of initial groups. Expert ``e`` is placed in
                group ``e % init_groups``.
            rank: Forwarded to the compression module as ``rank``.
            top_k: Forwarded to the router as ``top_k``.
            device: Device specification handed to the sub-components and to
                ``self.to``. ``None`` selects the CPU.
            enable_int8: Forwarded to the quantizer as ``use_int8_cache``.
            weight_only_int8: Forwarded to the quantizer as ``weight_only``.
            try_full_int8: Forwarded to the quantizer as ``try_full_int8``.
            param_dtype: Forwarded to the compression module as ``dtype``.
        """
        super().__init__()
        # Compare against None rather than relying on truthiness, so that a
        # valid but falsy device spec (e.g. the ordinal 0) is not silently
        # replaced by the CPU.
        self.device = device if device is not None else torch.device("cpu")

        # Round-robin assignment of expert ids to the initial groups.
        experts = list(range(num_initial_experts))
        groups = {g: experts[g::init_groups] for g in range(init_groups)}

        self.compression = GroupSharedLowRank(
            input_dim, output_dim, groups, rank=rank,
            device=self.device, dtype=param_dtype
        )
        self.router = HierarchicalRouter(input_dim, groups, top_k=top_k, device=self.device)
        self.quantizer = SimpleWeightQuantizer(
            use_int8_cache=enable_int8,
            weight_only=weight_only_int8,
            try_full_int8=try_full_int8
        )
        self.offloader = ExpertOffloadManager()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.rank = rank
        self.communicator = DistributedCommunicatorPlaceholder()
        self.to(self.device)

    def forward(self, hidden):
        """Route ``hidden`` to experts, reloading any required expert first.

        Args:
            hidden: Input handed to the router and to the expert forward pass.
                It is sliced along its first axis by token index and its
                ``dtype`` is passed to the offload manager
                (presumably a tensor with tokens on axis 0 - TODO confirm).

        Returns:
            A tuple ``(out, balance_loss, info)`` where ``out`` is the result
            of the expert forward pass, ``balance_loss`` is the second value
            returned by the router, and ``info`` is ``{"reloaded": n}`` with
            ``n`` the number of experts reloaded during this call.
        """
        # ``routing`` is traversed more than once below, so it must be
        # re-iterable: a sequence of (token_idx, [(expert_id, <second>), ...]).
        # The second item of each inner pair is not used in this method.
        routing, balance_loss = self.router(hidden)
        experts_needed = {e for _, picks in routing for e, _ in picks}

        reloaded = 0
        for e in experts_needed:
            changed = self.offloader.reload_if_needed(
                e, self.compression, self.input_dim, self.rank, self.output_dim, hidden.dtype
            )
            if not changed:
                continue
            reloaded += 1
            global_stats.reloaded_experts += 1
            # A reloaded expert needs its quantizer entry refreshed before the
            # int8 forward path is used.
            if self.quantizer.int8_enabled:
                self.quantizer.update_single_expert(e, self.compression)

        if self.quantizer.int8_enabled:
            out = self.quantizer.forward_experts_int8(hidden, routing, self.compression)
        else:
            out = self.compression.forward_experts(hidden, routing)

        # Feed each routed token (as a length-1 slice) to the statistics of
        # every expert it was sent to.
        for token_idx, picks in routing:
            token = hidden[token_idx:token_idx + 1]
            for e, _ in picks:
                global_stats.expert_stats[e].update(token)

        return out, balance_loss, {"reloaded": reloaded}

    def apply_quantization(self, build_int8_cache=True):
        """Quantize the compression module and return the quantizer's report.

        Args:
            build_int8_cache: When true and the quantizer's ``use_int8_cache``
                flag is set, the int8 cache is built as well.

        Returns:
            Whatever ``quantizer.compression_report`` returns for the
            compression module.
        """
        self.quantizer.quantize_module(self.compression)
        self.quantizer.replace_forward_weights(self.compression)
        if build_int8_cache and self.quantizer.use_int8_cache:
            self.quantizer.build_int8_cache(self.compression)
        return self.quantizer.compression_report(self.compression)

    def update_groups(self, new_map):
        """Apply a new group map to the compression module and the router.

        Args:
            new_map: Group map forwarded unchanged to both
                ``update_group_map`` methods.
        """
        self.compression.update_group_map(new_map)
        self.router.update_group_map(new_map)
        # The int8 cache is rebuilt after the group map has changed.
        if self.quantizer.int8_enabled:
            self.quantizer.build_int8_cache(self.compression)

    def offload_inactive(self, min_rate=0.0005):
        """Offload every tracked expert whose activation rate is below ``min_rate``.

        Args:
            min_rate: Experts with ``global_stats.activation_rate(id)`` strictly
                below this value are offloaded.
        """
        # Iterate over a snapshot of the keys so the loop is unaffected if the
        # stats mapping changes while experts are being processed.
        for exp_id in list(global_stats.expert_stats):
            rate = global_stats.activation_rate(exp_id)
            if rate < min_rate:
                self.offloader.offload(exp_id, self.compression)
                global_stats.offloaded_experts += 1
                # Drop the offloaded expert's entry from the int8 cache.
                if self.quantizer.int8_enabled and exp_id in self.quantizer.int8_cache.cache:
                    del self.quantizer.int8_cache.cache[exp_id]