import sys
import copy
import math
import socket
import pathlib
import argparse
import itertools
from datetime import datetime
from functools import partial

import os
import torch
import graphviz
import numpy as np
from torchvision import transforms, datasets
from tqdm import tqdm
import networkx as nx
import more_itertools
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.tensorboard import SummaryWriter
from networkx.algorithms.dag import topological_sort
from torch.utils.data import DataLoader, DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

from reactnet import reactnet

sys.path.append('.')
import logging_util
from baseline_imagenet.utils.utils import Lighting

# Keep cuDNN on and let it benchmark/select the fastest convolution algorithm
# per input shape (inputs are fixed-size 224x224 crops in this script).
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True


def make_graph(model):
    """Build the layer dependency DAG of ``model``, walking back from ``'fc'``.

    ``model.dependent_layers(name)`` is queried once for every reachable layer
    and an edge ``dep -> name`` is added for each dependency it yields. The walk
    is a depth-first search with an explicit stack that consumes dependencies in
    exactly the order a recursive DFS would, so edge insertion order (and hence
    the order ``topological_sort`` produces) is preserved.
    """
    dag = nx.DiGraph()
    seen = set()
    # Each frame holds a layer and the partially consumed iterator of its deps.
    stack = []

    def enter(layer_name):
        seen.add(layer_name)
        stack.append((layer_name, iter(model.dependent_layers(layer_name))))

    enter('fc')
    while stack:
        consumer, pending = stack[-1]
        for dep in pending:
            dag.add_edge(dep, consumer)
            if dep not in seen:
                # descend first; this frame resumes where it left off later
                enter(dep)
                break
        else:
            stack.pop()
    return dag


def dataset_cov(train_loader, extract_patches, dim, sample_portion=1.0,
                device='cuda', progress=True):
    """Estimate the covariance of the patch vectors produced from ``train_loader``.

    Per-batch statistics are merged with the pairwise (Chan et al.) update, so
    the result equals the population covariance of all patches seen regardless
    of how they are split into batches.

    Args:
        train_loader: iterable of ``(X, y)`` batches that supports ``len()``;
            ``y`` is ignored.
        extract_patches: callable mapping a batch ``X`` (already on ``device``)
            to a ``(dim, n_samples)`` matrix.
        dim: number of rows returned by ``extract_patches``.
        sample_portion: fraction of the loader to consume
            (``ceil(len(train_loader) * sample_portion)`` batches).
        device: device for ``X`` and the running statistics.
        progress: wrap the loader in a ``tqdm`` progress bar.

    Returns:
        ``(dim, dim)`` float32 population covariance (normalized by the total
        sample count); all zeros if no samples were seen.
    """
    count = 0
    mean = torch.zeros(dim, device=device, dtype=torch.float64)
    cov = torch.zeros(dim, dim, device=device, dtype=torch.float64)
    iters = math.ceil(len(train_loader) * sample_portion)
    batches = tqdm(train_loader) if progress else train_loader
    for X, _ in itertools.islice(batches, iters):
        X = X.to(device, non_blocking=True)
        patches = extract_patches(X)  # (dim, n_samples)

        other_count = patches.size(1)
        if other_count == 0:
            continue
        other_mean = patches.double().mean(1)
        # correction=0 gives the population covariance, which is what the
        # merge formula below combines exactly.
        other_cov = torch.cov(patches, correction=0).double()

        if count == 0:
            count = other_count
            mean.copy_(other_mean)
            cov.copy_(other_cov)
            continue

        merged_count = count + other_count
        delta = other_mean - mean  # uses the mean *before* this batch
        mean += delta * (other_count / merged_count)
        # C = (n_a * C_a + n_b * C_b) / n + (n_a * n_b / n^2) * delta delta^T
        cov *= count / merged_count
        cov += other_cov * (other_count / merged_count)
        cov += torch.outer(delta, delta) * (count * other_count / merged_count ** 2)
        count = merged_count
    return cov.float()


@torch.jit.script
def stochastic_rounding(x, T: float):
    """Stochastically round ``x`` to +/-1 with temperature ``T``.

    ``x * T`` is passed through a smooth clipped curve p in [-1, 1]
    (``2y - sign(y) * y**2`` for ``|y| < 1``, ``sign(y)`` otherwise), and +1 is
    sampled with probability ``(p + 1) / 2``. As ``T`` grows, ``|x * T| >= 1``
    for every non-zero entry and the output becomes the deterministic sign;
    exact zeros are still rounded to +/-1 with probability 1/2.

    The input is not modified, and the result has ``x.dtype``, so it can
    replace ``torch.sign(x)``.
    """
    s = torch.sign(x)
    # Out-of-place scaling: ``x *= T`` would silently rescale the caller's tensor.
    scaled = x * T
    p = torch.where(torch.abs(scaled) < 1, -s * scaled ** 2 + 2 * scaled, s)
    prob_pos = p * .5 + .5
    sampled = (torch.rand_like(x) < prob_pos) * 2 - 1
    return sampled.to(x.dtype)


class Sign(torch.autograd.Function):
    """Sign function with a clipped triangular surrogate gradient.

    Forward returns ``torch.sign(x)``. The temperature ``T`` is accepted for a
    stochastic-rounding variant that is currently disabled, and is otherwise
    ignored. Backward multiplies the incoming gradient by
    ``max(0, 2 * (1 - |x|))`` and returns no gradient for ``T``.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, T):
        ctx.save_for_backward(x)
        # Deterministic sign; ``stochastic_rounding(x, T)`` is the
        # temperature-controlled alternative (deterministic as T -> inf).
        return torch.sign(x)

    @staticmethod
    def backward(ctx, g):
        (saved_x,) = ctx.saved_tensors
        # TODO replace with gradient estimator in reactnet
        surrogate = (2 * (1 - saved_x.abs())).clamp(min=0.0)
        return g * surrogate, None


@torch.jit.script
def quantization_loss(q, M, w_hat, lamba: float = 1e-5):
    """Per-column fitting loss ``||M @ q - w_hat||^2 + lamba * ||q||_2``.

    Squared residuals are summed over dim 0, so one loss value is returned per
    column of ``q``. A small L2-norm penalty on each column of ``q`` is added.
    """
    residual = M @ q - w_hat
    return residual.pow(2).sum(0) + lamba * q.norm(2, dim=0)


@torch.enable_grad()
def find_q(w, M, nsteps=2000, nsamples=20, xnor_net=False):
    """Search for a binary approximation ``alpha * sign(w_prime)`` of ``w``.

    Args:
        w: weight matrix laid out as ``(in, out)``. The caller in this file
            passes the transpose of an ``(out, in)`` weight view.
        M: matrix applied on the left of both ``w`` and the quantized weight.
            The fit per output column is ``||M q - M w||^2`` (plus a small
            penalty). It is presumably square ``(in, in)``; the caller builds it
            from the input covariance.
        nsteps: SGD steps per restart.
        nsamples: number of restarts; ``w_prime`` is perturbed with small
            Gaussian noise between restarts.
        xnor_net: skip the search and use the XNOR-Net closed form
            (per-column mean magnitude times sign).

    Returns:
        ``(alpha, w_prime)`` where ``alpha`` has shape ``(1, out)`` and the
        quantized weight is ``alpha * torch.sign(w_prime)``. Both paths return
        this pair. When a process group is initialized, the solution with the
        lowest loss across ranks is broadcast to every rank.
    """
    if xnor_net:
        alpha = w.abs().mean(0, keepdim=True)
        # Same (alpha, w_prime) contract as the search path below.
        return alpha, torch.sign(w)

    w_hat = M @ w

    alpha = torch.nn.Parameter(w.abs().mean(0, keepdim=True))
    w = torch.nn.Parameter(w.clone() / alpha)

    initial_alpha_scale = alpha.data.norm(2).clone()
    print(f"initial_alpha={initial_alpha_scale}")

    opt = torch.optim.SGD([w, alpha], lr=0.01, momentum=0.9, nesterov=True)
    # Note: the scheduler is not reset between restarts.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=nsteps)

    best = torch.tensor(10. ** 9, device=w.device)
    best_w = copy.deepcopy(w.data)
    best_alpha = copy.deepcopy(alpha.data)
    for _ in range(nsamples):
        for step in range(nsteps):
            opt.zero_grad()
            T = 1 + step / nsteps * 20  # Temperature 1 --> 20

            # Surrogate loss through the straight-through Sign.
            surrogate = quantization_loss(
                alpha * Sign.apply(w, T),
                M, w_hat
            ).mean()
            surrogate.backward()

            # Evaluate the actual hard-sign loss and track the best iterate;
            # no autograd graph is needed for this.
            with torch.no_grad():
                actual = quantization_loss(
                    alpha * torch.sign(w),
                    M, w_hat
                ).mean()
                cond = torch.le(actual, best)
                best.copy_(torch.where(cond, actual, best))
                best_w.copy_(torch.where(cond, w.data, best_w))
                best_alpha.copy_(torch.where(cond, alpha.data, best_alpha))

            opt.step()
            scheduler.step()
        w.data.add_(torch.randn_like(w) * 0.01)
        print("best=", best.item())

    if dist.is_available() and dist.is_initialized():
        # Pick the rank with the lowest loss and share its solution.
        best_list = [torch.zeros_like(best) for _ in range(dist.get_world_size())]
        dist.all_gather(best_list, best)
        best_idx = torch.argmin(torch.stack(best_list), dim=0).item()

        dist.broadcast(best_alpha, best_idx)
        dist.broadcast(best_w, best_idx)

    print(f"final best={best}, best_alpha={best_alpha.norm(2)}")
    return best_alpha, best_w


# knowledge distillation loss
@torch.jit.script
def distillation(logits, labels, teacher_scores, T: float, alpha: float):
    """Per-sample knowledge-distillation loss.

    Mixes hard-label cross-entropy with the temperature-softened
    KL(teacher || student):

        (1 - alpha) * CE(logits, labels) + alpha * 2 * T**2 * KL

    Args:
        logits: student logits, ``(batch, classes)``.
        labels: hard class labels, ``(batch,)``.
        teacher_scores: teacher logits (soft labels), same shape as ``logits``.
        T: softmax temperature.
        alpha: weight of the distillation term (``1 - alpha`` on the task term).

    Returns:
        Tensor of per-sample losses (no reduction).
    """
    task_loss = F.cross_entropy(logits, labels, reduction='none')
    teacher_loss = F.kl_div(
        F.log_softmax(logits / T, dim=-1),
        F.log_softmax(teacher_scores / T, dim=-1),
        reduction='none',
        log_target=True
    ).sum(1)
    # alpha scales the distillation term, mirroring (1 - alpha) on the task
    # term; T**2 keeps soft-target gradients on the hard-label scale.
    return task_loss * (1 - alpha) + teacher_loss * (2 * T ** 2 * alpha)


@torch.enable_grad()
def train_model(model, opt, teacher, train_loader, n_steps,
                T=20.0, alpha=0.7):
    """Fine-tune ``model`` by distillation from ``teacher`` for ``n_steps`` steps.

    Both modules are wrapped in DDP: ``teacher`` in eval mode, ``model`` in
    train mode. Each pass over ``train_loader`` increments ``model.epoch``,
    which reseeds the sampler shuffle via ``set_epoch``. At the end of each
    pass, the all-reduced mean loss and accuracy of that pass are printed.
    Exactly ``n_steps`` optimizer steps are taken. The pass in which the
    budget runs out is cut short but still reported.

    Args:
        model: module with an integer ``epoch`` attribute; trained in place.
        opt: optimizer over ``model``'s parameters.
        teacher: module whose outputs are the soft targets.
        train_loader: loader whose ``sampler`` provides ``set_epoch``.
        n_steps: number of optimizer steps; ``<= 0`` returns immediately.
        T, alpha: temperature and mixing weight passed to ``distillation``.
    """
    # Nothing to do; also avoids spinning forever on an empty loader.
    if n_steps <= 0 or len(train_loader) == 0:
        return
    teacher = DDP(teacher)
    teacher.eval()
    model = DDP(model)
    model.train()
    steps = 0
    while True:
        # for shuffling
        model.module.epoch += 1
        train_loader.sampler.set_epoch(model.module.epoch)

        count = torch.tensor(0, device='cuda')
        loss_accum = torch.tensor(0.0, device='cuda')
        correct_count = torch.tensor(0, device='cuda')
        done = False
        for X, y in tqdm(train_loader, desc="Train"):
            X = X.cuda(non_blocking=True)
            y = y.cuda(non_blocking=True).flatten()
            model.zero_grad()

            logits = model(X)
            with torch.no_grad():
                teacher_scores = teacher(X)
            loss = distillation(logits, y, teacher_scores, T, alpha)
            loss.mean().backward()

            # apply EmpCov to the first layer
            #w = cur_w.data.view(cur_w.shape[0], -1).T  # (in, out)
            #cur_w.grad.add_((inv_sigma @ w).T.reshape_as(cur_w))

            opt.step()
            count += X.shape[0]
            loss_accum += loss.detach().sum()
            correct_count += (torch.argmax(logits, dim=1) == y).sum()

            steps += 1
            # stop after exactly n_steps optimizer steps
            if steps >= n_steps:
                done = True
                break
        # Collectives: assumes every rank iterates the same number of batches
        # (as a DistributedSampler provides), so all ranks reach this together.
        dist.all_reduce(loss_accum)
        dist.all_reduce(correct_count)
        dist.all_reduce(count)
        tqdm.write(f'{(loss_accum / count).item()}')
        accuracy = (correct_count / count).item()
        tqdm.write(f'{accuracy:.4f}')
        if done:
            return


@torch.no_grad()
def extract_patches(model, layer, X):
    """Return the inputs that ``layer`` sees for batch ``X``.

    ``model.inputs_for(layer, X)`` supplies the layer input. For a ``Conv2d``
    the input is unfolded into sliding-window patches with the layer's kernel
    size, dilation, padding and stride. The result is laid out as
    ``(ch*kh*kw, batch*n_positions)``: one column per patch. Any other layer
    type gets its input back unchanged.
    """
    inputs = model.inputs_for(layer, X)
    if not isinstance(layer, torch.nn.Conv2d):
        # NOTE(review): non-conv inputs are returned as-is, while dataset_cov
        # treats dim 0 as the feature axis; confirm the layout before using
        # this for e.g. Linear layers.
        return inputs
    cols = F.unfold(
        inputs,
        kernel_size=layer.kernel_size,
        dilation=layer.dilation,
        padding=layer.padding,
        stride=layer.stride,
    )  # (batch, ch*kh*kw, n_positions)
    # Feature axis first, then merge batch and spatial positions into columns.
    return cols.transpose(0, 1).flatten(1)  # (ch*kh*kw, batch*n_positions)


@torch.no_grad()
def quantize(model, train_loader, val_loader, block_iters, finetune_epochs, world_size,
             xnor_net=False, skip_until=None,
             block_size=7, sample_portion=1.0, log_dir=None, findq_iters=20, rank=0):
    """Binarize the layers of ``model`` group by group, fine-tuning in between.

    Layers are ordered by a topological sort of ``make_graph(model)`` and split
    into ``block_size``-sized groups. Each group gets ``block_iters`` rounds.
    In each round, every layer of the group is:
      1. frozen;
      2. quantized to ``alpha * sign(w_prime)`` via ``find_q``, using a metric
         ``M`` derived from the covariance of the layer's input patches;
      3. followed by a short distillation fine-tune.
    The teacher for distillation is a deep copy of ``model`` taken on entry.
    After every round, rank 0 saves ``block-{group_idx}.pth`` into ``log_dir``.
    Frozen layers are unfrozen between rounds, except after the last round.

    Args:
        model: model exposing ``dependent_layers`` / ``inputs_for``; modified
            in place.
        train_loader: loader used for covariance estimation and fine-tuning.
        val_loader: loader used for evaluation.
        block_iters: rounds per layer group.
        finetune_epochs: block-level fine-tuning runs
            ``len(train_loader) * finetune_epochs`` steps; the per-layer
            fine-tune runs that count divided by ``block_iters``.
        world_size: not used in this function.
        xnor_net: forwarded to ``find_q`` (closed-form XNOR-Net scaling).
        skip_until: layer name; layers visited before it are frozen but not
            quantized or fine-tuned individually.
        block_size: number of layers per group.
        sample_portion: fraction of ``train_loader`` used for the covariance.
        log_dir: ``pathlib.Path``-like directory for checkpoints. It must not
            be None on rank 0, because ``log_dir / ...`` is evaluated there.
        findq_iters: ``find_q`` is called with
            ``nsamples=findq_iters // block_iters`` in each round.
        rank: process rank; only rank 0 writes checkpoints.
    """
    graph = make_graph(model)
    modules = dict(model.named_modules())
    # Split parameters into BatchNorm affine params and everything else.
    weight_params = sum([list(m.parameters(recurse=False)) for m in model.modules()
                         if not isinstance(m, torch.nn.BatchNorm2d)], [])
    bn_params = sum([list(m.parameters(recurse=False)) for m in model.modules()
                     if isinstance(m, torch.nn.BatchNorm2d)], [])
    print(f"{len(weight_params)} weight params, {len(bn_params)} bn params")
    # epoch counter read and incremented by train_model for sampler shuffling
    model.epoch = 0

    # Copy of the model as it is before any quantization; used as the
    # distillation teacher.
    teacher = copy.deepcopy(model)
    total_steps = len(train_loader) * finetune_epochs
    n_steps = total_steps // block_iters

    # for layer_name in topological_sort(graph):
    layers = list(topological_sort(graph))
    layer_groups = list(more_itertools.chunked_even(layers, block_size))
    print("Layer groups", layer_groups)

    for group_idx, layer_group in enumerate(layer_groups):

        # Fresh optimizer per group. After the last round of a group,
        # weight_params is narrowed to still-trainable params (see the end of
        # the loop), so later groups only optimize unfrozen weights.
        opt = torch.optim.Adam(
            [{'params': bn_params},
             {'params': weight_params, 'weight_decay': 1e-5}],
            lr=1.25e-3 / 32
        )

        for it in range(block_iters):
            print(f"Iteration {it}.")

            # Every round except the very first starts with a full fine-tune.
            # train_model is decorated with @torch.enable_grad(), so it trains
            # despite this function's @torch.no_grad().
            if it > 0 or group_idx > 0:
                print("Block initial fine-tuning")
                train_model(model, opt, teacher, train_loader, n_steps=total_steps)
                test_model(model, val_loader)

            new_frozen = []
            for layer_name in layer_group:
                # Quantization resumes at (and including) skip_until.
                if layer_name == skip_until:
                    skip_until = None

                # skip last layer for now
                if layer_name == 'fc':
                    continue

                print("Quantizing", layer_name)
                # A '<name>_down' graph node stands for two modules
                # '<name>_down1' and '<name>_down2' that are quantized jointly
                # by stacking their weights along the output dimension.
                if layer_name.endswith('_down'):
                    layer = (modules[layer_name + '1'], modules[layer_name + '2'])
                    cur_params = [layer[0].weight, layer[1].weight]
                    groups = layer[0].groups
                    # input patches are taken from the first module only
                    get_patches = partial(extract_patches, model, modules[layer_name + '1'])

                    # calculate dimensions
                    # (torch.cat makes a copy, so results are written back below)
                    w = torch.cat([layer[0].weight.data, layer[1].weight.data], dim=0)
                    dim = np.prod(w.shape[1:])
                    w = w.view(w.size(0), dim)  # (out, in)

                else:
                    layer = modules[layer_name]
                    cur_params = [layer.weight]
                    groups = layer.groups
                    get_patches = partial(extract_patches, model, layer)

                    # calculate dimensions
                    # (w is a view of the weight storage, so w.copy_ below
                    # writes straight into the layer)
                    w = layer.weight.data
                    dim = np.prod(w.shape[1:])
                    w = w.view(w.size(0), dim)  # (out, in)

                # freeze layer
                for p in cur_params:
                    new_frozen.append(p)
                    p.requires_grad = False
                    p.grad = None

                # FIXME skip depthwise layer for now
                # (still frozen above, but left in full precision)
                if groups > 1:
                    continue
                # skip quantization until specified layer is reached
                if skip_until is not None:
                    continue

                # calculate input covariance
                model.eval()
                cov = dataset_cov(train_loader, get_patches, dim, sample_portion)
                # M = diag(s^0.4) V^T from the SVD of the covariance: the
                # metric under which find_q measures the quantization error.
                _, s, V = torch.svd(cov)
                M = torch.diag(s ** 0.4) @ V.t()

                # quantize layer
                print("Finding quantized values")
                alpha, w_prime = find_q(
                    w.T, M,
                    xnor_net=xnor_net,
                    nsamples=findq_iters // block_iters
                )

                # apply quantized layer weight
                if layer_name.endswith('_down'):
                    w = (alpha * torch.sign(w_prime)).t()
                    # assumes both halves have the same number of output
                    # channels -- NOTE(review): confirm for every *_down pair
                    layer[0].weight.data.copy_(w[:w.shape[0]//2].reshape(layer[0].weight.shape))
                    layer[1].weight.data.copy_(w[w.shape[0]//2:].reshape(layer[1].weight.shape))
                else:
                    w.copy_((alpha * torch.sign(w_prime)).t())

                test_model(model, val_loader)

                print("Fine-tuning")
                train_model(model, opt, teacher, train_loader, n_steps=n_steps)
                test_model(model, val_loader)

            print("Block final fine-tuning")
            train_model(model, opt, teacher, train_loader, n_steps=total_steps)
            test_model(model, val_loader)

            if rank == 0:
                print("Saving to ", log_dir / f"block-{group_idx}.pth")
                torch.save(model.state_dict(), log_dir / f"block-{group_idx}.pth")

            # unfreeze block layers
            if it < block_iters - 1:
                print(f"Unfreezing {len(new_frozen)} layers")
                for p in new_frozen:
                    p.requires_grad = True
            else:
                # Last round: the group's layers stay frozen for good; drop
                # them from the weight list used by later groups' optimizers.
                weight_params = [p for p in weight_params if p.requires_grad]


def load_model(path):
    """Build a 1000-class ReactNet on the current CUDA device and load ``path``.

    The checkpoint is expected to be a dict with ``'epoch'``,
    ``'best_top1_acc'`` and ``'state_dict'`` entries. A leading ``module.``
    (added when a DDP-wrapped model is saved) is stripped from every key.
    Loading is non-strict, but any missing or unexpected keys are printed
    instead of being silently ignored.

    NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.
    """
    model = reactnet(num_classes=1000).cuda()
    checkpoint = torch.load(path, map_location='cpu')
    print(checkpoint['epoch'], checkpoint['best_top1_acc'])
    prefix = 'module.'
    # Strip only the leading DDP prefix; str.replace would also mangle keys
    # that merely contain 'module.' somewhere inside them.
    state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in checkpoint['state_dict'].items()
    }
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        print(f"load_model: missing keys: {result.missing_keys}")
    if result.unexpected_keys:
        print(f"load_model: unexpected keys: {result.unexpected_keys}")
    return model


def train_acc_model(model, train_loader: DataLoader):
    """Report loss and accuracy of ``model`` over ``train_loader`` in train mode.

    The model is wrapped in DDP and switched to ``train()``. BatchNorm layers
    therefore normalize with batch statistics and, even under ``no_grad``,
    update their running statistics. Loss sums, correct counts and sample
    counts are all-reduced across ranks before averaging.

    Returns:
        ``(mean_loss, accuracy)`` as Python floats (also printed).
    """
    ddp_model = DDP(model)
    ddp_model.train()
    print("Train Acc")
    n_seen = torch.tensor(0, device='cuda')
    loss_sum = torch.tensor(0.0, device='cuda')
    n_correct = torch.tensor(0, device='cuda')
    loss_fn = torch.nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for inputs, targets in train_loader:
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True).flatten()
            logits = ddp_model(inputs)
            n_seen += inputs.shape[0]
            loss_sum += loss_fn(logits, targets).data
            n_correct += (logits.argmax(dim=1) == targets).sum()
        # same collective order on every rank: loss, correct, count
        for total in (loss_sum, n_correct, n_seen):
            dist.all_reduce(total)
        mean_loss = (loss_sum / n_seen).item()
        accuracy = (n_correct / n_seen).item()
        print(f'{mean_loss}')
        print(f'{accuracy:.4f}')
    return mean_loss, accuracy


def test_model(model, val_loader: DataLoader):
    """Evaluate ``model`` on ``val_loader`` in eval mode across all ranks.

    The model is wrapped in DDP and run under ``no_grad``. Per-rank loss sums,
    correct counts and sample counts are all-reduced before averaging.

    Returns:
        ``(mean_loss, accuracy)`` as Python floats (also printed).
    """
    net = DDP(model)
    net.eval()
    print("Test")
    seen = torch.tensor(0, device='cuda')
    total_loss = torch.tensor(0.0, device='cuda')
    hits = torch.tensor(0, device='cuda')
    xent = torch.nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True).flatten()
            out = net(images)
            seen += images.shape[0]
            total_loss += xent(out, labels).data
            hits += (out.argmax(dim=1) == labels).sum()
        dist.all_reduce(total_loss)
        dist.all_reduce(hits)
        dist.all_reduce(seen)
        mean_loss = (total_loss / seen).item()
        accuracy = (hits / seen).item()
        print(f'{mean_loss}')
        print(f'{accuracy:.4f}')
    return mean_loss, accuracy


def test_model_orig(model, val_loader: DataLoader):
    """Single-process evaluation of ``model`` on ``val_loader`` (no DDP, no all-reduce).

    Puts the model in eval mode, runs it under ``no_grad`` on CUDA, and
    averages the summed cross-entropy and the accuracy over the samples this
    process sees.

    Returns:
        ``(mean_loss, accuracy)`` as Python floats (also printed).
    """
    model.eval()
    print("Test")
    seen = torch.tensor(0, device='cuda')
    total_loss = torch.tensor(0.0, device='cuda')
    hits = torch.tensor(0, device='cuda')
    xent = torch.nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True).flatten()
            out = model(images)
            seen += images.shape[0]
            total_loss += xent(out, labels).data
            hits += (out.argmax(dim=1) == labels).sum()
        mean_loss = (total_loss / seen).item()
        accuracy = (hits / seen).item()
        print(f'{mean_loss}')
        print(f'{accuracy:.4f}')
    return mean_loss, accuracy


def render_dep_graph(model):
    """Render the layer dependency graph of ``model`` to ``dep-graph-reactnet``.

    Walks ``model.dependent_layers`` back from ``'fc'`` depth-first, in the same
    order as a recursive walk. Each reached layer becomes a node, and each
    dependency becomes an edge ``dep -- layer`` of an undirected
    ``graphviz.Graph``. The graph is then rendered.
    """
    dot = graphviz.Graph('dep-graph')
    marked = set()
    # frames of (layer, partially consumed iterator over its dependencies)
    stack = []

    def visit(layer):
        dot.node(layer, layer)
        marked.add(layer)
        stack.append((layer, iter(model.dependent_layers(layer))))

    visit('fc')
    while stack:
        layer, deps = stack[-1]
        for dep in deps:
            dot.edge(dep, layer)
            if dep not in marked:
                visit(dep)
                break
        else:
            stack.pop()
    dot.render('dep-graph-reactnet')


def _build_loaders(args, world_size):
    """Build this rank's distributed ImageNet ``(train_loader, val_loader)``.

    ``args.data`` must contain ``train/`` and ``val/`` ImageFolder trees.
    ``args.batch_size`` is the global batch size and is split evenly across
    ``world_size`` processes. ``args.workers`` sets the loader workers.
    """
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # data augmentation
    crop_scale = 0.08
    lighting_param = 0.1
    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(crop_scale, 1.0)),
        Lighting(lighting_param),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize])

    train_dataset = datasets.ImageFolder(
        traindir,
        transform=train_transforms)

    train_sampler = DistributedSampler(
        dataset=train_dataset,
        num_replicas=dist.get_world_size(),
        rank=dist.get_rank(),
        shuffle=True,
        drop_last=True
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size // world_size, sampler=train_sampler,
        num_workers=args.workers, pin_memory=True)

    # load validation data
    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]))

    val_sampler = DistributedSampler(
        dataset=val_dataset,
        num_replicas=dist.get_world_size(),
        rank=dist.get_rank(),
        shuffle=False,
        drop_last=False
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size // world_size, sampler=val_sampler,
        num_workers=args.workers, pin_memory=True)
    return train_loader, val_loader


def run(rank, world_size, args):
    """Per-process job: evaluate, quantize, re-evaluate and save the model.

    Sets up per-rank logging under ``args.log_dir`` and binds this process to
    ``cuda:{rank}``. It then loads the model from ``args.load_path``, reports
    full-precision train/val metrics, runs ``quantize``, and reports the
    metrics again. On rank 0 the result is saved to ``args.save_path``. The
    TensorBoard writer, used here only for its ``log_dir``, is always closed.
    """
    writer = SummaryWriter(log_dir=args.log_dir)
    try:
        logging_util.setup_logging(
            pathlib.Path(args.log_dir) / f'log_rank{rank}.txt',
            log_to_screen=(rank == 0)
        )
        print("===== Args =====")
        for name, value in vars(args).items():
            print(f"   {name}: {value}")
        print("================")

        print(f"Using GPU {rank}")
        torch.cuda.set_device(f"cuda:{rank}")

        # construct model
        model = load_model(args.load_path)
        # render_dep_graph(model)
        print(model)

        # load dataset
        train_loader, val_loader = _build_loaders(args, world_size)

        # evaluated fp model
        train_acc_model(model, train_loader)
        test_model(model, val_loader)

        # quantize model
        quantize(model, train_loader, val_loader,
                 block_iters=args.block_iters,
                 xnor_net=args.xnor_net,
                 skip_until=args.skip_until,
                 finetune_epochs=args.finetune_epochs,
                 block_size=args.block_size,
                 sample_portion=args.sample_portion,
                 findq_iters=args.findq_iters,
                 log_dir=pathlib.Path(writer.log_dir),
                 rank=rank, world_size=world_size)

        # evaluate quantized model
        train_acc_model(model, train_loader)
        test_model(model, val_loader)

        if rank == 0:
            torch.save(model.state_dict(), args.save_path)
    finally:
        writer.close()


def dist_main(rank, world_size, args):
    """``mp.spawn`` target: join the NCCL process group, run the job, tear down.

    ``MASTER_ADDR`` / ``MASTER_PORT`` default to ``localhost:12355`` but can be
    overridden through the environment, e.g. to run two jobs on one host. The
    process group is destroyed even if ``run`` raises.
    """
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        run(rank, world_size, args)
    finally:
        dist.destroy_process_group()


def main():
    """Parse command-line options and spawn one ``dist_main`` process per GPU.

    Derives a timestamped ``runs/reactnet_<time>_<host>`` log directory and a
    save path of ``<log_dir>/<load_path stem>-quantized.pt``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', default='/w14/dataset/ILSVRC2012/')
    parser.add_argument('--batch_size', type=int, default=512, help='batch size')
    parser.add_argument('-j', '--workers', default=12, type=int, metavar='N',
                        help='number of data loading workers (default: 12)')
    parser.add_argument('--xnor_net', default=False, action='store_true')
    parser.add_argument('--skip_until', default=None, type=str)
    parser.add_argument('--load_path', default='/d1/xxx/DBQ/imagenet-baseline-models/model_best.pth.tar', type=str)
    parser.add_argument('--block_iters', default=5, type=int)
    parser.add_argument('--finetune_epochs', default=10, type=int)
    parser.add_argument('--block_size', default=7, type=int)
    parser.add_argument('--sample_portion', default=1.0, type=float)
    parser.add_argument('--findq_iters', default=20, type=int)
    args = parser.parse_args()

    current_time = datetime.now().strftime('%b%d_%H-%M-%S-%f')
    args.log_dir = pathlib.Path('runs') / f"reactnet_{current_time}_{socket.gethostname()}"
    load_path = pathlib.Path(args.load_path)
    args.save_path = args.log_dir / f"{load_path.stem}-quantized.pt"
    print("Will save to ", args.save_path)

    # One worker process per GPU; fail loudly instead of launching nothing.
    world_size = torch.cuda.device_count()
    if world_size == 0:
        raise SystemExit("No CUDA devices found: at least one GPU is required.")
    mp.set_start_method('forkserver')
    mp.spawn(dist_main, args=(world_size, args), nprocs=world_size, join=True)


# Entry-point guard: mp.spawn starts workers with the "spawn" start method,
# which re-imports this module, so main() must only run in the launching process.
if __name__ == '__main__':
    main()
