import numpy as np
import torch
import torch.nn as nn
import numpy
import xlsxwriter

from collections import OrderedDict

def conflict_prob(angles, lb, hb, flood=0.0, n_tasks=2):
    """Histogram the values of ``angles[lb:hb]`` into threshold-delimited bins.

    ``flood`` supplies the bin edges and is expected in descending order.
    ``k`` thresholds produce ``k + 1`` bins:

    * bin ``0``:              ``value >= flood[0]``
    * bin ``i`` (0 < i < k):  ``flood[i] <= value < flood[i - 1]``
    * bin ``k``:              ``value < flood[k - 1]``

    Every value therefore falls into exactly one bin (NaNs fall into none).

    Args:
        angles: A tensor, or a sequence of tensors / numbers, sliceable along
            its first axis. Each entry may hold one or several values.
        lb: Start index of the window (inclusive).
        hb: End index of the window (exclusive).
        flood: A single threshold, or a list/tuple of thresholds.
        n_tasks: Number of tasks. Only used for normalisation: each entry is
            taken to contribute ``n_tasks * (n_tasks - 1) / 2`` values.

    Returns:
        ``(n_conflict, p_conflict)``: per-bin counts as Python ints, and the
        per-bin percentages ``count / (len(window) * n_pairs) * 100`` rounded
        to two decimals. Percentages are ``0.0`` when the window is empty or
        ``n_pairs`` is zero.
    """
    thresholds = list(flood) if isinstance(flood, (list, tuple)) else [flood]
    n_bins = len(thresholds) + 1
    n_pairs = int(n_tasks * (n_tasks - 1) / 2)

    window = angles[lb:hb]
    n_entries = len(window)

    n_conflict = [0] * n_bins
    if n_entries > 0:
        # Flatten everything into one 1-D tensor so each bin needs a single
        # vectorised comparison (and a single host sync) instead of one per
        # element.
        if isinstance(window, torch.Tensor):
            flat = window.reshape(-1)
        else:
            flat = torch.cat([torch.as_tensor(v).reshape(-1) for v in window])

        # Bin i is bounded above (exclusive) by uppers[i] and below
        # (inclusive) by lowers[i]; None means "unbounded on that side".
        uppers = [None] + thresholds
        lowers = thresholds + [None]
        for i, (upper, lower) in enumerate(zip(uppers, lowers)):
            mask = torch.ones_like(flat, dtype=torch.bool)
            if upper is not None:
                mask = mask & (flat < upper)
            if lower is not None:
                # Element-wise `&` rather than `and`, which is ambiguous for
                # tensors holding more than one value.
                mask = mask & (flat >= lower)
            n_conflict[i] = int(mask.sum().item())

    denom = n_entries * n_pairs
    if denom:
        p_conflict = [round(c / denom * 100, 2) for c in n_conflict]
    else:
        # Nothing to normalise by: report zero instead of dividing by zero.
        p_conflict = [0.0] * n_bins

    return n_conflict, p_conflict

if __name__ == '__main__':
    # Prefer the first GPU, but fall back to the CPU so the saved tensors can
    # still be loaded on machines without CUDA.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # sub_method_L = ['nothing_v2', 'cagrad', 'graddrop']
    sub_method_L = ['mgd']
    # Descending thresholds; conflict_prob turns k thresholds into k + 1 bins.
    flood = [0.0, -0.02, -0.04, -0.06, -0.08, -0.1]
    # flood = [0.0, -0.02, -0.04, -0.06]
    topK = 39


    for submethod in sub_method_L:
        # Run configuration; seed, base_model, topK, submethod, Optimizer and
        # epoch together form the name of the file to load.
        Optimizer = 'Adam'
        base_model = 'fw_b_na'
        epoch = 200
        seed = 0

        # Window to analyse; scaled by batch_size below to obtain indices
        # into `angles`.
        lb = 0
        hb = 200

        # # --------------------------------------------------------------------------------------------------
        path = f'./saved/{seed}{base_model}{topK}{submethod}{Optimizer}_{epoch}_angle.pt'
        # NOTE: torch.load unpickles the file; only load files you trust.
        angles = torch.load(path, map_location=device)
        # Entries per epoch (may be fractional) -- presumably the file holds
        # one entry per training step; TODO confirm.
        batch_size = len(angles) / epoch
        lb_batch = int(batch_size * lb)
        hb_batch = int(batch_size * hb)
        n_conflict, prob = conflict_prob(angles, lb_batch, hb_batch, flood)

        print(f'{submethod}')
        print(f'{base_model}{topK}{submethod}, [{lb}, {hb}]: {n_conflict})')
        print(f'{base_model}{topK}{submethod}, [{lb}, {hb}]: {prob})')

    # flood = -0.01
    # n_conflict, prob = conflict_prob(angles, lb_batch, hb_batch, flood)
    # print(f'{base_model}{topK}{submethod}, [{lb}, {hb}], flood={flood}:   {n_conflict}({prob:.4f})')
    #
    # flood = -0.02
    # n_conflict, prob = conflict_prob(angles, lb_batch, hb_batch, flood)
    # print(f'{base_model}{topK}{submethod}, [{lb}, {hb}], flood={flood}: {n_conflict}({prob:.4f})')
    #
    # flood = -0.03
    # n_conflict, prob = conflict_prob(angles, lb_batch, hb_batch, flood)
    # print(f'{base_model}{topK}{submethod}, [{lb}, {hb}], flood={flood}: {n_conflict}({prob:.4f})')

    # Five interval labels -- presumably for the hard-coded five-value
    # distributions defined further down; verify against their use.
    labels = ['[1.0, 0.0]', '(0.0, -0.02]', '(-0.02, -0.04]', '(-0.04, -0.06]', '(-0.06, -1.0]']

    def divd(list, d=100.0):
        """Return a new list with every element of ``list`` divided by ``d``.

        With the default ``d=100.0`` this turns percentages into fractions.
        The input sequence is not modified.
        """
        # The first parameter keeps its original name (it shadows the builtin
        # only inside this function) so keyword callers continue to work.
        # Divide by the `d` argument rather than a hard-coded 100.0.
        return [v / d for v in list]

    # Hard-coded five-value distributions (each sums to roughly 100), passed
    # through divd with its default divisor.

    # "*_recon" variants.
    joint_train_recon = divd([73.62, 20.13, 5.13, 0.94, 0.17])
    cagrad_recon = divd([74.54, 19.77, 4.62, 0.89, 0.18])
    graddrop_recon = divd([73.82, 19.75, 5.17, 1.05, 0.21])
    pcgrad_recon = divd([74.52, 19.43, 4.89, 0.96, 0.2])

    # Variants without the "_recon" suffix.
    joint_train = divd([59.55, 10.14, 8.52, 6.45, 15.33])
    cagrad = divd([60.79, 11.13, 8.83, 6.05, 13.19])
    graddrop = divd([59.56, 9.61, 8.19, 6.49, 16.14])
    pcgrad = divd([59.85, 9.58, 7.94, 6.24, 16.39])
    jointtrain_CPS = divd([58.29, 10.77, 8.72, 6.48, 15.74])
    jointtrain_RSL = divd([53.16, 9.01, 7.34, 5.69, 24.79])
