import torch
import numpy as np
import json
import os
import copy
from collections import Counter
from tqdm import trange

from util.misc import save_json, AverageMeter, setup_seed
from util.dataset import QueryBudgetDatasetNew

from typing import List, Union, Optional, Tuple, Dict, Any


def uniform_budgets_allocation(
    query_num: int, total_ice_num: int, num_clients: int
) -> List[List[int]]:
    """Give every client the same ICE budget for every query.

    The per-client share is ``total_ice_num / num_clients`` truncated to an
    integer, so any remainder of the division is left unassigned.

    Args:
        query_num (int): query number
        total_ice_num (int): total ICE number for ``num_clients`` clients
        num_clients (int): number of clients in the FL setting

    Returns:
        List[List[int]]: list of num_clients lists, each local client's list is a list of ``query_num`` integers, each integer represents the local budget for corresponding query
    """
    share = int(total_ice_num / num_clients)
    # One independent list per client so callers may edit them separately.
    return [[share] * query_num for _ in range(num_clients)]


def random_budgets_allocation(
    query_num: int, total_ice_num: int, num_clients: int, seed: int = 0
) -> List[List[int]]:
    """Randomly assign local ICE budgets to each client, the overall ICE number of each query is ``total_ice_num``.

    For every query, ``total_ice_num`` client ids are drawn uniformly with
    replacement; a client's budget for that query is the number of times its
    id was drawn.

    Args:
        query_num (int): query number
        total_ice_num (int): total ICE number for ``num_clients`` clients
        num_clients (int): number of clients in the FL setting
        seed (int, optional): Seed of the NumPy random generator used for the draws. Defaults to 0.

    Returns:
        List[List[int]]: list of num_clients lists, each local client's list is a list of integers, each integer represents the local budget for corresponding query. Always contains exactly ``num_clients`` lists, even when ``query_num`` is 0.
    """
    rng = np.random.default_rng(seed=seed)
    # Fill the per-client lists directly instead of transposing at the end:
    # a ``zip(*...)`` transpose collapses to ``[]`` when there are no queries.
    local_budgets: List[List[int]] = [[] for _ in range(num_clients)]
    for _ in range(query_num):
        rnd_cids = rng.choice(num_clients, size=total_ice_num, replace=True)
        # ``minlength`` keeps a zero entry for clients that were never drawn.
        per_client = np.bincount(rnd_cids, minlength=num_clients).tolist()
        for cid, budget in enumerate(per_client):
            local_budgets[cid].append(budget)
    return local_budgets


def train_val_split(
    X: np.ndarray, y: List[List[int]], train_ratio: float = 0.8, seed: int = 0
) -> Tuple[np.ndarray, List[List[int]], Optional[np.ndarray], List[List[int]]]:
    """Perform train-validation split given training feature ``X`` and the labels ``y``.

    Args:
        X (np.ndarray): with shape ``(num_samples, feature_dim)``.
        y (List[List[int]]): List of List, including ``num_clients`` sub lists, [[m1, ...], [m2, ...], [m3, ...], ...], each sublist is budget values or budget labels for all samples
        train_ratio (float, optional): train split ratio in the whole dataset. Defaults to 0.8.
        seed (int, optional): Random seed. Defaults to 0.

    Returns:
        Tuple: ``(train_X, train_y, val_X, val_y)``. ``train_y`` and ``val_y`` keep the per-client layout of ``y``. Samples keep their original relative order inside each split. When ``int(num_samples * train_ratio)`` is not smaller than ``num_samples`` everything goes to the train split (as copies), ``val_X`` is ``None`` and ``val_y`` holds one empty list per client.
    """
    assert isinstance(X, np.ndarray)
    print(f"X size: {X.shape}, y size: {len(y)}x{len(y[0])}")
    num_clients = len(y)
    total_num = len(y[0])
    train_num = int(total_num * train_ratio)

    if train_num >= total_num:
        # Nothing left for validation: hand back independent copies.
        return (
            copy.deepcopy(X),
            copy.deepcopy(y),
            None,
            [[] for _ in range(num_clients)],
        )

    rng = np.random.default_rng(seed=seed)
    train_set_idxs = sorted(rng.choice(total_num, train_num, replace=False).tolist())
    # Set membership keeps the complement computation linear in total_num
    # (a list lookup here would make it quadratic).
    train_lookup = set(train_set_idxs)
    val_set_idxs = [i for i in range(total_num) if i not in train_lookup]

    # Indexing with an index list already returns a fresh array, so no extra
    # copy is needed to decouple the splits from ``X``.
    train_X = X[train_set_idxs, :]
    val_X = X[val_set_idxs, :]

    train_y = []
    val_y = []
    for cid in range(num_clients):
        client_labels = np.array(y[cid])
        train_y.append(client_labels[train_set_idxs].tolist())
        val_y.append(client_labels[val_set_idxs].tolist())

    return train_X, train_y, val_X, val_y


def make_equal_range_label(
    budgets: List[int], max_num: int, num_classes: int
) -> List[int]:
    """Create budget labels by equally separating the value range [0, max_num] into ``num_classes`` classes, and assign the budget label for each sample.

    Naive method to assign budget label, can lead to severe class imbalance.

    Args:
        budgets (List[int]): budget value of each sample.
        max_num (int): upper end (inclusive) of the budget value range.
        num_classes (int): number of equal-width classes the range is cut into.

    Returns:
        List[int]: one label per entry of ``budgets``. Labels never exceed ``num_classes - 1``.
    """
    range_val = max_num / num_classes
    top_label = num_classes - 1
    # ``max_num`` itself (and float rounding near class borders) would land
    # on label ``num_classes``, one past the last valid class, so clamp it
    # into the last class.
    return [min(int(bb // range_val), top_label) for bb in budgets]


def get_basket_range(
    budget_values: List[int], allocations: List[List[int]]
) -> List[List[int]]:
    """Translate a bucket-to-basket assignment into the budget value range covered by each basket.

    For every basket the first and the last assigned bucket index are looked
    up in ``budget_values``; these two values form the basket's range.

    Args:
        budget_values (List[int]): List of budget values for each budget bucket.
        allocations (List[List[int]]): Assignment of each budget bucket to baskets, i.e. one list of bucket indices per basket. Can be obtained from :func:`basket_assign_heuristic` or :func:`basket_assign_brute_force`.

    Returns:
        List[List[int]]: ``[min_val, max_val]`` for each basket. For example, ``[[0,2], [3,7], [8,9]]`` means there are 3 baskets, the first basket's value range is 0 to 2, the second's value range is 3 to 7, the third's value range is 8 to 9.
    """
    return [
        [budget_values[bucket_ids[0]], budget_values[bucket_ids[-1]]]
        for bucket_ids in allocations
    ]


def cal_basket_sum(buckets: List[int], allocations: List[List[int]]) -> List[int]:
    """Sum up the items that end up in each basket under the bucket assignment ``allocations``.

    Args:
        buckets (List[int]): A list of integers, with each integer indicates the number of items in this bucket.
        allocations (List[List[int]]): contains ``k`` lists, each list holds the bucket indices assigned to that basket.

    Returns:
        List[int]: Total number of items in each basket, in basket order.
    """
    bucket_arr = np.array(buckets)
    # Fancy-index the bucket counts of each basket, then add them up as
    # plain Python numbers.
    return [sum(bucket_arr[bucket_ids].tolist()) for bucket_ids in allocations]


def find_group_indices(ranges: List[List[int]], values: List[int]) -> List[int]:
    """Assign budget basket index for each budget value in ``values`` based on each basket's value ranges ``ranges``.

    A value that lies inside a range gets that range's index. Basket ranges
    may leave gaps between them, so a value that matches no range is not
    dropped (which would silently misalign the result with ``values``);
    instead it is snapped to the first basket whose upper bound is not below
    the value, or to the last basket when the value exceeds every upper
    bound.

    Args:
        ranges (List[List[int]]): The value range of each basket. Can be obtained from :func:`get_basket_range`.
        values (List[int]): List of budget values.

    Returns:
        List[int]: List of basket index assigned for each budget value in ``values``. Always has the same length as ``values``.

    Raises:
        ValueError: if ``ranges`` is empty while ``values`` is not, since no basket index can be assigned.
    """
    if len(ranges) == 0:
        if len(values) > 0:
            raise ValueError("cannot assign basket indices: ``ranges`` is empty")
        return []

    last_index = len(ranges) - 1

    def _locate(value) -> int:
        # Index of the first basket whose upper bound is >= value; used only
        # when the value is not contained in any range.
        first_not_below = None
        for index, group_range in enumerate(ranges):
            upper = group_range[-1]
            if group_range[0] <= value <= upper:
                return index
            if first_not_below is None and value <= upper:
                first_not_below = index
        return last_index if first_not_below is None else first_not_below

    return [_locate(value) for value in values]


def backtrack(
    current_index,
    current_basket_assign,
    current_basket_value,
    num_basket,
    buckets,
    final_assign,
    final_difference,
):
    """Recursive exhaustive search over sequential bucket-to-basket assignments.

    Each bucket is either put into the basket of the previous bucket or opens
    the next basket. Among complete assignments that use exactly
    ``num_basket`` baskets, the one with the smallest gap between the largest
    and the smallest basket total is recorded in the output arguments.

    Args:
        current_index (int): index of the bucket to place in this call.
        current_basket_assign (List[int]): basket index chosen for every bucket placed so far; modified during the search and restored before returning.
        current_basket_value (List[int]): running item total of every basket opened so far; modified during the search and restored before returning.
        num_basket (int): number of baskets a complete assignment must use.
        buckets (List[int]): item count of each bucket.
        final_assign (List[int]): output, overwritten in place with the best assignment found so far.
        final_difference (List[int]): single-element list holding the best (smallest) max-min gap found so far; updated in place.
    """
    last_basket = current_basket_assign[-1]

    # Prune: more baskets opened than allowed.
    if last_basket + 1 > num_basket:
        return

    # All buckets placed: keep this assignment if it is strictly better.
    if current_index == len(buckets):
        if last_basket + 1 != num_basket:
            return
        spread = max(current_basket_value) - min(current_basket_value)
        if spread < final_difference[0]:
            final_assign[:] = current_basket_assign
            final_difference[0] = spread
        return

    items = buckets[current_index]
    next_index = current_index + 1

    # Option 1: the bucket joins the basket currently being filled.
    current_basket_assign.append(last_basket)
    current_basket_value[-1] += items
    backtrack(
        next_index,
        current_basket_assign,
        current_basket_value,
        num_basket,
        buckets,
        final_assign,
        final_difference,
    )
    current_basket_value[-1] -= items
    current_basket_assign.pop()

    # Option 2: the bucket opens a new basket.
    current_basket_assign.append(last_basket + 1)
    current_basket_value.append(items)
    backtrack(
        next_index,
        current_basket_assign,
        current_basket_value,
        num_basket,
        buckets,
        final_assign,
        final_difference,
    )
    current_basket_value.pop()
    current_basket_assign.pop()


def basket_assign_brute_force(
    buckets: Optional[List[int]] = None, num_basket: int = 4
) -> Optional[List[List[int]]]:
    """Try to assign the buckets in ``buckets`` into ``num_basket`` baskets in sequence.
    The purpose is to make the overall item number in each basket to be similar with each other.
    This function performs an exhaustive search using the backtrack algorithm (:func:`backtrack`), so among all sequential assignments it finds the one with the smallest gap between the fullest and the emptiest basket.
    However, this function can be slow compared with heuristic ver :func:`basket_assign_heuristic`, so it is only proper when the number of buckets is small.

    Args:
        buckets (List[int], optional): A list of integers, with each integer indicates the number of items in this bucket. Defaults to the example ``[4, 2, 3, 4, 4, 5, 6]`` when omitted.
        num_basket (int, optional): basket number. Defaults to 4.

    Returns:
        Optional[List[List[int]]]: contains ``num_basket`` lists, each list holds the bucket indices assigned to that basket. ``None`` (after printing a notice) when there are fewer buckets than baskets.
    """
    if buckets is None:
        # Historical example input; created per call so no list object is
        # shared between calls through the default argument.
        buckets = [4, 2, 3, 4, 4, 5, 6]

    if len(buckets) < num_basket:
        print("NO WAY TO ASSIGN")
        return None

    if len(buckets) == 0:
        # Only reachable with num_basket <= 0: nothing to place, no baskets.
        return [[] for _ in range(num_basket)]

    final_assign: List[int] = []
    # Start from +inf so that any complete assignment is accepted, however
    # large the bucket counts are.
    final_difference = [float("inf")]
    # The first bucket always sits in basket 0; the search starts at bucket 1.
    backtrack(
        1,
        [0],
        [buckets[0]],
        num_basket,
        buckets,
        final_assign,
        final_difference,
    )

    # Convert "basket index per bucket" into "bucket indices per basket".
    bucket_assign_list: List[List[int]] = [[] for _ in range(num_basket)]
    for bucket_idx, basket_idx in enumerate(final_assign):
        bucket_assign_list[basket_idx].append(bucket_idx)
    return bucket_assign_list


def basket_assign_heuristic(buckets: List[int], num_basket: int) -> List[List[int]]:
    """Try to assign the buckets in ``buckets`` into ``num_basket`` baskets in sequence.
    The purpose is to make the overall item number in each basket to be similar with each other.
    This function only uses heuristic algorithm, so it is fast but can lead to suboptimal assignment for some cases.

    Buckets are scanned in order and a new basket is opened whenever adding
    the next bucket would push the current basket above the average target
    ``sum(buckets) / num_basket``. In addition, a new basket is opened when
    the buckets that are left are only just enough to give every remaining
    basket one bucket, so the result has ``num_basket`` baskets whenever
    ``len(buckets) >= num_basket``.

    Args:
        buckets (List[int]): A list of integers, with each integer indicates the number of items in this bucket.
        num_basket (int): basket number, must be at least 1.

    Returns:
        List[List[int]]: one list of bucket indices per basket. Contains ``num_basket`` lists when there are at least that many buckets, otherwise fewer. An empty ``buckets`` gives ``[[]]``.

    Raises:
        ValueError: if ``num_basket`` is smaller than 1.
    """
    if num_basket < 1:
        raise ValueError(f"num_basket must be at least 1, got {num_basket}")

    target = sum(buckets) / num_basket
    num_buckets = len(buckets)
    partitions: List[List[int]] = []
    current_partition: List[int] = []
    current_sum = 0

    for i, num in enumerate(buckets):
        # Baskets that still have to be opened after the current one.
        baskets_left = num_basket - 1 - len(partitions)
        if current_partition and baskets_left > 0:
            over_target = current_sum != 0 and current_sum + num > target
            # Not cutting now would leave fewer buckets than baskets to fill.
            must_cut = num_buckets - i <= baskets_left
            if over_target or must_cut:
                partitions.append(current_partition)
                current_partition = []
                current_sum = 0
        current_partition.append(i)
        current_sum += num

    # The basket being filled when the scan ends is always the last one.
    partitions.append(current_partition)
    return partitions


def map_budget_values(
    budget_pred_labels,
    budget_label_ranges,
    strategy="max",
    server_ice_num=32,
    use_buffer=True,
    buffer=2,
):
    """Turn predicted budget labels into concrete budget numbers for every client and query.

    Args:
        budget_pred_labels: one list per client holding the predicted budget label of every query.
        budget_label_ranges: one list per client; entry ``label`` is the value range of that budget label, read as ``val_range[0]`` (lower end) and ``val_range[-1]`` (upper end).
        strategy: how a value is picked from a range. ``"max"`` takes the upper end, ``"min"`` the lower end, ``"medium"`` takes ``int((sum(val_range) + 1) / 2)``; anything else is converted with ``float`` and used as ratio: ``int((upper - lower) * ratio + lower + 1)``. Defaults to "max".
        server_ice_num: minimal total budget per query over all clients. Defaults to 32.
        use_buffer: whether ``buffer`` is added on top of every mapped value. Defaults to True.
        buffer: amount added per value when ``use_buffer`` is set. Defaults to 2.

    Returns:
        One list per client with the budget number of every query. If the clients' budgets for a query add up to less than ``server_ice_num``, the missing amount is handed out one by one to clients drawn with ``np.random.choice`` (with replacement).
    """
    num_clients = len(budget_pred_labels)
    query_num = len(budget_pred_labels[0])

    def _value_from_range(val_range):
        # from budget label range to budget number
        if strategy == "max":
            return val_range[-1]
        if strategy == "min":
            return val_range[0]
        if strategy == "medium":
            return int((sum(val_range) + 1) / 2)
        # strategy is a float number
        ratio = float(strategy)
        return int((val_range[-1] - val_range[0]) * ratio + val_range[0] + 1)

    pred_budgets = [[] for _ in range(num_clients)]
    progress = trange(
        query_num, disable=False, desc="Server side mapping budget values: "
    )
    for qid in progress:
        query_budgets = []
        for cid, client_labels in enumerate(budget_pred_labels):
            label = client_labels[qid]
            value = _value_from_range(budget_label_ranges[cid][label])
            if use_buffer:
                value += buffer
            query_budgets.append(value)

        shortfall = server_ice_num - sum(query_budgets)
        if shortfall > 0:
            # top up randomly chosen clients until the server total is met
            for lucky_cid in np.random.choice(num_clients, shortfall, replace=True):
                query_budgets[lucky_cid] += 1

        for cid, value in enumerate(query_budgets):
            pred_budgets[cid].append(value)

    return pred_budgets


def epoch_train_budget_model(
    model, criterion, optimizer, train_loader, num_classes, device=None
):
    """Train ``model`` for one pass over ``train_loader`` and collect statistics.

    Args:
        model: network mapping a batch of embeddings to class scores; indexed as ``(batch, class)`` by the arg-max below.
        criterion: loss called as ``criterion(outputs, targets)``.
        optimizer: optimizer stepped once per batch.
        train_loader: iterable yielding ``(embeds, targets)`` batches.
        num_classes (int): number of budget classes for the per-class statistics.
        device: device the batches are moved to. Defaults to None.

    Returns:
        Tuple ``(loss, acc, per_class_acc)``: the ``avg`` of the loss and accuracy meters (presumably batch-size weighted means, see ``AverageMeter``) and a list with the accuracy of every class. A class without any sample gets accuracy 0.0.
    """
    loss_stat = AverageMeter()
    acc_stat = AverageMeter()
    per_class_correct = [0.0 for _ in range(num_classes)]
    # Start at zero: a start value of one would inflate every denominator and
    # under-report the accuracy of each class.
    per_class_cnt = [0.0 for _ in range(num_classes)]
    model.train()
    for embeds, targets in train_loader:
        batch_size = len(targets)
        embeds = embeds.to(device)
        targets = targets.to(device)
        outputs = model(embeds)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            _, predicted = torch.max(outputs, 1)
            hits = predicted.eq(targets)
            acc_stat.update(torch.sum(hits).item() / batch_size, batch_size)
            loss_stat.update(loss.item(), batch_size)

            for c in range(num_classes):
                is_class = targets == c
                per_class_correct[c] += (hits * is_class).float().sum().item()
                per_class_cnt[c] += is_class.float().sum().item()

    per_class_acc = [
        (correct / cnt) if cnt else 0.0
        for correct, cnt in zip(per_class_correct, per_class_cnt)
    ]

    return loss_stat.avg, acc_stat.avg, per_class_acc


def eval_budget_model(model, criterion, val_loader, num_classes, device=None):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking and collect statistics.

    Args:
        model: network mapping a batch of embeddings to class scores; indexed as ``(batch, class)`` by the arg-max below.
        criterion: loss called as ``criterion(outputs, targets)``.
        val_loader: iterable yielding ``(embeds, targets)`` batches.
        num_classes (int): number of budget classes for the per-class statistics.
        device: device the batches are moved to. Defaults to None.

    Returns:
        Tuple ``(loss, acc, per_class_acc)``: the ``avg`` of the loss and accuracy meters (presumably batch-size weighted means, see ``AverageMeter``) and a list with the accuracy of every class. A class without any sample gets accuracy 0.0.
    """
    loss_stat = AverageMeter()
    acc_stat = AverageMeter()
    per_class_correct = [0.0 for _ in range(num_classes)]
    # Start at zero: a start value of one would inflate every denominator and
    # under-report the accuracy of each class.
    per_class_cnt = [0.0 for _ in range(num_classes)]
    model.eval()

    with torch.no_grad():
        for embeds, targets in val_loader:
            batch_size = len(targets)
            embeds = embeds.to(device)
            targets = targets.to(device)
            outputs = model(embeds)
            loss = criterion(outputs, targets)

            _, predicted = torch.max(outputs, 1)
            hits = predicted.eq(targets)
            acc_stat.update(torch.sum(hits).item() / batch_size, batch_size)
            loss_stat.update(loss.item(), batch_size)

            for c in range(num_classes):
                is_class = targets == c
                per_class_correct[c] += (hits * is_class).float().sum().item()
                per_class_cnt[c] += is_class.float().sum().item()

    per_class_acc = [
        (correct / cnt) if cnt else 0.0
        for correct, cnt in zip(per_class_correct, per_class_cnt)
    ]

    return loss_stat.avg, acc_stat.avg, per_class_acc
