import copy
import sys
from functools import partial

import logging_util
import torch
from tqdm import tqdm

import mobilenetv2

from torchvision.datasets import CIFAR100
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
from networkx.algorithms.dag import topological_sort
import networkx as nx
import argparse
import pathlib
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from cifar_torch import cifar100


def make_graph(model):
    """Build the dependency graph reachable from the ``'linear'`` layer.

    Starting at ``'linear'``, the walk asks
    ``model.dependent_layers(name, include_depthwise=True)`` for the
    dependencies of each layer and records one directed edge
    ``dependency -> layer`` per answer. Every layer is expanded at most once.

    Returns:
        networkx.DiGraph: the collected dependency edges.
    """
    dep_graph = nx.DiGraph()
    expanded = set()

    def _collect(node):
        # The edge into an already-expanded node is added by the caller before
        # this guard fires, so no edge is lost by returning early.
        if node in expanded:
            return
        expanded.add(node)
        for parent in model.dependent_layers(node, include_depthwise=True):
            dep_graph.add_edge(parent, node)
            _collect(parent)

    _collect('linear')
    return dep_graph


def dataset_cov(train_loader, extract_patches, dim):
    """Compute the covariance of patch features over a whole data loader.

    The covariance is accumulated batch by batch with the pairwise
    (Chan et al.) update, so the full patch matrix is never materialised.
    The running state is the sample count, the running mean and the scatter
    matrix (sum of outer products of centred samples); merging two partitions
    ``a`` and ``b`` gives::

        scatter = scatter_a + scatter_b + outer(d, d) * n_a * n_b / (n_a + n_b)
        mean    = mean_a + d * n_b / (n_a + n_b),      d = mean_b - mean_a

    Args:
        train_loader: iterable of ``(X, y)`` batches; ``y`` is ignored.
        extract_patches: callable mapping an input batch (already moved to
            CUDA) to a 2-D tensor laid out as ``(dim, n_samples)``, i.e.
            variables in rows and observations in columns.
        dim: number of variables (rows) produced by ``extract_patches``.

    Returns:
        A ``(dim, dim)`` float32 CUDA tensor holding the unbiased
        (``n - 1`` normalised) covariance of all observations seen. All zeros
        if no observation was seen.
    """
    count = 0
    mean = torch.zeros(dim, device='cuda', dtype=torch.float64)
    scatter = torch.zeros(dim, dim, device='cuda', dtype=torch.float64)
    for X, _ in tqdm(train_loader):
        X = X.cuda(non_blocking=True)
        # Accumulate in float64: the sums run over millions of patches.
        patches = extract_patches(X).double()

        batch_count = patches.size(1)
        if batch_count == 0:
            continue
        batch_mean = patches.mean(1)
        # Population covariance times the sample count is the batch scatter.
        batch_scatter = torch.cov(patches, correction=0) * batch_count

        merged_count = count + batch_count
        mean_diff = batch_mean - mean

        scatter += batch_scatter
        # Cross term correcting for the different means of the two partitions
        # (its weight is zero for the very first batch, when count == 0).
        scatter += (mean_diff[:, None] * mean_diff[None, :]) * (count * batch_count / merged_count)
        mean += mean_diff * (batch_count / merged_count)
        count = merged_count

    return (scatter / max(count - 1, 1)).float()


@torch.jit.script
def stochastic_rounding(x, T: float):
    """Randomly round each element of ``x`` to -1 or +1.

    With ``z = x * T`` and ``s = sign(x)``, the probability of drawing +1 is
    ``(1 + p) / 2`` where ``p = s * (2|z| - z^2)`` for ``|z| < 1`` and
    ``p = s`` otherwise. Larger ``T`` therefore makes the rounding more
    deterministic.

    Args:
        x: input tensor; it is left unmodified.
        T: temperature (scale applied to ``x`` before rounding).

    Returns:
        Tensor of the same shape as ``x`` whose entries are -1 or +1.
    """
    s = torch.sign(x)
    # Scale out of place: an in-place `x *= T` would silently rescale the
    # caller's tensor on every call.
    scaled = x * T
    orig = torch.where(torch.abs(scaled) < 1, (-s * scaled ** 2 + 2 * scaled), s)
    p = s * orig.abs()
    pr = p * .5 + .5
    sampled = (torch.rand_like(scaled) < pr) * 2 - 1
    return sampled


class Sign(torch.autograd.Function):
    """Sign function with a hand-written surrogate gradient.

    Forward returns ``sign(x)``. Backward scales the incoming gradient by
    ``max(0, 2 * (1 - |x|))``, so inputs with ``|x| >= 1`` receive no gradient.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, T):
        # `T` is the temperature of the (currently disabled) stochastic
        # rounding variant, stochastic_rounding(x, T); the deterministic sign
        # below ignores it.
        ctx.save_for_backward(x)
        return x.sign()

    @staticmethod
    def backward(ctx, g):
        (x,) = ctx.saved_tensors
        # TODO replace with gradient estimator in reactnet
        surrogate = (1 - x.abs()).mul(2).clamp(min=0.0)
        # The second return value is the (non-existent) gradient w.r.t. T.
        return g * surrogate, None


@torch.jit.script
def quantization_loss(q, M, w_hat, lamba: float = 1e-5):
    """Per-column reconstruction error of ``M @ q`` against ``w_hat``.

    For every column the result is the squared error
    ``sum((M @ q - w_hat) ** 2)`` over dim 0 plus ``lamba`` times the L2 norm
    of that column of ``q``.

    Returns:
        1-D tensor with one loss value per column.
    """
    residual = torch.matmul(M, q) - w_hat
    fit = residual.pow(2).sum(0)
    penalty = q.norm(2, dim=0)
    return fit + lamba * penalty


@torch.enable_grad()
def find_q(w, M, nsteps=2000, nsamples=20, xnor_net=False):
    """Search for a scaled binary approximation ``alpha * sign(w')`` of ``w``.

    The approximation minimises ``quantization_loss`` in the space transformed
    by ``M`` (the target is ``M @ w``). Latent weights and scales are trained
    with SGD through the ``Sign`` straight-through estimator; after each of
    the ``nsamples`` rounds the latent weights are perturbed with small
    Gaussian noise. The best (lowest true binarised loss) state is tracked.

    Args:
        w: 2-D weight matrix. One scale is learned per column (scales are
            initialised to the column-wise mean absolute value).
        M: matrix applied on the left of ``w`` inside the loss. Unused when
            ``xnor_net`` is true.
        nsteps: SGD steps per round; also the cosine schedule's ``T_max``.
        nsamples: number of rounds.
        xnor_net: if true, skip the search and use the closed-form
            ``mean(|w|) * sign(w)`` scaling.

    Returns:
        * ``xnor_net=True``: a single tensor ``alpha * sign(w)`` shaped like
          ``w``.
        * otherwise: a tuple ``(best_alpha, best_w)`` where ``best_w`` holds
          the *latent* weights; the quantized matrix is
          ``best_alpha * torch.sign(best_w)``.

        NOTE(review): the two modes return different types; callers must
        handle both.
    """
    if xnor_net:
        alpha = w.abs().mean(0, keepdim=True)
        return alpha * torch.sign(w)

    w_hat = M @ w

    alpha = torch.nn.Parameter(w.abs().mean(0, keepdim=True))
    w = torch.nn.Parameter(w.clone() / alpha)

    initial_alpha_scale = alpha.data.norm(2).clone()
    print(f"initial_alpha={initial_alpha_scale}")

    opt = torch.optim.SGD([w, alpha], lr=0.01, momentum=0.9, nesterov=True)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=nsteps)

    # Keep the bookkeeping on the same device as the weights rather than a
    # hard-coded 'cuda:0', so the search works wherever `w` lives.
    best = torch.tensor(10. ** 9, device=w.device)
    best_w = w.data.clone()
    best_alpha = alpha.data.clone()
    for _ in range(nsamples):
        for step in range(nsteps):
            opt.zero_grad()
            T = 1 + step / nsteps * 20  # Temperature 1 --> 20

            loss = quantization_loss(
                alpha * Sign.apply(w, T),
                M, w_hat
            )
            loss.mean().backward()

            # Evaluate the true (hard-sign) loss of the current state and keep
            # it if it is the best so far. No graph is needed for this.
            with torch.no_grad():
                lsum = quantization_loss(
                    alpha * torch.sign(w),
                    M, w_hat
                ).mean()
                cond = torch.le(lsum, best)
                best.copy_(torch.where(cond, lsum, best))
                best_w.copy_(torch.where(cond, w.data, best_w))
                best_alpha.copy_(torch.where(cond, alpha.data, best_alpha))

            opt.step()
            scheduler.step()
        # Small random kick before the next round.
        w.data.add_(torch.randn_like(w) * 0.01)
        print(best)

    print(f"best={best}, best_alpha={best_alpha.norm(2)}")
    return best_alpha, best_w


# knowledge distillation loss
@torch.jit.script
def distillation(logits, labels, teacher_scores, T: float, alpha: float):
    """Per-sample knowledge-distillation loss.

    Combines the hard-label cross entropy with the temperature-softened
    KL divergence between teacher and student::

        (1 - alpha) * CE(logits, labels)
            + alpha * 2 * T^2 * KL(softmax(teacher / T) || softmax(logits / T))

    Args:
        logits: student outputs, ``(batch, classes)``.
        labels: hard class indices, ``(batch,)``.
        teacher_scores: teacher outputs (soft labels), same shape as ``logits``.
        T: softmax temperature.
        alpha: interpolation weight; 0 uses only the hard labels, 1 only the
            teacher.

    Returns:
        1-D tensor with one loss value per sample (no reduction).
    """
    task_loss = F.cross_entropy(logits, labels, reduction='none')
    teacher_loss = F.kl_div(
        F.log_softmax(logits / T, dim=-1),
        F.log_softmax(teacher_scores / T, dim=-1),
        reduction='none',
        log_target=True
    ).sum(1)
    # The teacher term is scaled by T^2 to keep gradient magnitudes comparable
    # across temperatures, and by alpha so that the two terms interpolate
    # (the weight must be a product with alpha, not a sum).
    return task_loss * (1 - alpha) + teacher_loss * (2 * T ** 2 * alpha)


@torch.enable_grad()
def train_model(model, cur_w, weight_params, bn_params, inv_sigma,
                teacher, train_loader, n_steps, total_steps, last_step,
                lr=0.001, T=20.0, alpha=0.7, weight_decay=1/500,):
    """Fine-tune ``model`` against ``teacher`` for a fixed number of steps.

    Trains with SGD (Nesterov momentum 0.9) on the ``distillation`` loss and
    a cosine learning-rate schedule spanning ``total_steps``, resumed at
    ``last_step``. L2 weight decay is added to the gradients by hand for
    every optimised parameter. The loader is cycled until the step budget is
    used up; mean loss and accuracy are printed after each complete pass.

    Args:
        model: student network; switched to train mode.
        cur_w: extra parameter to optimise in addition to ``weight_params``
            and ``bn_params`` (or ``None``).
        weight_params: iterable of non-batch-norm parameters to optimise.
        bn_params: iterable of batch-norm parameters to optimise.
        inv_sigma: currently unused (kept for the disabled EmpCov
            regulariser).
        teacher: network producing the soft targets; evaluated without
            gradients.
        train_loader: iterable of ``(X, y)`` batches.
        n_steps: number of optimiser steps to perform (at least one step is
            always taken).
        total_steps: ``T_max`` of the cosine schedule.
        last_step: ``last_epoch`` handed to the scheduler.
        lr: initial learning rate.
        T: distillation temperature.
        alpha: distillation mixing weight.
        weight_decay: L2 coefficient added to the gradients.
    """
    model.train()
    params = [*weight_params, *bn_params]
    if cur_w is not None:
        params.append(cur_w)
    opt = torch.optim.SGD([{'params': params, 'initial_lr': lr}],
                          lr=lr, momentum=0.9, nesterov=True)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_steps, last_epoch=last_step)

    # Feed batches to whatever device the model lives on (CUDA in this script).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    steps = 0
    while True:
        count = 0
        loss_accum = 0
        correct_count = 0
        for X, y in tqdm(train_loader, desc="Train"):
            X = X.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True).flatten()
            model.zero_grad()

            logits = model(X)
            with torch.no_grad():
                teacher_scores = teacher(X)
            loss = distillation(logits, y, teacher_scores, T, alpha)
            loss.mean().backward()

            # apply EmpCov to the first layer
            #w = cur_w.data.view(cur_w.shape[0], -1).T  # (in, out)
            #cur_w.grad.add_((inv_sigma @ w).T.reshape_as(cur_w))

            # Manual L2 weight decay on every optimised parameter.
            # FIXME wd on the bn affine weights
            for param in params:
                if param.grad is not None:
                    param.grad.add_(param.data, alpha=weight_decay)

            opt.step()
            count += X.shape[0]
            loss_accum += loss.sum().data
            correct_count += (torch.argmax(logits, dim=1) == y).sum()
            scheduler.step()

            steps += 1
            # `>=` so that exactly n_steps updates are made; `>` ran one extra.
            if steps >= n_steps:
                return
        tqdm.write(f'{(loss_accum / count).item()}')
        accuracy = (correct_count / count).item()
        tqdm.write(f'{accuracy:.4f}')


@torch.no_grad()
def extract_patches(model, layer, X):
    """Return the inputs seen by ``layer`` for batch ``X``, unfolded for convs.

    ``model.inputs_for(layer, X)`` supplies the tensor fed into ``layer``.
    For a ``Conv2d`` that tensor is cut into the sliding windows the kernel
    sees and returned as a 2-D matrix ``(ch*ks*ks, batch*height*width)`` —
    one column per kernel application. For any other layer type the inputs
    are returned untouched.
    """
    feats = model.inputs_for(layer, X)
    if not isinstance(layer, torch.nn.Conv2d):
        return feats

    # (batch, ch*ks*ks, height*width)
    columns = F.unfold(
        feats,
        kernel_size=layer.kernel_size,
        dilation=layer.dilation,
        padding=layer.padding,
        stride=layer.stride,
    )
    # Bring the patch dimension to the front, then merge batch and positions:
    # (ch*ks*ks, batch, height*width) -> (dim, batch*height*width)
    return columns.transpose(0, 1).flatten(1)


# Quantization schedule used by `quantize`: the groups are processed in order,
# and the layers inside a group are visited in the order listed here. Entries
# are module names as produced by `model.named_modules()`.
LAYERS = [
    # group 0: conv1 and layers.0 - layers.3
    [
        'conv1',
        'layers.0.conv1', 'layers.0.shortcut.0', 'layers.0.conv3',
        'layers.1.conv1', 'layers.1.shortcut.0', 'layers.1.conv3',
        'layers.2.conv1', 'layers.2.conv3',
        'layers.3.conv1', 'layers.3.conv3',
    ],
    # group 1: layers.4 - layers.6
    [
        'layers.4.conv1', 'layers.4.conv3',
        'layers.5.conv1', 'layers.5.conv3',
        'layers.6.conv1', 'layers.6.conv3',
    ],
    # group 2: layers.7 - layers.13
    [
        'layers.7.conv1', 'layers.7.conv3',
        'layers.8.conv1', 'layers.8.conv3',
        'layers.9.conv1', 'layers.9.conv3',
        'layers.10.conv1', 'layers.10.shortcut.0', 'layers.10.conv3',
        'layers.11.conv1', 'layers.11.conv3',
        'layers.12.conv1', 'layers.12.conv3',
        'layers.13.conv1', 'layers.13.conv3',
    ],
    # group 3: layers.14 - layers.16, conv2 and the classifier
    [
        'layers.14.conv1', 'layers.14.conv3',
        'layers.15.conv1', 'layers.15.conv3',
        'layers.16.shortcut.0', 'layers.16.conv1', 'layers.16.conv3',
        'conv2',
        'linear',
    ],
]


@torch.no_grad()
def quantize(model, train_loader, val_loader, block_iters, finetune_epochs, xnor_net=False, skip_until=None):
    """Greedily binarize the conv layers of ``model`` in place, group by group.

    For every group in ``LAYERS`` the following is repeated ``block_iters``
    times:

    1. Each layer of the group is visited in order. ``Linear`` layers are
       skipped. The layer's weight is removed from the set of freely
       trainable weights. Depthwise layers (``groups > 1``) and layers before
       ``skip_until`` stop here.
    2. Otherwise the input covariance of the layer is measured, the model is
       fine-tuned by distillation from a frozen copy of the original model
       (not for ``conv1`` on the first iteration), and the layer weight is
       overwritten with its scaled-sign approximation from ``find_q``.
    3. After the last layer of the group the remaining trainable weights are
       fine-tuned for ``total_steps`` steps.

    Weights frozen while processing a group stay frozen for later groups.

    Args:
        model: network to quantize; modified in place.
        train_loader: training batches (covariance estimation and
            fine-tuning).
        val_loader: validation batches for the progress evaluations.
        block_iters: number of passes over each layer group.
        finetune_epochs: fine-tuning budget in epochs; converted to
            ``total_steps = len(train_loader) * finetune_epochs``.
        xnor_net: use the closed-form XNOR-Net scaling instead of the
            optimisation-based search.
        skip_until: name of the first layer to actually quantize; earlier
            layers are only frozen.
    """
    # Only consumed by the disabled topological-order traversal below.
    graph = make_graph(model)
    modules = dict(model.named_modules())
    weight_params = set(sum([list(m.parameters(recurse=False))
                             for m in model.modules() if not isinstance(m, torch.nn.BatchNorm2d)], []))
    bn_params = set(sum([list(m.parameters(recurse=False))
                         for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)], []))
    print(f"{len(weight_params)} weight params, {len(bn_params)} bn params")

    teacher = copy.deepcopy(model)
    total_steps = len(train_loader) * finetune_epochs
    n_steps = total_steps // block_iters

    # for layer_name in topological_sort(graph):
    for layer_group in LAYERS:
        for it in range(block_iters):
            print(f"Iteration {it}.")
            # Every pass starts from the weights that were still trainable
            # when the group began.
            group_weight_params = set(weight_params)

            for layer_name in layer_group:
                if layer_name == skip_until:
                    skip_until = None

                print("Quantizing", layer_name)
                layer = modules[layer_name]

                # skip last layer for now
                if isinstance(layer, torch.nn.Linear):
                    continue

                # calculate dimensions
                w = layer.weight.data
                dim = int(np.prod(w.shape[1:]))
                w = w.view(w.size(0), dim)  # (out, in); shares storage with layer.weight

                # freeze layer
                cur_param = layer.weight
                group_weight_params.remove(cur_param)

                # FIXME freeze bn layers as well
                # if 'shortcut' in layer_name:
                #     bn_name = layer_name[:-1] + "1"
                # elif 'conv' in layer_name:
                #     bn_name = layer_name.replace('conv', 'bn')
                # else:
                #     bn_name = None
                # if bn_name is not None:
                #     bn = modules[bn_name]
                #     bn_params.remove(bn.weight)
                #     bn_params.remove(bn.bias)

                # FIXME skip depthwise layer for now
                if layer.groups > 1:
                    continue
                # skip quantization until specified layer is reached
                if skip_until is not None:
                    continue

                if xnor_net:
                    # The closed-form XNOR scaling never looks at M, so the
                    # full-dataset covariance pass would be wasted work.
                    M = None
                else:
                    # calculate input covariance
                    model.eval()
                    cov = dataset_cov(train_loader, partial(extract_patches, model, layer), dim)
                    _, s, V = torch.svd(cov)
                    M = torch.diag(s ** 0.4) @ V.t()

                # fine-tune model with EmpCov
                if layer_name != 'conv1' or it > 0:
                    print("Fine-tuning")
                    train_model(model, cur_param, group_weight_params, bn_params,
                                None, teacher, train_loader,
                                n_steps=n_steps, total_steps=total_steps, last_step=it * n_steps)
                    test_model(model, val_loader)

                # quantize layer
                print("Finding quantized values")
                result = find_q(w.t(), M, xnor_net=xnor_net, nsamples=20 // block_iters)
                if isinstance(result, tuple):
                    # Search mode: (scales, latent weights).
                    alpha, w_prime = result
                    q = alpha * torch.sign(w_prime)
                else:
                    # XNOR mode: find_q already returns the quantized matrix.
                    # Unpacking it as a pair would iterate over its rows.
                    q = result

                # apply quantized layer weight
                w.copy_(q.t())

                test_model(model, val_loader)

            print("Block final fine-tuning")
            train_model(model, None, group_weight_params, bn_params,
                        None, teacher, train_loader,
                        n_steps=total_steps, total_steps=total_steps, last_step=0)
            test_model(model, val_loader)

        # update frozen layers
        weight_params = group_weight_params


def load_model(path):
    """Create a 100-class MobileNetV2 on CUDA and load weights from ``path``.

    Checkpoint entries whose key contains ``'quantizer'`` are dropped. If the
    checkpoint has no ``'linear.bias'`` entry, the last column of
    ``'linear.weight'`` is split off and used as the bias.

    NOTE(review): ``torch.load`` unpickles the file, so only load checkpoints
    from trusted sources.
    """
    net = mobilenetv2.MobileNetV2(num_classes=100, final_bias_trick=False).cuda()

    checkpoint = torch.load(path, map_location='cpu')
    state = {}
    for key, tensor in checkpoint.items():
        if 'quantizer' in key:
            continue
        state[key] = tensor

    if 'linear.bias' not in state:
        # Bias stored as an extra trailing column of the classifier weight.
        fused = state['linear.weight']
        state['linear.bias'] = fused[:, -1]
        state['linear.weight'] = fused[:, :-1]

    net.load_state_dict(state)
    return net


def test_model(model, val_loader: DataLoader):
    model.eval()
    print("Test")
    count = 0
    loss_accum = 0
    correct_count = 0
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for X, y in val_loader:
            X = X.cuda(non_blocking=True)
            y = y.cuda(non_blocking=True).flatten()
            logits = model(X)
            loss = criterion(logits, y)
            count += X.shape[0]
            loss_accum += loss.data
            correct_count += (torch.argmax(logits, dim=1) == y).sum()
        loss = (loss_accum / count).item()
        acc = (correct_count / count).item()
        print(f'{loss}')
        print(f'{acc:.4f}')
    return loss, acc


def train_acc_model(model, train_loader: DataLoader):
    """Measure loss and accuracy on ``train_loader`` with the model in train mode.

    Batches are moved to the device of the model's first parameter (falling
    back to CUDA for a parameter-free model). No gradients are computed.

    NOTE(review): the model is put in *train* mode, so layers such as batch
    norm use batch statistics and update their running statistics during
    this pass even though gradients are disabled — confirm this is intended.

    Args:
        model: network to evaluate; left in train mode.
        train_loader: iterable of ``(X, y)`` batches; ``y`` is flattened to
            1-D class indices.

    Returns:
        ``(loss, acc)``: cross-entropy averaged over all samples and the
        fraction of correctly classified samples, both as Python floats.
    """
    model.train()
    print("Train Acc")
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    count = 0
    loss_accum = 0
    correct_count = 0
    # Summed per batch so that the final division gives a per-sample mean.
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for X, y in train_loader:
            X = X.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True).flatten()
            logits = model(X)
            loss = criterion(logits, y)
            count += X.shape[0]
            loss_accum += loss.data
            correct_count += (torch.argmax(logits, dim=1) == y).sum()
        loss = (loss_accum / count).item()
        acc = (correct_count / count).item()
        print(f'{loss}')
        print(f'{acc:.4f}')
    return loss, acc


def main():
    """Command-line entry point: load, evaluate, quantize, re-evaluate, save.

    Parses the command-line options, sets up logging inside the TensorBoard
    run directory, loads the full-precision checkpoint, reports its train
    and validation metrics, runs ``quantize`` and finally stores the
    quantized ``state_dict``. Without ``--save_path`` the result is written
    to ``<log_dir>/<checkpoint stem>-quantized.pt``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--xnor_net', default=False, action='store_true')
    cli.add_argument('--skip_until', default=None, type=str)
    cli.add_argument('--load_path', default='/d1/xxx/TBQ/TBQ_experiments/model_fp_0.pt', type=str)
    cli.add_argument('--save_path', default=None, type=str)
    cli.add_argument('--block_iters', default=20, type=int)
    cli.add_argument('--finetune_epochs', default=4, type=int)
    args = cli.parse_args()

    # The SummaryWriter is used here only for its run directory.
    writer = SummaryWriter(comment='greedy')
    run_dir = pathlib.Path(writer.log_dir)
    logging_util.setup_logging(run_dir / 'log.txt')

    print("Args=", args)

    if args.save_path is None:
        checkpoint_stem = pathlib.Path(args.load_path).stem
        args.save_path = run_dir / f"{checkpoint_stem}-quantized.pt"
        print("Will save to ", args.save_path)

    model = load_model(args.load_path)
    train_loader, val_loader = cifar100(batch_size=128, workers=4)

    # Baseline metrics of the full-precision model.
    train_acc_model(model, train_loader)
    test_model(model, val_loader)

    quantize(
        model,
        train_loader,
        val_loader,
        block_iters=args.block_iters,
        finetune_epochs=args.finetune_epochs,
        xnor_net=args.xnor_net,
        skip_until=args.skip_until,
    )

    # Metrics of the quantized model.
    train_acc_model(model, train_loader)
    test_model(model, val_loader)

    torch.save(model.state_dict(), args.save_path)


# Run the quantization pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
