import numpy as np
import torch
import torch.nn as nn
import numpy
import xlsxwriter

from collections import OrderedDict

def turn_permute_to_dict(p):
    """Map each element of an ordering to its 0-based position.

    Args:
        p: An iterable of hashable items (in this script, an ordered list of
            layer names).

    Returns:
        A dict ``{item: index}`` giving the position of every item in ``p``.
        If an item occurs more than once, the index of its last occurrence
        wins, because later assignments overwrite earlier ones.
    """
    # A comprehension replaces the old local named ``dict``, which shadowed
    # the builtin type.
    return {value: index for index, value in enumerate(p)}

def diff(dict1, dict2):
    """Return the mean absolute difference between two key->number mappings.

    Args:
        dict1: Mapping from key to a numeric value (in this script, a layer's
            position in one ordering). Its keys drive the comparison.
        dict2: Mapping that must contain every key of ``dict1``; keys present
            only in ``dict2`` are ignored.

    Returns:
        ``sum(|dict1[k] - dict2[k]| for k in dict1) / len(dict1)``, or ``0.0``
        when ``dict1`` is empty (nothing to compare).

    Raises:
        KeyError: If a key of ``dict1`` is missing from ``dict2``.
    """
    if not dict1:
        # Previously this divided by zero; two empty mappings do not differ.
        return 0.0
    # Builtin ``sum`` is no longer shadowed by a local accumulator, and
    # ``len(dict1)`` avoids materialising a list of keys just to count them.
    total = sum(abs(v1 - dict2[key]) for key, v1 in dict1.items())
    return total / len(dict1)

if __name__ == '__main__':
    import itertools

    # One ordering of layer names per method. Every list is expected to be a
    # permutation of the same layer set. What the order represents (e.g. an
    # importance ranking) is decided wherever these lists were produced --
    # TODO confirm against that source.
    rankings = {
        'pcgrad': [
            'bn1',
            'layer1.0.bn1',
            'layer1.1.bn1',
            'layer1.0.bn2',
            'layer1.1.bn2',
            'layer2.0.bn1',
            'layer2.1.bn1',
            'conv1',
            'layer2.0.downsample.1',
            'layer2.0.bn2',
            'layer2.1.bn2',
            'layer1.0.conv1',
            'layer3.0.bn1',
            'layer3.1.bn1',
            'layer4.1.bn1',
            'layer3.0.downsample.1',
            'layer3.0.bn2',
            'layer4.0.bn1',
            'layer4.1.bn2',
            'layer3.1.bn2',
            'layer4.0.downsample.1',
            'layer4.0.bn2',
            'layer1.0.conv2',
            'layer1.1.conv1',
            'layer1.1.conv2',
            'layer2.0.downsample.0',
            'layer2.0.conv1',
            'layer2.1.conv2',
            'layer2.0.conv2',
            'fc',
            'layer2.1.conv1',
            'layer3.0.downsample.0',
            'layer3.0.conv2',
            'layer3.0.conv1',
            'layer3.1.conv1',
            'layer4.0.conv1',
            'layer4.1.conv2',
            'layer3.1.conv2',
            'layer4.0.downsample.0',
            'layer4.1.conv1',
            'layer4.0.conv2',
        ],
        'cagrad': [
            'bn1',
            'layer1.0.bn1',
            'layer1.1.bn1',
            'layer1.0.bn2',
            'layer1.1.bn2',
            'layer2.0.bn1',
            'layer2.1.bn1',
            'conv1',
            'layer2.0.bn2',
            'layer2.0.downsample.1',
            'layer2.1.bn2',
            'layer3.0.bn1',
            'layer1.0.conv1',
            'layer3.1.bn1',
            'layer4.1.bn1',
            'layer4.0.bn1',
            'layer3.0.bn2',
            'layer3.0.downsample.1',
            'layer4.1.bn2',
            'layer3.1.bn2',
            'layer4.0.downsample.1',
            'layer4.0.bn2',
            'layer1.0.conv2',
            'layer1.1.conv1',
            'layer1.1.conv2',
            'fc',
            'layer2.0.downsample.0',
            'layer2.0.conv1',
            'layer4.1.conv2',
            'layer2.0.conv2',
            'layer2.1.conv2',
            'layer3.0.conv1',
            'layer3.0.downsample.0',
            'layer2.1.conv1',
            'layer4.0.conv1',
            'layer3.1.conv2',
            'layer4.1.conv1',
            'layer3.1.conv1',
            'layer4.0.downsample.0',
            'layer3.0.conv2',
            'layer4.0.conv2',
        ],
        'graddrop': [
            'bn1',
            'layer1.0.bn1',
            'layer1.1.bn1',
            'layer1.0.bn2',
            'layer1.1.bn2',
            'layer2.0.bn1',
            'layer2.1.bn1',
            'conv1',
            'layer2.0.bn2',
            'layer2.0.downsample.1',
            'layer2.1.bn2',
            'layer3.0.bn1',
            'layer1.0.conv1',
            'layer3.1.bn1',
            'layer4.1.bn1',
            'layer3.0.bn2',
            'layer4.0.bn1',
            'layer3.0.downsample.1',
            'layer3.1.bn2',
            'layer4.1.bn2',
            'layer4.0.downsample.1',
            'layer4.0.bn2',
            'layer1.0.conv2',
            'layer1.1.conv1',
            'layer1.1.conv2',
            'layer2.0.downsample.0',
            'layer2.0.conv1',
            'fc',
            'layer2.0.conv2',
            'layer3.0.downsample.0',
            'layer2.1.conv1',
            'layer2.1.conv2',
            'layer3.0.conv1',
            'layer3.0.conv2',
            'layer4.0.conv1',
            'layer4.1.conv2',
            'layer3.1.conv2',
            'layer3.1.conv1',
            'layer4.1.conv1',
            'layer4.0.downsample.0',
            'layer4.0.conv2',
        ],
        'mgd': [
            'bn1',
            'layer1.0.bn1',
            'layer1.1.bn1',
            'layer1.0.bn2',
            'layer1.1.bn2',
            'conv1',
            'layer2.0.bn1',
            'layer2.1.bn1',
            'layer2.0.bn2',
            'layer2.0.downsample.1',
            'layer2.1.bn2',
            'layer3.0.bn1',
            'layer4.1.bn1',
            'layer3.1.bn1',
            'layer1.0.conv1',
            'layer3.0.bn2',
            'layer3.0.downsample.1',
            'layer3.1.bn2',
            'layer4.0.downsample.1',
            'layer4.0.bn1',
            'layer4.0.bn2',
            'layer4.1.bn2',
            'fc',
            'layer4.1.conv2',
            'layer1.0.conv2',
            'layer1.1.conv1',
            'layer1.1.conv2',
            'layer2.0.downsample.0',
            'layer2.0.conv1',
            'layer2.0.conv2',
            'layer2.1.conv2',
            'layer2.1.conv1',
            'layer3.0.downsample.0',
            'layer3.0.conv1',
            'layer4.0.conv1',
            'layer4.1.conv1',
            'layer3.0.conv2',
            'layer3.1.conv1',
            'layer4.0.conv2',
            'layer3.1.conv2',
            'layer4.0.downsample.0',
        ],
        'nothing': [
            'bn1',
            'layer1.0.bn1',
            'layer1.1.bn1',
            'layer1.0.bn2',
            'layer1.1.bn2',
            'layer2.1.bn1',
            'layer2.0.bn1',
            'conv1',
            'layer2.0.bn2',
            'layer2.0.downsample.1',
            'layer2.1.bn2',
            'layer3.0.bn1',
            'layer1.0.conv1',
            'layer3.1.bn1',
            'layer4.1.bn1',
            'layer3.0.bn2',
            'layer4.1.bn2',
            'layer3.0.downsample.1',
            'layer4.0.bn1',
            'layer3.1.bn2',
            'layer4.0.downsample.1',
            'layer4.0.bn2',
            'layer1.0.conv2',
            'layer1.1.conv1',
            'layer1.1.conv2',
            'layer2.0.downsample.0',
            'fc',
            'layer2.0.conv1',
            'layer2.0.conv2',
            'layer3.0.downsample.0',
            'layer2.1.conv2',
            'layer2.1.conv1',
            'layer3.1.conv1',
            'layer3.0.conv2',
            'layer3.0.conv1',
            'layer4.0.conv1',
            'layer3.1.conv2',
            'layer4.1.conv1',
            'layer4.1.conv2',
            'layer4.0.downsample.0',
            'layer4.0.conv2',
        ],
    }

    # The position comparison is only meaningful when every ordering is a
    # permutation of one shared layer set. Duplicates would silently collapse
    # in turn_permute_to_dict. A layer missing from one side makes diff()
    # raise KeyError or silently ignore it, depending on argument order.
    reference = set(rankings['pcgrad'])
    for name, order in rankings.items():
        if len(order) != len(set(order)) or set(order) != reference:
            raise ValueError(
                f'{name} is not a permutation of the reference layer set'
            )

    positions = {
        name: turn_permute_to_dict(order) for name, order in rankings.items()
    }

    # combinations() yields each unordered pair once, in the same order as
    # the original nested index loops (i < j over the insertion order above).
    for (name_a, pos_a), (name_b, pos_b) in itertools.combinations(
        positions.items(), 2
    ):
        print(f'diff between {name_a} and {name_b}')
        avg = diff(pos_a, pos_b)
        print(f'avg: {avg:.2f}')
