# Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group)

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

import os
import json
import pynvml
import random
import numpy as np
from collections import Counter
import torch.nn as nn
from .metrics import ECELoss, SCELoss
np.random.seed(886) 
import random
random.seed(886)
import pdb
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


class AverageMeter(object):
    """Record metrics information"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def evaluate(model, criterion, test_loader):
    """Evaluate classify task model accuracy.

    Returns:
        (loss.sum, acc.avg)
    """
    model.eval()
    gpu = next(model.parameters()).device

    loss_ = AverageMeter()
    acc_ = AverageMeter()

    pred_list = []
    target_list = []
    prob_list = []

    all_targets = None
    all_outputs = None

    with torch.no_grad():
        # for inputs, labels in test_loader:

        for inputs, labels,_ in test_loader: # for femnist
            # for femnist
            inputs = torch.tensor(inputs,dtype=torch.float32).reshape(-1, 1, 28, 28)
            labels = torch.tensor(labels, dtype=torch.long)
            # for femnist

            batch_size = len(labels)
            inputs = inputs.to(gpu)
            labels = labels.to(gpu)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            


            _, predicted = torch.max(outputs, 1)
            # prob_list.extend(outputs.tolist())
            prob_list.extend(nn.Softmax(dim=1)(outputs).tolist())
            
            pred_list.extend(predicted.tolist())
            target_list.extend(labels.tolist())
            
            loss_.update(loss.item(), batch_size)
            acc_.update(torch.sum(predicted.eq(labels)).item() / batch_size, batch_size)

            
            targets = labels.cpu().numpy()
            outputs = outputs.cpu().numpy()

            if all_targets is None:
                all_outputs = outputs
                all_targets = targets
            else:
                all_targets = np.concatenate([all_targets, targets], axis=0)
                all_outputs = np.concatenate([all_outputs, outputs], axis=0)
        # print(ECELoss().loss(all_outputs, all_targets, n_bins=15))
        # print(SCELoss().loss(all_outputs, all_targets, n_bins=15))
        # pdb.set_trace()
    return loss_.avg, acc_.avg, pred_list, target_list, prob_list, all_targets, all_outputs


def read_config_from_json(json_file: str, user_name: str):
    """Read config from `json_file` to get config for `user_name`

    Args:
        json_file (str): path for json_file
        user_name (str): read config for this user, it can be 'server' or 'client_id'

    Returns:
        a tuple with ip, port, world_size, rank about user with `user_name`

    Examples:
        read_config_from_json('../../../tests/data/config.json', 'server')

    Notes:
        config.json example as follows
        {
          "server": {
            "ip" : "127.0.0.1",
            "port": "3002",
            "world_size": 3,
            "rank": 0
          },
          "client_0": {
            "ip": "127.0.0.1",
            "port": "3002",
            "world_size": 3,
            "rank": 1
          },
          "client_1": {
            "ip": "127.0.0.1",
            "port": "3002",
            "world_size": 3,
            "rank": 2
          }
        }
    """
    with open(json_file) as f:
        config = json.load(f)
    config_info = config[user_name]
    return (
        config_info["ip"],
        config_info["port"],
        config_info["world_size"],
        config_info["rank"],
    )


def get_best_gpu():
    """Return gpu (:class:`torch.device`) with largest free memory."""
    assert torch.cuda.is_available()
    pynvml.nvmlInit()
    deviceCount = pynvml.nvmlDeviceGetCount()

    if "CUDA_VISIBLE_DEVICES" in os.environ.keys() is not None:
        cuda_devices = [
            int(device) for device in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
        ]
    else:
        cuda_devices = range(deviceCount)

    assert max(cuda_devices) < deviceCount
    deviceMemory = []
    for i in cuda_devices:
        handle = pynvml.nvmlDeviceGetHandleByIndex(i)
        mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        deviceMemory.append(mem_info.free)
    deviceMemory = np.array(deviceMemory, dtype=np.int64)
    best_device_index = np.argmax(deviceMemory)
    return torch.device("cuda:%d" % (best_device_index))


def partition_report(targets, data_indices, class_num=None, verbose=True, file=None):
    """Generate data partition report for clients in ``data_indices``.

    Generate data partition report for each client according to ``data_indices``, including
    ratio of each class and dataset size in current client. Report can be printed in screen or into
    file. The output format is comma-separated values which can be read by :func:`pandas.read_csv`
    or :func:`csv.reader`.

    Args:
        targets (list or numpy.ndarray): Targets for all data samples, with each element is in range of ``0`` to ``class_num-1``.
        data_indices (dict): Dict of ``client_id: [data indices]``.
        class_num (int, optional): Total number of classes. If set to ``None``, then ``class_num = max(targets) + 1``.
        verbose (bool, optional): Whether print data partition report in screen. Default as ``True``.
        file (str, optional): Output file name of data partition report. If ``None``, then no output in file. Default as ``None``.

    Examples:
        First generate synthetic data labels and data partition to obtain ``data_indices``
        (``{ client_id: sample indices}``):

        >>> sample_num = 15
        >>> class_num = 4
        >>> clients_num = 3
        >>> num_per_client = int(sample_num/clients_num)
        >>> labels = np.random.randint(class_num, size=sample_num)  # generate 15 labels, each label is 0 to 3
        >>> rand_per = np.random.permutation(sample_num)
        >>> # partition synthetic data into 3 clients
        >>> data_indices = {0: rand_per[0:num_per_client],
        ...                 1: rand_per[num_per_client:num_per_client*2],
        ...                 2: rand_per[num_per_client*2:num_per_client*3]}

        Check ``data_indices`` may look like:

        >>> data_indices
        {0: array([8, 6, 5, 7, 2]),
         1: array([ 3, 10, 14,  4,  1]),
         2: array([13,  9, 12, 11,  0])}

        Now generate partition report for each client and each class:

        >>> partition_report(labels, data_indices, class_num=class_num, verbose=True, file=None)
        Class frequencies:
        client,class0,class1,class2,class3,Amount
        Client   0,0.200,0.00,0.200,0.600,5
        Client   1,0.400,0.200,0.200,0.200,5
        Client   2,0.00,0.400,0.400,0.200,5

    """
    if not verbose and file is None:
        print("No partition report generated")
        return

    if not isinstance(targets, np.ndarray):
        targets = np.array(targets)

    if not class_num:
        class_num = max(targets) + 1

    sorted_cid = sorted(data_indices.keys())  # sort client id in ascending order

    header_line = "Class frequencies:"
    col_name = "client," + ",".join([f"class{i}" for i in range(class_num)]) + ",Amount"

    if verbose:
        print(header_line)
        print(col_name)
    if file is not None:
        reports = [header_line, col_name]
    else:
        reports = None

    for client_id in sorted_cid:
        indices = data_indices[client_id]
        client_targets = targets[indices]
        client_sample_num = len(indices)  # total number of samples of current client
        client_target_cnt = Counter(client_targets)  # { cls1: num1, cls2: num2, ... }

        report_line = (
            f"Client {client_id:3d},"
            + ",".join(
                [
                    f"{client_target_cnt[cls] / client_sample_num:.3f}"
                    if cls in client_target_cnt
                    else "0.00"
                    for cls in range(class_num)
                ]
            )
            + f",{client_sample_num}"
        )
        if verbose:
            print(report_line)
        if file is not None:
            reports.append(report_line)

    if file is not None:
        fh = open(file, "w")
        fh.write("\n".join(reports))
        fh.close()
