import sys
import copy
import math
import socket
import pathlib
import argparse
import itertools
from datetime import datetime
from functools import partial

import os
import torch
import graphviz
import numpy as np
from tqdm import tqdm
import networkx as nx
import more_itertools
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.tensorboard import SummaryWriter
from networkx.algorithms.dag import topological_sort
from torch.utils.data import DataLoader, DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

from reactnet import reactnet
sys.path.append('.')
from cifar_torch import cifar100
import logging_util

# cuDNN autotuning: benchmark candidate convolution algorithms the first time
# each input shape is seen and reuse the fastest one. This trades bitwise
# run-to-run reproducibility for speed.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True


def make_graph(model):
    """Build the layer dependency graph reachable from the ``'fc'`` layer.

    Starting at ``'fc'``, ``model.dependent_layers(name)`` is followed
    depth-first (recursively). For every dependency ``dep`` of a layer
    ``name`` an edge ``dep -> name`` is added, so a topological sort of the
    result lists producers before consumers. If ``'fc'`` reports no
    dependencies the returned graph is empty.

    Args:
        model: object exposing ``dependent_layers(name) -> iterable of names``.

    Returns:
        ``networkx.DiGraph`` whose nodes are layer names.
    """
    dep_graph = nx.DiGraph()
    seen = set()

    def visit(name):
        seen.add(name)
        for parent in model.dependent_layers(name):
            dep_graph.add_edge(parent, name)
            if parent in seen:
                continue
            visit(parent)

    visit('fc')
    return dep_graph


def dataset_cov(train_loader, extract_patches, dim, sample_portion=1.0):
    """Estimate the covariance of layer-input patches over (part of) a dataset.

    Batches are streamed from ``train_loader``. Each batch is turned into a
    ``(dim, n)`` matrix of column observations by ``extract_patches``. Its
    mean and covariance are merged into running float64 statistics with the
    pairwise update of Chan et al., so the full patch matrix is never held in
    memory.

    Args:
        train_loader: iterable of ``(X, y)`` batches; only ``X`` is used. Must
            support ``len()``.
        extract_patches: callable mapping a CUDA batch ``X`` to a ``(dim, n)``
            tensor whose columns are observations.
        dim: number of variables (rows produced by ``extract_patches``).
        sample_portion: fraction of ``len(train_loader)`` batches to use
            (rounded up).

    Returns:
        ``(dim, dim)`` float32 CUDA tensor holding the population covariance
        (normalized by the total number of observations). It is all zeros if
        no observation was seen.

    Note:
        Only the batches this process iterates are used. No reduction across
        distributed ranks is performed here.
    """
    count = 0
    mean = torch.zeros(dim, device='cuda', dtype=torch.float64)
    cov = torch.zeros(dim, dim, device='cuda', dtype=torch.float64)
    iters = math.ceil(len(train_loader) * sample_portion)
    for X, _ in itertools.islice(tqdm(train_loader), iters):
        X = X.cuda(non_blocking=True)
        # Work in float64 throughout to limit round-off accumulated over many batches.
        patches = extract_patches(X).double()

        other_count = patches.size(1)
        if other_count == 0:
            continue
        other_mean = patches.mean(1)
        # Population (biased) covariance, so the merge below is exact.
        other_cov = torch.cov(patches, correction=0)

        if count == 0:
            count = other_count
            mean.copy_(other_mean)
            cov.copy_(other_cov)
            continue

        merged_count = count + other_count
        delta = other_mean - mean
        mean += delta * (other_count / merged_count)

        # C = (n_a * C_a + n_b * C_b) / n + (n_a * n_b / n**2) * delta delta^T
        cov *= count / merged_count
        cov += other_cov * (other_count / merged_count)
        cov += torch.outer(delta, delta) * (count * other_count / merged_count ** 2)

        # The running statistics now summarize all observations seen so far.
        count = merged_count
    return cov.float()


@torch.jit.script
def stochastic_rounding(x, T: float):
    """Stochastically round ``x`` to -1/+1 with temperature ``T``.

    With ``u = T * x`` the target ``p`` is ``2u - u**2`` for ``0 <= u < 1``,
    its odd reflection for ``-1 < u < 0``, and ``sign(x)`` when ``|u| >= 1``.
    Each element becomes +1 with probability ``(p + 1) / 2`` and -1
    otherwise. Larger ``T`` therefore makes the result approach ``sign(x)``;
    exact zeros become +1 or -1 with equal probability.

    The input tensor is not modified. The result has ``x``'s dtype, so it can
    stand in for ``torch.sign(x)`` (an integer result would be
    non-differentiable for autograd).
    """
    s = torch.sign(x)
    x = x * T  # out-of-place: must not modify the caller's tensor
    orig = torch.where(torch.abs(x) < 1, (-s * x ** 2 + 2 * x), s)
    p = s * orig.abs()
    pr = p * .5 + .5
    sampled = (torch.rand_like(x) < pr).to(x.dtype) * 2 - 1
    return sampled


class Sign(torch.autograd.Function):
    """Binarize with ``torch.sign`` and use a piecewise-linear surrogate gradient.

    Forward returns ``sign(x)`` and ignores the temperature ``T``. Backward
    scales the incoming gradient by ``max(0, 2 * (1 - |x|))`` and returns no
    gradient for ``T``.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, T):
        ctx.save_for_backward(x)
        # stochastic rounding with temperature parameter (deterministic when T -> inf)
        # is available as stochastic_rounding(x, T) but is not used here.
        return torch.sign(x)

    @staticmethod
    def backward(ctx, g):
        (saved_x,) = ctx.saved_tensors
        # TODO replace with gradient estimator in reactnet
        surrogate = (2 * (1 - saved_x.abs())).clamp_min(0.0)
        return g * surrogate, None


@torch.jit.script
def quantization_loss(q, M, w_hat, lamba: float = 1e-5):
    """Per-column loss ``||M @ q - w_hat||**2 + lamba * ||q||``.

    For each column of ``q``, the squared residual between ``M @ q`` and
    ``w_hat`` is summed over rows. The column's L2 norm scaled by ``lamba``
    is then added, giving one loss value per column.
    """
    residual = M @ q - w_hat
    penalty = q.norm(2, dim=0)
    return residual.pow(2).sum(0) + lamba * penalty


@torch.enable_grad()
def find_q(w, M, nsteps=2000, nsamples=20, xnor_net=False):
    """Find a scaled-binary approximation ``alpha * sign(w')`` of ``w`` under the metric ``M``.

    Every column of ``w`` gets its own scale, so ``alpha`` has shape
    ``(1, w.shape[1])``. The objective is
    ``quantization_loss(alpha * sign(w'), M, M @ w)``.

    In the optimized (non-XNOR) mode, a latent ``w'`` and ``alpha`` are
    trained with SGD through the ``Sign`` surrogate gradient. This runs for
    ``nsamples`` rounds of ``nsteps`` steps, with Gaussian noise added to
    ``w'`` between rounds. The candidate with the lowest exact loss is kept.
    The best candidate across all distributed ranks is then broadcast, so
    this mode requires an initialized process group and every rank returns
    the same result.

    Args:
        w: 2-D weight matrix; the caller in this file passes ``(in, out)``.
        M: matrix applied on the left (``M @ w``).
        nsteps: SGD steps per round.
        nsamples: number of rounds.
        xnor_net: if True, skip optimization and return the closed-form
            XNOR-Net solution ``alpha = mean(|w|, dim=0)``.

    Returns:
        ``(alpha, w_prime)``; the quantized weight is ``alpha * torch.sign(w_prime)``.
    """
    if xnor_net:
        # Closed-form XNOR-Net solution: one scale (mean magnitude) per column.
        # Return the same (alpha, w_prime) pair as the optimized path so the
        # caller's ``alpha, w_prime = find_q(...)`` unpacking works.
        alpha = w.abs().mean(0, keepdim=True)
        return alpha, torch.sign(w)

    w_hat = M @ w

    alpha = torch.nn.Parameter(w.abs().mean(0, keepdim=True))
    # Optimize a normalized latent weight; the quantized value is alpha * sign(w).
    w = torch.nn.Parameter(w.clone() / alpha)

    initial_alpha_scale = alpha.data.norm(2).clone()
    print(f"initial_alpha={initial_alpha_scale}")

    opt = torch.optim.SGD([w, alpha], lr=0.01, momentum=0.9, nesterov=True)
    # NOTE(review): the scheduler is not reset between rounds. Past T_max,
    # CosineAnnealingLR follows the cosine back up, so the lr rises during
    # every other round — confirm intended.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=nsteps)

    best = torch.tensor(10. ** 9, device=w.device)
    best_w = copy.deepcopy(w.data)
    best_alpha = copy.deepcopy(alpha.data)
    for _ in range(nsamples):
        for step in range(nsteps):
            opt.zero_grad()
            T = 1 + step / nsteps * 20  # Temperature 1 --> 20

            loss = quantization_loss(
                alpha * Sign.apply(w, T),
                M, w_hat
            )
            lsum = loss.mean()
            lsum.backward()

            # evaluate actual loss (hard sign, no surrogate)
            lsum = quantization_loss(
                alpha * torch.sign(w),
                M, w_hat
            ).mean()

            # update best (tensor ops only, no host synchronization)
            cond = torch.le(lsum, best)
            best.data.copy_(torch.where(cond, lsum, best))
            best_w.data.copy_(torch.where(cond, w.data, best_w.data))
            best_alpha.data.copy_(torch.where(cond, alpha.data, best_alpha.data))

            opt.step()
            scheduler.step()
        # perturb the latent weights before the next round
        w.data.add_(torch.randn_like(w) * 0.01)
        print("best=", best.item())

    # Pick the rank with the lowest loss and share its solution with everyone.
    best_list = [torch.zeros_like(best) for _ in range(dist.get_world_size())]
    dist.all_gather(best_list, best)
    best_idx = torch.argmin(torch.stack(best_list), dim=0).item()

    dist.broadcast(best_alpha, best_idx)
    dist.broadcast(best_w, best_idx)

    print(f"final best={best}, best_alpha={best_alpha.norm(2)}")
    return best_alpha, best_w


# knowledge distillation loss
@torch.jit.script
def distillation(logits, labels, teacher_scores, T: float, alpha: float):
    """Per-sample knowledge-distillation loss.

    Returns ``(1 - alpha) * CE(logits, labels) + alpha * 2 * T**2 * KL``.
    Here ``KL = KL(softmax(teacher_scores / T) || softmax(logits / T))``,
    summed over classes.

    Args:
        logits: student outputs, ``(batch, classes)``.
        labels: hard class labels, ``(batch,)``.
        teacher_scores: teacher outputs (soft labels), same shape as ``logits``.
        T: softmax temperature.
        alpha: blend weight of the soft-target term; ``(1 - alpha)`` weights
            the hard-label cross-entropy.

    Returns:
        Loss tensor of shape ``(batch,)`` (no reduction).
    """
    task_loss = F.cross_entropy(logits, labels, reduction='none')
    # kl_div(input, target) = sum target * (log target - input), i.e. KL(teacher || student)
    teacher_loss = F.kl_div(
        F.log_softmax(logits / T, dim=-1),
        F.log_softmax(teacher_scores / T, dim=-1),
        reduction='none',
        log_target=True
    ).sum(1)
    # Convex blend of the two terms. T**2 keeps the soft-target gradient scale
    # comparable across temperatures; the extra factor 2 keeps the original weighting.
    return task_loss * (1 - alpha) + teacher_loss * (2 * T ** 2 * alpha)


@torch.enable_grad()
def train_model(model, opt, teacher, train_loader, n_steps,
                T=20.0, alpha=0.7):
    """Fine-tune ``model`` by distillation from ``teacher`` for exactly ``n_steps`` optimizer steps.

    Both networks are wrapped in DDP anew on every call. The teacher runs in
    eval mode without gradients. The student runs in train mode and is
    trained on the batch mean of
    ``distillation(logits, y, teacher_scores, T, alpha)``; ``opt`` is stepped
    after every batch. ``train_loader`` is iterated as many times as needed.
    After each (possibly partial) pass, the mean loss and accuracy,
    all-reduced across ranks, are printed.

    Args:
        model: student network (unwrapped).
        opt: optimizer over the student's trainable parameters.
        teacher: teacher network (unwrapped).
        train_loader: training data; must be non-empty.
        n_steps: number of optimizer steps; ``<= 0`` returns immediately.
        T: distillation temperature.
        alpha: weight of the soft-target term.

    Raises:
        ValueError: if ``train_loader`` yields no batches (the loop would never end).
    """
    if n_steps <= 0:
        return
    if len(train_loader) == 0:
        raise ValueError("train_model: train_loader is empty")
    teacher = DDP(teacher)
    teacher.eval()
    model = DDP(model)
    model.train()
    steps = 0
    while True:
        count = torch.tensor(0, device='cuda')
        loss_accum = torch.tensor(0.0, device='cuda')
        correct_count = torch.tensor(0, device='cuda')
        done = False
        for X, y in tqdm(train_loader, desc="Train"):
            X = X.cuda(non_blocking=True)
            y = y.cuda(non_blocking=True).flatten()
            model.zero_grad()

            logits = model(X)
            with torch.no_grad():
                teacher_scores = teacher(X)
            loss = distillation(logits, y, teacher_scores, T, alpha)
            loss.mean().backward()

            # apply EmpCov to the first layer (disabled experiment)
            #w = cur_w.data.view(cur_w.shape[0], -1).T  # (in, out)
            #cur_w.grad.add_((inv_sigma @ w).T.reshape_as(cur_w))

            opt.step()
            count += X.shape[0]
            loss_accum += loss.sum().data
            correct_count += (torch.argmax(logits, dim=1) == y).sum()

            steps += 1
            if steps >= n_steps:
                # stop after exactly n_steps optimizer updates
                done = True
                break
        # Every rank runs the same number of batches, provided all ranks'
        # loaders have equal length, so these collectives stay matched.
        dist.all_reduce(loss_accum)
        dist.all_reduce(correct_count)
        dist.all_reduce(count)
        tqdm.write(f'{(loss_accum / count).item()}')
        accuracy = (correct_count / count).item()
        tqdm.write(f'{accuracy:.4f}')
        if done:
            return


@torch.no_grad()
def extract_patches(model, layer, X):
    """Return the inputs ``layer`` receives for batch ``X``, arranged for covariance estimation.

    ``model.inputs_for(layer, X)`` supplies the layer's input activations.
    For a ``torch.nn.Conv2d`` they are unfolded into sliding windows that
    match the layer's kernel size, dilation, padding and stride. The windows
    are then laid out as a ``(C*kh*kw, batch*L)`` matrix with one column per
    window, where ``L`` is the number of window positions per sample. Any
    other layer type gets the activations back unchanged. Note that
    ``dataset_cov`` treats dim 0 of the result as the variable axis.
    """
    acts = model.inputs_for(layer, X)
    if not isinstance(layer, torch.nn.Conv2d):
        return acts

    cols = F.unfold(
        acts,
        kernel_size=layer.kernel_size,
        dilation=layer.dilation,
        padding=layer.padding,
        stride=layer.stride,
    )  # (batch, C*kh*kw, L)
    # patch-feature axis first, then merge the batch and position axes
    cols = cols.transpose(0, 1)  # (C*kh*kw, batch, L)
    return cols.reshape(cols.size(0), -1)  # (C*kh*kw, batch*L)


@torch.no_grad()
def quantize(model, train_loader, val_loader, block_iters, finetune_epochs, world_size,
             xnor_net=False, skip_until=None,
             block_size=7, sample_portion=1.0, log_dir=None, findq_iters=20, rank=0):
    """Binarize the weights of ``model`` in place, block by block, with distillation fine-tuning.

    Layers are ordered topologically using the dependency graph from
    ``make_graph``. They are split into consecutive groups of at most
    ``block_size`` layers whose sizes differ by at most one. Each group is
    processed in ``block_iters`` passes. In every pass, each eligible layer:

    - is frozen;
    - is replaced by ``alpha * sign(w')`` from ``find_q`` under the metric
      ``M = diag(s ** 0.4) @ V.T``, where ``s`` and ``V`` come from an SVD of
      the layer's input covariance;
    - is evaluated;
    - is followed by fine-tuning of the still-trainable parameters against
      ``teacher``, a copy of ``model`` taken before any quantization.

    Between passes the group's layers are unfrozen again. After the final
    pass they stay frozen and are filtered out of ``weight_params`` before
    later groups build their optimizers.

    Layers left unquantized:

    - ``'fc'``: skipped entirely and never frozen;
    - grouped convolutions (``groups > 1``): frozen but left in float;
    - layers visited before ``skip_until``: frozen but left in float. If
      ``skip_until`` never matches a layer name, nothing is quantized.

    Args:
        model: network to quantize in place. Must provide ``dependent_layers``
            and ``inputs_for`` (used by helpers in this file).
        train_loader: loader used for covariance estimation and fine-tuning.
        val_loader: loader passed to ``test_model`` after each stage.
        block_iters: number of passes over each layer group.
        finetune_epochs: defines
            ``total_steps = len(train_loader) * finetune_epochs``. Each
            per-layer fine-tune uses ``total_steps // block_iters`` steps.
        world_size: accepted from the caller; not used in this function.
        xnor_net: forwarded to ``find_q``.
        skip_until: optional layer name; quantization starts at that layer.
        block_size: maximum number of layers per group.
        sample_portion: fraction of training batches used for each covariance.
        log_dir: directory (``pathlib.Path``-like). Rank 0 writes
            ``block-{group_idx}.pth`` there after every pass, so it must not
            be ``None`` on rank 0.
        findq_iters: ``find_q`` rounds per layer, summed over all passes
            (each pass uses ``findq_iters // block_iters``).
        rank: process rank; only rank 0 saves checkpoints.
    """
    graph = make_graph(model)
    modules = dict(model.named_modules())
    # Split parameters by owner: BatchNorm2d modules vs. all other modules.
    weight_params = sum([list(m.parameters(recurse=False)) for m in model.modules()
                         if not isinstance(m, torch.nn.BatchNorm2d)], [])
    bn_params = sum([list(m.parameters(recurse=False)) for m in model.modules()
                     if isinstance(m, torch.nn.BatchNorm2d)], [])
    print(f"{len(weight_params)} weight params, {len(bn_params)} bn params")

    # Copy of the model before any quantization; the distillation teacher.
    # It is not part of any optimizer created here.
    teacher = copy.deepcopy(model)
    total_steps = len(train_loader) * finetune_epochs
    # Optimizer-step budget for the fine-tune that follows each quantized layer.
    n_steps = total_steps // block_iters

    # for layer_name in topological_sort(graph):
    layers = list(topological_sort(graph))
    # Consecutive groups of at most block_size layers, sizes differing by at most one.
    layer_groups = list(more_itertools.chunked_even(layers, block_size))
    print("Layer groups", layer_groups)

    for group_idx, layer_group in enumerate(layer_groups):

        # Fresh Adam optimizer per group. Weight decay applies only to the
        # non-BatchNorm group. Parameters left frozen by earlier groups have
        # already been filtered out of weight_params (end of the pass loop).
        opt = torch.optim.Adam(
            [{'params': bn_params},
             {'params': weight_params, 'weight_decay': 1e-5}],
            lr=1.25e-3
        )

        # Per-layer state carried across the block_iters passes of this group.
        last_weights = dict()  # pre-quantization weights seen in the previous pass
        Vs = dict()  # bases refined by find_V; replace the SVD basis in later passes
        for it in range(block_iters):
            print(f"Iteration {it}.")

            # Every pass except the first pass of the first group starts by
            # fine-tuning the currently trainable parameters.
            # NOTE(review): this uses the full total_steps budget, not n_steps — confirm intended.
            if it > 0 or group_idx > 0:
                print("Block initial fine-tuning")
                train_model(model, opt, teacher, train_loader, n_steps=total_steps)
                test_model(model, val_loader)

            new_frozen = []  # parameters frozen during this pass
            for layer_idx, layer_name in enumerate(layer_group):
                # Reaching the requested layer ends skipping; this layer itself is quantized.
                if layer_name == skip_until:
                    skip_until = None

                # skip last layer for now
                if layer_name == 'fc':
                    continue

                print("Quantizing", layer_name)
                if layer_name.endswith('_down'):
                    # '<name>_down' stands for two modules, '<name>_down1' and
                    # '<name>_down2'. They are quantized jointly, with their
                    # weights stacked along dim 0. Patches are taken from the
                    # first module only; this assumes both see the same
                    # input — TODO confirm.
                    layer = (modules[layer_name + '1'], modules[layer_name + '2'])
                    cur_params = [layer[0].weight, layer[1].weight]
                    groups = layer[0].groups
                    get_patches = partial(extract_patches, model, modules[layer_name + '1'])

                    # calculate dimensions
                    w = torch.cat([layer[0].weight.data, layer[1].weight.data], dim=0)
                    dim = np.prod(w.shape[1:])
                    w = w.view(w.size(0), dim)  # (out, in)

                else:
                    layer = modules[layer_name]
                    cur_params = [layer.weight]
                    groups = layer.groups
                    get_patches = partial(extract_patches, model, layer)

                    # calculate dimensions
                    w = layer.weight.data.clone()
                    dim = np.prod(w.shape[1:])
                    w = w.view(w.size(0), dim)  # (out, in)

                # freeze layer
                for p in cur_params:
                    new_frozen.append(p)
                    p.requires_grad = False
                    p.grad = None

                # FIXME skip depthwise layer for now
                if groups > 1:
                    continue
                # skip quantization until specified layer is reached
                if skip_until is not None:
                    continue

                # calculate input covariance
                # NOTE(review): dataset_cov uses only this rank's batches and does
                # no cross-rank reduction, so M may differ between ranks, while
                # find_q compares its per-rank losses to pick a winner.
                model.eval()
                cov = dataset_cov(train_loader, get_patches, dim, sample_portion)
                _, s, V = torch.svd(cov)
                # Metric M = diag(s ** 0.4) @ V.T over the layer's input space.
                S = torch.diag(s ** 0.4)
                if layer_name in Vs:
                    V = Vs[layer_name]
                M = S @ V.t()

                # remember weights
                prev_w = last_weights.get(layer_name)
                last_weights[layer_name] = w.clone()

                # quantize layer
                print("Finding quantized values")
                alpha, w_prime = find_q(
                    w.T, M,
                    xnor_net=xnor_net,
                    nsamples=findq_iters // block_iters
                )
                q = (alpha * torch.sign(w_prime)).t()  # (out, in)

                # find V
                # From the second pass on, refine this layer's basis using the
                # previous pass's weights; used in place of the SVD V next pass.
                if prev_w is not None:
                    Vs[layer_name] = find_V(prev_w, q.T, S, V)

                # apply quantized layer weight
                if layer_name.endswith('_down'):
                    # Split the stacked rows back into the two modules. This
                    # assumes both have the same number of output rows.
                    layer[0].weight.data.copy_(q[:q.shape[0]//2].reshape(layer[0].weight.shape))
                    layer[1].weight.data.copy_(q[q.shape[0]//2:].reshape(layer[1].weight.shape))
                else:
                    layer.weight.data.copy_(q.reshape(layer.weight.shape))

                test_model(model, val_loader)

                print("Fine-tuning")
                train_model(model, opt, teacher, train_loader, n_steps=n_steps)
                test_model(model, val_loader)

            print("Block final fine-tuning")
            train_model(model, opt, teacher, train_loader, n_steps=total_steps)
            test_model(model, val_loader)

            # Rank 0 overwrites the same per-group checkpoint after every pass.
            if rank == 0:
                print("Saving to ", log_dir / f"block-{group_idx}.pth")
                torch.save(model.state_dict(), log_dir / f"block-{group_idx}.pth")

            # unfreeze block layers
            if it < block_iters - 1:
                print(f"Unfreezing {len(new_frozen)} layers")
                for p in new_frozen:
                    p.requires_grad = True
            else:
                # Last pass: the group stays frozen. Drop all frozen params from
                # weight_params so the next group's optimizer excludes them.
                weight_params = [p for p in weight_params if p.requires_grad]


@torch.enable_grad()
def find_V(prev_w, q, S, V):
    """Refine the basis ``V`` so that ``q`` approximates ``prev_w`` under ``S @ V.T``.

    ``Vt`` (initialized to ``V.T``) is trained for 1000 Nesterov-SGD steps
    (lr 0.02, momentum 0.9). The objective is
    ``mean(quantization_loss(q, S @ Vt, S @ Vt @ prev_w.T))`` plus
    ``0.1 * ||Vt.T @ Vt - I||_F**2``, a soft penalty pushing the columns of
    ``Vt`` towards orthonormality. The objective is printed every 100 steps.

    Args:
        prev_w: weights from the previous pass; ``(out, in)`` as passed by ``quantize``.
        q: quantized weights; ``(in, out)`` as passed by ``quantize``.
        S: scaling matrix applied on the left of ``Vt``.
        V: current basis.

    Returns:
        The refined basis, in the same orientation as ``V``.
    """
    print("Find V")
    basis_t = torch.nn.Parameter(V.T.clone())
    sgd = torch.optim.SGD([basis_t], lr=0.02, momentum=0.9, nesterov=True)
    for step in range(1000):
        sgd.zero_grad()

        metric = S @ basis_t
        objective = quantization_loss(q, metric, metric @ prev_w.T).mean()

        gram = basis_t.T @ basis_t
        gram.diagonal().sub_(1.0)  # gram now holds Vt.T @ Vt - I
        penalty = gram.norm(2) ** 2  # squared Frobenius norm
        objective += penalty * 0.1

        objective.backward()
        if step % 100 == 0:
            print(objective.item())

        sgd.step()
    return basis_t.data.T


def load_model(path):
    """Build a 100-class ReActNet on CUDA and load weights from the checkpoint at ``path``.

    The checkpoint is read as a dict with ``'epoch'``, ``'best_top1_acc'``
    (both printed) and ``'state_dict'`` entries. Leading ``'module.'``
    prefixes are stripped from state-dict keys; these are typically added
    when a (Distributed)DataParallel-wrapped model is saved. Only the prefix
    is stripped, so ``'module.'`` elsewhere in a key is kept. Loading is
    non-strict, but any missing or unexpected keys are printed so a
    mismatched checkpoint does not go unnoticed.
    """
    prefix = 'module.'

    def strip_prefix(key):
        # handle nested wrappers such as 'module.module.' as well
        while key.startswith(prefix):
            key = key[len(prefix):]
        return key

    model = reactnet(num_classes=100).cuda()
    checkpoint = torch.load(path, map_location='cpu')
    print(checkpoint['epoch'], checkpoint['best_top1_acc'])
    state_dict = {strip_prefix(k): v for k, v in checkpoint['state_dict'].items()}
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        print(f"Missing keys when loading {path}: {result.missing_keys}")
    if result.unexpected_keys:
        print(f"Unexpected keys when loading {path}: {result.unexpected_keys}")
    return model


def train_acc_model(model, train_loader: DataLoader):
    """Report mean cross-entropy and top-1 accuracy of ``model`` on ``train_loader``.

    The model is wrapped in DDP and put in *train* mode, without gradients.
    Any BatchNorm layers therefore normalize with batch statistics, and those
    tracking running statistics update them during this pass (train-mode
    BatchNorm does so even under ``no_grad``). Per-rank sums are
    all-reduced, so every rank returns and prints the global
    ``(mean_loss, accuracy)``.
    """
    ddp_model = DDP(model)
    ddp_model.train()
    print("Train Acc")
    n_seen = torch.tensor(0, device='cuda')
    loss_sum = torch.tensor(0.0, device='cuda')
    n_correct = torch.tensor(0, device='cuda')
    ce_sum = torch.nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for inputs, targets in train_loader:
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True).flatten()
            scores = ddp_model(inputs)
            batch_loss = ce_sum(scores, targets)
            n_seen += inputs.shape[0]
            loss_sum += batch_loss.data
            n_correct += (scores.argmax(dim=1) == targets).sum()
        for stat in (loss_sum, n_correct, n_seen):
            dist.all_reduce(stat)
        mean_loss = (loss_sum / n_seen).item()
        accuracy = (n_correct / n_seen).item()
        print(f'{mean_loss}')
        print(f'{accuracy:.4f}')
    return mean_loss, accuracy


def test_model(model, val_loader: DataLoader):
    """Evaluate ``model`` on ``val_loader`` across all ranks.

    The model is wrapped in DDP and put in eval mode; no gradients are
    computed. Summed cross-entropy, correct predictions and sample counts
    are all-reduced, so every rank returns and prints the same global
    ``(mean_loss, accuracy)``. If the loader's sampler pads its shards (as
    ``DistributedSampler`` does by default), a few samples may be counted
    more than once.
    """
    ddp_model = DDP(model)
    ddp_model.eval()
    print("Test")
    n_seen = torch.tensor(0, device='cuda')
    loss_sum = torch.tensor(0.0, device='cuda')
    n_correct = torch.tensor(0, device='cuda')
    ce_sum = torch.nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True).flatten()
            scores = ddp_model(inputs)
            batch_loss = ce_sum(scores, targets)
            n_seen += inputs.shape[0]
            loss_sum += batch_loss.data
            n_correct += (scores.argmax(dim=1) == targets).sum()
        for stat in (loss_sum, n_correct, n_seen):
            dist.all_reduce(stat)
        mean_loss = (loss_sum / n_seen).item()
        accuracy = (n_correct / n_seen).item()
        print(f'{mean_loss}')
        print(f'{accuracy:.4f}')
    return mean_loss, accuracy


def test_model_orig(model, val_loader: DataLoader):
    model.eval()
    print("Test")
    count = torch.tensor(0, device='cuda')
    loss_accum = torch.tensor(0.0, device='cuda')
    correct_count = torch.tensor(0, device='cuda')
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for X, y in val_loader:
            X = X.cuda(non_blocking=True)
            y = y.cuda(non_blocking=True).flatten()
            logits = model(X)
            loss = criterion(logits, y)
            count += X.shape[0]
            loss_accum += loss.data
            correct_count += (torch.argmax(logits, dim=1) == y).sum()
        loss = (loss_accum / count).item()
        acc = (correct_count / count).item()
        print(f'{loss}')
        print(f'{acc:.4f}')
    return loss, acc


def render_dep_graph(model):
    """Render the layer dependency graph reachable from ``'fc'`` with Graphviz.

    Starting at ``'fc'``, ``model.dependent_layers`` is followed recursively.
    Each dependency is drawn as a directed edge ``dep -> layer``, the same
    direction ``make_graph`` uses. The drawing is written by ``render`` under
    the base name ``'dep-graph-reactnet'``.
    """
    # Digraph (not Graph) so the drawing keeps the direction of each dependency.
    dot = graphviz.Digraph('dep-graph')
    marked = set()

    def add_layer(layer):
        dot.node(layer, layer)
        marked.add(layer)
        for dep in model.dependent_layers(layer):
            dot.edge(dep, layer)
            if dep not in marked:
                add_layer(dep)

    add_layer('fc')
    dot.render('dep-graph-reactnet')


def run(rank, world_size, args):
    """Per-process entry point: evaluate, quantize, re-evaluate and save the model.

    Per-rank logging is set up under ``args.log_dir``, and this process is
    bound to ``cuda:{rank}``. The model is loaded from ``args.load_path`` and
    distributed CIFAR-100 loaders are built. The float model is evaluated,
    then ``quantize`` runs, and the result is evaluated again. Rank 0 saves
    the final state dict to ``args.save_path``. The TensorBoard writer is
    closed even if a stage raises.
    """
    writer = SummaryWriter(log_dir=args.log_dir)
    try:
        logging_util.setup_logging(
            pathlib.Path(args.log_dir) / f'log_rank{rank}.txt',
            log_to_screen=(rank == 0)
        )
        print("===== Args =====")
        for name, value in vars(args).items():
            print(f"   {name}: {value}")
        print("================")

        print(f"Using GPU {rank}")
        torch.cuda.set_device(f"cuda:{rank}")

        # construct model
        model = load_model(args.load_path)
        # render_dep_graph(model)
        print(model)

        # load dataset
        train_loader, val_loader = cifar100(batch_size=200, workers=4, distributed=True)

        # evaluate fp model
        train_acc_model(model, train_loader)
        test_model(model, val_loader)

        # quantize model
        quantize(model, train_loader, val_loader,
                 block_iters=args.block_iters,
                 xnor_net=args.xnor_net,
                 skip_until=args.skip_until,
                 finetune_epochs=args.finetune_epochs,
                 block_size=args.block_size,
                 sample_portion=args.sample_portion,
                 findq_iters=args.findq_iters,
                 log_dir=pathlib.Path(writer.log_dir),
                 rank=rank, world_size=world_size)

        # evaluate quantized model
        train_acc_model(model, train_loader)
        test_model(model, val_loader)

        if rank == 0:
            torch.save(model.state_dict(), args.save_path)
    finally:
        # flush and close the TensorBoard event file even if a stage fails
        writer.close()


def dist_main(rank, world_size, args):
    """Worker started by ``mp.spawn``: join an NCCL process group on localhost and call ``run``.

    ``MASTER_ADDR`` is always set to ``localhost``. ``MASTER_PORT`` defaults
    to ``12355`` unless already set in the environment. The process group is
    destroyed even if ``run`` raises.
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ.setdefault('MASTER_PORT', '12355')

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        run(rank, world_size, args)
    finally:
        dist.destroy_process_group()


def main():
    """Parse command-line options and launch one quantization worker per visible CUDA device.

    A timestamped run directory ``runs/reactnet_<time>_<host>`` is chosen as
    ``args.log_dir``. The quantized model will be saved there as
    ``<checkpoint stem>-quantized.pt``.

    Raises:
        SystemExit: if no CUDA device is visible.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--xnor_net', default=False, action='store_true')
    parser.add_argument('--skip_until', default=None, type=str)
    parser.add_argument('--load_path', default='/d1/xxx/DBQ/baseline/1_step1/models_run1/model_best.pth.tar', type=str)
    parser.add_argument('--block_iters', default=5, type=int)
    parser.add_argument('--finetune_epochs', default=10, type=int)
    parser.add_argument('--block_size', default=7, type=int)
    parser.add_argument('--sample_portion', default=1.0, type=float)
    parser.add_argument('--findq_iters', default=20, type=int)
    args = parser.parse_args()

    current_time = datetime.now().strftime('%b%d_%H-%M-%S-%f')
    args.log_dir = pathlib.Path('runs') / f"reactnet_{current_time}_{socket.gethostname()}"
    load_path = pathlib.Path(args.load_path)
    args.save_path = args.log_dir / f"{load_path.stem}-quantized.pt"
    print("Will save to ", args.save_path)

    world_size = torch.cuda.device_count()
    if world_size == 0:
        # mp.spawn with nprocs=0 would start no workers and return without doing anything.
        raise SystemExit("No CUDA devices found; at least one GPU is required.")
    mp.set_start_method('forkserver')
    mp.spawn(dist_main, args=(world_size, args), nprocs=world_size, join=True)


# Only launch when executed as a script, so that importing this module
# (e.g. from spawned worker processes) does not start a new run.
if __name__ == '__main__':
    main()
