from typing import List, Optional, Tuple
from vit import ViT_B_16
import torch.nn as nn
from peft import LoraConfig, AdaLoraConfig, VeraConfig, get_peft_model

# Modes handled by toggling ``requires_grad`` on the plain ViT, matched by name substring.
_FREEZE_MODES = ("full", "linear")
# Modes handled by wrapping the ViT with a peft adapter via ``get_peft_model``.
_PEFT_MODES = ("lora", "dora", "pissa", "nblora", "vera", "adalora")


def _set_trainable(model: "nn.Module", keywords: Tuple[str, ...]) -> None:
    """Make trainable exactly the parameters whose name contains any of *keywords*.

    Every other parameter of *model* gets ``requires_grad = False``.
    """
    for name, param in model.named_parameters():
        param.requires_grad = any(key in name for key in keywords)


def _build_peft_config(
    ft_mode: str,
    *,
    target_modules: List[str],
    save_modules: List[str],
    lora_r: int,
    lora_gamma: float,
    lora_sp: str,
    lora_dropout: float,
    init_r: int,
    total_steps: int,
    deltaT: int,
    orth_reg_weight: float,
):
    """Return the peft adapter config for one of the ``_PEFT_MODES``.

    Raises:
        ValueError: if *ft_mode* is not one of the peft modes.
    """
    # Settings shared by every LoRA-family mode. lora_alpha is tied to r; with
    # peft's default LoRA scaling (lora_alpha / r) the update is scaled by 1.
    lora_common = dict(
        r=lora_r,
        lora_alpha=lora_r,
        lora_dropout=lora_dropout,
        target_modules=target_modules,
        modules_to_save=save_modules,
    )
    if ft_mode == "lora":
        return LoraConfig(**lora_common)
    if ft_mode == "dora":
        return LoraConfig(use_dora=True, **lora_common)
    if ft_mode == "pissa":
        # PiSSA is an initialisation for plain LoRA ("pissa_niter_16" = fast SVD
        # with 16 iterations). ``use_dora`` is deliberately NOT set: enabling it
        # would turn this mode into DoRA+PiSSA.
        return LoraConfig(init_lora_weights="pissa_niter_16", **lora_common)
    if ft_mode == "nblora":
        # NOTE(review): ``lora_gamma``, ``lora_sp`` and the "nblora_init"
        # initialiser are not options of upstream peft's LoraConfig; this mode
        # presumably relies on a patched peft install -- confirm.
        return LoraConfig(
            lora_gamma=lora_gamma,
            lora_sp=lora_sp,
            init_lora_weights="nblora_init",
            **lora_common,
        )
    if ft_mode == "vera":
        return VeraConfig(
            r=lora_r,
            vera_dropout=lora_dropout,
            d_initial=0.1,
            target_modules=target_modules,
            modules_to_save=save_modules,
        )
    if ft_mode == "adalora":
        # Budget schedule: the initial phase (tinit) is 10% of training and the
        # final phase (tfinal) is twice as long as tinit.
        # NOTE: per peft docs, AdaLoRA also needs the training loop to call
        # ``model.base_model.update_and_allocate(step)``; not done here.
        tinit = int(0.1 * total_steps)
        tfinal = tinit * 2
        return AdaLoraConfig(
            lora_alpha=lora_r,
            lora_dropout=lora_dropout,
            target_r=lora_r,
            init_r=init_r,
            tinit=tinit,
            tfinal=tfinal,
            deltaT=deltaT,
            orth_reg_weight=orth_reg_weight,
            total_step=total_steps,
            target_modules=target_modules,
            modules_to_save=save_modules,
        )
    raise ValueError(f"ft_mode {ft_mode!r} is not a peft mode; expected one of {_PEFT_MODES}")


def build_model(
    ft_mode: str = "linear",
    num_classes: int = 100,
    target_modules: Optional[List[str]] = None,
    save_modules: Optional[List[str]] = None,
    lora_r: int = 8,
    lora_gamma: float = 0.0,
    lora_sp: str = "nuc",
    lora_dropout: float = 0.,
    init_r: int = 12,
    total_steps: int = 2000,
    deltaT: int = 1,
    orth_reg_weight: float = 0.5
):
    """Build a ``ViT_B_16`` configured for the requested fine-tuning mode.

    Args:
        ft_mode: One of ``"full"``, ``"linear"`` (freeze by parameter name) or
            ``"lora"``, ``"dora"``, ``"pissa"``, ``"nblora"``, ``"vera"``,
            ``"adalora"`` (wrap with a peft adapter).
        num_classes: Forwarded to ``ViT_B_16`` as ``target_num_classes``.
        target_modules: Module names the peft adapter is attached to
            (peft modes only). Defaults to ``["query", "value"]``.
        save_modules: Modules passed to peft as ``modules_to_save`` (peft modes
            only). Defaults to ``["heads.head"]``.
        lora_r: Adapter rank; also used as ``lora_alpha`` for the LoRA family
            and as ``target_r`` for AdaLoRA.
        lora_gamma: ``"nblora"`` only.
        lora_sp: ``"nblora"`` only.
        lora_dropout: Adapter dropout (``vera_dropout`` for VeRA).
        init_r: AdaLoRA initial rank.
        total_steps: AdaLoRA ``total_step``; also sets ``tinit = int(0.1 * total_steps)``
            and ``tfinal = 2 * tinit``.
        deltaT: AdaLoRA ``deltaT``.
        orth_reg_weight: AdaLoRA orthogonality-regulariser weight.

    Returns:
        The ViT itself for ``"full"``/``"linear"``; otherwise the model returned
        by ``get_peft_model``.

    Raises:
        ValueError: If *ft_mode* is not recognised. This is checked before the
            backbone is built.
    """
    valid_modes = _FREEZE_MODES + _PEFT_MODES
    if ft_mode not in valid_modes:
        raise ValueError(
            f"Unknown ft_mode {ft_mode!r}; expected one of {', '.join(valid_modes)}"
        )
    # Fresh lists per call instead of shared mutable default arguments.
    if target_modules is None:
        target_modules = ["query", "value"]
    if save_modules is None:
        save_modules = ["heads.head"]

    model = ViT_B_16(target_num_classes=num_classes)
    if ft_mode == "full":
        # NOTE(review): despite the name, this trains only parameters whose name
        # contains 'heads', 'query' or 'value' -- not the whole network.
        _set_trainable(model, ("heads", "query", "value"))
        return model
    if ft_mode == "linear":
        # Only parameters under 'heads' (presumably the classifier head) train.
        _set_trainable(model, ("heads",))
        return model

    config = _build_peft_config(
        ft_mode,
        target_modules=target_modules,
        save_modules=save_modules,
        lora_r=lora_r,
        lora_gamma=lora_gamma,
        lora_sp=lora_sp,
        lora_dropout=lora_dropout,
        init_r=init_r,
        total_steps=total_steps,
        deltaT=deltaT,
        orth_reg_weight=orth_reg_weight,
    )
    return get_peft_model(model, config)
