import numpy as np
import torch
import torch.nn as nn
import numpy
import xlsxwriter

from collections import OrderedDict

def conflict_prob(angles, lb, hb, flood=0.0, n_tasks=2):
    """Count how many values in ``angles[lb:hb]`` fall into threshold bands.

    The thresholds in ``flood`` split the value range into
    ``len(flood) + 1`` buckets:

    * bucket ``0``: ``value > flood[0]``
    * bucket ``i`` (``0 < i < len(flood)``):
      ``flood[i] <= value < flood[i - 1]``
    * last bucket: ``value < flood[-1]``

    The middle bands are non-empty only when ``flood`` is given in
    descending order. A value exactly equal to ``flood[0]`` is counted in
    no bucket, so the returned percentages need not sum to 100.

    Args:
        angles: Sliceable sequence or tensor. Each element of the slice may
            be a scalar or a tensor holding several values (presumably one
            per task pair -- confirm against the code that saves it).
        lb: Inclusive start index of the window inside ``angles``.
        hb: Exclusive end index of the window inside ``angles``.
        flood: A single threshold or a list/tuple of thresholds.
        n_tasks: Number of tasks; ``n_tasks * (n_tasks - 1) / 2`` pairs are
            assumed per element when normalising the counts.

    Returns:
        A tuple ``(n_conflict, p_conflict)``: the per-bucket counts, and the
        per-bucket percentages of ``len(window) * n_pairs`` rounded to two
        decimals. Percentages are all ``0.0`` when the window is empty or
        there are no task pairs.
    """
    thresholds = list(flood) if isinstance(flood, (list, tuple)) else [flood]

    n_pairs = n_tasks * (n_tasks - 1) // 2
    window = angles[lb:hb]

    n_conflict = [0] * (len(thresholds) + 1)

    for value in window:
        # Coerce so plain Python numbers and tensors go through the same path.
        value = torch.as_tensor(value)

        n_conflict[0] += (value > thresholds[0]).sum().item()
        for i in range(1, len(thresholds)):
            # Element-wise `&` (not `and`): `value` may hold several pair
            # values, for which Python's `and` raises an ambiguity error.
            in_band = (value < thresholds[i - 1]) & (value >= thresholds[i])
            n_conflict[i] += in_band.sum().item()
        n_conflict[-1] += (value < thresholds[-1]).sum().item()

    total = len(window) * n_pairs
    if total == 0:
        # Nothing to normalise by; avoid a ZeroDivisionError.
        return n_conflict, [0.0] * len(n_conflict)

    p_conflict = [round(c / total * 100, 2) for c in n_conflict]

    return n_conflict, p_conflict

if __name__ == '__main__':
    # Load a saved angle log and report how its values are distributed over
    # the `flood` threshold bands for the epoch range [lb, hb).

    # Fall back to CPU so the analysis also runs on machines without CUDA;
    # mapping the checkpoint to 'cuda:0' there would make torch.load fail.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    submethod = 'nothing_v2'
    Optimizer = 'SGD'

    base_model = 'fw_ablation'
    flag = 'RSL'
    topK = 15

    epoch = 120
    seed = 0
    flood = [0.0, -0.01, -0.02, -0.03]
    lb = 0
    hb = 120

    # # --------------------------------------------------------------------------------------------------
    path = f'./saved/{seed}{base_model}{flag}_{topK}{submethod}{Optimizer}_{epoch}_angle.pt'
    # NOTE(review): torch.load unpickles the file; only load trusted files.
    angles = torch.load(path, map_location=device)

    # Entries per epoch -- assumes the saved sequence holds the same number
    # of entries for each of the `epoch` epochs (TODO confirm).
    batch_size = len(angles) / epoch
    # Convert the epoch bounds into index bounds within `angles`.
    lb_batch = int(batch_size * lb)
    hb_batch = int(batch_size * hb)

    n_conflict, prob = conflict_prob(angles, lb_batch, hb_batch, flood)
    print(f'{base_model}{topK}{submethod}, [{lb}, {hb}]: {n_conflict})')
    print(f'{base_model}{topK}{submethod}, [{lb}, {hb}]: {prob})')

    # flood = -0.01
    # n_conflict, prob = conflict_prob(angles, lb_batch, hb_batch, flood)
    # print(f'{base_model}{topK}{submethod}, [{lb}, {hb}], flood={flood}:   {n_conflict}({prob:.4f})')
    #
    # flood = -0.02
    # n_conflict, prob = conflict_prob(angles, lb_batch, hb_batch, flood)
    # print(f'{base_model}{topK}{submethod}, [{lb}, {hb}], flood={flood}: {n_conflict}({prob:.4f})')
    #
    # flood = -0.03
    # n_conflict, prob = conflict_prob(angles, lb_batch, hb_batch, flood)
    # print(f'{base_model}{topK}{submethod}, [{lb}, {hb}], flood={flood}: {n_conflict}({prob:.4f})')