import copy
import gc
import pickle
import random
from fractions import Fraction

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoModelForSequenceClassification

from plora_func.LoRAConfig import PLoraConfig
from plora_func.plora_and_get_peft import (
    get_peft_model as get_tied_peft_model
)
from utils.data_pre_process import load_partition, DatasetSplit
from utils.global_aggregator import average_lora_fedplora
from utils.local_solver import LocalUpdate
from utils.model_utils import model_setup
from test import test_mrpc
import re


# Substrings identifying quantization bookkeeping entries that must not be
# passed to ``load_state_dict``.
_QUANT_KEY_MARKERS = ("quant_state", "absmax", "quant_map")

# Matches every run of digits in a parameter name.
_INT_IN_KEY = re.compile(r"\d+")


def _log_banner(args, title):
    """Log a fixed-width section banner containing *title*."""
    args.logger.info(
        "{:<50}".format("-" * 15 + f" {title} " + "-" * 50)[0:60],
        main_process_only=True,
    )


def _strip_quant_keys(state_dict):
    """Return a copy of *state_dict* without quantization-related entries."""
    return {
        k: v
        for k, v in state_dict.items()
        if not any(marker in k for marker in _QUANT_KEY_MARKERS)
    }


def _release_state_dict(state_dict):
    """Drop every tensor reference held by *state_dict* (keys are kept)."""
    for key in state_dict:
        state_dict[key] = None


def _sync_client_indices(args, indices):
    """Broadcast rank 0's *indices* object to every process and return it.

    The object is pickled on the main process, its byte length is broadcast
    first so the other ranks can allocate a receive buffer, then the payload
    itself is broadcast and unpickled on every rank.
    """
    accelerator = args.accelerator
    device = accelerator.device

    # Step 1: the main process serializes the object and determines its size.
    if accelerator.is_main_process:
        serialized_data = pickle.dumps(indices)
        dict_size = torch.tensor(len(serialized_data), device=device)
    else:
        serialized_data = None
        dict_size = torch.tensor(0, device=device)

    # Step 2: broadcast the size.
    torch.distributed.broadcast(dict_size, src=0)

    # Step 3: broadcast the serialized payload.
    if accelerator.is_main_process:
        dict_tensor = torch.ByteTensor(list(serialized_data)).to(device)
    else:
        dict_tensor = torch.empty(
            dict_size.item(),
            dtype=torch.uint8,
            device=device,
        )
    torch.distributed.broadcast(dict_tensor, src=0)

    received_bytes = bytes(dict_tensor.cpu().numpy())
    accelerator.wait_for_everyone()
    # NOTE(review): pickle.loads executes arbitrary code; this is acceptable
    # only because the payload originates from rank 0 of the same job.
    synced = pickle.loads(received_bytes)
    accelerator.wait_for_everyone()
    return synced


def _assign_heterogeneity_groups(args):
    """Populate ``args.user_groupid_list`` (and the tied-layer list).

    Each entry of ``args.heterogeneous_group`` is parsed with ``Fraction`` and
    gives the share of users in that group; the last group receives whatever
    remains so the counts always sum to ``args.num_users``.
    """
    last_group = len(args.heterogeneous_group) - 1
    group_cnt = []
    for g, share in enumerate(args.heterogeneous_group):
        if g == last_group:
            group_cnt.append(args.num_users - sum(group_cnt))
        else:
            group_cnt.append(int(args.num_users * float(Fraction(share))))

    args.user_groupid_list = []
    for gid, count in enumerate(group_cnt):
        args.user_groupid_list.extend([gid] * count)

    # Per-user tied-layer budget, looked up from ``args.num_tied_layer<gid>``.
    # Only built when ``num_tied_layer0`` exists and is an int.
    if isinstance(getattr(args, "num_tied_layer0", None), int):
        args.num_tied_layer_list = [
            getattr(args, f"num_tied_layer{gid}")
            for gid in args.user_groupid_list
        ]


def _build_client_model(args):
    """Create a fresh base model wrapped with the tied PLoRA adapter.

    Returns the raw base model and the PEFT-wrapped model (moved to
    ``args.device``).

    Raises
    ------
    NotImplementedError
        If ``args.model`` is not ``"bert-base-uncased"``.
    """
    if args.model != "bert-base-uncased":
        raise NotImplementedError("Current example uses BERT base")

    base_model = AutoModelForSequenceClassification.from_pretrained(
        args.model,
        num_labels=args.num_classes,
    )
    lora_config = PLoraConfig(
        r=args.max_rank,
        lora_alpha=args.max_rank,
        target_modules=["query", "value"],
        lora_dropout=0.1,
        bias="none",
        num_layer=args.num_tied_layer,
        density=args.density,
        sparsity_type=args.sparsity_type,
        shared_adapter=True,
    )
    peft_model = get_tied_peft_model(base_model, lora_config)
    peft_model.to(args.device)
    return base_model, peft_model


def _client_state_from_global(args, net_glob):
    """Build the state dict a client model is initialised from.

    LoRA tensors are sliced to ``args.max_rank`` along the axis whose size
    equals ``args.max_rank`` (the last such axis if several match), and
    quantization bookkeeping entries are dropped.

    Raises
    ------
    ValueError
        If a LoRA tensor has no rank axis at position 0 or 1.
    NotImplementedError
        If ``args.agg_mode`` is not ``"ParallelFedAgg"``.
    """
    client_state = {}
    for key, tensor in net_glob.state_dict().items():
        if "lora" in key:
            rank_axes = [
                axis
                for axis, size in enumerate(tensor.size())
                if size == args.max_rank
            ]
            r_dim = rank_axes[-1] if rank_axes else None
            if r_dim not in (0, 1):
                raise ValueError("Unexpected LoRA dimension")
            if args.agg_mode != "ParallelFedAgg":
                raise NotImplementedError("Aggregation mode not supported")
            if r_dim == 0:
                tensor = tensor[: args.max_rank, :]
            else:
                tensor = tensor[:, : args.max_rank]
        client_state[key] = tensor
    return _strip_quant_keys(client_state)


def _freeze_sub_loras(args, model, client_id):
    """Disable gradients for the sub-LoRA weights this client does not train.

    ``args.num_tied_layer - args.num_tied_layer_list[client_id]`` sub-LoRA
    indices are sampled without replacement and frozen. With
    ``args.multirandom`` set, a new sample is drawn for every attention layer;
    otherwise one sample is shared by all layers.
    """
    multirandom = getattr(args, "multirandom", False)

    if not multirandom:
        frozen_ids = sorted(
            random.sample(
                range(args.num_tied_layer),
                args.num_tied_layer - args.num_tied_layer_list[client_id],
            )
        )

    for layer_index in range(args.lora_layer):
        if multirandom:
            frozen_ids = random.sample(
                range(args.num_tied_layer),
                args.num_tied_layer - args.num_tied_layer_list[client_id],
            )

        # The trailing dot keeps ".layer.1." from also matching
        # ".layer.10." / ".layer.11.".
        layer_pattern = f".layer.{layer_index}."
        for name, param in model.named_parameters():
            if layer_pattern not in name:
                continue
            if any(f".{sub_id}.weight" in name for sub_id in frozen_ids):
                param.requires_grad = False


def fedplora(args):
    """
    Main FedPLoRA training loop.

    This function handles:
      1 data partition loading
      2 model initialization
      3 client side local fine tuning with tied LoRA
      4 server side aggregation of LoRA updates
      5 evaluation on the test set after every aggregated round

    Parameters
    ----------
    args
        Run configuration object. It is read for hyper-parameters, the
        logger and the accelerator, and it is also mutated: ``dim``,
        ``local_lr``, ``user_groupid_list`` and (when configured)
        ``num_tied_layer_list`` are written to it.

    Returns
    -------
    tuple
        Best metrics in order
        (Accuracy, F1, Macro F1, Matthews correlation, Micro F1, RougeL).
        Only Accuracy and F1 are updated by this function, and only on the
        local main process; the rest stay at 0.0.
    dict
        A dict of metric flags indicating which ones are valid for this run

    Raises
    ------
    NotImplementedError
        For datasets other than MRPC, PEFT methods other than LoRA, models
        other than ``bert-base-uncased`` or unsupported aggregation modes.
    ValueError
        If a LoRA tensor has no axis of size ``args.max_rank`` at position
        0 or 1.
    """
    # ===================== Data and logging setup ===================== #
    _log_banner(args, "data setup")
    (
        args,
        dataset_train,
        dataset_test,
        dataset_val,
        dataset_public,
        dict_users,
        dataset_fim,
    ) = load_partition(args)

    # The test set may be a dict of splits; count samples, not dict keys.
    if isinstance(dataset_test, dict):
        num_test = sum(len(v) for v in dataset_test.values())
    else:
        num_test = len(dataset_test)

    args.logger.info(
        f"length of dataset:{len(dataset_train) + num_test}",
        main_process_only=True,
    )
    args.logger.info(
        f"num. of training data:{len(dataset_train)}",
        main_process_only=True,
    )
    args.logger.info(
        f"num. of testing data:{num_test}",
        main_process_only=True,
    )
    args.logger.info(
        f"num. of users:{len(dict_users)}",
        main_process_only=True,
    )

    sample_per_users = int(
        sum(len(dict_users[i]) for i in range(len(dict_users))) / len(dict_users)
    )
    args.logger.info(
        f"average num. of samples per user:{sample_per_users}",
        main_process_only=True,
    )

    _log_banner(args, "log path")
    # Only the local main process writes TensorBoard events.
    writer = (
        SummaryWriter(args.log_path)
        if args.accelerator.is_local_main_process
        else None
    )
    args.logger.info(args.log_path, main_process_only=True)

    # ======================= Model setup ======================= #
    _log_banner(args, "model setup")
    args, net_glob, global_model, args.dim = model_setup(args)
    args.logger.info(f"model dim: {args.dim}", main_process_only=True)

    # ======================= Dataloader per client ======================= #
    _log_banner(args, "training...")

    data_loader_list = []
    for i in range(args.num_users):
        # For environments with limited GPU memory, sync dict_users across processes
        if args.batch_size < 1 and args.iid == 0:
            dict_users[i] = _sync_client_indices(args, dict_users[i])

        dataset = DatasetSplit(dataset_train, dict_users[i], args)

        if "mrpc" not in args.dataset:
            raise NotImplementedError("Current example supports GLUE MRPC dataset")

        data_loader_list.append(
            DataLoader(
                dataset,
                shuffle=True,
                collate_fn=args.data_collator,
                batch_size=args.batch_size,
            )
        )
        args.accelerator.wait_for_everyone()

    # ======================= Heterogeneity setting ======================= #
    _assign_heterogeneity_groups(args)

    # ======================= Metrics and tracking ======================= #
    best_test_acc = 0.0
    best_test_f1 = 0.0
    best_test_matthews_correlation = 0.0
    best_test_macro_f1 = 0.0
    best_test_micro_f1 = 0.0
    best_test_rougeL = 0.0

    metric_keys = {
        "Accuracy": 0,
        "F1": 0,
        "Macro_F1": 0,
        "Micro_F1": 0,
        "Matthews_correlation": 0,
        "RougeL": 0,
    }

    # Latest LoRA update of every client, kept on CPU.
    # NOTE(review): nothing in this function reads the cache back.
    local_updates_for_local_specialty = {
        i: [] for i in range(args.num_users)
    }

    # ======================= Federated rounds ======================= #
    for t in range(args.round):
        args.logger.info(
            f"Round: {t}/{args.round}",
            main_process_only=True,
        )

        # Learning rate decay
        if (t + 1) % args.lr_step_size == 0:
            args.local_lr = args.local_lr * args.decay_weight

        # Client sampling
        selected_idxs = list(
            np.random.choice(
                range(args.num_users),
                args.num_selected_users,
                replace=False,
            )
        )

        net_glob.train()
        local_solver = LocalUpdate(args=args)

        local_losses = []
        local_updates = []
        delta_norms = []

        for num_index, i in enumerate(selected_idxs):
            if args.peft != "lora":
                raise NotImplementedError("Current example uses LoRA only")

            model_, net_glob_ = _build_client_model(args)

            # Initialise the client from the current global weights.
            model_dict_ = _client_state_from_global(args, net_glob)
            net_glob_.load_state_dict(model_dict_)
            # Snapshot of the starting point, used to compute the update.
            global_model_ = copy.deepcopy(net_glob_.state_dict())

            # Freeze partial LoRA ranks
            _freeze_sub_loras(args, net_glob_, i)

            net_glob_.train()
            copy_net_glob = copy.deepcopy(net_glob_)

            local_model, local_loss, no_weight_lora = (
                local_solver.lora_tuning_tied_weights(
                    model=copy_net_glob,
                    ldr_train=data_loader_list[i],
                    args=args,
                    client_index=num_index,
                    client_real_id=i,
                    round=t,
                    hete_group_id=args.user_groupid_list[i],
                )
            )

            if local_loss:
                local_losses.append(local_loss)

            # Compute model update on LoRA parameters
            model_update = {}
            for k in global_model_:
                if "lora" not in k:
                    continue
                # Exclude ranks that were masked out
                # Assume the index in the name is rank id
                # NOTE(review): only the FIRST integer in the key is used;
                # confirm it is the sub-LoRA id and not the encoder layer
                # index for the key layout produced by the tied adapter.
                ids_in_key = _INT_IN_KEY.findall(k)
                if ids_in_key and int(ids_in_key[0]) not in no_weight_lora:
                    model_update[k] = (
                        local_model[k].detach().cpu()
                        - global_model_[k].detach().cpu()
                    )

            # Replace cached local specialty updates
            local_updates_for_local_specialty[i] = [copy.deepcopy(model_update)]

            # Compute update norm for diagnostics
            flat_updates = [torch.flatten(v) for v in model_update.values()]
            if flat_updates:
                delta_norms.append(torch.norm(torch.cat(flat_updates)))

            local_updates.append(model_update)

            # Release local resources. Models are moved to CPU and dict
            # entries are cleared before the names are dropped, so device
            # memory can be reclaimed even if another object still holds a
            # reference to the container.
            _release_state_dict(local_model)
            del local_model

            net_glob_ = net_glob_.to("cpu")
            del net_glob_

            copy_net_glob = copy_net_glob.to("cpu")
            del copy_net_glob

            del model_

            _release_state_dict(global_model_)
            del global_model_

            _release_state_dict(model_dict_)
            del model_dict_

            torch.cuda.empty_cache()

        gc.collect()
        torch.cuda.empty_cache()

        if not local_updates:
            args.logger.info(
                "The number of trainable clients is zero skip the round for average",
                main_process_only=True,
            )
            continue

        # ======================= Diagnostics ======================= #
        if delta_norms:
            norm = torch.median(torch.stack(delta_norms)).cpu()
        else:
            norm = torch.tensor(100.0)

        if local_losses:
            train_loss = sum(local_losses) / len(local_losses)
        else:
            train_loss = 100.0

        if writer is not None:
            writer.add_scalar("norm", norm, t)
            writer.add_scalar("train_loss", train_loss, t)

        # ======================= Global aggregation ======================= #
        global_model, aggregated_updates = average_lora_fedplora(
            args,
            global_model,
            local_updates,
        )
        # The aggregated updates are not used here; drop the reference right
        # away instead of keeping them alive through the next round.
        del aggregated_updates

        # ======================= Evaluation ======================= #
        global_model = _strip_quant_keys(global_model)

        net_glob.load_state_dict(global_model)
        net_glob.eval()

        if "mrpc" in args.dataset:
            test_f1, test_acc, test_loss = test_mrpc(
                copy.deepcopy(net_glob),
                dataset_test,
                args,
                t,
            )

            if writer is not None:
                writer.add_scalar("test_acc", test_acc, t)
                writer.add_scalar("test_f1", test_f1, t)

                # The best round is selected by F1; accuracy is reported
                # for that same round.
                if test_f1 > best_test_f1:
                    best_test_f1 = test_f1
                    best_test_acc = test_acc
                    metric_keys["Accuracy"] = 1
                    metric_keys["F1"] = 1

            args.logger.info(
                "t {:3d}: train_loss = {:.3f}, norm = {:.3f}, "
                "test_f1 = {:.3f}, test_acc = {:.3f}".format(
                    t,
                    train_loss,
                    norm,
                    test_f1,
                    test_acc,
                ),
                main_process_only=True,
            )

        args.accelerator.wait_for_everyone()

    # Flush pending TensorBoard events before returning.
    if writer is not None:
        writer.close()

    return (
        best_test_acc,
        best_test_f1,
        best_test_macro_f1,
        best_test_matthews_correlation,
        best_test_micro_f1,
        best_test_rougeL,
    ), metric_keys