# python roto_symmetry.py --nonlinearity ReLU --model0 ckpt/mnist_cnn_ReLU_seed_23.pt --model1 ckpt/mnist_cnn_ReLU_seed_2023.pt
import torch
from mnist import CNN, MLP, test, nonlinearities
import argparse
import torch.nn.functional as F
from torchvision import datasets, transforms
from scipy.optimize import linear_sum_assignment
from collections import OrderedDict
import copy
import plotly.express as px
import plotly.graph_objects as go
import numpy
import time
import random
# import copy
import ot

def align(a, p, dim=0):
    """Mix the slices of ``a`` along axis ``dim`` using the matrix ``p``.

    ``p`` is always read as ``p[new, old]``: slice ``j`` of the result along
    ``dim`` is ``sum_k p[j, k] * a_k`` where ``a_k`` is slice ``k`` of ``a``
    along ``dim``.  For a permutation matrix this is a plain gather of
    rows/columns.  The orientation matches the cost matrices produced by
    ``gram`` (rows index its first argument, columns its second), so a plan
    that maximises ``sum(C * p)`` moves the second model's units onto the
    first model's positions.

    All supported ranks share this one convention, so a layer's weight and
    its bias are permuted identically even when ``p`` is not symmetric.

    Args:
        a: tensor of rank 4 (``dim`` 0 or 1), rank 2 (``dim`` 0 or 1) or
            rank 1 (``dim`` 0).
        p: 2-D tensor whose second axis has the size of ``a`` along ``dim``.
        dim: axis of ``a`` to act on; 0 = output units, 1 = input units.

    Returns:
        Tensor shaped like ``a`` except that axis ``dim`` has length
        ``p.shape[0]``.

    Raises:
        NotImplementedError: for any other rank / ``dim`` combination.
    """
    # Upper-case letters are the new index, lower-case the old one; ``p`` is
    # written as [New, old] in every spec.
    specs = {
        (4, 0): 'oikl,Oo->Oikl',
        (4, 1): 'oikl,Ii->oIkl',
        (2, 0): 'oi,Oo->Oi',
        (2, 1): 'oi,Ii->oI',
        (1, 0): 'o,Oo->O',
    }
    spec = specs.get((len(a.shape), dim))
    if spec is None:
        raise NotImplementedError
    return torch.einsum(spec, a, p)
        

def gram(a, b, dim=0):
    """Return the matrix of inner products between slices of ``a`` and ``b``.

    Entry ``[r, c]`` is the sum of the element-wise product of slice ``r`` of
    ``a`` and slice ``c`` of ``b``, both taken along axis ``dim``; every other
    axis is summed out.

    Args:
        a, b: tensors of equal rank, either 2 or 4.
        dim: axis whose slices are compared (0 or 1).

    Returns:
        2-D tensor of shape ``(a.shape[dim], b.shape[dim])``.

    Raises:
        AssertionError: if the ranks of ``a`` and ``b`` differ.
        NotImplementedError: for an unsupported rank or ``dim``.
    """
    rank = len(a.shape)
    assert rank == len(b.shape)

    # einsum spec for each supported (rank, dim) pair
    contraction = {
        (4, 1): 'oikl,oIkl->iI',
        (4, 0): 'oikl,Oikl->oO',
        (2, 1): 'oi,oI->iI',
        (2, 0): 'oi,Oi->oO',
    }
    try:
        spec = contraction[(rank, dim)]
    except KeyError:
        raise NotImplementedError
    return torch.einsum(spec, a, b)

def main():
    """Align model1 to model0 with per-layer transport plans and plot the result.

    Driven entirely by command-line arguments:

    1. Load both checkpoints (state dicts) onto the selected device.
    2. For up to ``K`` rounds, visit the hidden layers in random order and
       solve an exact OT problem (``ot.emd`` with all-ones marginals) whose
       gain matrix is the ``gram`` of the layer's own weights plus the
       ``gram`` of the following layer's input weights.  Stop once the total
       gain of a round changes by less than 1e-2.
    3. Apply the plans to a copy of model1 (output axis, bias, and the input
       axis of the following layer).
    4. Evaluate with ``test`` before/after alignment, evaluate linear
       interpolations between model0 and each variant, and write the curves
       to PDF.
    """
    import os  # local import: only needed to create the figure directory

    parser = argparse.ArgumentParser(description='Roto-Equivariance')
    parser.add_argument('--model0', type=str, help='Load the principal model')
    parser.add_argument('--model1', type=str, help='Load the alternative model')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--test-batch-size', type=int, default=50, metavar='N',
                        help='input batch size for testing (default: 50)')
    parser.add_argument('--network', type=str, default="cnn",
                        help='cnn or mlp (default:cnn)')
    parser.add_argument('--nonlinearity', type=str, default="ReLU",
                        help='Activation function e.g. Co+[ReLU/SiLU](+[_1/_neg/_inf])')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    # The layer visiting order below comes from ``random.shuffle``; seed it as
    # well so that --seed makes the whole run reproducible.
    random.seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")

    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': 10,
                       'pin_memory': True,
                       'shuffle': True}
        test_kwargs.update(cuda_kwargs)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        transforms.Resize(32, antialias=True)
        ])
    dataset2 = datasets.MNIST('../data', train=False,
                              transform=transform)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    # map_location keeps the weights on the same device as the plans in P,
    # and lets a CUDA-saved checkpoint be opened on a CPU-only machine.
    sd0 = torch.load(args.model0, map_location=device)  # fixed
    sd1 = torch.load(args.model1, map_location=device)  # fixed
    sd2 = copy.deepcopy(sd1)  # permuted from sd1
    sd = copy.deepcopy(sd1)   # container reused by interpolate()

    print('layer shapes:', [(key, sd[key].shape) for key in sd])

    # One plan per weight tensor (biases are skipped), initialised to the
    # identity, plus P[0] for the network's input channels.
    P = OrderedDict()
    key0 = list(sd0.keys())[0]
    P[0] = torch.eye(sd0[key0].shape[1], device=device)  # input channel P[0]
    for key in sd0:
        if len(sd0[key].shape) > 1:  # ignore bias
            P[key] = torch.eye(sd0[key].shape[0], device=device)
    keys = list(P.keys())
    print('computing P with layers:', keys)

    old_loss = - 1e9
    K = 500
    permuted_keys = list(enumerate(keys))
    for k in range(K):
        # Total gain over all layers updated in this round.  It has to be
        # accumulated: the visiting order is shuffled, so the gain of the
        # last-visited layer alone is not comparable between rounds.
        new_loss = 0
        random.shuffle(permuted_keys)
        for i, key in permuted_keys:
            # The input plan and the final layer's plan stay at the identity.
            if key in [0, keys[-1]]:
                continue
            prevkey = keys[i-1]
            nextkey = keys[i+1]
            C = gram(sd0[key], align(sd1[key], P[prevkey], 1), 0) \
              + gram(sd0[nextkey], align(sd1[nextkey], P[nextkey], 0), 1)
            if args.nonlinearity[:2] == "Co":
                # "Co*" activations: index 0 is left untouched and only the
                # remaining units are matched.
                C = C[1:, 1:]
                a = b = torch.ones(P[key].shape[0]-1, device=device)
                P[key][1:, 1:] = ot.emd(a, b, -C).to(device)
                new_loss += (C * P[key][1:, 1:]).sum()
            else:
                a = b = torch.ones(P[key].shape[0], device=device)
                P[key] = ot.emd(a, b, -C).to(device)
                new_loss += (C * P[key]).sum()
            # Entropic alternative that was tried here:
            #   P[key] = ot.sinkhorn(a, b, -C, 1e-1, numItermax=1000).to(device)
        print('New cost:', new_loss)
        if abs(new_loss - old_loss) < 1e-2:  # converges
            print('greedy algorithm terminates at round', k+1, '/', K)
            break
        old_loss = new_loss

    print("OT Plans:")
    print(P)

    # Apply P: every weight gets the previous layer's plan on its input axis
    # and its own plan on its output axis; the matching bias follows the
    # output plan.
    for i, key in enumerate(keys):
        if key == 0:  # skip the first layer
            continue
        prevkey = keys[i-1]
        print('processing', key, '(second dim) , len(p_prev) = ', len(P[prevkey]))
        sd2[key] = align(sd2[key], P[prevkey], 1)
        print('processing', key, ', len(p) = ', len(P[key]))
        sd2[key] = align(sd2[key], P[key], 0)
        key_bias = key[:-6]+'bias'  # '<layer>.weight' -> '<layer>.bias'
        if key_bias in sd2:
            print('processing', key_bias, ', len(p) = ', len(P[key]))
            sd2[key_bias] = align(sd2[key_bias], P[key])

    # Evaluate the aligned and the original model1.
    Net = CNN if args.network == 'cnn' else MLP
    model = Net(args.nonlinearity).to(device)
    model.load_state_dict(sd2)
    print("after alignment:")
    loss2 = test(model, device, test_loader)
    model.load_state_dict(sd1)
    print("before alignment:")
    loss1 = test(model, device, test_loader)

    # Interpolation before/after permutation.
    ind = torch.arange(0, 1.001, .05)

    def interpolate(sd0, sd1, sd=None):
        """Evaluate ``(1 - s) * sd0 + s * sd1`` for every ``s`` in ``ind``.

        ``sd`` is a scratch state dict that is overwritten in place; a deep
        copy of ``sd1`` is used when none is given.  Returns the list of
        values returned by ``test``, one per ``s``.
        """
        if not sd:
            sd = copy.deepcopy(sd1)
        accs = []
        for s in ind:
            model = Net(args.nonlinearity).to(device)
            for key in sd:
                sd[key] = (1 - s) * sd0[key] + s * sd1[key]
            model.load_state_dict(sd)
            print("s =", s, ":")
            acc = test(model, device, test_loader)
            accs.append(acc)
        return accs

    print("after alignment:")
    accs_after = interpolate(sd0, sd2, sd)
    print("before alignment:")
    accs_before = interpolate(sd0, sd1, sd)

    fig = go.Figure()
    # NOTE(review): this first write of the still-empty figure looks like a
    # renderer warm-up (see the sleep below); the file is overwritten at the
    # end of the function.
    fig.write_image("{}_roto.pdf".format(args.nonlinearity))
    fig.add_trace(go.Scatter(x=ind, y=accs_before, mode='lines', name='before alignment'))
    fig.add_trace(go.Scatter(x=ind, y=accs_after, mode='lines', name='after alignment'))
    fig.update_layout(
        title={
            'text': args.nonlinearity,
            'y': 0.9,
            'x': 0.5,
            'xanchor': 'center',
            'yanchor': 'top'},
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="right",
            x=0.99
        ))
    # Create the output directory so a missing folder does not discard the
    # whole run at the very last step.
    os.makedirs("fig", exist_ok=True)
    fig.write_image("fig/roto_symm_"+args.nonlinearity+".pdf")
    time.sleep(2)  # load [MathJax]/extensions/MathMenu.js
    fig.write_image("{}_roto.pdf".format(args.nonlinearity))

# Run the experiment only when this file is executed as a script, so that
# importing the module (e.g. to reuse align/gram) has no side effects.
if __name__ == '__main__':
    main()
    