from itertools import islice

import copy
import torch
import ot
import matplotlib.pyplot as plt
from tqdm import tqdm

import mobilenetv2

from torchvision.datasets import CIFAR100
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import graphviz
from networkx.utils.union_find import UnionFind


def find_groups(model):
    """Partition layer input/output channel axes into groups sharing one permutation.

    Starting from the ``'linear'`` layer, ``model.dependent_layers`` is walked
    depth-first; each ``"<layer>.in"`` node is unioned with the ``"<dep>.out"``
    nodes of every layer it depends on.

    Returns:
        ``(groups_by_node, unionfind)``: ``groups_by_node`` maps every node name
        to the set of node names in its group; ``unionfind`` is the underlying
        ``UnionFind`` structure.
    """
    uf = UnionFind()
    seen = set()

    def visit(name):
        seen.add(name)
        producers = model.dependent_layers(name)
        uf.union(f"{name}.in", *[f"{producer}.out" for producer in producers])
        for producer in producers:
            # `seen` is re-checked every iteration: deeper recursion may have
            # reached this producer already.
            if producer in seen:
                continue
            visit(producer)

    visit('linear')

    groups_by_node = {node: members for members in uf.to_sets() for node in members}
    return groups_by_node, uf


def find_align_matrix(n_in, src_layer, src_model, tgt_layer, tgt_model, train_loader, n_batches):
    """Compute an optimal-transport plan matching the input channels of two layers.

    Input activations of ``src_layer`` (in ``src_model``) and ``tgt_layer`` (in
    ``tgt_model``) are collected over ``n_batches`` batches. Channels are compared
    by Euclidean distance between their activation rows; the distance matrix is
    normalised to a maximum of 1 and solved exactly with ``ot.emd`` under uniform
    marginals.

    Args:
        n_in: number of input channels of both layers.
        src_layer, src_model: layer/model being aligned.
        tgt_layer, tgt_model: reference layer/model.
        train_loader: data source for activations.
        n_batches: number of batches to draw from ``train_loader``.

    Returns:
        ``(n_in, n_in)`` transport plan whose rows and columns each sum to ``1 / n_in``.
    """
    src_inputs = get_inputs_for(src_layer, src_model, train_loader, n_batches)  # (n_in, n_samples)
    tgt_inputs = get_inputs_for(tgt_layer, tgt_model, train_loader, n_batches)  # (n_in, n_samples)
    # Approximate size in bytes (assuming 4-byte elements) and shape of the activations.
    print(f"{src_inputs.numel() * 4:,d}, {src_inputs.shape}")

    # Keep the marginals on the same device/dtype as the activations instead of
    # hard-coding CUDA.
    src_weights = torch.full([n_in], 1 / n_in, device=src_inputs.device, dtype=src_inputs.dtype)
    tgt_weights = torch.full([n_in], 1 / n_in, device=src_inputs.device, dtype=src_inputs.dtype)

    D = torch.cdist(src_inputs[None], tgt_inputs[None], p=2).squeeze(0)
    max_dist = D.max()
    # An all-zero cost matrix would otherwise become all-NaN after normalisation.
    if max_dist > 0:
        D /= max_dist

    return ot.emd(src_weights, tgt_weights, D)


def _permute_channels_(perm, vec):
    """Reorder a per-channel vector in place: ``vec[t] <- sum_o perm[o, t] * vec[o]``.

    ``None`` (e.g. a layer created without bias) is ignored. Must run under
    ``torch.no_grad()`` when ``vec`` is a parameter that requires grad.
    """
    if vec is not None:
        vec.copy_(torch.einsum('ot,o->t', perm, vec))


def _permute_bn_(perm, bn):
    """Reorder a batch-norm layer's affine parameters and running statistics in place."""
    for tensor in (bn.weight, bn.bias, bn.running_var, bn.running_mean):
        _permute_channels_(perm, tensor)


@torch.no_grad()
def connect(
        src_model: mobilenetv2.MobileNetV2,
        tgt_model: mobilenetv2.MobileNetV2,
        train_loader,
        use_activations=False,
        n_batches=20):
    """Permute the channels of ``src_model`` in place so they line up with ``tgt_model``.

    Channel groups come from ``find_groups``: each group couples layer-input axes
    (``"<layer>.in"``) with the output axes (``"<layer>.out"``) of the layers that
    feed them. For every group, a transport plan between the two models'
    activations at one consuming layer is computed with ``find_align_matrix``,
    converted into a permutation matrix, and applied to every weight axis in the
    group. On output axes it is also applied to the layer's bias, its batch-norm
    layer and, for ``*.conv1`` producers, the following depthwise ``conv2``/``bn2``.

    Args:
        src_model: model whose weights are permuted in place.
        tgt_model: reference model (only read).
        train_loader: loader used to collect activations.
        use_activations: currently unused; kept for interface compatibility.
        n_batches: number of batches used to collect activations.
    """
    src_modules = dict(src_model.named_modules())
    tgt_modules = dict(tgt_model.named_modules())
    _, unionfind = find_groups(src_model)

    for group in unionfind.to_sets():
        # align layer input & other groups
        print('-'*5 + f" {group} " + '-'*5)

        # A group can contain several ".in" entries (several consumers of the same
        # channels). Choose one deterministically: iterating the raw set depends on
        # per-process string hashing, so the reference layer -- and therefore the
        # resulting alignment -- could change from run to run.
        in_names = sorted(name for name in group if name.endswith('.in'))
        if not in_names:
            raise ValueError(f"channel group {group} has no consuming layer")
        layer_name = in_names[0][:-len('.in')]
        print(layer_name)

        # align group
        src_layer, tgt_layer = src_modules[layer_name], tgt_modules[layer_name]
        n_in = src_layer.weight.shape[1]
        T = find_align_matrix(
            n_in,
            src_layer, src_model,
            tgt_layer, tgt_model,
            train_loader, n_batches
        )
        # T has uniform marginals 1/n_in, so T * n_in is a 0/1 permutation matrix.
        # Rounding removes float noise so weights are moved exactly instead of
        # being rescaled by ~(1 +- eps). Computed once per group (loop-invariant).
        perm = torch.round(T * n_in)

        # align all layers in the group
        for align_name in sorted(group):
            print(align_name)

            layer_name, _, side = align_name.rpartition('.')
            layer = src_modules[layer_name]
            weight = layer.weight
            weight = weight.view(weight.size(0), weight.size(1), -1)  # (n_out, n_in, k)

            if side == 'in':
                # Reorder the input channels of the consuming layer.
                weight.copy_(torch.einsum('it,oif->otf', perm, weight))
                continue

            # Reorder the output channels of a producing layer (and its bias).
            weight.copy_(torch.einsum('ot,oif->tif', perm, weight))
            _permute_channels_(perm, getattr(layer, 'bias', None))

            # A conv1 inside a block feeds a channel-preserving depthwise conv2/bn2
            # that is not part of the dependency graph; carry the permutation through.
            if align_name.endswith('.conv1.out'):
                depthwise_name = layer_name.replace('conv1', 'conv2')
                dw_layer = src_modules[depthwise_name]
                dw_layer.weight.copy_(torch.einsum('ot,oabc->tabc', perm, dw_layer.weight))
                _permute_channels_(perm, getattr(dw_layer, 'bias', None))
                _permute_bn_(perm, src_modules[depthwise_name.replace('conv2', 'bn2')])

            # Batch-norm layer that follows this producer, derived from its name.
            if 'shortcut' in layer_name:
                bn_name = layer_name[:-1] + "1"
            elif 'conv' in layer_name:
                bn_name = layer_name.replace('conv', 'bn')
            else:
                bn_name = None
            if bn_name is not None:
                _permute_bn_(perm, src_modules[bn_name])


def get_inputs_for(layer, model, loader, n_batches):
    """Collect the input activations of ``layer`` over the first ``n_batches`` batches.

    ``model`` is switched to eval mode and ``model.inputs_for(layer, X)`` is
    evaluated on each batch moved to CUDA. Dim 1 of each result is treated as the
    channel axis (assumed layout: batch first, channels second). Every other axis
    is flattened, and batches are concatenated along the sample axis.

    Returns:
        Tensor of shape ``(C, n_samples)`` with one row per channel, or ``None``
        if ``loader`` yields no batches.
    """
    model.eval()
    results = None
    begin = 0
    with torch.no_grad():
        for X, _ in tqdm(islice(loader, 0, n_batches), total=n_batches, ncols=60):
            acts = model.inputs_for(layer, X.cuda(non_blocking=True))
            n_channels = acts.shape[1]
            # transpose(0, 1) equals permute(1, 0, 2, 3) for 4-D tensors and also
            # handles 2-D (batch, features) activations.
            acts = acts.transpose(0, 1).reshape(n_channels, -1)
            if results is None:
                # Preallocate assuming every batch is as large as the first one.
                results = torch.empty(n_channels, n_batches * acts.shape[1], device=acts.device)
            results[:, begin:begin + acts.shape[1]] = acts
            begin += acts.shape[1]
    if results is None:
        return None
    # Drop columns that were never written (short final batch, or fewer than
    # n_batches batches); torch.empty leaves them uninitialised.
    return results[:, :begin]


def test_model(model, val_loader: DataLoader):
    model.eval()
    print("Test")
    loss_accum = 0
    correct_count = 0
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for X, y in val_loader:
            X = X.cuda(non_blocking=True)
            y = y.cuda(non_blocking=True)
            logits = model(X)
            loss = criterion(logits, y)
            loss_accum += loss.data
            correct_count += (torch.argmax(logits, axis=1) == y).sum()
        loss = (loss_accum / len(val_loader.dataset)).item()
        acc = (correct_count / len(val_loader.dataset)).item()
        print(f'{loss}')
        print(f'{acc:.4f}')
    return loss, acc


def train_acc_model(model, train_loader: DataLoader):
    """Measure loss and accuracy of ``model`` on ``train_loader`` with the model in train mode.

    NOTE: ``model.train()`` is used deliberately or not -- with it, any batch-norm
    layers normalise with per-batch statistics and update their running estimates
    in place (even under ``no_grad``), so this call modifies ``model``.

    Prints a ``Train`` header, then the mean cross-entropy loss and the accuracy
    (4 decimals).

    Returns:
        ``(loss, acc)`` as Python floats, both normalised by
        ``len(train_loader.dataset)``.
    """
    model.train()
    print("Train")
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    total_loss = 0
    n_correct = 0
    with torch.no_grad():
        for inputs, targets in train_loader:
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True)
            logits = model(inputs)
            total_loss += criterion(logits, targets).data
            n_correct += (logits.argmax(dim=1) == targets).sum()
        n_samples = len(train_loader.dataset)
        loss = (total_loss / n_samples).item()
        acc = (n_correct / n_samples).item()
        print(f'{loss}')
        print(f'{acc:.4f}')
    return loss, acc


def load_model(path):
    """Build a CUDA ``MobileNetV2`` (100 classes) and load the checkpoint at ``path``.

    Checkpoint entries whose key contains ``'quantizer'`` are discarded. If the
    checkpoint has no ``'linear.bias'``, the last column of ``'linear.weight'`` is
    split off and used as the classifier bias (presumably a checkpoint saved with
    the bias folded into the weight -- confirm against the training code).
    """
    model = mobilenetv2.MobileNetV2(num_classes=100, final_bias_trick=False).cuda()
    raw = torch.load(path, map_location='cpu')

    state = {}
    for key, value in raw.items():
        if 'quantizer' in key:
            continue
        state[key] = value

    if 'linear.bias' not in state:
        folded = state['linear.weight']
        state['linear.bias'] = folded[:, -1]
        state['linear.weight'] = folded[:, :-1]

    model.load_state_dict(state)
    return model


def render_dep_graph(model):
    """Render the layer dependency graph of ``model`` with graphviz under the name ``dep-graph``.

    Nodes are the layer names reachable from ``'linear'`` via
    ``model.dependent_layers``. An undirected edge ``dep -- layer`` is added for
    every dependency encountered, and each layer is expanded only once.
    """
    graph = graphviz.Graph('dep-graph')
    seen = set()

    def add_subtree(current):
        graph.node(current, current)
        seen.add(current)
        for producer in model.dependent_layers(current):
            graph.edge(producer, current)
            if producer in seen:
                continue
            add_subtree(producer)

    add_subtree('linear')
    graph.render('dep-graph')


def main():
    """Align ``model_fp_1`` to ``model_fp_0`` and save it as ``model_fp_1_aligned.pt``.

    The training-set loss/accuracy of ``model_1`` is printed before and after the
    alignment performed by ``connect``.

    NOTE(review): ``train_acc_model`` runs the model in train mode. If the model's
    batch-norm layers are standard torch BatchNorm, every call also updates
    ``model_1``'s running statistics, and the saved checkpoint includes those
    updates. Confirm this is intended.
    """
    # model_fp_2 used to be loaded here as well but was never used; loading it only
    # cost a checkpoint read and GPU memory.
    model_0 = load_model('/d1/xxx/TBQ_experiments/model_fp_0.pt')
    model_1 = load_model('/d1/xxx/TBQ_experiments/model_fp_1.pt')

    # render_dep_graph(model_0)

    # load dataset
    DATASET_DIR = '/home/xxx/DBQ/'
    train_set = CIFAR100(DATASET_DIR, train=True, download=True,
                         transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
                         ]))
    test_set = CIFAR100(DATASET_DIR, train=False, download=True,
                        transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
                        ]))

    num_workers = 4
    # No shuffling: connect() therefore always sees the same leading batches.
    train_loader = DataLoader(train_set, batch_size=128, num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(test_set, batch_size=128 * 4, num_workers=num_workers, pin_memory=True)

    train_acc_model(model_1, train_loader)
    connect(model_1, model_0, train_loader)
    train_acc_model(model_1, train_loader)

    torch.save(model_1.state_dict(), "/d1/xxx/TBQ_experiments/model_fp_1_aligned.pt")


def avg_experiment():
    """Average ``model_fp_0`` with the aligned ``model_fp_1`` and report training-set metrics.

    Every state-dict entry (parameters and buffers) of the averaged model is set
    to ``0.5 * a + 0.5 * b``. The result is then evaluated with
    ``train_acc_model``, which runs the model in train mode.
    """
    DATASET_DIR = '/home/xxx/DBQ/'
    train_set = CIFAR100(DATASET_DIR, train=True, download=True,
                         transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
                         ]))
    test_set = CIFAR100(DATASET_DIR, train=False, download=True,
                        transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
                        ]))

    num_workers = 4
    train_loader = DataLoader(train_set, batch_size=128, num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(test_set, batch_size=128 * 4, num_workers=num_workers, pin_memory=True)

    # average models
    model_a = load_model('/d1/xxx/TBQ_experiments/model_fp_0.pt')
    model_b = load_model('/d1/xxx/TBQ_experiments/model_fp_1_aligned.pt')
    # model_b = load_model('/d1/xxx/TBQ_experiments/model_fp_1.pt')

    avg_model = copy.deepcopy(model_a)

    # Build each state dict once rather than three times per key. The tensors
    # returned by state_dict() share storage with the module's parameters and
    # buffers, so copy_ writes directly into avg_model.
    state_a = model_a.state_dict()
    state_b = model_b.state_dict()
    state_avg = avg_model.state_dict()
    for name, a in state_a.items():
        state_avg[name].copy_(0.5 * a + 0.5 * state_b[name])

    train_acc_model(avg_model, train_loader)


if __name__ == "__main__":
    # main() writes model_fp_1_aligned.pt, which avg_experiment() loads;
    # re-enable it when the aligned checkpoint needs to be regenerated.
    # main()
    avg_experiment()
