import copy 
import numpy as np
import torch 
from tqdm import tqdm 
import wandb

from utils.train_utils import train, eval, create_optimizer
from utils.model_utils import copy_model, models_are_equal
from utils.het_utils import truncate_model, flex_lora_model, zero_pad, freeze_from_rankings

""" Homogeneous Rank Aggregation and Training """
def agg_models(server_model,
               client_states,
               agg_method,
               adaptation_method,
               num_heads: int = 0):
    """Aggregate same-rank client state dicts into ``server_model`` in place.

    Args:
        server_model: global model. Must expose ``.device``; its tensors are
            overwritten in place through ``state_dict()`` views.
        client_states: list of client ``state_dict``s whose keys match the
            server's parameter names. Each tensor is moved to the server's
            device before averaging.
        agg_method: only used when ``adaptation_method == "lora"``:
            ``"FedAvg"`` averages ``lora_A`` and ``lora_B``; ``"FFA"`` averages
            ``lora_B`` only; ``"FedExLoRA"`` does FedAvg and additionally adds
            the averaging residual ``mean(B_i @ A_i) - mean(B) @ mean(A)`` to
            the layer's base weight.
        adaptation_method: ``"full_ft"`` (average every parameter), ``"lora"``,
            ``"vravan"`` / ``"sb"`` (average every trainable parameter) or
            ``"ravan"`` (per-head ``scaling * R`` averaging).
        num_heads: number of ``lora_R_{i}`` / ``lora_scaling_{i}`` heads per
            adapted layer; only read for ``"ravan"``.

    Returns:
        ``server_model`` (the same object, mutated in place).

    Raises:
        ValueError: ``adaptation_method == "lora"`` with an unknown
            ``agg_method`` (previously this silently skipped aggregation).
        NotImplementedError: unknown ``adaptation_method``.
    """
    device = server_model.device
    # state_dict() tensors share storage with the live parameters, so the
    # in-place copy_/add_/fill_ calls below update the model. Build the dict
    # once instead of once per parameter.
    server_state = server_model.state_dict()

    def stack_and_mean(param_name):
        """Element-wise mean of ``param_name`` across all client states."""
        tensors = [cs[param_name].to(device) for cs in client_states]
        return torch.stack(tensors, dim=0).mean(dim=0)

    def adapted_layers():
        """Yield the names of ``base_layer`` weight parameters."""
        for param_name, _ in server_model.named_parameters():
            if "base_layer" in param_name and "weight" in param_name:
                yield param_name

    with torch.no_grad():
        if adaptation_method == "full_ft":
            for name, _ in server_model.named_parameters():
                server_state[name].copy_(stack_and_mean(name))

        elif adaptation_method == "lora":
            if agg_method not in {"FedAvg", "FFA", "FedExLoRA"}:
                raise ValueError(f"Unknown LoRA agg_method: {agg_method}")

            for name in adapted_layers():
                # Companion LoRA factors live next to the base weight.
                name_A = name.replace("base_layer.weight", "lora_A")
                name_B = name.replace("base_layer.weight", "lora_B")
                avg_B = stack_and_mean(name_B)

                if agg_method == "FFA":
                    # lora_A is frozen on the clients (see fl_training), so
                    # only lora_B is aggregated.
                    server_state[name_B].copy_(avg_B)
                    continue

                avg_A = stack_and_mean(name_A)
                if agg_method == "FedExLoRA":
                    avg_prod = torch.stack(
                        [(cs[name_B] @ cs[name_A]).to(device)
                         for cs in client_states]
                    ).mean(dim=0)
                    # Fold the part of mean(B_i A_i) that mean(B) @ mean(A)
                    # cannot represent into the base weight.
                    # NOTE(review): no LoRA scaling factor (e.g. alpha / r) is
                    # applied to this residual; confirm that matches how the
                    # adapter output is scaled in the forward pass.
                    server_state[name].add_(avg_prod - avg_B @ avg_A)
                server_state[name_A].copy_(avg_A)
                server_state[name_B].copy_(avg_B)

        elif adaptation_method in {"vravan", "sb"}:
            for name, param in server_model.named_parameters():
                if param.requires_grad:
                    server_state[name].copy_(stack_and_mean(name))

        elif adaptation_method == "ravan":
            for name in adapted_layers():
                for i in range(num_heads):
                    name_R = name.replace("base_layer.weight", f"lora_R_{i}")
                    name_scaling = name.replace("base_layer.weight",
                                                f"lora_scaling_{i}")
                    # Average the scaled heads, then reset the scaling to 1 so
                    # the server's effective head equals that average.
                    avg_R = torch.stack(
                        [(cs[name_scaling] * cs[name_R]).to(device)
                         for cs in client_states]
                    ).mean(dim=0)
                    server_state[name_R].copy_(avg_R)
                    server_state[name_scaling].fill_(1.0)

        else:
            raise NotImplementedError

    return server_model

def fl_training(server_model, clients, testloader, tokenizer, args):
    """Run homogeneous-rank federated fine-tuning of ``server_model``.

    Each communication round draws ``args.clients_round`` distinct clients
    (a prefix of a random permutation). For each one it:

    * builds a fresh client copy of the server model with ``copy_model`` and
      moves it to ``"cuda"``;
    * freezes ``lora_A`` parameters when aggregating with FFA;
    * trains it for ``args.epochs`` calls of ``train`` (``args.local_steps``
      steps each);
    * stores its ``state_dict`` on CPU.

    The collected states are then aggregated into ``server_model`` in place
    via :func:`agg_models`.

    The global model is evaluated once before training and then every
    ``args.eval_every`` rounds. When ``args`` has no ``eval_every`` attribute
    the default is 1, i.e. every round, which matches the previous behaviour.
    Accuracy is logged to wandb when ``args.wandb_logging`` is set.
    """
    eval_every = max(1, int(getattr(args, "eval_every", 1)))

    print("Pretrained model:")
    acc = eval(server_model, testloader, args.dataset, tokenizer=tokenizer)
    if args.wandb_logging:
        wandb.log({"global_acc": acc})

    for rnd in range(args.comm_rounds):
        print(f"Training communication round: {rnd}")
        client_ids = torch.randperm(len(clients))[: args.clients_round]
        client_states = []

        for cid in tqdm(client_ids, total=len(client_ids)):
            client_model = copy_model(server_model, args.dataset,
                                      args.rank, args.alpha,
                                      args.b_var, args.r_var, args.a_var,
                                      args.num_heads, args.adaptation_method
                                      ).to("cuda")

            if args.aggregation_method == "FFA":
                # FFA keeps lora_A fixed at its initial value on every client.
                for n, p in client_model.named_parameters():
                    if "lora_A" in n:
                        p.requires_grad = False

            opt = create_optimizer(client_model, args.optimizer,
                                   args.client_lr, args.momentum)

            for _ in range(args.epochs):
                client_model = train(client_model,
                                     clients[cid],
                                     opt,
                                     args.dataset,
                                     tokenizer=tokenizer,
                                     steps=args.local_steps)

            # Keep client states on CPU so GPU memory holds one client at a time.
            client_states.append(
                {k: v.cpu() for k, v in client_model.state_dict().items()})

            del client_model, opt
            torch.cuda.empty_cache()

        agg_models(server_model, client_states,
                   args.aggregation_method, args.adaptation_method,
                   num_heads=args.num_heads)
        del client_states
        torch.cuda.empty_cache()

        if (rnd + 1) % eval_every == 0:
            print(f"Evaluation at round: {rnd + 1}")
            acc = eval(server_model, testloader, args.dataset, tokenizer=tokenizer)
            if args.wandb_logging:
                wandb.log({"global_acc": acc})

""" Heterogeneous Rank Aggregation and Training """

def agg_models_het(
        server_model,
        client_states,
        client_ranks,
        agg_method,
        adaptation_method,
        num_heads: int = 0,
):
    """Aggregate heterogeneous-rank LoRA client states without building client models.

    For every ``...base_layer.weight`` parameter of ``server_model`` the
    companion ``lora_A`` / ``lora_B`` tensors are read from each client
    state, and their products ``B_i @ A_i`` are formed.

    * ``"FlexLoRA"``: the server is left untouched. The uniform mean of
      ``B_i @ A_i`` is returned per layer in ``avg_prod_dict``.
    * ``"HetLoRA"``: each client's factors are zero-padded to the server's
      LoRA rank. They are then combined with weights proportional to the
      Frobenius norm of ``B_i @ A_i``, and the result is written into the
      server's ``lora_A`` / ``lora_B`` in place.

    Args:
        server_model: global model exposing ``.device``.
        client_states: non-empty list of client ``state_dict``s.
        client_ranks: per-client LoRA ranks for this round. Kept for API
            compatibility; padding targets the server tensors' own rank so the
            aggregate always fits them, even when no client in the round
            trained at full rank.
        agg_method: ``"FlexLoRA"`` or ``"HetLoRA"``.
        adaptation_method: must be ``"lora"``.
        num_heads: unused; kept for signature parity with :func:`agg_models`.

    Returns:
        ``(server_model, avg_prod_dict)`` for FlexLoRA, ``(server_model, None)``
        for HetLoRA.

    Raises:
        NotImplementedError: ``adaptation_method`` is not ``"lora"``.
        ValueError: unknown ``agg_method`` or empty ``client_states``.
    """
    if adaptation_method != "lora":
        raise NotImplementedError("Het path currently supports LoRA only.")
    if agg_method not in {"FlexLoRA", "HetLoRA"}:
        raise ValueError(f"Unknown heterogeneous agg_method: {agg_method}")
    if not client_states:
        raise ValueError("agg_models_het requires at least one client state.")

    device = server_model.device
    # Views sharing storage with the parameters; built once, not per layer.
    server_state = server_model.state_dict()
    avg_prod_dict = {}

    for name, _ in server_model.named_parameters():
        if "base_layer" not in name or "weight" not in name:
            continue

        name_A = name.replace("base_layer.weight", "lora_A")
        name_B = name.replace("base_layer.weight", "lora_B")
        As = [cs[name_A] for cs in client_states]
        Bs = [cs[name_B] for cs in client_states]
        prods = [B @ A for B, A in zip(Bs, As)]

        if agg_method == "FlexLoRA":
            avg_prod_dict[name] = torch.stack(prods).mean(dim=0)
            continue

        # HetLoRA: pad to the server's rank (A is (r, in), B is (out, r)),
        # not to this round's max client rank. That would not fit the server
        # tensors whenever no sampled client used the full rank.
        server_r_A = server_state[name_A].shape[0]
        server_r_B = server_state[name_B].shape[-1]
        padded_As = [zero_pad(A, server_r_A, pad_B=False) for A in As]
        padded_Bs = [zero_pad(B, server_r_B, pad_B=True) for B in Bs]

        norms = torch.stack([torch.norm(p, p="fro") for p in prods])
        total_norm = norms.sum()
        if total_norm > 0:
            weights = norms / total_norm
        else:
            # Every client update is exactly zero: use a plain mean instead
            # of dividing by zero and writing NaNs into the server.
            weights = torch.full_like(norms, 1.0 / len(prods))

        with torch.no_grad():
            weighted_A = sum(w * a for w, a in zip(weights, padded_As))
            weighted_B = sum(w * b for w, b in zip(weights, padded_Bs))
            server_state[name_A].copy_(weighted_A.to(device))
            server_state[name_B].copy_(weighted_B.to(device))

    torch.cuda.empty_cache()

    return server_model, (avg_prod_dict if agg_method == "FlexLoRA" else None)


def fl_training_het(
        server_model,
        clients,
        testloader,
        tokenizer,
        client_ranks,
        args,
):
    """Run federated fine-tuning where clients train at heterogeneous ranks.

    Client ``cid`` trains at rank ``int(client_ranks[cid] * args.rank)``,
    clamped to at least 1. How its model is built depends on
    ``args.aggregation_method``:

    * ``"FlexLoRA"``: round 0 truncates the server model to the client rank.
      Later rounds build it with ``flex_lora_model`` from the previous round's
      averaged ``B @ A`` products (``avg_prod_dict``).
    * ``"HetLoRA"``: every round truncates the server model to the client rank.
    * anything else: a full ``copy_model`` whose trainable parts are then set
      by ``freeze_from_rankings``. Aggregation falls back to :func:`agg_models`.

    The global model is evaluated before training and then every
    ``args.eval_every`` rounds. When ``args`` has no such attribute the
    default is 10, matching the previous behaviour.
    """
    het_lora_methods = {"FlexLoRA", "HetLoRA"}
    method = args.aggregation_method
    eval_every = max(1, int(getattr(args, "eval_every", 10)))

    print("Pretrained model:")
    acc = eval(server_model, testloader, args.dataset, tokenizer=tokenizer)
    if args.wandb_logging:
        wandb.log({"global_acc": acc})

    avg_prod_dict = None
    for rnd in range(args.comm_rounds):
        print(f"Training communication round: {rnd}")
        client_ids = torch.randperm(len(clients))[:args.clients_round]
        client_states = []
        round_ranks = []

        for cid in tqdm(client_ids, total=len(client_ids)):
            # int() truncates, so a small ratio could yield a rank-0 adapter;
            # clamp so every LoRA client gets at least rank 1.
            rank = max(1, int(client_ranks[cid] * args.rank))
            client_data = clients[cid]

            if method == "FlexLoRA" and rnd > 0:
                client_model = flex_lora_model(server_model, args.dataset, rank, args.alpha,
                                               args.b_var, args.a_var, avg_prod_dict)
            elif method in het_lora_methods:
                client_model = truncate_model(server_model, args.dataset, rank, args.alpha,
                                              args.b_var, args.a_var)
            else:
                client_model = copy_model(server_model, args.dataset, args.rank, args.alpha,
                                          args.b_var, args.r_var, args.a_var,
                                          args.num_heads, args.adaptation_method)

            client_optimizer = create_optimizer(client_model, args.optimizer,
                                                args.client_lr, args.momentum)
            if method not in het_lora_methods:
                # Runs after the optimizer exists because it receives it.
                freeze_from_rankings(client_model, client_ranks[cid], args.num_heads,
                                     args.ranking, client_data, client_optimizer,
                                     args.dataset, tokenizer)

            for _ in range(args.epochs):
                client_model = train(client_model, client_data, client_optimizer,
                                     args.dataset, tokenizer=tokenizer,
                                     steps=args.local_steps)

            client_states.append(
                {k: v.cpu() for k, v in client_model.state_dict().items()})
            round_ranks.append(rank)

            del client_model, client_optimizer
            torch.cuda.empty_cache()

        if method in het_lora_methods:
            server_model, avg_prod_dict = agg_models_het(
                server_model, client_states, round_ranks, method,
                args.adaptation_method, num_heads=args.num_heads)
        else:
            agg_models(server_model, client_states, method,
                       args.adaptation_method, num_heads=args.num_heads)

        del client_states
        torch.cuda.empty_cache()

        if (rnd + 1) % eval_every == 0:
            print(f"Evaluation at round {rnd + 1}")
            if method == "FlexLoRA":
                # NOTE(review): this rebinds server_model, and later rounds build
                # client models from it. Evaluation (and thus eval_every) may
                # therefore influence training. Confirm whether a separate
                # eval-only model was intended.
                server_model = flex_lora_model(server_model, args.dataset, args.rank,
                                               args.alpha, args.b_var, args.a_var,
                                               avg_prod_dict)
            acc = eval(server_model, testloader, args.dataset, tokenizer=tokenizer)
            if args.wandb_logging:
                wandb.log({"global_acc": acc})