import torch
from torch import nn
from adapters import AdapterLinear, MultiHeadClassifier
from transformers.models.roberta.modeling_roberta import RobertaForSequenceClassification, RobertaClassificationHead

def add_adapters(
    net: RobertaForSequenceClassification,
    adapter_dim: int,
    num_tasks: int,
    num_labels_list: list,
    alpha: int,
    p_dropout,
    only_qv=False,
):
    """Replace encoder Linear layers with adapter layers and install a multi-head classifier.

    Every ``nn.Linear`` that is a direct child of some module under
    ``net.roberta`` is swapped in place for
    ``AdapterLinear.from_linear(child, ...)``. Afterwards ``net.classifier``
    is replaced by ``MultiHeadClassifier.from_roberta_head(...)``.

    Args:
        net: Model to modify in place; must expose ``roberta`` and
            ``classifier`` attributes.
        adapter_dim: Forwarded to ``AdapterLinear.from_linear`` as
            ``adapter_dim``.
        num_tasks: Forwarded to ``AdapterLinear.from_linear`` as ``num_tasks``.
        num_labels_list: Forwarded to ``MultiHeadClassifier.from_roberta_head``
            (presumably one label count per task -- confirm against that class).
        alpha: Forwarded to ``AdapterLinear.from_linear`` as ``lora_alpha``.
        p_dropout: Forwarded to ``AdapterLinear.from_linear`` as ``p_dropout``.
        only_qv: When truthy, only children named ``"query"`` or ``"value"``
            are converted.

    Returns:
        None. ``net`` is mutated in place.
    """
    adapter_kwargs = dict(
        adapter_dim=adapter_dim,
        num_tasks=num_tasks,
        lora_alpha=alpha,
        p_dropout=p_dropout,
    )

    # Walk a snapshot because the module tree is mutated during the walk.
    # Only ``net.roberta`` is traversed, so the classification head's own
    # Linear layers are never converted.
    for parent in list(net.roberta.modules()):
        for child_name, child in parent.named_children():
            if not isinstance(child, nn.modules.linear.Linear):
                continue
            if only_qv and child_name not in ("query", "value"):
                continue
            replacement = AdapterLinear.from_linear(child, **adapter_kwargs)
            setattr(parent, child_name, replacement)

    net.classifier = MultiHeadClassifier.from_roberta_head(net.classifier, num_labels_list)

def set_active_task(net, task_idx):
    """Select which task the adapter layers and classifier heads operate on.

    Sets ``active_task = task_idx`` on every ``AdapterLinear`` and
    ``MultiHeadClassifier`` found in ``net``, then updates ``net.num_labels``
    to the ``out_features`` of the selected head's ``out_proj`` layer.

    Args:
        net: Model to modify in place; must expose ``classifier.head``,
            indexable by ``task_idx``.
        task_idx: Index of the task to activate.
    """
    task_aware_types = (AdapterLinear, MultiHeadClassifier)
    for module in list(net.modules()):
        if isinstance(module, task_aware_types):
            module.active_task = task_idx

    selected_head = net.classifier.head[task_idx]
    net.num_labels = selected_head.out_proj.out_features

def freeze_base_thaw_adapters(net):
    """Make only adapter and classifier parameters trainable.

    A parameter is trainable after this call exactly when its qualified name
    (as reported by ``net.named_parameters()``) contains the substring
    ``"adapters"`` or ``"classifier"``; every other parameter is frozen.

    Args:
        net: Module whose parameters' ``requires_grad`` flags are overwritten.
    """
    for param_name, param in net.named_parameters():
        param.requires_grad = "adapters" in param_name or "classifier" in param_name

def freeze_adapters_thaw_base(net):
    """Make only the base parameters trainable.

    A parameter is frozen after this call exactly when its qualified name
    (as reported by ``net.named_parameters()``) contains the substring
    ``"adapters"`` or ``"classifier"``; every other parameter is trainable.

    Args:
        net: Module whose parameters' ``requires_grad`` flags are overwritten.
    """
    for param_name, param in net.named_parameters():
        param.requires_grad = "adapters" not in param_name and "classifier" not in param_name

def freeze_lora_thaw_base(net):
    """Freeze adapter parameters and make everything else trainable.

    A parameter is frozen after this call exactly when its qualified name
    (as reported by ``net.named_parameters()``) contains the substring
    ``"adapters"``. All other parameters, including any whose name contains
    ``"classifier"``, become trainable.

    Args:
        net: Module whose parameters' ``requires_grad`` flags are overwritten.
    """
    for param_name, param in net.named_parameters():
        param.requires_grad = "adapters" not in param_name

def freeze_head(net):
    """Freeze every parameter whose qualified name contains ``"classifier"``.

    Parameters whose name does not contain that substring keep whatever
    ``requires_grad`` value they already had.

    Args:
        net: Module whose matching parameters are frozen in place.
    """
    for param_name, param in net.named_parameters():
        if "classifier" not in param_name:
            continue
        param.requires_grad = False

def thaw(net):
    """Mark every parameter of ``net`` as trainable.

    Args:
        net: Module whose parameters all get ``requires_grad = True``.
    """
    for param in net.parameters():
        param.requires_grad = True

def instantiate_model(net, state_dict):
    """Copy every tensor in ``state_dict`` into the matching attribute of ``net``.

    Each key is treated as a dotted attribute path (for example
    ``"encoder.layer.0.weight"``) that is resolved from ``net`` with repeated
    ``getattr`` calls. The resolved tensor is overwritten in place via
    ``copy_`` with gradient tracking disabled.

    Attributes of ``net`` that have no entry in ``state_dict`` are left
    untouched. A key that cannot be resolved on ``net`` makes ``getattr``
    raise ``AttributeError``.

    Args:
        net: Object (typically an ``nn.Module``) to load values into.
        state_dict: Mapping from dotted attribute paths to source tensors.
    """
    with torch.no_grad():
        for dotted_path, source in state_dict.items():
            target = net
            for attr in dotted_path.split('.'):
                target = getattr(target, attr)
            target.copy_(source)

def instantiate_multi_model(net, state_dict_list):
    """Load one state dict per task into a multi-task model.

    The first state dict is loaded in full through ``instantiate_model``.
    For each later state dict at list position ``t`` (``t >= 1``), only two
    kinds of keys are used:

    * keys containing ``"adapters."`` -- every ``"adapters.0"`` in the key is
      rewritten to ``"adapters.<t>"`` before the tensor is copied;
    * keys containing ``"classifier.head."`` -- every ``"head.0"`` in the key
      is rewritten to ``"head.<t>"`` before the tensor is copied.

    All other keys of the later state dicts are ignored. A key matching both
    patterns is processed once per pattern, adapter rewrite first.

    Args:
        net: Model to load values into; rewritten keys are resolved on it as
            dotted attribute paths.
        state_dict_list: Sequence of state dicts; index 0 supplies the full
            model, later entries supply per-task adapters and heads.
    """
    with torch.no_grad():
        instantiate_model(net, state_dict_list[0])

        for task_idx, task_state in enumerate(state_dict_list[1:], start=1):
            slot = str(task_idx)
            # (substring selecting a key, text to replace, replacement text)
            rewrites = (
                ("adapters.", "adapters.0", "adapters." + slot),
                ("classifier.head.", "head.0", "head." + slot),
            )
            for key, source in task_state.items():
                for marker, old, new in rewrites:
                    if marker not in key:
                        continue
                    target = net
                    for attr in key.replace(old, new).split('.'):
                        target = getattr(target, attr)
                    target.copy_(source)

def instantiate_base_model(net,state_dict):
    """Copy only the base tensors of ``state_dict`` into ``net``.

    Keys containing the substring ``"classifier"`` or ``"adapters"`` are
    skipped. Every remaining key is treated as a dotted attribute path that
    is resolved from ``net`` with repeated ``getattr`` calls, and the
    resolved tensor is overwritten in place via ``copy_`` with gradient
    tracking disabled.

    Args:
        net: Object (typically an ``nn.Module``) to load values into.
        state_dict: Mapping from dotted attribute paths to source tensors.

    Returns:
        None.
    """
    with torch.no_grad():
        for dotted_path, source in state_dict.items():
            is_base = "classifier" not in dotted_path and "adapters" not in dotted_path
            if is_base:
                target = net
                for attr in dotted_path.split('.'):
                    target = getattr(target, attr)
                target.copy_(source)

def set_dropout(net,p):
    """Set the drop probability of every ``torch.nn.Dropout`` module in ``net``.

    Only instances of ``torch.nn.Dropout`` (and its subclasses) are updated;
    other modules are left unchanged.

    Args:
        net: Module tree to search.
        p: Value assigned to each matching module's ``p`` attribute.
    """
    dropout_layers = (m for m in net.modules() if isinstance(m, torch.nn.Dropout))
    for layer in dropout_layers:
        layer.p = p