import os
import numpy as np
import torch
import os
from openpyxl import Workbook, load_workbook
from openpyxl.utils.dataframe import dataframe_to_rows
from datetime import datetime

def count_parameters(model, trainable=False):
    if trainable:
        return sum(p.numel() for p in model.parameters() if p.requires_grad)
    return sum(p.numel() for p in model.parameters())


def tensor2numpy(x):
    return x.cpu().data.numpy() if x.is_cuda else x.data.numpy()


def target2onehot(targets, n_classes):
    onehot = torch.zeros(targets.shape[0], n_classes).to(targets.device)
    onehot.scatter_(dim=1, index=targets.long().view(-1, 1), value=1.0)
    return onehot

def convert_time(seconds):
    m, s = divmod(seconds, 60)
    h, m = divmod(m, 60)
    return "%d:%02d:%02d" % (h, m, s)

def makedirs(path):
    if not os.path.exists(path):
        os.makedirs(path)


def accuracy(y_pred, y_true, nb_old, increment=10):
    # nb_old 是已知/过去task上遇到的类别的总量
    assert len(y_pred) == len(y_true), "Data length error."
    all_acc = {}
    # 取出的对应类别的所有测试数据的准确率（%）
    all_acc["total"] = np.around(
        (y_pred == y_true).sum() * 100 / len(y_true), decimals=2
    )

    # Grouped accuracy
    # 计算到当前task为止，前面每一个增量task上的一组类别的分类准确率
    for class_id in range(0, np.max(y_true), increment):
        idxes = np.where(
            np.logical_and(y_true >= class_id, y_true < class_id + increment)
        )[0]    # 0是取出索引
        # 例如： 0-9 表示第一个incremental上，0-9个类别的预测准确率
        label = "{}-{}".format(
            str(class_id).rjust(2, "0"), str(class_id + increment - 1).rjust(2, "0")    # .rjust(2, "0") 表示在小于2位长度的字符串左侧补0，使其长度为2
        )
        # 计算特定一组类别的准确率
        all_acc[label] = np.around(
            (y_pred[idxes] == y_true[idxes]).sum() * 100 / len(idxes), decimals=2
        )

    # Old accuracy
    # 计算旧类别的准确率
    idxes = np.where(y_true < nb_old)[0]
    all_acc["old"] = (
        0 if len(idxes) == 0
        else np.around(
            (y_pred[idxes] == y_true[idxes]).sum() * 100 / len(idxes), decimals=2
        )
    )

    # New accuracy
    idxes = np.where(y_true >= nb_old)[0]
    all_acc["new"] = np.around(
        (y_pred[idxes] == y_true[idxes]).sum() * 100 / len(idxes), decimals=2
    )

    return all_acc


def split_images_labels(imgs):
    # split trainset.imgs in ImageFolder
    """
    把读取的图像与其对应的标签分开，返回两个列表
    (PIL.Image.open('train_dir/cats/cat_1.jpg'), 0)
    (PIL.Image.open('train_dir/cats/cat_2.jpg'), 0)
    (PIL.Image.open('train_dir/dogs/dog_1.jpg'), 1)
    (PIL.Image.open('train_dir/dogs/dog_2.jpg'), 1)
    """
    images = []
    labels = []
    for item in imgs:
        images.append(item[0])
        labels.append(item[1])

    return np.array(images), np.array(labels)

def list2dict(list):
    dict = {}
    for l in list:
        s = l.split(' ')
        id = int(s[0])
        cls = s[1]
        if id not in dict.keys():
            dict[id] = cls
        else:
            raise EOFError('The same ID can only appear once')
    return dict

def text_read(file):
    with open(file, 'r') as f:
        lines = f.readlines()
        for i, line in enumerate(lines):
            lines[i] = line.strip('\n')
    return lines




def save_results_to_excel(dataset_name, method_name, incremental_num, results, runing_time='', device='', note='', seed=''):
    """
    将神经网络训练结果保存到一个 Excel 文件中，每个 Sheet 对应一个评价指标名称。
    Excel 文件位于 incremental_num 命名的文件夹中，存储在数据集命名的文件夹下。
    每行数据前插入当前时间，第一行写字段名称。
    """
    # 创建指定的文件夹结构
    base_dir = os.path.join(os.getcwd(), 'results', dataset_name, '{}'.format(str(incremental_num)))
    os.makedirs(base_dir, exist_ok=True)

    # 定义 Excel 文件路径
    excel_path = os.path.join(base_dir, f"{method_name}.xlsx")

    # 获取当前时间
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # 如果 Excel 文件已存在，加载它；否则，创建新的文件
    if os.path.exists(excel_path):
        book = load_workbook(excel_path)
    else:
        book = Workbook()  # 创建新的工作簿
        book.remove(book.active)  # 移除默认创建的空 sheet

    # 遍历每个评价指标，保存结果到对应的 Sheet
    for metric, data in results.items():
        if metric not in book.sheetnames:
            # 如果该 sheet 不存在，创建新的 sheet
            sheet = book.create_sheet(metric)
            # 写入标题行
            sheet.append(['Method', 'Timestamp', 'seed', 'Parameters'] + [f'task_{i}' for i in range(len(data[0][2]))] +['inc_acc','grouped_top1_acc', 'running_time','device','note']) 
        else:
            # 如果该 sheet 已经存在，获取该 sheet
            sheet = book[metric]

        # 插入数据
        
        for entry in data:
            if entry != []:
                # 在每一条数据前插入时间戳
                row = [entry[0], current_time, seed, entry[1]] + list(entry[2])   + [str(entry[3]),str(entry[4]), runing_time, device, note]
                sheet.append(row)

    # 保存 Excel 文件
    book.save(excel_path)

def get_device_name(device_type):
    # device_type: [-1,0,1]  -1代表cpu，0代表gpu0，1代表gpu1
    device_names = []
    
    for device in device_type:
        if device == -1 or device == "-1":  # 如果设备类型是 -1，则表示使用 CPU
            device_names.append("CPU")
        else:  # 否则，获取指定 GPU 的名称
            try:
                device_names.append(torch.cuda.get_device_name(int(device)))
            except Exception as e:
                # 如果获取 GPU 名称出错，输出错误信息
                pass
                # device_names.append(f"Error getting device name for cuda:{device}, {str(e)}")
    
    return device_names


# 示例调用
if __name__ == "__main__":
    # dataset_name = "CIFAR-10"
    # method_name = "MyNeuralNet"
    # incremental_num = "10"  # 示例传入的增量编号
    # grouped_top1_acc = 0.01
    # results = {
    #     'Accuracy': [
    #         ('Method1', 'lr=0.01, batch=32','seed', [0.85, 0.87, 0.88], 'inc_acc', grouped_top1_acc,'run time','device','note'),
    #         ('Method2', 'lr=0.001, batch=64','seed', [0.83, 0.84, 0.85], 'inc_acc', grouped_top1_acc,'run time','device','note')
    #     ],
    #     'Loss': [
    #         ('Method1', 'lr=0.01, batch=32','seed', [0.35, 0.34, 0.33], 'inc_acc', grouped_top1_acc,'run time','device','note'),
    #         ('Method2', 'lr=0.001, batch=64','seed', [0.40, 0.38, 0.37], 'inc_acc', grouped_top1_acc,'run time','device','note')
    #     ]
    # }
    # save_results_to_excel(dataset_name, method_name, incremental_num, results, runing_time='123', device='', note='')
    device_name = get_device_name([0,1])
    print(device_name)