import sys
import copy
import glob
import math
import socket
import shutil
import pathlib
import argparse
import itertools
from datetime import datetime
from functools import partial

import os
import torch
import graphviz
import numpy as np
from tqdm import tqdm
import networkx as nx
import more_itertools
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.tensorboard import SummaryWriter
from networkx.algorithms.dag import topological_sort
from torch.utils.data import DataLoader, DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

from reactnet import reactnet
sys.path.append('.')
from cifar_torch import cifar100
import logging_util

# Keep cuDNN enabled and let it benchmark convolution algorithms; input shapes
# are fixed throughout this script.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True


def make_graph(model):
    """Build the layer dependency DAG reachable from the ``'fc'`` layer.

    Starting at ``'fc'``, walks ``model.dependent_layers`` depth-first and
    adds an edge ``dep -> layer`` for every dependency it reports, so edges
    point from the producing layer to the consuming one.

    Args:
        model: object exposing ``dependent_layers(layer_name)`` that returns an
            iterable of layer names.

    Returns:
        ``networkx.DiGraph`` whose nodes are layer names.
    """
    dag = nx.DiGraph()
    seen = {'fc'}
    # An explicit stack of (layer, iterator over its deps) reproduces the
    # recursive depth-first order exactly, so node/edge insertion order (which
    # topological_sort depends on) is unchanged.
    stack = [('fc', iter(model.dependent_layers('fc')))]
    while stack:
        consumer, pending = stack[-1]
        for producer in pending:
            dag.add_edge(producer, consumer)
            if producer not in seen:
                seen.add(producer)
                stack.append((producer, iter(model.dependent_layers(producer))))
                break
        else:
            # all dependencies of `consumer` handled
            stack.pop()
    return dag


def dataset_cov(train_loader, extract_patches, dim, sample_portion=1.0):
    """Estimate the covariance of layer-input patches over (part of) a dataset.

    Per-batch statistics are merged with the pairwise update of Chan et al.
    For two groups with counts ``n_a, n_b``, means ``mu_a, mu_b`` and
    population covariances ``C_a, C_b``::

        n  = n_a + n_b
        d  = mu_b - mu_a
        mu = mu_a + d * n_b / n
        C  = C_a * n_a / n + C_b * n_b / n + outer(d, d) * n_a * n_b / n**2

    Args:
        train_loader: iterable of ``(X, y)`` batches; ``len()`` must be defined.
        extract_patches: callable mapping a CUDA batch ``X`` to a 2-D tensor of
            shape ``(dim, n_samples)`` (rows are features, columns samples).
        dim: number of rows produced by ``extract_patches``.
        sample_portion: fraction of ``train_loader`` batches to visit
            (rounded up).

    Returns:
        ``(dim, dim)`` float32 population covariance on CUDA (all zeros if no
        non-empty batch was visited).
    """
    count = 0
    mean = torch.zeros(dim, device='cuda', dtype=torch.float64)
    cov = torch.zeros(dim, dim, device='cuda', dtype=torch.float64)
    iters = math.ceil(len(train_loader) * sample_portion)
    for X, _ in itertools.islice(tqdm(train_loader), iters):
        X = X.cuda(non_blocking=True)
        patches = extract_patches(X).double()  # accumulate in float64

        other_count = patches.size(1)
        if other_count == 0:
            # nothing to merge (mean of an empty set would be NaN)
            continue
        other_mean = patches.mean(1)
        # Population covariance (divide by n): the merge below is exact for it.
        other_cov = torch.cov(patches, correction=0)

        if count == 0:
            count = other_count
            mean.copy_(other_mean)
            cov.copy_(other_cov)
            continue

        merged_count = count + other_count
        mean_diff = other_mean - mean
        mean += mean_diff * (other_count / merged_count)
        # Weight each part by its OWN share of the samples, then add the
        # between-group term n_a * n_b / n**2 * outer(d, d).
        cov *= count / merged_count
        cov += other_cov * (other_count / merged_count)
        cov += torch.outer(mean_diff, mean_diff) * (count * other_count / merged_count ** 2)
        count = merged_count
    return cov.float()


class Sign(torch.autograd.Function):
    """Sign activation with a piecewise-linear surrogate gradient.

    The forward pass returns ``sign(x)``. The backward pass replaces the true
    derivative (zero almost everywhere) with ``max(0, 2 * (1 - |x|))``. That is
    a triangle peaking at 2 for ``x == 0`` and vanishing for ``|x| >= 1``.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        ctx.save_for_backward(x)
        return x.sign()

    @staticmethod
    def backward(ctx, g):
        (saved_x,) = ctx.saved_tensors
        surrogate = (2 * (1 - saved_x.abs())).clamp(min=0.0)
        return g * surrogate


@torch.jit.script
def quantization_loss(q, M, w_hat, lamba: float = 1e-5):
    # Column-wise squared reconstruction error of q after mapping it through M,
    # compared with the already-mapped target w_hat. A column-wise L2 penalty
    # weighted by `lamba` is added.
    #   q: (d, k)   M: (d', d)   w_hat: (d', k)   ->   result: (k,)
    residual = torch.matmul(M, q) - w_hat
    penalty = lamba * q.norm(2, dim=0)
    return residual.pow(2).sum(0) + penalty


# knowledge distillation loss
@torch.jit.script
def distillation(logits, labels, teacher_scores, T: float, alpha: float):
    # Hinton-style knowledge distillation:
    #   loss = (1 - alpha) * CE(logits, labels)
    #          + alpha * T^2 * KL(softmax(teacher_scores / T) || softmax(logits / T))
    # logits: student outputs (batch, classes); labels: hard labels (batch,);
    # teacher_scores: teacher outputs (batch, classes), used as soft labels.
    # Returns (per-sample task CE, per-sample combined loss), both (batch,).
    task_loss = F.cross_entropy(logits, labels, reduction='none')
    teacher_loss = F.kl_div(
        F.log_softmax(logits / T, dim=-1),
        F.log_softmax(teacher_scores / T, dim=-1),
        reduction='none',
        log_target=True
    ).sum(1)
    # The soft term is scaled by alpha (the complement of the task weight
    # 1 - alpha). T^2 compensates for the 1/T^2 shrinkage of the soft-target
    # gradients.
    return task_loss, task_loss * (1 - alpha) + teacher_loss * (alpha * T ** 2)


@torch.jit.script
def fused_clip(max_norm: float, total_norm):
    # Gradient-clipping scale factor max_norm / total_norm, limited to [0, 1]
    # so gradients are only ever shrunk. The 1e-6 guards against division by
    # zero; everything stays a tensor op.
    denominator = total_norm + 1e-6
    return torch.clamp(max_norm / denominator, 0.0, 1.0)


@torch.no_grad()
def clip_grad_norm_(
        parameters, max_norm: float, norm_type: float = 2.0) -> torch.Tensor:
    """Scale gradients in place so their combined norm is at most ``max_norm``.

    The combined norm is the ``norm_type``-norm of the per-parameter gradient
    norms (the max of per-parameter max-abs values for ``inf``). The scale
    factor comes from ``fused_clip``, so no Python branch depends on the norm
    value.

    Args:
        parameters: a tensor or iterable of tensors; those without ``.grad``
            are ignored.
        max_norm: target upper bound for the combined gradient norm.
        norm_type: order of the norm (``float('inf')`` for max-norm).

    Returns:
        The combined gradient norm before clipping, or ``tensor(0.)`` if no
        parameter has a gradient.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    with_grad = [p for p in parameters if p.grad is not None]
    if not with_grad:
        return torch.tensor(0.)
    grads = [p.grad.detach() for p in with_grad]
    if norm_type == float('inf'):
        per_param = [g.abs().max() for g in grads]
        if len(per_param) == 1:
            total_norm = per_param[0]
        else:
            total_norm = torch.max(torch.stack(per_param))
    else:
        per_param = [torch.norm(g, norm_type) for g in grads]
        total_norm = torch.norm(torch.stack(per_param), norm_type)
    clip_coef = fused_clip(max_norm, total_norm)
    # detach() shares storage with .grad, so this scales the real gradients
    for g in grads:
        g.mul_(clip_coef)
    return total_norm


@torch.enable_grad()
def train_model(model, opt, teacher, train_loader, n_steps,
                T=20.0, alpha=0.7):
    """Fine-tune ``model`` by knowledge distillation from ``teacher``.

    Both networks are wrapped in DDP, so a process group must already be
    initialised. Training performs ``n_steps + 1`` optimizer steps (step
    indices ``0 .. n_steps``) and cycles over ``train_loader`` as often as
    needed.

    Every param group's learning rate follows a triangular schedule. It rises
    linearly from 0 to param group 0's initial lr at ``n_steps / 2``, then
    falls back to 0 at ``n_steps``. Gradients of all optimizer params are
    clipped to a global 2-norm of 2.0 before each step.

    After every complete pass over ``train_loader``, the loss and accuracy
    sums are all-reduced across ranks and printed.

    Args:
        model: student network.
        opt: optimizer over the student's params; its lrs are overwritten.
        teacher: network providing soft targets (eval mode, no grad).
        train_loader: iterable of ``(X, y)`` batches. If it exposes a
            ``sampler`` with ``set_epoch`` (e.g. ``DistributedSampler``), the
            pass number is set before each pass so shuffling differs per pass.
        n_steps: index of the last optimizer step; ``<= 0`` trains nothing.
        T: distillation temperature.
        alpha: weight of the distillation term (see ``distillation``).
    """
    if n_steps <= 0:
        # nothing to do (and the lr schedule below would divide by zero)
        return
    teacher = DDP(teacher)
    teacher.eval()
    model = DDP(model)
    model.train()
    steps = 0
    orig_lr = opt.param_groups[0]['lr']
    params = sum([g['params'] for g in opt.param_groups], [])
    sampler = getattr(train_loader, 'sampler', None)

    epoch = 0
    while True:
        # DistributedSampler reuses the same permutation unless set_epoch is
        # called before each pass.
        if hasattr(sampler, 'set_epoch'):
            sampler.set_epoch(epoch)
        epoch += 1

        n_batches = 0
        count = torch.tensor(0, device='cuda')
        task_loss_accum = torch.tensor(0.0, device='cuda')
        loss_accum = torch.tensor(0.0, device='cuda')
        correct_count = torch.tensor(0, device='cuda')
        for X, y in tqdm(train_loader, desc="Train"):
            n_batches += 1
            X = X.cuda(non_blocking=True)
            y = y.cuda(non_blocking=True).flatten()
            model.zero_grad()

            logits = model(X)
            with torch.no_grad():
                teacher_scores = teacher(X)
            task_loss, loss = distillation(logits, y, teacher_scores, T, alpha)
            loss.mean().backward()

            clip_grad_norm_(params, 2.0)  # clip gradient

            # triangular schedule: 0 -> orig_lr (at n_steps / 2) -> 0
            r = steps / n_steps
            lr = orig_lr * min(r, 1 - r) * 2
            for g in opt.param_groups:
                g['lr'] = lr

            opt.step()
            count += X.shape[0]
            task_loss_accum += task_loss.sum().data
            loss_accum += loss.sum().data
            correct_count += (torch.argmax(logits, dim=1) == y).sum()

            steps += 1
            if steps > n_steps:
                return
        if n_batches == 0:
            # an empty loader would otherwise make this loop spin forever
            return
        dist.all_reduce(task_loss_accum)
        dist.all_reduce(loss_accum)
        dist.all_reduce(correct_count)
        dist.all_reduce(count)
        tqdm.write(f'task{(task_loss_accum / count).item()} / dist{(loss_accum / count).item()}')
        accuracy = (correct_count / count).item()
        tqdm.write(f'{accuracy:.4f}')


class QuantizationModel(torch.nn.Module):
    """Wrap ``model`` and register ``V`` and ``log_s`` as trainable parameters.

    The forward pass simply delegates to ``model``. The wrapper only exists so
    that ``V`` and ``log_s`` become parameters of a module (and thus of
    anything that wraps or optimises it). Both are cloned, so the caller's
    tensors are left untouched.
    """

    def __init__(self, model, log_s, V):
        super().__init__()
        self.model = model
        self.register_parameter('V', torch.nn.Parameter(V.clone()))
        self.register_parameter('log_s', torch.nn.Parameter(log_s.clone()))

    def forward(self, x):
        return self.model(x)


def _pool_input_dim(w):
    """Average-pool the columns of a 2-D ``(rows, in)`` matrix to about 256.

    The stride is the next power of two of ``in // 256``. Matrices with at
    most 256 columns are returned unchanged.
    """
    if w.shape[1] > 256:
        stride = next_power_of_2(w.shape[1] // 256)
        assert w.shape[1] % stride == 0
        w = F.avg_pool1d(w[None], stride, stride)[0]
    return w


@torch.enable_grad()
def quantize_layer(model, teacher, train_loader, n_steps,
                   bn_params, cur_layers, log_s, V, distance,
                   add_distance=True,
                   T=20.0, distill_alpha=0.7):
    """Learn quantized weights for ``cur_layers`` by distillation from ``teacher``.

    Procedure:

    1. Freeze all non-BatchNorm parameters of ``model``.
    2. Switch every layer in ``cur_layers`` to ``'quantize'`` mode.
    3. Train the layers' ``alpha`` and ``scores`` with Adam for
       ``n_steps + 1`` steps, together with ``bn_params`` and (when the
       distance term is used) ``log_s`` and ``V``.
    4. Switch the layers back to ``'original'`` mode, passing the value
       returned by ``set_mode('quantize')``.
    5. Restore the previous ``requires_grad`` flags.

    When ``add_distance`` is true and at least one layer group is given, the
    loss also contains ``distance(log_s, V, wq, w)``:

    * ``w`` holds each group's original (pooled) weight matrix, arranged
      block-diagonally (row block = output channels, column block = pooled
      input dims). This matches a block-diagonal ``V`` assembled one group at
      a time.
    * ``wq`` is the same arrangement of ``alpha * scores``.

    Two regularisers are added as well: one keeps ``V`` near-orthonormal, the
    other keeps ``sum(log_s)`` near its initial value.

    Args:
        model: network to quantize (wrapped in DDP here, so a process group
            must be initialised).
        teacher: network providing distillation targets.
        train_loader: iterable of ``(X, y)`` batches.
        n_steps: index of the last optimizer step.
        bn_params: BatchNorm parameters, trained alongside.
        cur_layers: list of layer groups. Each group is a list of layers whose
            weights are concatenated along the output dim; every layer must
            provide ``weight``, ``set_mode``, ``alpha`` and ``scores``.
        log_s, V: initial metric parameters. Their size must equal the summed
            pooled input dims of all groups.
        distance: callable ``(log_s, V, wq, w) -> loss``.
        add_distance: include the distance term and its regularisers.
        T, distill_alpha: distillation temperature / weight.

    Returns:
        ``(log_s, V)``: the trained parameters if the distance term was used,
        otherwise the inputs unchanged.
    """
    # freeze all layers except BatchNorm, remembering their previous state
    original_requires_grad = dict()
    for module_name, module in model.named_modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            continue
        for name, p in module.named_parameters(recurse=False):
            # named_parameters() uses no leading '.' for the root module
            key = f'{module_name}.{name}' if module_name else name
            original_requires_grad[key] = p.requires_grad
            p.requires_grad = False
            p.grad = None

    flat_layers_list = []
    orig_weight_params = []
    cur_params = []
    ws = []
    for layers in cur_layers:
        w = torch.cat([layer.weight.data for layer in layers], dim=0)
        ws.append(_pool_input_dim(w.view(w.shape[0], -1)))  # (output, input)
        for layer in layers:
            flat_layers_list.append(layer)
            orig_weight_params.append(layer.set_mode('quantize'))
            cur_params.extend([layer.alpha, layer.scores])

    # With no layers, log_s/V are empty and det would be 0 (0/0 in det_reg).
    add_distance = add_distance and len(ws) > 0

    det = log_s.sum()

    orig_model = model
    if add_distance:
        # full-precision reference weights, block-diagonal over the groups
        w_ref = torch.block_diag(*ws)
        model = QuantizationModel(model, log_s, V)
        log_s, V = model.log_s, model.V

    opt = torch.optim.Adam(
        [{'params': bn_params},
         {'params': cur_params, 'weight_decay': 1e-5},
         {'params': [log_s, V]}],
        lr=1.25e-3 / 8
    )

    teacher = DDP(teacher)
    teacher.eval()
    model = DDP(model)
    model.train()
    sampler = getattr(train_loader, 'sampler', None)
    steps = 0
    epoch = 0
    while True:
        # DistributedSampler reuses the same permutation unless set_epoch is
        # called before each pass.
        if hasattr(sampler, 'set_epoch'):
            sampler.set_epoch(epoch)
        epoch += 1

        n_batches = 0
        count = torch.tensor(0, device='cuda')
        dist_loss_accum = torch.tensor(0.0, device='cuda')
        task_loss_accum = torch.tensor(0.0, device='cuda')
        loss_accum = torch.tensor(0.0, device='cuda')
        correct_count = torch.tensor(0, device='cuda')
        for X, y in tqdm(train_loader, desc="Train"):
            n_batches += 1
            X = X.cuda(non_blocking=True)
            y = y.cuda(non_blocking=True).flatten()
            model.zero_grad()

            logits = model(X)
            with torch.no_grad():
                teacher_scores = teacher(X)
            task_loss, loss = distillation(logits, y, teacher_scores, T, distill_alpha)
            dist_loss = loss

            # add distance metric
            if add_distance:
                wqs = []
                for layers in cur_layers:
                    alpha = torch.cat([l.alpha for l in layers], dim=0)
                    scores = torch.cat([l.scores for l in layers], dim=0)
                    wq = alpha * scores
                    wqs.append(_pool_input_dim(wq.view(wq.shape[0], -1)))  # (out, in)
                # same block-diagonal layout as w_ref
                wq_all = torch.block_diag(*wqs)
                dist_value = distance(log_s, V, wq_all, w_ref)
                loss = loss + dist_value

                # add regularization
                VVt = (V @ V.T)
                VVt.diagonal().sub_(1.0)
                ortho_reg = VVt.norm(2) ** 2  # squared frobenius norm
                det_reg = ((log_s.sum() - det) / det)**2
                reg = ortho_reg * 0.1 + det_reg
                loss += reg * 10
                tqdm.write(f"ortho_reg={ortho_reg.item()}, det_reg={det_reg.item()}")

            # backward pass
            loss.mean().backward()

            # optimizer step
            opt.step()
            count += X.shape[0]
            task_loss_accum += task_loss.sum().data
            dist_loss_accum += dist_loss.sum().data
            loss_accum += loss.sum().data
            correct_count += (torch.argmax(logits, dim=1) == y).sum()

            steps += 1
            if steps > n_steps:
                break
        # stop when done, or if the loader is empty (would loop forever)
        if steps > n_steps or n_batches == 0:
            break

        dist.all_reduce(task_loss_accum)
        dist.all_reduce(dist_loss_accum)
        dist.all_reduce(loss_accum)
        dist.all_reduce(correct_count)
        dist.all_reduce(count)

        tqdm.write(f'task{(task_loss_accum / count).item()} / '
                   f'dist{(dist_loss_accum / count).item()} / '
                   f'reg{(loss_accum / count).item()}')
        accuracy = (correct_count / count).item()
        tqdm.write(f'{accuracy:.4f}')

    # restore layer mode
    for layer, param in zip(flat_layers_list, orig_weight_params):
        layer.set_mode('original', param)

    # restore frozen state before
    params = dict(orig_model.named_parameters())
    for name, requires_grad in original_requires_grad.items():
        params[name].requires_grad = requires_grad
        if not requires_grad:
            params[name].grad = None

    return log_s, V


def weight_distance(log_s, V, wq, w):
    """Mean ``quantization_loss`` between ``wq`` and ``w`` under a learned metric.

    Builds ``M = diag(exp(log_s)) @ V.T`` and returns the mean over columns of
    ``quantization_loss(wq.T, M, M @ w.T)``. That is the squared distance
    between ``M @ wq.T`` and ``M @ w.T`` per column, plus its small norm
    penalty on ``wq``.

    Args:
        log_s: ``(D,)`` log scales.
        V: ``(n, D)`` matrix (square in this file).
        wq: ``(k, n)`` candidate (quantized) weights.
        w: ``(k, n)`` reference weights.

    Returns:
        Scalar tensor.
    """
    scale = torch.diag(torch.exp(log_s))
    transform = scale @ V.T
    target = transform @ w.T
    per_column = quantization_loss(wq.T, transform, target)
    return per_column.mean()


def expand_S_V(old_log_s: torch.Tensor, old_V: torch.Tensor,
               new_log_s: torch.Tensor, new_V: torch.Tensor):
    """Append a new ``(log_s, V)`` block to an existing one.

    Args:
        old_log_s: ``(n,)`` existing log scales (``n`` may be 0).
        old_V: ``(n, n)`` existing matrix.
        new_log_s: ``(m,)`` log scales to append (``m`` may be 0).
        new_V: ``(m, m)`` matrix to append.

    Returns:
        ``(large_log_s, large_V)``, where ``large_log_s`` is the concatenation
        ``[old_log_s, new_log_s]`` and ``large_V`` is the block-diagonal
        ``diag(old_V, new_V)`` with zero off-diagonal blocks. Both take the
        dtype and device of ``old_log_s`` / ``old_V``.
    """
    old_dim = old_V.shape[0]
    dim = new_V.shape[0]
    new_dim = old_dim + dim

    large_log_s = old_log_s.new_zeros((new_dim,))
    large_log_s[:old_dim] = old_log_s
    # Slice from old_dim rather than -dim: for an empty new block `[-0:]`
    # would select the whole tensor.
    large_log_s[old_dim:] = new_log_s

    large_V = old_V.new_zeros((new_dim, new_dim))
    large_V[:old_dim, :old_dim] = old_V
    large_V[old_dim:, old_dim:] = new_V

    return large_log_s, large_V


def eigh(cov):
    """Eigendecomposition of a symmetric matrix, with a fallback on failure.

    Calls ``torch.linalg.eigh(cov)`` directly. If the solver fails to converge
    (``torch.linalg.LinAlgError``), it retries once in float64, with a
    diagonal jitter of ``1e-6`` times the mean absolute diagonal (at least
    ``1e-12``). The fallback result is cast back to ``cov``'s dtype; a second
    failure propagates.

    Returns:
        ``(eigenvalues, eigenvectors)`` with eigenvalues in ascending order.
    """
    try:
        return torch.linalg.eigh(cov)
    except torch.linalg.LinAlgError:
        # Re-running the identical call would fail the same way, so change
        # the problem: higher precision plus a small ridge on the diagonal.
        cov64 = cov.double()
        jitter = max(cov64.diagonal().abs().mean().item() * 1e-6, 1e-12)
        eye = torch.eye(cov64.shape[-1], dtype=cov64.dtype, device=cov64.device)
        s, v = torch.linalg.eigh(cov64 + jitter * eye)
        return s.to(cov.dtype), v.to(cov.dtype)


def next_power_of_2(x):
    """Return the smallest power of two ``>= x`` for positive ``x``, and 1 for 0."""
    if x == 0:
        return 1
    exponent = math.ceil(math.log2(x))
    return 2 ** exponent


def _finetune_block(model, teacher, train_loader, val_loader, bn_params,
                    weight_params, n_steps, lr, weight_decay):
    """Distillation fine-tune of BN + weight params for ``n_steps``, then evaluate."""
    opt = torch.optim.Adam(
        [{'params': bn_params},
         {'params': weight_params, 'weight_decay': weight_decay}],
        lr=lr
    )
    train_model(model, opt, teacher, train_loader, n_steps=n_steps)
    test_model(model, val_loader)


@torch.no_grad()
def quantize(model, train_loader, val_loader, block_iters, finetune_epochs, world_size,
             xnor_net=False, no_distance=False, skip_until=None,
             block_size=7, sample_portion=1.0, log_dir=None, findq_iters=20, rank=0):
    """Quantize ``model`` group by group, distilling from a frozen FP copy.

    Layers reachable from ``'fc'`` are topologically sorted and split into
    ``block_size`` groups with ``more_itertools.chunked_even``. For each group
    and each of ``block_iters`` iterations:

    1. Except on the very first iteration, fine-tune BN and weight params for
       ``len(train_loader) * finetune_epochs`` steps, then evaluate.
    2. Process every eligible layer: not ``'fc'``, not grouped (``groups >
       1``), and at or after ``skip_until``. For each one:

       * estimate its input-patch covariance;
       * average-pool it to about 256 dims;
       * eigendecompose it;
       * append ``(0.5 * log(max(eig, 0) + 1e-6), eigvecs)`` to the
         block-diagonal ``(log_s, V)``;
       * freeze the layer's weights.

    3. Learn quantized values with ``quantize_layer``, then evaluate.
    4. Fine-tune again and evaluate. On rank 0, when ``log_dir`` is given,
       save ``block-{group_idx}.pth``.
    5. Unfreeze the group's weights, unless this was the last iteration. In
       that case they stay frozen and are dropped from ``weight_params``.

    A layer named ``<name>_down`` is handled as the pair ``<name>_down1`` /
    ``<name>_down2``. ``world_size``, ``xnor_net`` and ``findq_iters`` are
    accepted but currently unused.
    """
    graph = make_graph(model)
    modules = dict(model.named_modules())
    weight_params = sum([list(m.parameters(recurse=False)) for m in model.modules()
                         if not isinstance(m, torch.nn.BatchNorm2d)], [])
    bn_params = sum([list(m.parameters(recurse=False)) for m in model.modules()
                     if isinstance(m, torch.nn.BatchNorm2d)], [])
    print(f"{len(weight_params)} weight params, {len(bn_params)} bn params")

    teacher = copy.deepcopy(model)
    total_steps = len(train_loader) * finetune_epochs
    n_steps = total_steps // block_iters

    layers = list(topological_sort(graph))
    layer_groups = list(more_itertools.chunked_even(layers, block_size))
    print("Layer groups", layer_groups)

    finetune_lr = 1.25e-3 / 8
    weight_decay = 1e-5

    for group_idx, layer_group in enumerate(layer_groups):
        for it in range(block_iters):
            print(f"Iteration {it}.")

            if it > 0 or group_idx > 0:
                print("Block initial fine-tuning")
                _finetune_block(model, teacher, train_loader, val_loader,
                                bn_params, weight_params, total_steps,
                                finetune_lr, weight_decay)

            # metric for this iteration, grown block-diagonally layer by layer
            V = torch.zeros((0, 0), device='cuda')
            log_s = torch.zeros((0,), device='cuda')

            new_frozen = []
            for layer_name in layer_group:
                if layer_name == skip_until:
                    skip_until = None

                # skip last layer for now
                if layer_name == 'fc':
                    continue

                print("Quantizing", layer_name)
                if layer_name.endswith('_down'):
                    cur_layers = [modules[layer_name + '1'], modules[layer_name + '2']]
                else:
                    cur_layers = [modules[layer_name]]

                # FIXME skip depthwise layer for now
                if cur_layers[0].groups > 1:
                    continue
                # skip quantization until specified layer is reached
                if skip_until is not None:
                    continue

                # calculate input covariance
                model.eval()
                get_patches = partial(extract_patches, model, cur_layers[0])
                dim = np.prod(cur_layers[0].weight.shape[1:])
                cov = dataset_cov(train_loader, get_patches, dim, sample_portion)

                if cov.shape[1] > 256:
                    stride = cov.shape[1] // 256
                    print(f"stride = {stride}")
                    stride = next_power_of_2(stride)
                    assert cov.shape[1] % stride == 0
                    cov = F.avg_pool2d(cov[None, None], stride, stride)[0, 0]

                s, new_V = eigh(cov)
                # Round-off can make eigenvalues of a PSD matrix slightly
                # negative; clamp so the log stays finite.
                new_log_s = torch.log(s.clamp(min=0.0) + 1e-6) * 0.5

                # expand covariance matrix
                log_s, V = expand_S_V(log_s, V, new_log_s, new_V)

                # freeze layer
                new_frozen.append(cur_layers)
                for layer in cur_layers:
                    layer.weight.requires_grad = False
                    layer.weight.grad = None

            # quantize layer
            print("Finding quantized values")
            log_s, V = quantize_layer(
                model, teacher, train_loader, n_steps,
                bn_params=bn_params, cur_layers=new_frozen,
                distance=weight_distance, log_s=log_s, V=V,
                add_distance=(not no_distance)
            )

            test_model(model, val_loader)

            print("Block final fine-tuning")
            _finetune_block(model, teacher, train_loader, val_loader,
                            bn_params, weight_params, total_steps,
                            finetune_lr, weight_decay)

            # log_dir defaults to None; only save when a directory was given
            if rank == 0 and log_dir is not None:
                print("Saving to ", log_dir / f"block-{group_idx}.pth")
                torch.save(model.state_dict(), log_dir / f"block-{group_idx}.pth")

            # unfreeze block layers
            if it < block_iters - 1:
                print(f"Unfreezing {len(new_frozen)} layers")
                for layers in new_frozen:
                    for layer in layers:
                        layer.weight.requires_grad = True
            else:
                weight_params = [p for p in weight_params if p.requires_grad]


@torch.no_grad()
def extract_patches(model, layer, X):
    """Return ``layer``'s input for batch ``X``, unfolded into patches for convolutions.

    ``model.inputs_for(layer, X)`` supplies the layer's input. For a
    ``torch.nn.Conv2d``, that input is unfolded with the layer's kernel size,
    dilation, padding and stride. It is then rearranged to
    ``(C*kh*kw, B*L)``, one column per receptive field, where ``L`` is the
    number of output positions. Any other layer type gets the ``inputs_for``
    result back unchanged.
    """
    inputs = model.inputs_for(layer, X)
    if not isinstance(layer, torch.nn.Conv2d):
        return inputs
    cols = F.unfold(
        inputs,
        kernel_size=layer.kernel_size,
        dilation=layer.dilation,
        padding=layer.padding,
        stride=layer.stride,
    )  # (batch, ch*ks*ks, height*width)
    # put the patch-feature axis first, then merge batch and position axes
    features_first = cols.transpose(0, 1)  # (ch*ks*ks, batch, height*width)
    return features_first.reshape(features_first.shape[0], -1)


def load_model(path):
    """Build a 100-class ReActNet on the current CUDA device and load a checkpoint.

    The checkpoint at ``path`` must be a dict with ``'epoch'``,
    ``'best_top1_acc'`` (both printed) and ``'state_dict'`` entries. A
    *leading* ``'module.'`` prefix (presumably from saving a DataParallel/DDP
    wrapped model) is stripped from each key. Loading is non-strict: missing
    or unexpected keys are printed instead of raising.
    """
    model = reactnet(num_classes=100).cuda()
    # NOTE: torch.load unpickles arbitrary objects -- only load trusted files.
    checkpoint = torch.load(path, map_location='cpu')
    print(checkpoint['epoch'], checkpoint['best_top1_acc'])
    prefix = 'module.'
    # strip only the wrapper prefix; str.replace would also mangle keys that
    # merely contain 'module.' and strict=False would then drop them silently
    state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in checkpoint['state_dict'].items()
    }
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys or result.unexpected_keys:
        print(f"load_model: missing keys {result.missing_keys}, "
              f"unexpected keys {result.unexpected_keys}")
    return model


def train_acc_model(model, train_loader: DataLoader):
    """Report mean cross-entropy and accuracy of ``model`` on ``train_loader``.

    The model is wrapped in DDP and run in *train* mode without gradients, so
    any BatchNorm layers use batch statistics. Per-rank sums of loss, correct
    predictions and samples are all-reduced. Every rank therefore prints and
    returns the same global ``(mean_loss, accuracy)``. Requires an
    initialised process group.
    """
    ddp_model = DDP(model)
    ddp_model.train()
    print("Train Acc")
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    n_seen = torch.tensor(0, device='cuda')
    loss_sum = torch.tensor(0.0, device='cuda')
    n_correct = torch.tensor(0, device='cuda')
    with torch.no_grad():
        for inputs, targets in train_loader:
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True).flatten()
            logits = ddp_model(inputs)
            n_seen += inputs.shape[0]
            loss_sum += criterion(logits, targets).data
            n_correct += (logits.argmax(dim=1) == targets).sum()
        # same reduction order on every rank
        for total in (loss_sum, n_correct, n_seen):
            dist.all_reduce(total)
        loss = (loss_sum / n_seen).item()
        acc = (n_correct / n_seen).item()
        print(f'{loss}')
        print(f'{acc:.4f}')
    return loss, acc


def test_model(model, val_loader: DataLoader):
    """Evaluate ``model`` on ``val_loader`` across all ranks.

    Wraps the model in DDP, switches it to eval mode (it stays there
    afterwards) and accumulates summed cross-entropy, correct predictions and
    sample count without gradients. The sums are all-reduced, so every rank
    prints and returns the same global ``(mean_loss, accuracy)``. Requires an
    initialised process group.
    """
    ddp_model = DDP(model)
    ddp_model.eval()
    print("Test")
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    n_seen = torch.tensor(0, device='cuda')
    loss_sum = torch.tensor(0.0, device='cuda')
    n_correct = torch.tensor(0, device='cuda')
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True).flatten()
            logits = ddp_model(inputs)
            n_seen += inputs.shape[0]
            loss_sum += criterion(logits, targets).data
            n_correct += (logits.argmax(dim=1) == targets).sum()
        # same reduction order on every rank
        for total in (loss_sum, n_correct, n_seen):
            dist.all_reduce(total)
        loss = (loss_sum / n_seen).item()
        acc = (n_correct / n_seen).item()
        print(f'{loss}')
        print(f'{acc:.4f}')
    return loss, acc


def test_model_orig(model, val_loader: DataLoader):
    """Evaluate ``model`` on ``val_loader`` in this process only.

    Unlike ``test_model`` there is no DDP wrapper and no cross-rank
    reduction. The model is put in eval mode; batches are moved to CUDA. The
    function prints and returns ``(mean_loss, accuracy)`` computed from summed
    cross-entropy, correct predictions and sample count.
    """
    model.eval()
    print("Test")
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    n_seen = torch.tensor(0, device='cuda')
    loss_sum = torch.tensor(0.0, device='cuda')
    n_correct = torch.tensor(0, device='cuda')
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True).flatten()
            logits = model(inputs)
            n_seen += inputs.shape[0]
            loss_sum += criterion(logits, targets).data
            n_correct += (logits.argmax(dim=1) == targets).sum()
        loss = (loss_sum / n_seen).item()
        acc = (n_correct / n_seen).item()
        print(f'{loss}')
        print(f'{acc:.4f}')
    return loss, acc


def render_dep_graph(model):
    """Render the layer dependency tree reachable from ``'fc'`` with graphviz.

    Each visited layer becomes a node. Each ``(dep, layer)`` pair reported by
    ``model.dependent_layers`` becomes an edge. The graph is rendered under
    the name ``'dep-graph-reactnet'``. ``graphviz.Graph`` is undirected, so
    edge direction is not drawn.
    """
    dot = graphviz.Graph('dep-graph')
    drawn = set()

    def visit(layer):
        dot.node(layer, layer)
        drawn.add(layer)
        for dep in model.dependent_layers(layer):
            dot.edge(dep, layer)
            if dep not in drawn:
                visit(dep)

    visit('fc')
    dot.render('dep-graph-reactnet')


def run(rank, world_size, args):
    """Per-rank worker: load, evaluate, quantize, re-evaluate and save the model.

    Steps:

    1. Set up per-rank logging to ``<log_dir>/log_rank{rank}.txt``, with
       screen logging only on rank 0.
    2. Bind this process to ``cuda:{rank}``.
    3. Load the checkpoint at ``args.load_path`` and evaluate it.
    4. Run ``quantize`` with the CLI options, then evaluate again.
    5. On rank 0, save the state dict to ``args.save_path``.

    The ``SummaryWriter`` is used for its log directory and is always closed
    on exit.
    """
    writer = SummaryWriter(log_dir=args.log_dir)
    try:
        logging_util.setup_logging(
            pathlib.Path(args.log_dir) / f'log_rank{rank}.txt',
            log_to_screen=(rank == 0)
        )
        print("===== Args =====")
        for name, value in vars(args).items():
            print(f"   {name}: {value}")
        print("================")

        print(f"Using GPU {rank}")
        torch.cuda.set_device(f"cuda:{rank}")

        # construct model
        model = load_model(args.load_path)
        # render_dep_graph(model)
        print(model)

        # load dataset
        train_loader, val_loader = cifar100(batch_size=200, workers=4, distributed=True)

        # evaluated fp model
        train_acc_model(model, train_loader)
        test_model(model, val_loader)

        # quantize model
        quantize(model, train_loader, val_loader,
                 block_iters=args.block_iters,
                 xnor_net=args.xnor_net,
                 skip_until=args.skip_until,
                 finetune_epochs=args.finetune_epochs,
                 block_size=args.block_size,
                 sample_portion=args.sample_portion,
                 findq_iters=args.findq_iters,
                 no_distance=args.no_distance,
                 log_dir=pathlib.Path(writer.log_dir),
                 rank=rank, world_size=world_size)

        # evaluate quantized model
        train_acc_model(model, train_loader)
        test_model(model, val_loader)

        if rank == 0:
            torch.save(model.state_dict(), args.save_path)
    finally:
        # flush and release the event file even if quantization fails
        writer.close()


def dist_main(rank, world_size, args):
    """``mp.spawn`` target: join the NCCL process group, run, then tear it down.

    Rendezvous is on ``localhost``. ``MASTER_PORT`` defaults to ``12355``
    unless it is already set in the environment. The process group is
    destroyed even if ``run`` raises.
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ.setdefault('MASTER_PORT', '12355')

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        run(rank, world_size, args)
    finally:
        dist.destroy_process_group()


def main():
    """CLI entry point: parse options, snapshot sources, spawn one worker per GPU.

    Steps:

    1. Create a run directory ``runs/reactnet_<timestamp>_<hostname>`` and
       store it as ``args.log_dir``.
    2. Derive ``args.save_path`` from the checkpoint's file stem.
    3. Copy every ``*.py`` file next to this script into ``<run>/source``.
    4. Start ``dist_main`` on ``torch.cuda.device_count()`` processes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--xnor_net', default=False, action='store_true')
    parser.add_argument('--skip_until', default=None, type=str)
    parser.add_argument('--load_path', default='/d1/xxx/DBQ/baseline/1_step1/models_run1/model_best.pth.tar', type=str)
    parser.add_argument('--block_iters', default=1, type=int)
    parser.add_argument('--finetune_epochs', default=40, type=int)
    parser.add_argument('--block_size', default=4, type=int)
    parser.add_argument('--sample_portion', default=1.0, type=float)
    parser.add_argument('--findq_iters', default=20, type=int)
    parser.add_argument('--no_distance', default=False, action='store_true')
    args = parser.parse_args()

    stamp = datetime.now().strftime('%b%d_%H-%M-%S-%f')
    run_dir = pathlib.Path('runs') / f"reactnet_{stamp}_{socket.gethostname()}"
    args.log_dir = run_dir
    args.save_path = run_dir / f"{pathlib.Path(args.load_path).stem}-quantized.pt"
    print("Will save to ", args.save_path)

    # snapshot this directory's *.py files alongside the run's logs
    snapshot_dir = run_dir / 'source'
    snapshot_dir.mkdir(parents=True, exist_ok=False)
    here = pathlib.Path(__file__).parent
    for src in glob.glob(str(here / '*.py')):
        rel = pathlib.Path(src).relative_to(here)
        print(f"Copy {rel} to {snapshot_dir}")
        dest = snapshot_dir / rel
        dest.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(src, dest)

    n_gpus = torch.cuda.device_count()
    mp.set_start_method('forkserver')
    mp.spawn(dist_main, args=(n_gpus, args), nprocs=n_gpus, join=True)


# Script entry point. The guard also keeps processes started with the 'spawn'
# start method (as used by mp.spawn) from re-running main() when they
# re-import this module.
if __name__ == '__main__':
    main()
