import copy
import sys
from functools import partial

import graphviz
import torch
from tqdm import tqdm

from torch.utils.data import DataLoader
import torch.nn.functional as F
from networkx.algorithms.dag import topological_sort
import networkx as nx
import argparse
import pathlib
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import time

from reactnet import reactnet

sys.path.append('.')
from cifar_torch import cifar100
import logging_util


def make_graph(model):
    """Build the layer dependency DAG reachable from the ``'fc'`` layer.

    Walks ``model.dependent_layers`` depth-first starting at ``'fc'`` and adds
    an edge ``dep -> layer`` for every dependency encountered. An explicit
    stack replaces recursion; edges are inserted in the same order a
    recursive pre-order traversal would insert them.

    Returns:
        networkx.DiGraph whose nodes are layer names.
    """
    dag = nx.DiGraph()
    root = 'fc'
    seen = {root}
    exhausted = object()  # sentinel marking a finished dependency iterator
    # Each stack entry is (layer, iterator over its not-yet-processed deps).
    pending = [(root, iter(model.dependent_layers(root)))]
    while pending:
        layer_name, deps = pending[-1]
        dep = next(deps, exhausted)
        if dep is exhausted:
            pending.pop()
            continue
        dag.add_edge(dep, layer_name)
        if dep not in seen:
            seen.add(dep)
            pending.append((dep, iter(model.dependent_layers(dep))))
    return dag


def dataset_cov(train_loader, extract_patches, dim, device='cuda'):
    """Estimate the covariance of layer-input patches over a whole dataset.

    Streams over ``train_loader`` and merges per-batch mean/covariance with the
    pairwise (Chan et al.) update, so the patches of the full dataset never
    have to be held in memory at once. Accumulation is done in float64.

    Args:
        train_loader: iterable of ``(X, y)`` batches; ``y`` is ignored.
        extract_patches: callable mapping a batch ``X`` to a ``(dim, n)``
            matrix whose rows are variables and whose columns are samples.
        dim: number of rows returned by ``extract_patches``.
        device: device batches and accumulators are placed on.

    Returns:
        ``(dim, dim)`` float32 population covariance (all zeros if the loader
        yields no samples).
    """
    count = 0
    mean = torch.zeros(dim, device=device, dtype=torch.float64)
    cov = torch.zeros(dim, dim, device=device, dtype=torch.float64)
    for X, _ in tqdm(train_loader):
        X = X.to(device, non_blocking=True)
        # Do the statistics in float64, not just the accumulators.
        patches = extract_patches(X).double()

        batch_count = patches.size(1)
        if batch_count == 0:
            continue
        batch_mean = patches.mean(1)
        # Population covariance (divide by n), matching the merge formula below.
        batch_cov = torch.cov(patches, correction=0)

        if count == 0:
            count = batch_count
            mean.copy_(batch_mean)
            cov.copy_(batch_cov)
            continue

        merged_count = count + batch_count
        delta = batch_mean - mean

        mean += delta * (batch_count / merged_count)
        # cov_AB = (n_A cov_A + n_B cov_B) / n + (n_A n_B / n^2) * delta delta^T
        cov *= count / merged_count
        cov += batch_cov * (batch_count / merged_count)
        cov += torch.outer(delta, delta) * (count * batch_count / merged_count ** 2)
        count = merged_count
    return cov.float()


@torch.jit.script
def stochastic_rounding(x, T: float):
    """Stochastically round ``x`` to +/-1 with temperature ``T``.

    ``x`` is first scaled by ``T``. Entries with ``|T*x| >= 1`` round
    deterministically to ``sign(x)``. Inside (-1, 1) the probability of +1 is
    ``0.5 + 0.5 * (2*x - sign(x)*x**2)``, evaluated on the scaled ``x``. A
    larger ``T`` therefore makes the rounding more deterministic. Zero entries
    round to +/-1 with probability 0.5 each. The input tensor is not modified.

    Returns:
        Integer tensor of +/-1 values with the same shape as ``x``.
    """
    s = torch.sign(x)
    # Out-of-place so the caller's tensor is not scaled as a side effect.
    x = x * T
    smooth = torch.where(torch.abs(x) < 1, (-s * x ** 2 + 2 * x), s)
    p = s * smooth.abs()
    prob_plus = p * .5 + .5
    return (torch.rand_like(x) < prob_plus) * 2 - 1


class Sign(torch.autograd.Function):
    """Binarize with ``sign`` in the forward pass and use a triangular surrogate gradient.

    In the backward pass the incoming gradient is multiplied by
    ``max(0, 2 * (1 - |x|))``. That factor peaks at 2 for ``x == 0`` and
    vanishes for ``|x| >= 1``. The temperature ``T`` is accepted because it
    was meant for stochastic rounding, but the forward pass currently ignores
    it and it receives no gradient.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, T):
        ctx.save_for_backward(x)
        # Deterministic sign; stochastic_rounding(x, T) is the disabled alternative.
        return x.sign()

    @staticmethod
    def backward(ctx, g):
        (x,) = ctx.saved_tensors
        # TODO replace with gradient estimator in reactnet
        surrogate = (1 - x.abs()).mul(2).clamp(min=0.0)
        return g * surrogate, None


@torch.jit.script
def quantization_loss(q, M, w_hat, lamba: float = 1e-5):
    """Per-column reconstruction error of ``q`` under the metric ``M``, plus a norm penalty.

    For each column j it computes
    ``||M @ q[:, j] - w_hat[:, j]||^2 + lamba * ||q[:, j]||_2``.
    The penalty is the L2 norm itself, not its square. For 2-D inputs the
    result has one entry per column of ``q``.
    """
    residual = torch.matmul(M, q) - w_hat
    reconstruction = residual.pow(2).sum(0)
    penalty = q.norm(2, dim=0)
    return reconstruction + lamba * penalty


@torch.enable_grad()
def find_q(w, M, nsteps=2000, nsamples=20, xnor_net=False):
    """Find per-column scales ``alpha`` and latent weights whose sign best reconstructs ``w``.

    The quantized weight is ``alpha * sign(w_prime)``.

    If ``xnor_net`` is False, it is fit by SGD on ``quantization_loss``, with
    the reconstruction measured under ``M``. Gradients flow through ``Sign``'s
    surrogate gradient. The search runs ``nsamples`` rounds of ``nsteps``
    steps each, and between rounds the latent weights receive a small Gaussian
    perturbation. The best iterate under the exact hard-sign loss is returned.

    Args:
        w: (in, out) weight matrix; each column is one output channel.
        M: matrix with ``in`` columns applied before comparing weights.
            Unused (and may be None) when ``xnor_net`` is True.
        nsteps: optimizer steps per round.
        nsamples: number of rounds.
        xnor_net: use the closed-form XNOR-Net scale ``mean(|w|)`` per column
            instead of optimizing.

    Returns:
        ``(alpha, w_prime)`` where ``alpha`` has shape (1, out). The caller
        forms ``alpha * sign(w_prime)``.
    """
    if xnor_net:
        # XNOR-Net closed form: scale each column by its mean absolute value.
        alpha = w.abs().mean(0, keepdim=True)
        # Return a pair like the optimized path; callers unpack (alpha, w_prime).
        return alpha, w

    w_hat = M @ w

    alpha = torch.nn.Parameter(w.abs().mean(0, keepdim=True))
    # Latent weights, normalized per column by the initial scale.
    w = torch.nn.Parameter(w.clone() / alpha)

    initial_alpha_scale = alpha.data.norm(2).clone()
    print(f"initial_alpha={initial_alpha_scale}")

    opt = torch.optim.SGD([w, alpha], lr=0.01, momentum=0.9, nesterov=True)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=nsteps)

    best = torch.tensor(float('inf'), device=w.device, dtype=w.dtype)
    best_w = w.detach().clone()
    best_alpha = alpha.detach().clone()
    for _ in range(nsamples):
        for step in range(nsteps):
            opt.zero_grad()
            T = 1 + step / nsteps * 20  # Temperature 1 --> 20 (Sign.forward ignores it)

            loss = quantization_loss(
                alpha * Sign.apply(w, T),
                M, w_hat
            )
            loss.mean().backward()

            with torch.no_grad():
                # Track the best iterate under the exact (hard-sign) loss,
                # evaluated on the parameters before this step's update.
                exact = quantization_loss(alpha * torch.sign(w), M, w_hat).mean()
                improved = torch.le(exact, best)
                best.copy_(torch.where(improved, exact, best))
                best_w.copy_(torch.where(improved, w, best_w))
                best_alpha.copy_(torch.where(improved, alpha, best_alpha))

            opt.step()
            scheduler.step()
        # Perturb the latent weights before the next round to explore other sign patterns.
        with torch.no_grad():
            w.add_(torch.randn_like(w) * 0.01)
        print("best=", best.item())

    print(f"final best={best}, best_alpha={best_alpha.norm(2)}")
    return best_alpha, best_w


# knowledge distillation loss
@torch.jit.script
def distillation(logits, labels, teacher_scores, T: float, alpha: float):
    """Per-sample distillation loss that mixes hard-label CE with a soft-label KL term.

    Args:
        logits: student logits.
        labels: hard class labels.
        teacher_scores: teacher logits (soft targets).
        T: softmax temperature applied to both student and teacher logits.
        alpha: mixing weight; the cross-entropy term is scaled by ``1 - alpha``.

    Returns:
        Unreduced loss with one value per sample:
        ``(1 - alpha) * CE + (2 * T**2 + alpha) * KL(teacher_T || student_T)``.

    NOTE(review): the KL weight is ``2 * T**2 + alpha``, not the usual
    ``alpha * T**2``. For large T (train_model defaults to T=20) ``alpha``
    barely changes the teacher term. Confirm this is intended before relying
    on ``alpha`` to rebalance the two losses.
    """
    hard_loss = F.cross_entropy(logits, labels, reduction='none')
    student_log_p = F.log_softmax(logits / T, dim=-1)
    teacher_log_p = F.log_softmax(teacher_scores / T, dim=-1)
    soft_loss = F.kl_div(
        student_log_p,
        teacher_log_p,
        reduction='none',
        log_target=True,
    ).sum(1)
    hard_weight = 1 - alpha
    soft_weight = 2 * T ** 2 + alpha
    return hard_loss * hard_weight + soft_loss * soft_weight


@torch.enable_grad()
def train_model(model, cur_params, weight_params, bn_params, inv_sigma,
                teacher, train_loader, n_steps, total_steps, last_step,
                lr=0.01, T=20.0, alpha=0.7, weight_decay=1/500, device='cuda'):
    """Fine-tune ``model`` by distillation from ``teacher`` for exactly ``n_steps`` steps.

    Optimizes ``weight_params``, ``bn_params`` and ``cur_params`` with SGD
    (Nesterov momentum 0.9). The learning rate follows a cosine schedule of
    length ``total_steps``, resumed at ``last_step``. L2 weight decay is added
    to the gradients by hand. Mean loss and accuracy are printed only when a
    full pass over ``train_loader`` completes.

    Args:
        model: student network; put in train mode.
        cur_params: parameters of the layer being prepared for quantization.
        weight_params: other trainable (non-BatchNorm) parameters.
        bn_params: BatchNorm affine parameters.
        inv_sigma: unused; kept for the disabled EmpCov regularizer.
        teacher: network providing soft targets (run under no_grad).
        train_loader: iterable of ``(X, y)`` batches.
        n_steps: number of optimizer steps to run; ``<= 0`` does nothing.
        total_steps: ``T_max`` of the cosine schedule.
        last_step: schedule position to resume from.
        lr, T, alpha: base learning rate and ``distillation`` parameters.
        weight_decay: L2 coefficient applied to all optimized parameters.
        device: device batches are moved to (the model must already be there).

    Raises:
        ValueError: if ``train_loader`` yields no batches, which would
            otherwise loop forever.
    """
    model.train()
    if n_steps <= 0:
        # Nothing to train (e.g. finetune_epochs == 0). Returning here also
        # avoids constructing CosineAnnealingLR with T_max == 0.
        return
    params = [*weight_params, *bn_params, *cur_params]
    opt = torch.optim.SGD([{'params': params, 'initial_lr': lr}],
                          lr=lr, momentum=0.9, nesterov=True)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_steps, last_epoch=last_step)
    steps = 0
    while True:
        count = 0
        loss_accum = 0
        correct_count = 0
        for X, y in tqdm(train_loader, desc="Train"):
            X = X.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True).flatten()
            model.zero_grad()

            logits = model(X)
            with torch.no_grad():
                teacher_scores = teacher(X)
            loss = distillation(logits, y, teacher_scores, T, alpha)
            loss.mean().backward()

            # Manual L2 weight decay on every optimized group.
            # FIXME this also decays the BatchNorm affine weights.
            for group in (cur_params, weight_params, bn_params):
                for p in group:
                    # Parameters unused in the forward pass have no gradient.
                    if p.grad is not None:
                        p.grad.add_(p.data, alpha=weight_decay)

            opt.step()
            count += X.shape[0]
            loss_accum += loss.sum().detach()
            correct_count += (torch.argmax(logits, dim=1) == y).sum()
            scheduler.step()

            steps += 1
            # Stop after exactly n_steps so consecutive calls resumed at
            # last_step = k * n_steps do not overlap on the schedule.
            if steps >= n_steps:
                return
        if count == 0:
            raise ValueError("train_loader yielded no batches")
        tqdm.write(f'{(loss_accum / count).item()}')
        accuracy = (correct_count / count).item()
        tqdm.write(f'{accuracy:.4f}')


@torch.no_grad()
def extract_patches(model, layer, X):
    """Return the inputs seen by ``layer`` for batch ``X``, laid out for covariance estimation.

    For a ``Conv2d`` layer, the input feature map is unfolded into sliding
    patches. The result is a ``(channels * kh * kw, batch * n_positions)``
    matrix with one column per patch. For any other layer type,
    ``model.inputs_for(layer, X)`` is returned unchanged.

    NOTE(review): ``dataset_cov`` treats rows as variables and columns as
    samples. For non-conv layers that only works if ``inputs_for`` already
    returns (features, samples); confirm before using this on e.g. a Linear
    layer.
    """
    layer_input = model.inputs_for(layer, X)
    if not isinstance(layer, torch.nn.Conv2d):
        return layer_input
    columns = F.unfold(
        layer_input,
        kernel_size=layer.kernel_size,
        dilation=layer.dilation,
        padding=layer.padding,
        stride=layer.stride,
    )  # (batch, channels*kh*kw, n_positions)
    patch_dim = columns.shape[1]
    # Move the patch dimension first, then merge batch and spatial positions.
    return columns.transpose(0, 1).reshape(patch_dim, -1)


def _weight_matrix(layers):
    """Stack the weights of ``layers`` along the output dim into an (out, in) matrix copy."""
    w = torch.cat([m.weight.data for m in layers], dim=0)
    return w.reshape(w.size(0), -1)


def _assign_weight_matrix(layers, q):
    """Copy consecutive row blocks of the (out, in) matrix ``q`` into the weights of ``layers``."""
    offset = 0
    for m in layers:
        rows = m.weight.shape[0]
        m.weight.data.copy_(q[offset:offset + rows].reshape(m.weight.shape))
        offset += rows


@torch.no_grad()
def quantize(model, train_loader, val_loader, block_iters, finetune_epochs, xnor_net=False, skip_until=None):
    """Greedily binarize the model's conv layers, fine-tuning by distillation in between.

    Runs ``block_iters`` passes. Each pass visits layers in topological order
    of ``make_graph(model)``. For every layer it:
      1. freezes the layer for later fine-tunes in this pass;
      2. estimates the input-patch covariance and derives the metric ``M``
         for ``find_q``;
      3. fine-tunes the still-unfrozen weights plus the current layer against
         a copy of the original model (the teacher);
      4. replaces the layer weight by ``alpha * sign(w_prime)``.

    Passes after the first begin by fine-tuning all weights again, and every
    pass ends with a fine-tune of the weights that were not frozen. The
    ``'fc'`` layer is never quantized. ``*_down`` names refer to the module
    pair ``<name>1``/``<name>2``, which is quantized jointly. Depthwise
    (``groups > 1``) layers are frozen but not quantized (FIXME).

    Args:
        model: network modified in place.
        train_loader: data used for fine-tuning and covariance estimation.
        val_loader: data used for evaluation after each step.
        block_iters: number of passes. Each per-layer fine-tune runs
            ``len(train_loader) * finetune_epochs // block_iters`` steps.
        finetune_epochs: fine-tuning budget in epochs.
        xnor_net: use XNOR-Net scaling. The covariance pass is skipped
            because ``find_q`` does not use ``M`` in that mode.
        skip_until: layer name. Until it is reached, layers are frozen but
            neither fine-tuned nor quantized.
    """
    graph = make_graph(model)
    modules = dict(model.named_modules())
    weight_params = {p for m in model.modules() if not isinstance(m, torch.nn.BatchNorm2d)
                     for p in m.parameters(recurse=False)}
    bn_params = {p for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)
                 for p in m.parameters(recurse=False)}
    print(f"{len(weight_params)} weight params, {len(bn_params)} bn params")

    teacher = copy.deepcopy(model)
    # The teacher only provides targets; keep its BatchNorm statistics fixed.
    teacher.eval()
    total_steps = len(train_loader) * finetune_epochs
    n_steps = total_steps // block_iters

    for it in range(block_iters):
        print(f"Iteration {it}.")
        # Layers are removed from this set as they are frozen during the pass.
        group_weight_params = set(weight_params)

        if it > 0:
            print("Block initial fine-tuning")
            train_model(model, [], group_weight_params, bn_params,
                        None, teacher, train_loader,
                        n_steps=total_steps, total_steps=total_steps, last_step=0)
            test_model(model, val_loader)

        for layer_name in topological_sort(graph):
            if layer_name == skip_until:
                skip_until = None

            # skip last layer for now
            if layer_name == 'fc':
                continue

            print("Quantizing", layer_name)
            if layer_name.endswith('_down'):
                layers = (modules[layer_name + '1'], modules[layer_name + '2'])
            else:
                layers = (modules[layer_name],)
            cur_params = [m.weight for m in layers]
            groups = layers[0].groups
            get_patches = partial(extract_patches, model, layers[0])
            dim = int(np.prod(layers[0].weight.shape[1:]))

            # freeze layer
            for p in cur_params:
                group_weight_params.remove(p)

            # FIXME depthwise layers are skipped for now; also skip until
            # the requested layer is reached.
            if skip_until is not None or groups > 1:
                continue

            M = None
            if not xnor_net:
                # calculate input covariance
                model.eval()
                cov = dataset_cov(train_loader, get_patches, dim)
                _, s, Vh = torch.linalg.svd(cov)
                M = torch.diag(s ** 0.4) @ Vh

            if layer_name != 'conv1' or it > 0:
                print("Fine-tuning")
                train_model(model, cur_params, group_weight_params, bn_params,
                            None, teacher, train_loader,
                            n_steps=n_steps, total_steps=total_steps, last_step=it * n_steps)
                test_model(model, val_loader)

            # Read the weights only now so the fine-tuning above is included.
            w = _weight_matrix(layers)  # (out, in)
            print("Finding quantized values")
            alpha, w_prime = find_q(w.t(), M, xnor_net=xnor_net, nsamples=20 // block_iters)
            _assign_weight_matrix(layers, (alpha * torch.sign(w_prime)).t())

            test_model(model, val_loader)

        print("Block final fine-tuning")
        train_model(model, [], group_weight_params, bn_params,
                    None, teacher, train_loader,
                    n_steps=total_steps, total_steps=total_steps, last_step=0)
        test_model(model, val_loader)


def load_model(path):
    """Build a 100-class ReactNet on CUDA and load the checkpoint at ``path`` into it.

    The checkpoint is expected to contain ``'epoch'``, ``'best_top1_acc'`` and
    ``'state_dict'``. A leading ``'module.'`` (as written by DataParallel) is
    stripped from each key. Loading is non-strict, but any missing or
    unexpected keys are reported so a mismatched checkpoint does not go
    unnoticed.
    """
    model = reactnet(num_classes=100).cuda()
    # NOTE: torch.load unpickles data; only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location='cpu')
    print(checkpoint['epoch'], checkpoint['best_top1_acc'])
    prefix = 'module.'
    state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v
                  for k, v in checkpoint['state_dict'].items()}
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        print(f"Warning: {len(result.missing_keys)} missing keys: {result.missing_keys}")
    if result.unexpected_keys:
        print(f"Warning: {len(result.unexpected_keys)} unexpected keys: {result.unexpected_keys}")
    return model


def test_model(model, val_loader: DataLoader, device='cuda'):
    """Evaluate ``model`` in eval mode, printing and returning mean loss and accuracy.

    Args:
        model: classifier returning logits with classes along dim 1.
        val_loader: iterable of ``(X, y)`` batches; ``y`` is flattened.
        device: device batches are moved to (the model must already be there).

    Returns:
        ``(loss, acc)``: mean cross-entropy per sample and top-1 accuracy,
        as Python floats.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    model.eval()
    print("Test")
    count = 0
    loss_accum = 0
    correct_count = 0
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for X, y in val_loader:
            X = X.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True).flatten()
            logits = model(X)
            loss_accum += criterion(logits, y)
            count += X.shape[0]
            correct_count += (torch.argmax(logits, dim=1) == y).sum()
    if count == 0:
        raise ValueError("val_loader yielded no samples")
    loss = (loss_accum / count).item()
    acc = (correct_count / count).item()
    print(f'{loss}')
    print(f'{acc:.4f}')
    return loss, acc


def train_acc_model(model, train_loader: DataLoader, device='cuda'):
    """Measure loss and accuracy on ``train_loader`` with the model in train mode.

    Runs under ``no_grad`` but with ``model.train()``. Layers such as
    BatchNorm therefore use batch statistics, and they still update their
    running statistics during this pass.
    NOTE(review): that mutates the model being measured; confirm this is
    intended rather than an ``eval()`` measurement.

    Args:
        model: classifier returning logits with classes along dim 1.
        train_loader: iterable of ``(X, y)`` batches; ``y`` is flattened.
        device: device batches are moved to (the model must already be there).

    Returns:
        ``(loss, acc)``: mean cross-entropy per sample and top-1 accuracy,
        as Python floats.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()
    print("Train Acc")
    count = 0
    loss_accum = 0
    correct_count = 0
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for X, y in train_loader:
            X = X.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True).flatten()
            logits = model(X)
            loss_accum += criterion(logits, y)
            count += X.shape[0]
            correct_count += (torch.argmax(logits, dim=1) == y).sum()
    if count == 0:
        raise ValueError("train_loader yielded no samples")
    loss = (loss_accum / count).item()
    acc = (correct_count / count).item()
    print(f'{loss}')
    print(f'{acc:.4f}')
    return loss, acc


def render_dep_graph(model):
    """Render the layer dependency graph reachable from ``'fc'`` to ``dep-graph-reactnet``.

    A directed graph is used, so each edge points from a dependency to the
    layer that depends on it. Rendering is only a debugging aid: if the
    Graphviz ``dot`` executable is missing, a message is printed instead of
    aborting the run.
    """
    dot = graphviz.Digraph('dep-graph')
    marked = set()

    def add_subtree(layer):
        dot.node(layer, layer)
        marked.add(layer)
        for dep in model.dependent_layers(layer):
            dot.edge(dep, layer)
            if dep not in marked:
                add_subtree(dep)

    add_subtree('fc')
    try:
        dot.render('dep-graph-reactnet')
    except graphviz.ExecutableNotFound as err:
        print(f"Skipping dependency graph rendering: {err}")


def main():
    """Command-line entry point: load a ReactNet checkpoint, quantize it and save the result.

    Logs go to a fresh TensorBoard run directory, which is also the default
    location of the quantized checkpoint.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--xnor_net', default=False, action='store_true')
    parser.add_argument('--skip_until', default=None, type=str)
    parser.add_argument('--load_path', default='/d1/xxx/DBQ/baseline/1_step1/models/model_best.pth.tar', type=str)
    parser.add_argument('--save_path', default=None, type=str)
    parser.add_argument('--block_iters', default=5, type=int)
    parser.add_argument('--finetune_epochs', default=4, type=int)
    args = parser.parse_args()
    # Validate up front: quantize() divides by block_iters, and failing here is
    # far cheaper than failing after the model and data have been loaded.
    if args.block_iters < 1:
        parser.error("--block_iters must be at least 1")
    if args.finetune_epochs < 0:
        parser.error("--finetune_epochs must be non-negative")

    writer = SummaryWriter(comment='greedy-reactnet')
    try:
        logging_util.setup_logging(pathlib.Path(writer.log_dir) / 'log.txt')

        print("Args=", args)

        if args.save_path is None:
            load_path = pathlib.Path(args.load_path)
            args.save_path = pathlib.Path(writer.log_dir) / f"{load_path.stem}-quantized.pt"
            print("Will save to ", args.save_path)
        model = load_model(args.load_path)
        print(model)

        render_dep_graph(model)

        # load dataset
        train_loader, val_loader = cifar100(batch_size=200, workers=4)

        # evaluated fp model
        train_acc_model(model, train_loader)
        test_model(model, val_loader)

        # quantize model
        quantize(model, train_loader, val_loader,
                 block_iters=args.block_iters,
                 xnor_net=args.xnor_net,
                 skip_until=args.skip_until,
                 finetune_epochs=args.finetune_epochs)

        # evaluate quantized model
        train_acc_model(model, train_loader)
        test_model(model, val_loader)

        torch.save(model.state_dict(), args.save_path)
    finally:
        writer.close()


if __name__ == "__main__":
    # Run the quantization pipeline only when executed as a script, not on import.
    main()
