import torch
from mnist import MLP, CNN, test, nonlinearities
import argparse
import torch.nn.functional as F
from torchvision import datasets, transforms
from scipy.optimize import linear_sum_assignment
from collections import OrderedDict
import copy
import plotly.express as px
import plotly.graph_objects as go
import numpy
import time
import random
import copy

def _compute_permutations(sd0, sd1, max_rounds=50, tol=1e-4):
    """Greedily match the hidden units of ``sd1`` to those of ``sd0``.

    One permutation is kept per weight tensor (every entry with more than one
    dimension); 1-D entries such as biases get none. ``P[0]`` is the identity
    over the input channels of the first weight. The permutation of the last
    weight tensor is never updated, so the network outputs keep their order.

    Each sweep visits the interior layers in order and solves a linear
    assignment that maximizes:
    - the correlation of the layer's own weights, with ``sd1``'s input
      dimension reordered by the previous layer's current permutation; plus
    - the correlation of the next layer's incoming weights.

    Sweeps stop once the total matched cost changes by less than ``tol``, or
    after ``max_rounds`` sweeps.

    NOTE(review): the conv branch applies the ``'oikl'`` einsum to the *next*
    weight too, so it assumes a 4-D weight is followed by another 4-D weight.
    A conv -> linear transition (flatten) is not handled here; confirm
    against the CNN definition.

    Args:
        sd0: Reference state dict (kept fixed).
        sd1: State dict to align to ``sd0`` (not modified).
        max_rounds: Maximum number of coordinate-ascent sweeps.
        tol: Absolute change of the total cost treated as convergence.

    Returns:
        ``(P, keys)``. ``P`` is an OrderedDict mapping ``0`` and every weight
        key to an index array. ``keys`` is ``list(P.keys())``.
    """
    P = OrderedDict()
    first_key = next(iter(sd0))
    P[0] = torch.arange(sd0[first_key].shape[1])  # input channels: identity
    for key, weight in sd0.items():
        # Only weights (conv1.weight, fc1.weight, ...) get a permutation;
        # the matching bias is reordered together with its weight later.
        if weight.ndim > 1:
            P[key] = torch.arange(weight.shape[0])
    keys = list(P.keys())
    print('Computing P with layers:', keys)

    old_loss = -1e9
    for k in range(max_rounds):
        new_loss = 0.0
        for i, key in enumerate(keys):
            if key in (0, keys[-1]):  # input and output orderings stay fixed
                continue
            prevkey = keys[i - 1]
            nextkey = keys[i + 1]
            if sd0[key].ndim > 2:  # conv
                assert sd0[key].ndim == 4
                C = torch.einsum('oikl,Oikl->oO', sd0[key], sd1[key][:, P[prevkey]]) \
                    + torch.einsum('oikl,oIkl->iI', sd0[nextkey], sd1[nextkey][P[nextkey]])
            else:  # fc
                assert sd0[key].ndim == 2
                C = torch.einsum('oi,Oi->oO', sd0[key], sd1[key][:, P[prevkey]]) \
                    + torch.einsum('oi,oI->iI', sd0[nextkey], sd1[nextkey][P[nextkey]])
            C = C.cpu()  # scipy needs host memory
            ri, ci = linear_sum_assignment(C, maximize=True)
            new_loss += C[ri, ci].sum().item()
            # sd1's unit ci[j] is matched to sd0's unit j.
            P[key] = ci
        print('New cost:', new_loss)
        if abs(new_loss - old_loss) < tol:  # converged
            print('greedy algorithm terminates at round', k + 1, '/', max_rounds)
            break
        old_loss = new_loss
    return P, keys


def _apply_permutations(sd, P, keys):
    """Reorder the units of ``sd`` in place using ``P`` from ``_compute_permutations``.

    For every weight key:
    - dim 1 (inputs) is reordered by the previous layer's permutation;
    - dim 0 (outputs) is reordered by the layer's own permutation;
    - the sibling ``...bias`` entry, if present, is reordered by the same
      output permutation.

    For example::

        p = P['conv1.weight']
        sd['conv1.weight'] = sd['conv1.weight'][p]
        sd['conv1.bias'] = sd['conv1.bias'][p]
        sd['conv2.weight'] = sd['conv2.weight'][:, p]
        p = P['conv2.weight']
        ...

    The permuted network should compute the same function as the original.

    NOTE(review): only ``.bias`` entries follow their weight. Any other
    per-channel tensors (e.g. normalization parameters, if the model has
    them) are left unpermuted.
    """
    for i, key in enumerate(keys):
        if key == 0:  # placeholder for the input channels, not a real entry
            continue
        prevkey = keys[i - 1]
        print('processing', key, '(second dim) , len(p_prev) = ', len(P[prevkey]))
        sd[key] = sd[key][:, P[prevkey]]
        print('processing', key, ', len(p) = ', len(P[key]))
        sd[key] = sd[key][P[key]]
        if key.endswith('weight'):
            key_bias = key[:-len('weight')] + 'bias'
            if key_bias in sd:
                print('processing', key_bias, ', len(p) = ', len(P[key]))
                sd[key_bias] = sd[key_bias][P[key]]


def _interpolate(sd0, sd1, model, device, test_loader, ind, sd=None):
    """Evaluate ``model`` along the straight line from ``sd0`` to ``sd1``.

    For each ``s`` in ``ind``, ``sd`` is overwritten key by key with
    ``(1 - s) * sd0 + s * sd1``. The result is loaded into ``model`` and
    scored with ``test``. A single model instance is reused, since
    ``load_state_dict`` replaces every entry of the state dict.

    Args:
        sd0, sd1: Endpoint state dicts with identical keys.
        model: Network whose parameters are overwritten at every step.
        device, test_loader: Forwarded to ``test``.
        ind: Iterable of interpolation weights.
        sd: Scratch state dict overwritten in place. When ``None``, a deep
            copy of ``sd1`` is used.

    Returns:
        List with one ``test`` result per interpolation weight.
    """
    if sd is None:
        sd = copy.deepcopy(sd1)
    accs = []
    for s in ind:
        for key in sd:
            sd[key] = (1 - s) * sd0[key] + s * sd1[key]
        model.load_state_dict(sd)
        print("s =", s, ":")
        accs.append(test(model, device, test_loader))
    return accs


def main():
    """Align ``--model1`` to ``--model0`` by permuting hidden units, then compare.

    Steps:
    1. Load both state dicts onto the selected device.
    2. Compute and apply the unit permutations.
    3. Evaluate model1 on the MNIST test split before and after alignment.
    4. Write ``<nonlinearity>_perm.pdf``, plotting test results along the
       linear interpolation from model0 to model1 with and without alignment.
    """
    parser = argparse.ArgumentParser(description='Roto-Equivariance')
    parser.add_argument('--model0', type=str, help='Load the base model')
    parser.add_argument('--model1', type=str, help='Load the alternative model')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--network', type=str, default="cnn",
                        help='cnn or mlp (default:cnn)')
    parser.add_argument('--nonlinearity', type=str, default="ReLU",
                        help='Activation function e.g. Co+[ReLU/SiLU](+[_1/_neg/_inf])')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True}
        test_kwargs.update(cuda_kwargs)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        transforms.Resize(32, antialias=True)
        ])
    dataset2 = datasets.MNIST('../data', train=False,
                              transform=transform)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    # map_location puts every checkpoint on the selected device. This honours
    # --no-cuda, works on CPU-only hosts even for GPU-saved checkpoints, and
    # keeps sd0/sd1 on the same device for the einsum cost.
    sd0 = torch.load(args.model0, map_location=device)  # fixed reference
    sd1 = torch.load(args.model1, map_location=device)  # fixed, to be aligned
    sd2 = torch.load(args.model1, map_location=device)  # sd1 after permutation
    sd = torch.load(args.model1, map_location=device)   # scratch container

    print('layer shapes:', [(key, sd[key].shape) for key in sd])
    P, keys = _compute_permutations(sd0, sd1)
    print("Permutations:")
    print(P)
    _apply_permutations(sd2, P, keys)

    Net = CNN if args.network == 'cnn' else MLP
    model = Net(args.nonlinearity).to(device)
    # Permuting hidden units should not change the function. The two
    # results below are expected to agree; a mismatch points at a tensor
    # that _apply_permutations did not reorder.
    model.load_state_dict(sd2)
    print("after alignment:")
    test(model, device, test_loader)
    model.load_state_dict(sd1)
    print("before alignment:")
    test(model, device, test_loader)

    # Interpolation before/after permutation.
    ind = torch.arange(0, 1.001, .05)
    print("after alignment:")
    accs_after = _interpolate(sd0, sd2, model, device, test_loader, ind, sd)
    print("before alignment:")
    accs_before = _interpolate(sd0, sd1, model, device, test_loader, ind, sd)

    out_path = "{}_perm.pdf".format(args.nonlinearity)
    xs = ind.tolist()
    # Deliberate double export: the empty figure written first, plus the
    # sleep below, give MathJax ([MathJax]/extensions/MathMenu.js) time to
    # load. This keeps its loading banner out of the final PDF.
    fig = go.Figure()
    fig.write_image(out_path)
    fig.add_trace(go.Scatter(x=xs, y=accs_before, mode='lines', name='before alignment'))
    fig.add_trace(go.Scatter(x=xs, y=accs_after, mode='lines', name='after alignment'))
    fig.update_layout(
        title={
            'text': args.nonlinearity,
            'y': 0.9,
            'x': 0.5,
            'xanchor': 'center',
            'yanchor': 'top'},
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="right",
            x=0.99
        ))
    time.sleep(2)  # load [MathJax]/extensions/MathMenu.js
    fig.write_image(out_path)

if __name__ == "__main__":
    # Script entry point: importing this module does not run the alignment.
    main()
    