import torch
import copy

from transformers import AutoModelForSequenceClassification
from peft import LoraConfig, get_peft_model

from plora_func.LoRAConfig import PLoraConfig
from plora_func.plora_and_get_peft import get_peft_model as get_plora_peft_model


def model_setup(args):
    """
    Build the global backbone, wrap it with LoRA or PLoRA adapters, and snapshot it.

    Only ``args.model == "bert-base-uncased"`` is currently supported; it is
    loaded with ``AutoModelForSequenceClassification`` and ``args.num_classes``
    output labels. The adapter flavour is chosen by ``args.model_heterogeneity``:

        - contains ``"fedplora"`` -> ``PLoraConfig`` (tied/sparse options taken
          from ``args.num_tied_layer``, ``args.density``, ``args.sparsity_type``,
          with ``shared_adapter=True``), wrapped via the PLoRA ``get_peft_model``;
        - otherwise              -> plain peft ``LoraConfig`` / ``get_peft_model``.

    Both flavours use rank and alpha ``args.max_rank``, target the ``"query"``
    and ``"value"`` modules, dropout 0.1 and ``bias="none"``.

    Returns
    -------
    tuple
        ``(args, net_glob, global_model, dim)`` where ``args`` is the same object
        passed in, ``net_glob`` is the adapted model moved to ``args.device``,
        ``global_model`` is a deep copy of ``net_glob.state_dict()`` (kept as the
        server-side aggregation reference), and ``dim`` is ``model_dim`` of that
        copy.

    Raises
    ------
    ValueError
        If ``args.model`` is not a recognised backbone name.
    """
    # Fail fast on unsupported backbones before touching any weights.
    if args.model != "bert-base-uncased":
        raise ValueError(f"Unrecognized base model: {args.model}")

    backbone = AutoModelForSequenceClassification.from_pretrained(
        args.model,
        num_labels=args.num_classes,
    )

    # Hyperparameters common to both adapter flavours, kept in one place so the
    # LoRA and PLoRA branches cannot silently diverge.
    shared_lora_kwargs = dict(
        r=args.max_rank,
        lora_alpha=args.max_rank,
        target_modules=["query", "value"],  # example target modules
        lora_dropout=0.1,
        bias="none",
    )

    if "fedplora" in args.model_heterogeneity:
        # FedPLoRA: tied & sparse adapters.
        adapter_config = PLoraConfig(
            **shared_lora_kwargs,
            num_layer=args.num_tied_layer,
            density=args.density,
            sparsity_type=args.sparsity_type,
            shared_adapter=True,
        )
        wrap_with_adapters = get_plora_peft_model
    else:
        # Standard LoRA.
        adapter_config = LoraConfig(**shared_lora_kwargs)
        wrap_with_adapters = get_peft_model

    net_glob = wrap_with_adapters(backbone, adapter_config)
    net_glob.to(args.device)

    # Independent copy of the weights used as the aggregation reference.
    weights_snapshot = copy.deepcopy(net_glob.state_dict())
    return args, net_glob, weights_snapshot, model_dim(weights_snapshot)


def model_dim(state_dict):
    """
    Return the total number of scalar entries across all tensors in ``state_dict``.

    This equals the length of the vector obtained by flattening and
    concatenating every value. For a torch ``state_dict`` it counts every
    entry the dict holds, so persistent buffers are included as well as
    parameters.

    Parameters
    ----------
    state_dict : Mapping
        Mapping whose values expose ``.numel()`` (e.g. torch tensors).

    Returns
    -------
    int
        Total element count; ``0`` for an empty mapping.
    """
    total_elements = 0
    for tensor in state_dict.values():
        total_elements += tensor.numel()
    return total_elements