import numpy as np
import torch
import torch.nn as nn
import numpy
import xlsxwriter

from collections import OrderedDict

def avg_tensor(list, dim=0):
    """Stack a sequence of tensors and average the stack along ``dim``.

    Args:
        list: Sequence of equally shaped tensors, as accepted by
            ``torch.stack`` (the stack is built along a new leading dim).
        dim: Dimension of the *stacked* tensor to reduce. The default ``0``
            is the dimension introduced by stacking, i.e. an element-wise
            mean over the sequence.

    Returns:
        The mean of the stacked tensor along ``dim``.
    """
    return torch.stack(list).mean(dim=dim)

def avg_dict(dict):
    """Average every tensor list of a mapping with ``avg_tensor``.

    Args:
        dict: Mapping from a key to a sequence of equally shaped tensors.

    Returns:
        An ``OrderedDict`` with the same keys, in the same order, whose values
        are the element-wise means over each sequence (``avg_tensor`` with its
        default ``dim=0``).
    """
    return OrderedDict(
        (name, avg_tensor(tensors)) for name, tensors in dict.items()
    )

def avg_dict_2(dict):
    """Reduce every tensor list of a mapping to one scalar mean and rank them.

    Each value is first averaged element-wise over the sequence
    (``avg_tensor``) and the result is then averaged over all of its elements.

    Args:
        dict: Mapping from a key to a sequence of equally shaped tensors.

    Returns:
        A plain ``dict`` of 0-dim tensors ordered by ascending mean.
    """
    scored = [
        (name, avg_tensor(tensors).mean()) for name, tensors in dict.items()
    ]
    # list.sort is stable, so ties keep the input order.
    scored.sort(key=lambda pair: pair[1])
    return {name: score for name, score in scored}

def flooding_lower_each_task(dict, flood_level=0.0):
    """Count, per position, how many stacked entries fall below a threshold.

    Args:
        dict: Mapping from a key to a sequence of equally shaped tensors.
        flood_level: Threshold; an element counts when it is strictly smaller.

    Returns:
        An ``OrderedDict`` with the input keys in input order. Each value is
        an integer tensor shaped like one sequence element, holding the number
        of sequence entries whose element at that position is below
        ``flood_level``.
    """
    return OrderedDict(
        (name, (torch.stack(tensors) < flood_level).sum(dim=0))
        for name, tensors in dict.items()
    )

def flooding_lower(dict, flood_level=0.0):
    """Count all elements below a threshold for each key and rank the keys.

    Args:
        dict: Mapping from a key to a sequence of equally shaped tensors.
        flood_level: Threshold; an element counts when it is strictly smaller.

    Returns:
        A plain ``dict`` mapping each key to a 0-dim integer tensor (the total
        number of elements below ``flood_level`` across the whole sequence),
        ordered by ascending count.
    """
    totals = []
    for name, tensors in dict.items():
        below = torch.stack(tensors) < flood_level
        totals.append((name, below.sum(dim=0).sum()))

    # Stable sort: keys with equal counts keep their input order.
    totals.sort(key=lambda pair: pair[1])
    return {name: total for name, total in totals}

def flooding_lower_balance(dict, flood_level=0.0):
    """Rank keys by how often *all* entries along dim 1 are below a threshold.

    For every key the sequence is stacked along a new leading dimension; a
    stacked entry counts when every element along dimension 1 is strictly
    smaller than ``flood_level``.

    Args:
        dict: Mapping from a key to a sequence of equally shaped tensors
            (at least 1-D, so the stacked tensor has a dimension 1).
        flood_level: Threshold used for the strict ``<`` comparison.

    Returns:
        A plain ``dict`` mapping each key to the fraction of stacked entries
        that count, ordered by descending fraction.
    """
    ratios = []
    for name, tensors in dict.items():
        below = torch.stack(tensors) < flood_level
        # "All below along dim 1" is the same as the count along dim 1
        # reaching the size of that dimension.
        simultaneous = below.all(dim=1)
        ratios.append((name, simultaneous.sum() / simultaneous.size(0)))

    ratios.sort(key=lambda pair: pair[1], reverse=True)
    return {name: ratio for name, ratio in ratios}

def quantile_dict(dict, q, interpolation='linear'):
    """Compute a quantile of the dim-1 means for each key and rank the keys.

    Each sequence is stacked and averaged along dimension 1 of the stacked
    tensor (``avg_tensor(value, dim=1)``); ``torch.quantile`` is then taken
    over the resulting values.

    Args:
        dict: Mapping from a key to a sequence of equally shaped tensors.
        q: Quantile passed to ``torch.quantile``. The ranking compares the
            results directly, so it presumably has to be a single value.
        interpolation: Interpolation mode forwarded to ``torch.quantile``.

    Returns:
        A plain ``dict`` of quantile values ordered ascending.
    """
    ranked = sorted(
        (
            (name, torch.quantile(avg_tensor(tensors, dim=1), q, interpolation=interpolation))
            for name, tensors in dict.items()
        ),
        key=lambda pair: pair[1],
    )
    return {name: quant for name, quant in ranked}

def print_dict(dict):
    """Print every entry of a mapping as ``key: value``, one per line."""
    for entry in dict.items():
        print('{}: {}'.format(*entry))

def topK_bonus_avg(task_diff):
    """Rank keys by their mean value per task and by the mean across tasks.

    Args:
        task_diff: Non-empty mapping from a key to a non-empty sequence of
            1-D tensors; the length of the first tensor is taken as the
            number of tasks.

    Returns:
        A tuple ``(diff_dict, dict_avg)``:

        * ``diff_dict`` - list with one plain ``dict`` per task, mapping each
          key to the mean of that task's column (0-dim tensor), ordered by
          descending mean.
        * ``dict_avg`` - plain ``dict`` mapping each key to the NumPy mean of
          its per-task means, ordered by descending value.
    """
    n_task = list(task_diff.values())[0][0].size(0)
    diff_dict = [OrderedDict() for _ in range(n_task)]

    # Stack every history exactly once; stacking is independent of the task
    # index, so redoing it for each task only repeats the same work.
    for key, value in task_diff.items():
        stacked = torch.stack(value)
        n_samples = stacked.size(0)
        for i in range(n_task):
            diff_dict[i][key] = stacked[:, i].sum() / n_samples

    for i in range(n_task):
        diff_dict[i] = dict(
            sorted(diff_dict[i].items(), key=lambda item: item[1], reverse=True)
        )

    # Average the per-task means. Iterating in task 0's ranked order keeps the
    # tie-breaking of the final (stable) sort unchanged.
    dict_avg = OrderedDict()
    for key in diff_dict[0]:
        dict_avg[key] = np.mean([diff_dict[i][key].item() for i in range(n_task)])

    dict_avg = dict(
        sorted(dict_avg.items(), key=lambda item: item[1], reverse=True)
    )

    return diff_dict, dict_avg


def topK_prob(task_diff):
    """Rank keys by how often their stacked entries are positive.

    Two statistics are computed per key from the stacked sequence:

    * the fraction of stacked entries in which *every* element along
      dimension 1 is ``> 0``;
    * the fraction of positive elements, i.e. the positive count divided by
      ``size(0) * size(1)`` of the stacked tensor.

    The first ranking is also printed.

    Args:
        task_diff: Mapping from a key to a sequence of equally shaped tensors
            (at least 1-D).

    Returns:
        A tuple of two plain dicts ``(all_positive, any_positive)`` holding
        the statistics above, each ordered by descending value.
    """
    all_positive = []
    element_positive = []

    for key, value in task_diff.items():
        positive = torch.stack(value) > 0
        n_rows, n_task = positive.size(0), positive.size(1)

        element_positive.append((key, positive.sum() / (n_rows * n_task)))

        # A row is a "hit" when the positive count along dim 1 equals n_task.
        hits = positive.all(dim=1)
        all_positive.append((key, hits.sum() / hits.size(0)))

    diff_dict = dict(sorted(all_positive, key=lambda pair: pair[1], reverse=True))
    diff_dict_2 = dict(sorted(element_positive, key=lambda pair: pair[1], reverse=True))

    print('---------------------topK_prob-----------------------')
    print_dict(diff_dict)

    return diff_dict, diff_dict_2

def topK_prob_task(task_diff):
    """Rank keys by the per-task fraction of positive entries.

    For each task ``i`` and key, the statistic is the fraction of stacked
    entries whose ``i``-th element is ``> 0``. Every ranking is printed.

    Args:
        task_diff: Non-empty mapping from a key to a non-empty sequence of
            1-D tensors; the length of the first tensor is taken as the
            number of tasks.

    Returns:
        A tuple ``(diff_dict, dict_avg, dict_std)``:

        * ``diff_dict`` - list with one plain ``dict`` per task, ordered by
          descending fraction (0-dim tensors).
        * ``dict_avg`` - NumPy mean of each key's per-task fractions, ordered
          descending.
        * ``dict_std`` - NumPy standard deviation of each key's per-task
          fractions, ordered descending.
    """
    n_task = list(task_diff.values())[0][0].size(0)

    dict_avg = OrderedDict()
    dict_std = OrderedDict()
    diff_dict = [OrderedDict() for _ in range(n_task)]

    # Stack and threshold every history exactly once instead of once per
    # task; the per-task column is just a slice of the same tensor.
    for key, value in task_diff.items():
        positive = torch.stack(value) > 0
        n_samples = positive.size(0)
        for i in range(n_task):
            diff_dict[i][key] = positive[:, i].sum() / n_samples

    for i in range(n_task):
        diff_dict[i] = dict(
            sorted(diff_dict[i].items(), key=lambda item: item[1], reverse=True)
        )

        print(f'---------------------topK prob of task {i}-----------------------')
        print_dict(diff_dict[i])

    # Aggregate across tasks. Iterating in task 0's ranked order keeps the
    # tie-breaking of the stable sorts below unchanged.
    for key in diff_dict[0]:
        v_l = [diff_dict[i][key].item() for i in range(n_task)]
        dict_avg[key] = np.mean(v_l)
        dict_std[key] = np.std(v_l)

    dict_avg = dict(sorted(dict_avg.items(), key=lambda item: item[1], reverse=True))
    print('---------------------topK prob of task avg-----------------------')
    print_dict(dict_avg)

    dict_std = dict(sorted(dict_std.items(), key=lambda item: item[1], reverse=True))
    print('---------------------topK prob of task std----------------------')
    print_dict(dict_std)

    return diff_dict, dict_avg, dict_std

def norm_dict(dictionary):
    """Divide every value of ``dictionary`` by the mean value, in place.

    The mean is the sum of all values divided by the number of keys. The
    mapping is modified in place (each value is rebound to a new quotient)
    and nothing is returned.
    """
    mean_value = sum(list(dictionary.values())) / len(dictionary)
    dictionary.update(
        {name: entry / mean_value for name, entry in dictionary.items()}
    )

def topK_prob_task_norm(task_diff):
    """Rank keys by mean-normalised per-task fractions of positive entries.

    For each task ``i`` and key, the fraction of stacked entries whose
    ``i``-th element is ``> 0`` is computed; within each task the fractions
    are then divided by their mean over all keys (``norm_dict``). Every
    ranking is printed.

    Args:
        task_diff: Non-empty mapping from a key to a non-empty sequence of
            1-D tensors; the length of the first tensor is taken as the
            number of tasks.

    Returns:
        A tuple ``(dict_avg, dict_std)`` of plain dicts holding the NumPy
        mean and standard deviation of each key's normalised per-task values,
        each ordered by descending value.
    """
    n_task = list(task_diff.values())[0][0].size(0)

    dict_avg = OrderedDict()
    dict_std = OrderedDict()
    diff_dict = [OrderedDict() for _ in range(n_task)]

    # Stack and threshold every history exactly once instead of once per
    # task; the per-task column is just a slice of the same tensor.
    for key, value in task_diff.items():
        positive = torch.stack(value) > 0
        n_samples = positive.size(0)
        for i in range(n_task):
            diff_dict[i][key] = positive[:, i].sum() / n_samples

    for i in range(n_task):
        # Normalise within the task before ranking.
        norm_dict(diff_dict[i])
        diff_dict[i] = dict(
            sorted(diff_dict[i].items(), key=lambda item: item[1], reverse=True)
        )

        print(f'---------------------topK prob of task {i}-----------------------')
        print_dict(diff_dict[i])

    # Aggregate across tasks. Iterating in task 0's ranked order keeps the
    # tie-breaking of the stable sorts below unchanged.
    for key in diff_dict[0]:
        v_l = [diff_dict[i][key].item() for i in range(n_task)]
        dict_avg[key] = np.mean(v_l)
        dict_std[key] = np.std(v_l)

    dict_avg = dict(sorted(dict_avg.items(), key=lambda item: item[1], reverse=True))
    print('---------------------topK prob of task norm avg-----------------------')
    print_dict(dict_avg)

    dict_std = dict(sorted(dict_std.items(), key=lambda item: item[1], reverse=True))
    print('---------------------topK prob of task norm std----------------------')
    print_dict(dict_std)

    return dict_avg, dict_std

def topK_prob_range(task_diff, topK, l, h, n_epoch):
    """Return the ``topK`` keys most often all-positive inside an epoch window.

    Each stacked history is assumed to cover ``n_epoch`` epochs of equal
    length; only the entries from epoch ``l`` up to (excluding) epoch ``h``
    are considered. The score of a key is the fraction of those entries in
    which every element along dimension 1 is ``> 0``. The full ranking is
    printed.

    Args:
        task_diff: Mapping from a key to a sequence of equally shaped tensors
            (at least 1-D).
        topK: Number of best-ranked keys to return.
        l: Lower epoch bound of the window (inclusive).
        h: Upper epoch bound of the window (exclusive).
        n_epoch: Number of epochs the full history spans.

    Returns:
        List with the first ``topK`` keys of the descending ranking.
    """
    scores = []
    for key, value in task_diff.items():
        stacked = torch.stack(value)
        per_epoch = stacked.size(0) / n_epoch
        window = stacked[int(l * per_epoch): int(h * per_epoch)]

        # "Positive count along dim 1 > n_task - 1" means all are positive.
        hits = (window > 0).all(dim=1)
        scores.append((key, hits.sum() / hits.size(0)))

    diff_dict = dict(sorted(scores, key=lambda pair: pair[1], reverse=True))

    print(f'---------------------topK_prob_range [{l}, {h}]-----------------------')
    print_dict(diff_dict)

    return list(diff_dict)[:topK]

def topK_prob2(task_diff, topK):
    """Return the ``topK`` keys whose stacked entries are most often all-positive.

    The score of a key is the fraction of stacked entries in which every
    element along dimension 1 is ``> 0``. The full ranking is printed.

    Args:
        task_diff: Mapping from a key to a sequence of equally shaped tensors
            (at least 1-D).
        topK: Number of best-ranked keys to return.

    Returns:
        List with the first ``topK`` keys of the descending ranking.
    """
    scores = []
    for key, value in task_diff.items():
        # "Positive count along dim 1 > n_task - 1" means all are positive.
        hits = (torch.stack(value) > 0).all(dim=1)
        scores.append((key, hits.sum() / hits.size(0)))

    diff_dict = dict(sorted(scores, key=lambda pair: pair[1], reverse=True))

    print('---------------------topK_prob-----------------------')
    print_dict(diff_dict)

    return list(diff_dict)[:topK]

def topK_value(task_diff, topK):
    """Return the ``topK`` keys with the smallest overall mean.

    The score of a key is the mean over all elements of its stacked history.
    Keys are ranked in ascending order of that score and the full ranking is
    printed.

    Args:
        task_diff: Mapping from a key to a sequence of equally shaped tensors.
        topK: Number of leading keys of the ascending ranking to return.

    Returns:
        List with the first ``topK`` keys of the ascending ranking.
    """
    scores = [(key, torch.stack(value).mean()) for key, value in task_diff.items()]
    diff_dict = dict(sorted(scores, key=lambda pair: pair[1]))

    output = list(diff_dict)[:topK]

    print('-------------topK_value---------------')
    print_dict(diff_dict)

    return output

def write_task_diff(task_diff, worksheet_task_diff):
    """Write all positive-fraction rankings of ``task_diff`` to a worksheet.

    The rankings come from ``topK_prob``, ``topK_prob_task`` and
    ``topK_prob_task_norm`` (which also print them). Each ranking occupies a
    pair of columns - key and value - under a header in row 0, and successive
    rankings start four columns apart.

    Args:
        task_diff: Non-empty mapping from a key to a non-empty sequence of
            1-D tensors; the length of the first tensor is taken as the
            number of tasks.
        worksheet_task_diff: Worksheet-like object exposing
            ``write(row, col, value)``.
    """
    n_task = list(task_diff.values())[0][0].size(0)

    total_prob_dict, exist_prob_dict = topK_prob(task_diff)
    task_prob, task_prob_avg, task_prob_std = topK_prob_task(task_diff)
    task_prob_norm_avg, task_prob_norm_std = topK_prob_task_norm(task_diff)

    # (header, ranking) pairs in left-to-right sheet order.
    sections = [
        ('total_prob_dict', total_prob_dict),
        ('task_prob_exist', exist_prob_dict),
    ]
    sections += [(f'task_prob_task_{i}', task_prob[i]) for i in range(n_task)]
    sections += [
        ('task_prob_avg', task_prob_avg),
        ('task_prob_std', task_prob_std),
        ('task_prob_norm_avg', task_prob_norm_avg),
        ('task_prob_norm_std', task_prob_norm_std),
    ]

    for index, (title, ranking) in enumerate(sections):
        col = 4 * index
        worksheet_task_diff.write(0, col, title)
        worksheet_task_diff.write(0, col + 1, 'value')
        for row, (key, value) in enumerate(ranking.items(), start=1):
            worksheet_task_diff.write(row, col, key)
            worksheet_task_diff.write(row, col + 1, value)

def write_task_diff_avg(task_diff, worksheet_task_diff):
    """Write the per-task and overall mean rankings of ``task_diff``.

    The rankings come from ``topK_bonus_avg``. One key/value column pair is
    written per task, followed by one for the cross-task average; successive
    pairs start four columns apart, with headers in row 0.

    Args:
        task_diff: Non-empty mapping from a key to a non-empty sequence of
            1-D tensors; the length of the first tensor is taken as the
            number of tasks.
        worksheet_task_diff: Worksheet-like object exposing
            ``write(row, col, value)``.
    """
    n_task = list(task_diff.values())[0][0].size(0)
    diff_dict, diff_avg = topK_bonus_avg(task_diff)

    # (header, ranking) pairs in left-to-right sheet order.
    sections = [(f'bonus_avg_task_{i}', diff_dict[i]) for i in range(n_task)]
    sections.append(('bonus_avg', diff_avg))

    for index, (title, ranking) in enumerate(sections):
        col = 4 * index
        worksheet_task_diff.write(0, col, title)
        worksheet_task_diff.write(0, col + 1, 'value')
        for row, (key, value) in enumerate(ranking.items(), start=1):
            worksheet_task_diff.write(row, col, key)
            worksheet_task_diff.write(row, col + 1, value)

def write_lw_task_avg(cos, worksheet_cos):
    """Write one ascending ranking of averaged values per task to a worksheet.

    Every history in ``cos`` is averaged element-wise (``avg_dict``). For each
    task index the keys are ranked by their averaged value at that index and
    written as a key/value column pair under an ``avg_angle_t{i}`` header;
    successive pairs start four columns apart.

    Args:
        cos: Non-empty mapping from a key to a sequence of equally shaped
            tensors. The element count of the first averaged tensor is taken
            as the number of tasks.
        worksheet_cos: Worksheet-like object exposing
            ``write(row, col, value)``.
    """
    avg_angle = avg_dict(cos)
    n_task = list(avg_angle.values())[0].view(-1).size(0)

    output = [OrderedDict() for _ in range(n_task)]
    for i in range(n_task):
        for key, value in avg_angle.items():
            # With a single task the average may be a 0-dim tensor, which
            # cannot be indexed; use it as is, like write_lw_task_flood does.
            output[i][key] = value if n_task == 1 else value[i]

    for i in range(n_task):
        output[i] = dict(sorted(output[i].items(), key=lambda item: item[1]))

    cur_col = 0
    for i in range(n_task):
        worksheet_cos.write(0, cur_col, f'avg_angle_t{i}')
        worksheet_cos.write(0, cur_col + 1, 'value')
        for cur_row, (key, value) in enumerate(output[i].items(), start=1):
            worksheet_cos.write(cur_row, cur_col, key)
            worksheet_cos.write(cur_row, cur_col + 1, value)

        cur_col += 4

def write_lw_task_flood(cos, worksheet_cos, flood=0.0):
    """Write one ascending ranking of below-threshold counts per task.

    The counts come from ``flooding_lower_each_task``. For each task index
    the keys are ranked by their count at that index and written as a
    key/value column pair under a ``flood{flood}_t{i}`` header; successive
    pairs start four columns apart.

    Args:
        cos: Non-empty mapping from a key to a sequence of equally shaped
            tensors. The element count of the first count tensor is taken as
            the number of tasks.
        worksheet_cos: Worksheet-like object exposing
            ``write(row, col, value)``.
        flood: Threshold forwarded as ``flood_level``.
    """
    flood_data = flooding_lower_each_task(cos, flood_level=flood)
    n_task = list(flood_data.values())[0].view(-1).size(0)
    single_task = n_task == 1

    # With a single task the count tensor is used whole instead of indexed.
    tables = [
        OrderedDict(
            (key, counts if single_task else counts[task])
            for key, counts in flood_data.items()
        )
        for task in range(n_task)
    ]
    tables = [
        dict(sorted(table.items(), key=lambda item: item[1])) for table in tables
    ]

    for task, table in enumerate(tables):
        col = 4 * task
        worksheet_cos.write(0, col, f'flood{flood}_t{task}')
        worksheet_cos.write(0, col + 1, 'value')
        for row, (key, value) in enumerate(table.items(), start=1):
            worksheet_cos.write(row, col, key)
            worksheet_cos.write(row, col + 1, value)

def write_lw(cos, worksheet_cos, flood=0.0):
    """Write four summary rankings of ``cos`` to a worksheet.

    The rankings are, left to right: the overall mean (``avg_dict_2``), the
    number of elements below ``flood`` (``flooding_lower``), the fraction of
    entries entirely below ``flood`` (``flooding_lower_balance``) and the
    0.25 quantile (``quantile_dict``). Each one is a key/value column pair
    under a header in row 0; successive pairs start four columns apart.

    Args:
        cos: Mapping from a key to a sequence of equally shaped tensors.
        worksheet_cos: Worksheet-like object exposing
            ``write(row, col, value)``.
        flood: Threshold forwarded as ``flood_level`` to the flooding helpers.
    """
    avg_angle = avg_dict_2(cos)
    flood_angle = flooding_lower(cos, flood_level=flood)
    flood_conflict = flooding_lower_balance(cos, flood_level=flood)
    q_angle = quantile_dict(cos, q=0.25)

    sections = (
        ('avg_angle', avg_angle),
        (f'flood_angle_{flood}', flood_angle),
        ('flood_conflict_meanwhile', flood_conflict),
        ('q_angle', q_angle),
    )

    for index, (title, ranking) in enumerate(sections):
        col = 4 * index
        worksheet_cos.write(0, col, title)
        worksheet_cos.write(0, col + 1, 'value')
        for row, (key, value) in enumerate(ranking.items(), start=1):
            worksheet_cos.write(row, col, key)
            worksheet_cos.write(row, col + 1, value)

def split_dict(d, n_epoch, lb, hb):
    """Trim every history in ``d`` to an epoch window, in place.

    The length of the first history divided by ``n_epoch`` gives the number
    of entries per epoch; every value is then replaced by its slice from
    epoch ``lb`` (inclusive) to epoch ``hb`` (exclusive). Nothing is returned.

    Args:
        d: Non-empty mapping from a key to a sliceable sequence.
        n_epoch: Number of epochs the full histories span.
        lb: Lower epoch bound of the window.
        hb: Upper epoch bound of the window.
    """
    per_epoch = len([*d.values()][0]) / n_epoch
    window = slice(int(lb * per_epoch), int(hb * per_epoch))

    # Rebinding existing keys does not change the dict size, so iterating
    # while assigning is safe.
    for key in d:
        d[key] = d[key][window]

if __name__ == '__main__':
    # Post-process the tensors saved under ./saved for one run configuration
    # and export the rankings to an Excel workbook.

    # Fall back to the CPU when CUDA is unavailable so the saved statistics
    # can still be loaded (map_location) on a machine without a GPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    dataset = 'mnist'
    submethod = 'mgd'
    Optimizer = 'SGD'
    epoch = 40
    seed = 0
    flood = -0.1
    # Epoch window [lb, hb) that every history is trimmed to by split_dict.
    lb = 0
    hb = 40
    workbook = xlsxwriter.Workbook(f'./saved/{seed}{submethod}{Optimizer}_{dataset}_{Optimizer}_{epoch}_fd{flood}_{lb}_{hb}_v2.xlsx')

    # --------------------------------------------------------------------------------------------------
    # Disabled: 'task_diff_data' sheets.
    # name = 'task_diff_data'
    # path = f'./saved/{seed}{submethod}{Optimizer}_{epoch}_{name}_{dataset}.pt'
    # dictionary = torch.load(path, map_location=device)
    # task_diff = dictionary['task_diff']
    # split_dict(task_diff, epoch, lb, hb)
    # worksheet_task_diff = workbook.add_worksheet(f'{name}')
    #
    # write_task_diff(task_diff, worksheet_task_diff)
    #
    # worksheet_task_diff = workbook.add_worksheet(f'{name}_avg')
    #
    # write_task_diff_avg(task_diff, worksheet_task_diff)
    #
    # #--------------------------------------------------------------------------------------------------
    # Disabled: 'diff_data' sheets.
    # name = 'diff_data'
    # path = f'./saved/{seed}{submethod}{Optimizer}_{epoch}_{name}_{dataset}.pt'
    # dictionary = torch.load(path, map_location=device)
    # diff = dictionary['diff']
    # split_dict(diff, epoch, lb, hb)
    # worksheet_diff = workbook.add_worksheet(f'{name}')
    # write_task_diff(diff, worksheet_diff)
    #
    # worksheet_diff = workbook.add_worksheet(f'{name}_avg')
    # write_task_diff_avg(diff, worksheet_diff)
    #

    # --------------------------------------------------------------------------------------------------
    # 'lw_cos': summary sheet, per-task average sheet, per-task flood sheet.
    name = 'lw_cos'
    path = f'./saved/{seed}{submethod}{Optimizer}_{epoch}_{name}_{dataset}.pt'
    dictionary = torch.load(path, map_location=device)
    cos = dictionary['cos']
    split_dict(cos, epoch, lb, hb)

    worksheet_cos = workbook.add_worksheet(f'{name}')
    write_lw(cos, worksheet_cos, flood=flood)

    worksheet_cos = workbook.add_worksheet(f'{name}_task_avg')
    write_lw_task_avg(cos, worksheet_cos)

    worksheet_flood_each_task = workbook.add_worksheet(f'{name}_flood_each_task')
    write_lw_task_flood(cos, worksheet_flood_each_task, flood=flood)

    # --------------------------------------------------------------------------------------------------
    # 'lw_task_cos': the same three sheets for the 'task_cos' data.
    name = 'lw_task_cos'
    path = f'./saved/{seed}{submethod}{Optimizer}_{epoch}_{name}_{dataset}.pt'
    dictionary = torch.load(path, map_location=device)
    task_cos = dictionary['task_cos']
    split_dict(task_cos, epoch, lb, hb)

    worksheet_task_cos = workbook.add_worksheet(f'{name}')
    write_lw(task_cos, worksheet_task_cos, flood=flood)

    worksheet_cos = workbook.add_worksheet(f'{name}_task_avg')
    write_lw_task_avg(task_cos, worksheet_cos)

    worksheet_flood_each_task = workbook.add_worksheet(f'{name}_flood_each_task')
    write_lw_task_flood(task_cos, worksheet_flood_each_task, flood=flood)

    # The workbook is only written to disk on close().
    workbook.close()

    print('finished')

# if __name__ == '__main__':
#
#     device = torch.device('cuda:0')
#     dataset = 'mnist'
#     submethod = 'cagrad'
#     Optimizer = 'Adam'
#     epoch = 100
#     seed = 1997
#     flood = 0.0
#     lb = 0
#     hb = 20
#     workbook = xlsxwriter.Workbook(f'./saved/{seed}{submethod}{Optimizer}_{dataset}_{Optimizer}_{epoch}_fd{flood}_{lb}_{hb}.xlsx')
#
#     # --------------------------------------------------------------------------------------------------
#     name = 'task_diff_data'
#     path = f'./saved/{seed}{submethod}{Optimizer}_{epoch}_{name}_{dataset}.pt'
#     dictionary = torch.load(path, map_location=device)
#     task_diff = dictionary['task_diff']
#     split_dict(task_diff, epoch, lb, hb)
#     worksheet_task_diff = workbook.add_worksheet(f'{name}')
#
#     write_task_diff(task_diff, worksheet_task_diff)
#
#     #--------------------------------------------------------------------------------------------------
#     name = 'diff_data'
#     path = f'./saved/{seed}{submethod}{Optimizer}_{epoch}_{name}_{dataset}.pt'
#     dictionary = torch.load(path, map_location=device)
#     diff = dictionary['diff']
#     split_dict(diff, epoch, lb, hb)
#     worksheet_diff = workbook.add_worksheet(f'{name}')
#
#     write_task_diff(diff, worksheet_diff)
#
#     #--------------------------------------------------------------------------------------------------
#     name = 'lw_cos'
#     path = f'./saved/{seed}{submethod}{Optimizer}_{epoch}_{name}_{dataset}.pt'
#     dictionary = torch.load(path, map_location=device)
#     cos = dictionary['cos']
#     split_dict(cos, epoch, lb, hb)
#
#     worksheet_cos = workbook.add_worksheet(f'{name}')
#
#     write_lw(cos, worksheet_cos, flood=0.0)
#
#     # --------------------------------------------------------------------------------------------------
#     # name = 'lw_cos'
#     # path = f'./saved/{seed}{submethod}{Optimizer}_{epoch}_{name}_{dataset}.pt'
#     # dictionary = torch.load(path, map_location=device)
#     # cos = dictionary['cos']
#     worksheet_cos = workbook.add_worksheet(f'{name}_task_avg')
#
#     write_lw_task_avg(cos, worksheet_cos)
#
#     # --------------------------------------------------------------------------------------------------
#     name = 'lw_task_cos'
#     path = f'./saved/{seed}{submethod}{Optimizer}_{epoch}_{name}_{dataset}.pt'
#     dictionary = torch.load(path, map_location=device)
#     task_cos = dictionary['task_cos']
#     split_dict(task_cos, epoch, lb, hb)
#
#     worksheet_task_cos = workbook.add_worksheet(f'{name}')
#
#     write_lw(task_cos, worksheet_task_cos, flood=flood)
#
#     # --------------------------------------------------------------------------------------------------
#     # name = 'lw_task_cos'
#     # path = f'./saved/{seed}{submethod}{Optimizer}_{epoch}_{name}_{dataset}.pt'
#     # dictionary = torch.load(path, map_location=device)
#     # cos = dictionary['task_cos']
#     worksheet_cos = workbook.add_worksheet(f'{name}_task_avg')
#
#     write_lw_task_avg(cos, worksheet_cos)
#
#     workbook.close()

