import torch
import numpy as np
from tqdm import tqdm
from accelerate import Accelerator, DistributedDataParallelKwargs


class LocalUpdate:
    """
    Local training helper for one client in a federated setting.

    This class runs LoRA style fine tuning for a single client by
    iterating over its local dataloader with Accelerate support.
    """

    def __init__(self, args):
        """
        Store the global argument object for later use.

        Parameters
        ----------
        args :
            Argument object that contains training hyperparameters such as
            tau, local_lr, max_steps, gradient_accumulation_steps and logger.
        """
        self.args = args

    @staticmethod
    def _steps_per_pass(max_steps, num_batches):
        """
        Return how many batches a single local pass should process.

        Parameters
        ----------
        max_steps :
            Requested cap on the number of batches per pass. ``None`` or a
            non-positive value is treated as "no cap".
        num_batches :
            Number of batches the dataloader yields per pass.

        Returns
        -------
        int
            ``num_batches`` when there is no cap, otherwise the smaller of
            ``max_steps`` and ``num_batches``.
        """
        # NOTE(review): non-positive values are interpreted as "no cap",
        # following the common ``max_steps=-1`` convention; confirm this
        # matches how the main script populates ``args.max_steps``.
        if max_steps is None or max_steps <= 0:
            return num_batches
        return min(int(max_steps), num_batches)

    @staticmethod
    def _split_parameters(named_parameters, frozen_markers):
        """
        Partition parameters into trainable and frozen lists in one pass.

        Parameters
        ----------
        named_parameters :
            Iterable of ``(name, parameter)`` pairs.
        frozen_markers :
            Iterable of markers; a parameter is frozen when ``str(marker)``
            is a substring of its name.

        Returns
        -------
        tuple
            ``(trainable, frozen)`` lists, each preserving the input order.
        """
        trainable, frozen = [], []
        for name, param in named_parameters:
            if any(str(marker) in name for marker in frozen_markers):
                frozen.append(param)
            else:
                trainable.append(param)
        return trainable, frozen

    def lora_tuning_tied_weights(
        self,
        model,
        ldr_train,
        args,
        client_index,
        client_real_id,
        round,
        hete_group_id,
    ):
        """
        Run local LoRA training for one client with Accelerate.

        The method performs ``self.args.tau`` passes over the dataloader.
        Each pass processes at most ``args.max_steps`` batches (all batches
        when ``max_steps`` is missing, ``None`` or non-positive). Gradients
        are accumulated manually over ``args.gradient_accumulation_steps``
        batches, and the last batch of every pass always triggers an
        optimizer step so that no accumulated gradient is left unapplied.

        Parameters
        ----------
        model :
            Model to be fine tuned on this client. It is called as
            ``model(**batch)`` and the result must expose ``.loss``.
        ldr_train :
            Dataloader that yields batches of tokenized training data.
        args :
            Argument object passed through from the main script. Reads
            ``local_lr``, ``logger`` and, optionally, ``max_steps`` and
            ``gradient_accumulation_steps``. If the latter attribute is
            missing it is set to ``1`` on ``args``. Note that ``tau`` is read
            from ``self.args``, not from this argument.
        client_index :
            Index of the client in the simulation; used in the progress bar
            description.
        client_real_id :
            Real identifier of the client. Not used by this method.
        round :
            Current federated round index. Not used by this method (the name
            shadows the builtin but is kept for caller compatibility).
        hete_group_id :
            Heterogeneity group identifier of this client. Not used by this
            method.

        Returns
        -------
        state_dict :
            State dictionary of the locally updated, unwrapped model.
        avg_loss :
            Mean of the losses sampled at each optimizer step across all tau
            passes. Losses are only recorded on the local main process;
            other processes (and runs without any step) return ``0.0``.
        no_weight_lora :
            A list placeholder for LoRA modules that should not be updated.
            It is always empty in the current implementation.
        """
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

        # This list can be used to freeze a subset of LoRA parameters by name
        no_weight_lora = []

        # Two parameter groups: the regular one, and one pinned to a zero
        # learning rate for parameters whose name matches ``no_weight_lora``.
        trainable_params, frozen_params = self._split_parameters(
            model.named_parameters(), no_weight_lora
        )
        optimizer = torch.optim.AdamW(
            [
                {"params": trainable_params},
                {"params": frozen_params, "lr": 0.0},
            ],
            lr=args.local_lr,
        )

        # Prepare model, optimizer and dataloader for Accelerate
        model, optimizer, train_dataloader = accelerator.prepare(
            model, optimizer, ldr_train
        )

        # Resolve loop-invariant settings once instead of on every batch.
        if not hasattr(args, "gradient_accumulation_steps"):
            args.gradient_accumulation_steps = 1
        # Guard against 0 / None, which would break the modulo below.
        accumulation_steps = max(1, int(args.gradient_accumulation_steps or 1))

        num_batches = len(train_dataloader)
        steps_per_pass = self._steps_per_pass(
            getattr(args, "max_steps", None), num_batches
        )
        # Only break out of the dataloader when the cap really truncates it,
        # so an uncapped pass still exhausts the iterator normally.
        stops_early = steps_per_pass < num_batches

        recorded_losses = []

        # Tau controls how many local epochs or passes we run on this client
        for t_au in range(self.args.tau):
            progress = tqdm(
                enumerate(train_dataloader),
                desc=f"Local Training Client {client_index} Tau: {t_au}",
                total=steps_per_pass,
                disable=not accelerator.is_local_main_process,
            )
            # Gradient accumulation is handled manually below, so Accelerate's
            # ``accumulate`` context (meant to wrap a single batch) is not
            # used around the loop.
            for step, batch in progress:
                outputs = model(**batch)
                loss = outputs.loss
                # NOTE(review): the loss is not divided by the accumulation
                # factor, matching the previous behaviour; confirm that
                # summed (rather than averaged) gradients are intended.
                accelerator.backward(loss)

                is_last_step = step + 1 >= steps_per_pass
                # Also step on the final batch so a trailing partial
                # accumulation window is applied instead of leaking into the
                # next pass or being dropped.
                if (step + 1) % accumulation_steps == 0 or is_last_step:
                    optimizer.step()
                    optimizer.zero_grad()

                    if accelerator.is_local_main_process:
                        recorded_losses.append(loss.detach().float().item())

                # Respect max_steps if it is smaller than the full dataloader
                if stops_early and is_last_step:
                    break

        # Single barrier so every process has finished training before the
        # weights are read.
        accelerator.wait_for_everyone()

        avg_loss = float(np.mean(recorded_losses)) if recorded_losses else 0.0
        args.logger.info(
            f"Total local training loss is: {avg_loss}",
            main_process_only=True,
        )

        # Unwrap the model back to the original type before returning the state dict
        return accelerator.unwrap_model(model).state_dict(), avg_loss, no_weight_lora