import warnings

import torch.nn as nn

from .adapter import Adapter

class AdapterManager(nn.Module):
    """Own one ``Adapter`` per (task, layer) pair and control which ones train.

    Adapters live in an ``nn.ModuleDict`` under keys of the form
    ``"task{task_id}_layer{layer_id}"``, so they are registered as submodules
    of this manager.

    Args:
        num_layers: Number of layers adapters may be attached to; valid layer
            ids are ``0 .. num_layers - 1``.
        d_model: Forwarded to ``Adapter(d_model=...)``.
        num_tasks: Number of tasks; each task gets its own set of adapters.
        bottleneck: Forwarded to ``Adapter(bottleneck=...)``.
        init_type: Forwarded to ``Adapter(init_option=...)``.
        init_layers: Iterable of layer ids that receive an adapter; ``None``
            means every layer. Ids outside ``range(num_layers)`` cannot match
            any layer and are ignored, with a ``UserWarning``.
        dropout: Forwarded to ``Adapter(dropout=...)``.
        adapter_scalar: Forwarded to ``Adapter(adapter_scalar=...)``.
        layernorm_option: Forwarded to
            ``Adapter(adapter_layernorm_option=...)``.
    """

    def __init__(self, num_layers, d_model, num_tasks=1, bottleneck=64, init_type="adapter",
                 init_layers=None, dropout=0.0, adapter_scalar=0.1, layernorm_option="in"):
        super().__init__()
        self.num_layers = num_layers
        self.num_tasks = num_tasks
        self.init_layers = set(init_layers) if init_layers is not None else set(range(num_layers))

        # Out-of-range ids (e.g. -1, or 1-based ids) would otherwise silently
        # produce no adapter at all; surface that misconfiguration.
        ignored = sorted(self.init_layers.difference(range(num_layers)), key=repr)
        if ignored:
            warnings.warn(
                f"AdapterManager: init_layers entries {ignored} are outside "
                f"range({num_layers}) and will not receive an adapter",
                stacklevel=2,
            )

        self.adapters = nn.ModuleDict()
        # Task-major, layer-minor creation order is kept deliberately: it fixes
        # the ModuleDict / state_dict ordering and the order in which adapter
        # initialisation runs.
        for task_id in range(num_tasks):
            for layer_id in range(num_layers):
                if layer_id in self.init_layers:
                    self.adapters[self._key(layer_id, task_id)] = Adapter(
                        d_model=d_model,
                        bottleneck=bottleneck,
                        dropout=dropout,
                        init_option=init_type,
                        adapter_scalar=adapter_scalar,
                        adapter_layernorm_option=layernorm_option,
                    )

    @staticmethod
    def _task_prefix(task_id):
        """Return the key prefix shared by every adapter of ``task_id``."""
        return f"task{task_id}_"

    @classmethod
    def _key(cls, layer_id, task_id):
        """Return the ``ModuleDict`` key for the adapter at (task, layer)."""
        return f"{cls._task_prefix(task_id)}layer{layer_id}"

    def get(self, layer_id, task_id=0):
        """Return the adapter for ``(task_id, layer_id)``, or ``None`` if absent.

        ``None`` is returned both for layers excluded via ``init_layers`` and
        for ids that were never created.
        """
        key = self._key(layer_id, task_id)
        return self.adapters[key] if key in self.adapters else None

    def freeze_all(self):
        """Set ``requires_grad=False`` on every adapter parameter."""
        self.adapters.requires_grad_(False)

    def freeze_all_except_task(self, task_id):
        """Make only ``task_id``'s adapters trainable and freeze all others.

        The task is matched on the exact key prefix (``startswith``), derived
        from the same helper that builds the keys, rather than a substring test.
        """
        prefix = self._task_prefix(task_id)
        for key, adapter in self.adapters.items():
            adapter.requires_grad_(key.startswith(prefix))

    def unfreeze_layers_for_task(self, layer_ids, task_id):
        """Set ``requires_grad=True`` on ``task_id``'s adapters at ``layer_ids``.

        Layer ids that have no adapter are skipped. Adapters that are not
        listed keep their current ``requires_grad``; nothing is frozen here.
        """
        for layer_id in layer_ids:
            key = self._key(layer_id, task_id)
            if key in self.adapters:
                self.adapters[key].requires_grad_(True)