import sys
import copy
import glob
import math
import socket
import shutil
import pathlib
import argparse
import itertools
from datetime import datetime
from functools import partial

import os
import torch
import graphviz
import numpy as np
from tqdm import tqdm
import networkx as nx
import more_itertools
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.tensorboard import SummaryWriter
from networkx.algorithms.dag import topological_sort
from torch.utils.data import DataLoader, DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

from reactnet import reactnet
sys.path.append('.')
from cifar_torch import cifar100
import logging_util

# Process-wide cuDNN configuration: keep cuDNN enabled and let it benchmark the
# candidate convolution algorithms for each input shape, reusing the fastest.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True


def make_graph(model):
    """Build the layer dependency graph reachable backwards from ``'fc'``.

    For every layer ``name`` reached, each ``dep`` returned by
    ``model.dependent_layers(name)`` becomes an edge ``dep -> name`` and is
    itself explored if not seen before. Nodes only enter the graph through
    edges, so layers unreachable from ``'fc'`` are absent (and the graph is
    empty if ``'fc'`` has no dependencies).

    The walk uses an explicit stack of iterators but visits layers and adds
    edges in exactly the order of a recursive depth-first search, so the
    graph's insertion order is the same as a recursive implementation's.
    """
    dag = nx.DiGraph()
    seen = {'fc'}
    pending = [('fc', iter(model.dependent_layers('fc')))]
    while pending:
        consumer, deps = pending[-1]
        for dep in deps:
            dag.add_edge(dep, consumer)
            if dep not in seen:
                seen.add(dep)
                # descend first; this iterator resumes once dep is finished
                pending.append((dep, iter(model.dependent_layers(dep))))
                break
        else:
            pending.pop()
    return dag


def dataset_cov(train_loader, extract_patches, dim, sample_portion=1.0, device='cuda'):
    """Estimate the covariance of ``extract_patches`` features over a dataset.

    Batches are folded in one at a time with the pairwise (Chan et al.)
    update for mean and covariance, so all patches never have to be held in
    memory at once. Accumulation is done in float64.

    Args:
        train_loader: iterable of ``(X, y)`` batches; only ``X`` is used.
        extract_patches: callable mapping a batch ``X`` (already on
            ``device``) to a ``(dim, n_observations)`` matrix — one row per
            variable, one column per observation, as ``torch.cov`` expects.
        dim: number of variables (rows) produced by ``extract_patches``.
        sample_portion: fraction of ``len(train_loader)`` batches to visit
            (rounded up).
        device: device for the accumulators and input batches.

    Returns:
        ``(dim, dim)`` float32 population covariance (normalized by N) of all
        visited observations; all zeros if no observation was visited.
    """
    count = 0
    mean = torch.zeros(dim, device=device, dtype=torch.float64)
    cov = torch.zeros(dim, dim, device=device, dtype=torch.float64)
    iters = math.ceil(len(train_loader) * sample_portion)
    for X, _ in itertools.islice(tqdm(train_loader), iters):
        X = X.to(device, non_blocking=True)
        patches = extract_patches(X).double()

        batch_count = patches.size(1)
        if batch_count == 0:
            continue
        batch_mean = patches.mean(1)
        # Population covariance (divide by N) so the merge below is exact.
        batch_cov = torch.cov(patches, correction=0)

        if count == 0:
            count = batch_count
            mean.copy_(batch_mean)
            cov.copy_(batch_cov)
            continue

        merged_count = count + batch_count
        w_old = count / merged_count
        w_new = batch_count / merged_count

        # delta must use the mean *before* it is updated
        delta = batch_mean - mean
        mean += delta * w_new

        # C = w_old * C_old + w_new * C_new + w_old * w_new * delta delta^T
        cov.mul_(w_old).add_(
            batch_cov * w_new
            + delta[:, None] * delta[None, :] * (w_old * w_new)
        )
        count = merged_count
    return cov.float()


class Sign(torch.autograd.Function):
    """Binarize with ``sign`` going forward; use a smooth surrogate gradient going back.

    Forward returns ``torch.sign(x)`` (exact zeros stay 0). Backward scales
    the incoming gradient by ``max(0, 2 * (1 - |x|))``: 2 at ``x == 0``,
    decreasing linearly to 0 at ``|x| >= 1``.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        ctx.save_for_backward(x)
        return x.sign()

    @staticmethod
    def backward(ctx, g):
        (inp,) = ctx.saved_tensors
        surrogate = (1 - inp.abs()).mul(2).clamp(min=0.0)
        return g * surrogate


@torch.jit.script
def quantization_loss(q, M, w_hat, lamba: float = 1e-5):
    """Per-column reconstruction error of ``M @ q`` against ``w_hat``, plus a norm penalty.

    For 2-D inputs, entry ``j`` of the result is
    ``||M @ q[:, j] - w_hat[:, j]||^2 + lamba * ||q[:, j]||_2``.
    ``lamba`` is the penalty weight (the spelling is kept so keyword callers
    keep working).
    """
    penalty = q.norm(2, dim=0)
    residual = torch.matmul(M, q) - w_hat
    return residual.pow(2).sum(0) + lamba * penalty


# knowledge distillation loss
@torch.jit.script
def distillation(logits, labels, teacher_scores, T: float, alpha: float):
    """Knowledge-distillation loss mixing hard-label CE with a soft teacher KL term.

    Args:
        logits: student logits, ``(batch, classes)``.
        labels: hard integer class labels, ``(batch,)``.
        teacher_scores: teacher logits, same shape as ``logits``.
        T: temperature applied to both student and teacher logits.
        alpha: weight of the soft (teacher) term; the hard-label term gets
            ``1 - alpha``.

    Returns:
        ``(task_loss, loss)``, both per-sample ``(batch,)`` tensors:
        ``task_loss`` is the plain cross-entropy and
        ``loss = (1 - alpha) * task_loss + 2 * T**2 * alpha * KL(teacher || student)``.
    """
    task_loss = F.cross_entropy(logits, labels, reduction='none')
    teacher_loss = F.kl_div(
        F.log_softmax(logits / T, dim=-1),
        F.log_softmax(teacher_scores / T, dim=-1),
        reduction='none',
        log_target=True
    ).sum(1)
    # alpha must *scale* the soft term (it was added: `2 * T ** 2 + alpha`,
    # which made alpha nearly irrelevant and kept the teacher term at alpha=0).
    return task_loss, task_loss * (1 - alpha) + teacher_loss * (2.0 * T ** 2 * alpha)


@torch.jit.script
def fused_clip(max_norm: float, total_norm):
    """Return the gradient scale ``clamp(max_norm / (total_norm + 1e-6), 0, 1)``.

    The small epsilon avoids division by zero when ``total_norm`` is 0; the
    clamp means gradients are only ever shrunk, never amplified.
    """
    ratio = max_norm / (total_norm + 1e-6)
    return ratio.clamp(0.0, 1.0)


@torch.no_grad()
def clip_grad_norm_(
        parameters, max_norm: float, norm_type: float = 2.0) -> torch.Tensor:
    """Rescale gradients in place so their combined norm is at most ``max_norm``.

    Parameters without a ``.grad`` are ignored. The combined norm is the
    ``norm_type``-norm of the per-parameter gradient norms (the max of the
    per-parameter abs-max for ``inf``). Every gradient is multiplied by
    ``fused_clip(max_norm, total_norm)``.

    Returns:
        The combined norm before clipping, or ``torch.tensor(0.)`` if no
        parameter has a gradient.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    if not grads:
        return torch.tensor(0.)
    if norm_type == float('inf'):
        per_grad = [g.abs().max() for g in grads]
        if len(per_grad) == 1:
            total_norm = per_grad[0]
        else:
            total_norm = torch.max(torch.stack(per_grad))
    else:
        per_grad = [torch.norm(g, norm_type) for g in grads]
        total_norm = torch.norm(torch.stack(per_grad), norm_type)
    scale = fused_clip(max_norm, total_norm)
    for g in grads:
        # detached views share storage, so this rescales p.grad itself
        g.mul_(scale)
    return total_norm


@torch.enable_grad()
def train_model(model, opt, teacher, train_loader, n_steps,
                T=20.0, alpha=0.7):
    """Fine-tune ``model`` by knowledge distillation from ``teacher``.

    Both networks are wrapped in DDP. The teacher runs in eval mode under
    ``torch.no_grad``. Each step does the following:

    * minimizes the mean ``distillation`` loss;
    * clips the global gradient norm of the optimizer's parameters to 2.0;
    * sets the lr of every param group to ``orig_lr * 2 * min(r, 1 - r)``
      with ``r = steps / n_steps``, where ``orig_lr`` is the first group's lr
      on entry. This is a triangular schedule: 0 at the first step,
      ``orig_lr`` at the midpoint and 0 at the last step;
    * calls ``opt.step()``.

    Training stops after ``n_steps + 1`` steps, possibly mid-epoch. After
    every epoch, including the final partial one, the all-reduced task loss,
    distillation loss and accuracy are printed.

    Args:
        model: student network, updated in place through ``opt``.
        opt: optimizer over the student's parameters; its lrs are overwritten.
        teacher: network providing the soft targets.
        train_loader: iterable of ``(X, y)`` batches. If its ``sampler`` is a
            ``DistributedSampler``, ``set_epoch`` is called every epoch so the
            shuffle differs between epochs.
        n_steps: length of the lr schedule.
        T, alpha: distillation temperature and soft-loss weight.

    Raises:
        ValueError: if ``train_loader`` yields no batches (the loop would
            otherwise never terminate).
    """
    teacher = DDP(teacher)
    teacher.eval()
    model = DDP(model)
    model.train()
    steps = 0
    orig_lr = opt.param_groups[0]['lr']
    params = [p for g in opt.param_groups for p in g['params']]
    sampler = getattr(train_loader, 'sampler', None)
    schedule_len = max(n_steps, 1)  # avoid ZeroDivisionError when n_steps == 0

    epoch = 0
    while True:
        if isinstance(sampler, DistributedSampler):
            # Without this, every epoch replays the same shuffled order.
            sampler.set_epoch(epoch)
        epoch += 1

        count = torch.tensor(0, device='cuda')
        task_loss_accum = torch.tensor(0.0, device='cuda')
        loss_accum = torch.tensor(0.0, device='cuda')
        correct_count = torch.tensor(0, device='cuda')
        n_batches = 0
        done = False
        for X, y in tqdm(train_loader, desc="Train"):
            n_batches += 1
            X = X.cuda(non_blocking=True)
            y = y.cuda(non_blocking=True).flatten()
            model.zero_grad()

            logits = model(X)
            with torch.no_grad():
                teacher_scores = teacher(X)
            task_loss, loss = distillation(logits, y, teacher_scores, T, alpha)
            loss.mean().backward()

            clip_grad_norm_(params, 2.0)  # clip gradient

            # triangular lr schedule, identical for all param groups
            r = steps / schedule_len
            lr = orig_lr * min(r, 1 - r) * 2
            for g in opt.param_groups:
                g['lr'] = lr

            opt.step()
            count += X.shape[0]
            task_loss_accum += task_loss.sum().data
            loss_accum += loss.sum().data
            correct_count += (torch.argmax(logits, dim=1) == y).sum()

            steps += 1
            if steps > n_steps:
                done = True
                break
        if n_batches == 0:
            raise ValueError("train_loader yielded no batches")

        # every rank stops at the same step, so these collectives stay matched
        dist.all_reduce(task_loss_accum)
        dist.all_reduce(loss_accum)
        dist.all_reduce(correct_count)
        dist.all_reduce(count)
        tqdm.write(f'task{(task_loss_accum / count).item()} / dist{(loss_accum / count).item()}')
        accuracy = (correct_count / count).item()
        tqdm.write(f'{accuracy:.4f}')
        if done:
            return


class QuantizationModel(torch.nn.Module):
    """Wrap a network together with trainable ``V`` / ``log_s`` parameters.

    ``V`` and ``log_s`` are registered as parameters of this wrapper, next to
    the wrapped network's own parameters. ``forward`` simply delegates to the
    wrapped network.

    Args:
        model: network to wrap (stored as ``self.model``).
        log_s: initial value for the ``log_s`` parameter (cloned).
        V: initial value for the ``V`` parameter (cloned).
    """

    def __init__(self, model, log_s, V):
        super().__init__()
        self.model = model
        # Clone so that training these parameters never writes into the
        # caller's tensors. V is registered before log_s (parameter order).
        self.V = torch.nn.Parameter(V.clone())
        self.log_s = torch.nn.Parameter(log_s.clone())

    def forward(self, x):
        out = self.model(x)
        return out


@torch.enable_grad()
def quantize_layer(model, teacher, train_loader, n_steps,
                   bn_params, cur_layers, log_s, V, distance,
                   T=20.0, alpha=0.7):
    """Train the quantized parameters of ``cur_layers`` jointly with a learned basis.

    Steps:

    1. Every parameter owned directly by a non-``BatchNorm2d`` module of
       ``model`` is frozen (its ``requires_grad`` is remembered).
    2. Each layer in ``cur_layers`` is switched to ``'quantize'`` mode.
    3. An Adam optimizer trains ``bn_params``, the layers' ``alpha`` /
       ``scores`` parameters, and trainable copies of ``log_s`` and ``V``.

    The per-batch objective is the sum of:

    * the distillation loss against ``teacher``;
    * ``distance(log_s, V, cur_params, w)``, where ``w`` holds the layers'
      weights at entry, concatenated along dim 0 and flattened to 2-D;
    * ``10 * (0.1 * ||V V^T - I||_F^2 + ((sum(log_s) - det) / det)^2)``,
      where ``det`` is ``sum(log_s)`` at entry.

    Training runs for ``n_steps + 1`` optimizer steps. Afterwards the layers
    are switched back to ``'original'`` mode and the remembered
    ``requires_grad`` flags are restored.

    Args:
        model: network containing ``cur_layers``; wrapped in
            ``QuantizationModel`` and DDP inside this function.
        teacher: reference network for distillation; wrapped in DDP here.
        train_loader: iterable of ``(X, y)`` batches.
        n_steps: the loop stops once ``steps > n_steps``.
        bn_params: BatchNorm parameters, trained alongside.
        cur_layers: layers to quantize. They must provide ``weight``,
            ``set_mode``, ``alpha`` and ``scores``.
        log_s, V: initial basis values. They are cloned by
            ``QuantizationModel``, so the caller's tensors are not modified here.
        distance: callable ``(log_s, V, cur_params, w)`` whose result is added
            to the loss.
        T, alpha: distillation temperature and soft-loss weight.

    Returns:
        ``(log_s, V)``: the trained ``torch.nn.Parameter`` copies.
    """

    # freeze all layers
    # (every parameter owned directly by a non-BatchNorm2d module; the old
    # requires_grad flag is kept under key module_name + '.' + name)
    original_requires_grad = dict()
    for module_name, module in model.named_modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            continue
        for name, p in module.named_parameters(recurse=False):
            original_requires_grad[module_name + '.' + name] = p.requires_grad
            p.requires_grad = False
            p.grad = None

    # weights at entry, concatenated along dim 0 and flattened to 2-D
    w = torch.cat([layer.weight.data for layer in cur_layers], dim=0)
    w = w.view(w.shape[0], -1)
    # enter quantize mode; the returned values are handed back to
    # set_mode('original', ...) at the end
    orig_weight_params = []
    for layer in cur_layers:
        orig_weight_params.append(layer.set_mode('quantize'))
    cur_params = []
    for layer in cur_layers:
        cur_params.extend([layer.alpha, layer.scores])

    # reference value of sum(log_s) for det_reg below
    det = log_s.sum()

    # register trainable clones of log_s / V on a wrapper module (so the DDP
    # wrapper below covers them) and rebind the names to those Parameters
    model = QuantizationModel(model, log_s, V)
    log_s, V = model.log_s, model.V

    opt = torch.optim.Adam(
        [{'params': bn_params},
         {'params': cur_params, 'weight_decay': 1e-5},
         {'params': [log_s, V]}],
        lr=1.25e-3 / 4
    )

    teacher = DDP(teacher)
    teacher.eval()
    model = DDP(model)
    model.train()
    steps = 0
    while True:
        count = torch.tensor(0, device='cuda')
        dist_loss_accum = torch.tensor(0.0, device='cuda')
        task_loss_accum = torch.tensor(0.0, device='cuda')
        loss_accum = torch.tensor(0.0, device='cuda')
        correct_count = torch.tensor(0, device='cuda')
        for X, y in tqdm(train_loader, desc="Train"):
            X = X.cuda(non_blocking=True)
            y = y.cuda(non_blocking=True).flatten()
            model.zero_grad()

            logits = model(X)
            with torch.no_grad():
                teacher_scores = teacher(X)
            task_loss, loss = distillation(logits, y, teacher_scores, T, alpha)
            dist_loss = loss

            # add distance metric
            # (with weight_distance this is a scalar, broadcast over the
            # per-sample losses)
            loss = loss + distance(log_s, V, cur_params, w)

            # add regularization
            # ortho_reg = ||V V^T - I||_F^2 pushes V towards orthonormal rows
            VVt = (V @ V.T)
            VVt.diagonal().sub_(1.0)
            ortho_reg = VVt.norm(2) ** 2  # squared frobenius norm
            # det_reg = squared relative drift of sum(log_s) from its value at entry
            # NOTE(review): divides by det; undefined if sum(log_s) was 0 at entry
            det_reg = ((log_s.sum() - det) / det)**2
            reg = ortho_reg * 0.1 + det_reg
            loss += reg * 10
            # NOTE(review): .item() forces a device sync and prints on every step
            tqdm.write(f"ortho_reg={ortho_reg.item()}, det_reg={det_reg.item()}")

            # backward pass
            loss.mean().backward()

            # optimizer step
            opt.step()
            count += X.shape[0]
            task_loss_accum += task_loss.sum().data
            dist_loss_accum += dist_loss.sum().data
            loss_accum += loss.sum().data
            correct_count += (torch.argmax(logits, dim=1) == y).sum()

            steps += 1
            if steps > n_steps:
                break
        # leave the outer loop as well; the final partial epoch is not reported
        if steps > n_steps:
            break

        dist.all_reduce(task_loss_accum)
        dist.all_reduce(dist_loss_accum)
        dist.all_reduce(loss_accum)
        dist.all_reduce(correct_count)
        dist.all_reduce(count)

        tqdm.write(f'task{(task_loss_accum / count).item()} / dist{(dist_loss_accum / count).item()} / reg{(loss_accum / count).item()}')
        accuracy = (correct_count / count).item()
        tqdm.write(f'{accuracy:.4f}')

    # back to original mode, passing in what set_mode('quantize') returned
    for layer, param in zip(cur_layers, orig_weight_params):
        layer.set_mode('original', param)

    # restore frozen state before
    # NOTE(review): keys were built as module_name + '.' + name, which yields
    # '.name' for parameters owned by the root module itself; such keys would
    # not be found here. This restore is also skipped if training raises.
    params = dict(model.module.model.named_parameters())
    for name, requires_grad in original_requires_grad.items():
        params[name].requires_grad = requires_grad
        if not requires_grad:
            params[name].grad = None

    # trained Parameters owned by the QuantizationModel wrapper
    return log_s, V


@torch.no_grad()
def extract_patches(model, layer, X):
    """Return the input of ``layer`` arranged as a ``(variables, observations)`` matrix.

    The layer's input is obtained from ``model.inputs_for(layer, X)``. For a
    ``Conv2d`` layer it is unfolded into sliding patches that use the layer's
    kernel size, dilation, padding and stride. The result is then rearranged
    to ``(ch*kh*kw, batch*L)``, with one row per patch element and one column
    per (sample, location) pair. Any other layer's input is returned
    unchanged.
    """
    inputs = model.inputs_for(layer, X)
    if not isinstance(layer, torch.nn.Conv2d):
        return inputs
    cols = F.unfold(
        inputs,
        kernel_size=layer.kernel_size,
        dilation=layer.dilation,
        padding=layer.padding,
        stride=layer.stride,
    )  # (batch, ch*kh*kw, L)
    # swap batch and patch-element axes, then merge batch with locations
    return cols.transpose(0, 1).flatten(1)


def weight_distance(log_s, V, wq, w):
    """Mean projected distance between binarized weights and the reference weights ``w``.

    ``wq`` is either ``[alpha, scores]`` or ``[alpha1, scores1, alpha2, scores2]``
    (the two halves are concatenated along dim 0). The binarized weight is
    ``alpha * Sign(scores)``, flattened to one row per entry of dim 0. Both it
    and ``w`` are projected by ``M = diag(exp(log_s)) @ V.T``, compared with
    ``quantization_loss`` (one value per row), and averaged.
    """
    if len(wq) == 2:
        scale, scores = wq[0], wq[1]
    else:
        scale = torch.cat([wq[0], wq[2]], dim=0)
        scores = torch.cat([wq[1], wq[3]], dim=0)
    binarized = scale * Sign.apply(scores)
    binarized = binarized.view(binarized.shape[0], -1)

    projection = torch.diag(torch.exp(log_s)) @ V.T
    per_row = quantization_loss(binarized.T, projection, projection @ w.T)
    return per_row.mean()


@torch.no_grad()
def quantize(model, train_loader, val_loader, block_iters, finetune_epochs, world_size,
             xnor_net=False, skip_until=None,
             block_size=7, sample_portion=1.0, log_dir=None, findq_iters=20, rank=0):
    """Quantize ``model`` in place, group by group, with distillation fine-tuning.

    Layers are put in topological order (``make_graph`` + ``topological_sort``)
    and split by ``more_itertools.chunked_even`` into groups of at most
    ``block_size``. For every group, ``block_iters`` passes are made.

    Each pass:

    1. Fine-tunes trainable weights and BatchNorm parameters for
       ``total_steps`` steps. This is skipped on the very first pass of the
       first group.
    2. Walks the group's layers (``'fc'`` excluded) and freezes each layer's
       weights. Unless the layer has ``groups > 1`` or ``skip_until`` has not
       been reached yet, it then:

       * estimates the input covariance, takes its SVD and runs
         ``quantize_layer``;
       * fine-tunes for ``n_steps`` steps, unless it is the last layer in the
         group.
    3. Fine-tunes for ``total_steps`` steps and, on rank 0, saves
       ``log_dir / f"block-{group_idx}.pth"``.
    4. Unfreezes the group's weights, except after the last pass. After the
       last pass, frozen parameters are removed from ``weight_params`` instead.

    The fine-tuning learning rate starts at ``1.25e-3 / 4`` and halves after
    every group.

    Args:
        model: network to quantize (modified in place).
        train_loader, val_loader: training / validation batch iterables.
        block_iters: passes per layer group.
        finetune_epochs: ``total_steps = len(train_loader) * finetune_epochs``.
        skip_until: if set, layers before this name are not quantized.
        block_size: maximum number of layers per group.
        sample_portion: fraction of batches used for the covariance estimate.
        log_dir: directory for per-group checkpoints (needed on rank 0).
        rank: process rank; only rank 0 saves.
        world_size, xnor_net, findq_iters: accepted but not used here.
    """
    graph = make_graph(model)
    modules = dict(model.named_modules())
    # parameters owned directly by non-BatchNorm2d modules
    weight_params = sum([list(m.parameters(recurse=False)) for m in model.modules()
                         if not isinstance(m, torch.nn.BatchNorm2d)], [])
    # BatchNorm2d parameters; they are never frozen in this function
    bn_params = sum([list(m.parameters(recurse=False)) for m in model.modules()
                     if isinstance(m, torch.nn.BatchNorm2d)], [])
    print(f"{len(weight_params)} weight params, {len(bn_params)} bn params")

    # copy of the model as passed in, used as the distillation teacher
    teacher = copy.deepcopy(model)
    total_steps = len(train_loader) * finetune_epochs
    # budget for quantize_layer and for fine-tuning between layers
    n_steps = total_steps // block_iters

    # for layer_name in topological_sort(graph):
    layers = list(topological_sort(graph))
    layer_groups = list(more_itertools.chunked_even(layers, block_size))
    print("Layer groups", layer_groups)

    finetune_lr = 1.25e-3 / 4
    weight_decay = 1e-5

    for group_idx, layer_group in enumerate(layer_groups):

        # per-layer SVD basis and log singular values, re-created for each group
        V_dict = dict()
        log_s_dict = dict()
        for it in range(block_iters):
            print(f"Iteration {it}.")

            """
            params = set()
            for g in opt.param_groups:
                for p in g['params']:
                    params.add(id(p))
            for name, p in model.named_parameters():
                if id(p) in params:
                    print(f"{name} in opt, {p.requires_grad}")
                else:
                    print(f"{name} NOT in opt")
            """

            if it > 0 or group_idx > 0:
                print("Block initial fine-tuning")

                opt = torch.optim.Adam(
                    [{'params': bn_params},
                     {'params': weight_params, 'weight_decay': weight_decay}],
                    lr=finetune_lr
                )

                train_model(model, opt, teacher, train_loader, n_steps=total_steps)
                test_model(model, val_loader)

            # layers frozen in this pass (unfrozen again below unless last pass)
            new_frozen = []
            for layer_idx, layer_name in enumerate(layer_group):
                # NOTE(review): once reached, skip_until stays None, so in later
                # passes (it > 0) the layers before it are quantized as well
                if layer_name == skip_until:
                    skip_until = None

                # skip last layer for now
                if layer_name == 'fc':
                    continue

                print("Quantizing", layer_name)
                # '<name>_down' is handled as the module pair '<name>_down1' / '<name>_down2'
                if layer_name.endswith('_down'):
                    cur_layers = [modules[layer_name + '1'], modules[layer_name + '2']]
                else:
                    cur_layers = [modules[layer_name]]
                cur_params = [layer.weight for layer in cur_layers]
                groups = cur_layers[0].groups
                get_patches = partial(extract_patches, model, cur_layers[0])
                dim = np.prod(cur_layers[0].weight.shape[1:])

                # freeze layer
                # NOTE(review): this also happens for layers skipped just below,
                # which then stay full precision (and frozen after the last pass)
                new_frozen.extend(cur_layers)
                for p in cur_params:
                    p.requires_grad = False
                    p.grad = None

                # FIXME skip depthwise layer for now
                if groups > 1:
                    continue
                # skip quantization until specified layer is reached
                if skip_until is not None:
                    continue

                # calculate input covariance
                model.eval()
                #if it == 0:
                if True:
                    # always recomputed; cov = U diag(s) V^T
                    cov = dataset_cov(train_loader, get_patches, dim, sample_portion)
                    _, s, V_dict[layer_name] = torch.svd(cov)
                    log_s_dict[layer_name] = torch.log(s + 1e-6)
                log_s = log_s_dict[layer_name]
                V = V_dict[layer_name]

                # quantize layer
                print("Finding quantized values")
                new_log_s, new_V = quantize_layer(
                    model, teacher, train_loader, n_steps,
                    bn_params=bn_params, cur_layers=cur_layers,
                    distance=weight_distance, log_s=log_s, V=V
                )
                # NOTE(review): with the covariance always recomputed above, these
                # stored values are overwritten before they are read again
                log_s_dict[layer_name].copy_(new_log_s)
                V_dict[layer_name].copy_(new_V)

                test_model(model, val_loader)

                if layer_idx < len(layer_group) - 1:
                    print("Fine-tuning")

                    opt = torch.optim.Adam(
                        [{'params': bn_params},
                         {'params': weight_params, 'weight_decay': weight_decay}],
                        lr=finetune_lr
                    )

                    train_model(model, opt, teacher, train_loader, n_steps=n_steps)
                    test_model(model, val_loader)

            print("Block final fine-tuning")

            opt = torch.optim.Adam(
                [{'params': bn_params},
                 {'params': weight_params, 'weight_decay': weight_decay}],
                lr=finetune_lr
            )

            train_model(model, opt, teacher, train_loader, n_steps=total_steps)
            test_model(model, val_loader)

            # NOTE(review): log_dir defaults to None, which would fail here on rank 0
            if rank == 0:
                print("Saving to ", log_dir / f"block-{group_idx}.pth")
                torch.save(model.state_dict(), log_dir / f"block-{group_idx}.pth")

            # unfreeze block layers
            if it < block_iters - 1:
                print(f"Unfreezing {len(new_frozen)} layers")
                for layer in new_frozen:
                    layer.weight.requires_grad = True
            else:
                # last pass: drop now-frozen weights from later optimizers
                weight_params = [p for p in weight_params if p.requires_grad]

        finetune_lr /= 2.0


def load_model(path):
    """Build a 100-class ReactNet on the current CUDA device and load a checkpoint into it.

    The checkpoint must be a dict with ``'epoch'``, ``'best_top1_acc'`` and
    ``'state_dict'`` entries. A leading ``'module.'`` (as left by saving a
    ``DataParallel``/``DDP``-wrapped model) is stripped from each key, and
    only the leading occurrence is stripped. Loading is non-strict, but any
    missing or unexpected keys are printed instead of being silently
    ignored.

    NOTE: ``torch.load`` unpickles the file — only load trusted checkpoints.
    """
    model = reactnet(num_classes=100).cuda()
    checkpoint = torch.load(path, map_location='cpu')
    print(checkpoint['epoch'], checkpoint['best_top1_acc'])
    prefix = 'module.'
    # str.replace would also mangle keys such as 'submodule.weight'
    state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in checkpoint['state_dict'].items()
    }
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        print(f"Missing keys when loading {path}: {result.missing_keys}")
    if result.unexpected_keys:
        print(f"Unexpected keys when loading {path}: {result.unexpected_keys}")
    return model


def train_acc_model(model, train_loader: DataLoader):
    """Report cross-entropy and top-1 accuracy of ``model`` on ``train_loader``, over all ranks.

    The model is wrapped in DDP and put in *train* mode. This means any
    BatchNorm layers use batch statistics and update their running statistics
    as a side effect, although no gradients are computed. Per-rank sums are
    all-reduced before averaging, and both values are printed.

    Returns:
        ``(loss, acc)`` as Python floats.
    """
    net = DDP(model)
    net.train()
    print("Train Acc")
    seen = torch.tensor(0, device='cuda')
    loss_sum = torch.tensor(0.0, device='cuda')
    hits = torch.tensor(0, device='cuda')
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for batch, target in train_loader:
            batch = batch.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True).flatten()
            logits = net(batch)
            batch_loss = criterion(logits, target)
            seen += batch.shape[0]
            loss_sum += batch_loss.data
            hits += logits.argmax(dim=1).eq(target).sum()
        for total in (loss_sum, hits, seen):
            dist.all_reduce(total)
        loss = (loss_sum / seen).item()
        acc = (hits / seen).item()
        print(f'{loss}')
        print(f'{acc:.4f}')
    return loss, acc


def test_model(model, val_loader: DataLoader):
    """Report cross-entropy and top-1 accuracy of ``model`` on ``val_loader``, over all ranks.

    The model is wrapped in DDP and put in eval mode (and left in eval mode).
    Per-rank sums of loss, correct predictions and sample counts are
    all-reduced before averaging, and both values are printed.

    Returns:
        ``(loss, acc)`` as Python floats.
    """
    net = DDP(model)
    net.eval()
    print("Test")
    n_seen = torch.tensor(0, device='cuda')
    total_loss = torch.tensor(0.0, device='cuda')
    n_correct = torch.tensor(0, device='cuda')
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True).flatten()
            logits = net(inputs)
            batch_loss = criterion(logits, targets)
            n_seen += inputs.shape[0]
            total_loss += batch_loss.data
            n_correct += logits.argmax(dim=1).eq(targets).sum()
        for total in (total_loss, n_correct, n_seen):
            dist.all_reduce(total)
        loss = (total_loss / n_seen).item()
        acc = (n_correct / n_seen).item()
        print(f'{loss}')
        print(f'{acc:.4f}')
    return loss, acc


def test_model_orig(model, val_loader: DataLoader, device='cuda'):
    """Evaluate ``model`` on ``val_loader`` in this process only (no DDP, no all-reduce).

    The model is put in eval mode and both values are printed. With an empty
    loader the averages are 0/0 (NaN).

    Args:
        model: network to evaluate; expected to live on ``device``.
        val_loader: iterable of ``(X, y)`` batches.
        device: device for the inputs and accumulators (default ``'cuda'``,
            the previously hard-coded value).

    Returns:
        ``(loss, acc)`` as Python floats: mean cross-entropy and top-1 accuracy.
    """
    model.eval()
    print("Test")
    count = torch.tensor(0, device=device)
    loss_accum = torch.tensor(0.0, device=device)
    correct_count = torch.tensor(0, device=device)
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for X, y in val_loader:
            X = X.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True).flatten()
            logits = model(X)
            loss = criterion(logits, y)
            count += X.shape[0]
            loss_accum += loss.data
            correct_count += (torch.argmax(logits, dim=1) == y).sum()
        loss = (loss_accum / count).item()
        acc = (correct_count / count).item()
        print(f'{loss}')
        print(f'{acc:.4f}')
    return loss, acc


def render_dep_graph(model, root='fc', filename='dep-graph-reactnet'):
    """Render the layer dependency graph of ``model`` with Graphviz.

    Starting at ``root``, layers are visited depth-first via
    ``model.dependent_layers``. Each dependency is drawn as a directed edge
    ``dep -> layer``, matching the ``nx.DiGraph`` built by ``make_graph``;
    the previous undirected ``graphviz.Graph`` lost the direction. Output is
    written by ``dot.render(filename)``.

    Args:
        model: object providing ``dependent_layers(name)``.
        root: layer to start from (default ``'fc'``).
        filename: output file name passed to ``render``.
    """
    dot = graphviz.Digraph('dep-graph')
    marked = set()

    def visit(layer):
        dot.node(layer, layer)
        marked.add(layer)
        for dep in model.dependent_layers(layer):
            dot.edge(dep, layer)
            if dep not in marked:
                visit(dep)

    visit(root)
    dot.render(filename)


def run(rank, world_size, args):
    """Per-rank driver.

    Steps:

    1. Sets up logging and selects GPU ``rank``.
    2. Loads the model and the CIFAR-100 loaders.
    3. Evaluates the model, quantizes it, and evaluates it again.
    4. On rank 0, saves the final state dict to ``args.save_path``.

    The ``SummaryWriter`` provides the log directory for per-group
    checkpoints. It is closed on exit, even if a step raises.
    """
    writer = SummaryWriter(log_dir=args.log_dir)
    try:
        logging_util.setup_logging(
            pathlib.Path(args.log_dir) / f'log_rank{rank}.txt',
            log_to_screen=(rank == 0)
        )
        print("===== Args =====")
        for name, value in vars(args).items():
            print(f"   {name}: {value}")
        print("================")

        print(f"Using GPU {rank}")
        torch.cuda.set_device(f"cuda:{rank}")

        # construct model
        model = load_model(args.load_path)
        # render_dep_graph(model)
        print(model)

        # load dataset
        train_loader, val_loader = cifar100(batch_size=200, workers=4, distributed=True)

        # evaluated fp model
        train_acc_model(model, train_loader)
        test_model(model, val_loader)

        # quantize model
        quantize(model, train_loader, val_loader,
                 block_iters=args.block_iters,
                 xnor_net=args.xnor_net,
                 skip_until=args.skip_until,
                 finetune_epochs=args.finetune_epochs,
                 block_size=args.block_size,
                 sample_portion=args.sample_portion,
                 findq_iters=args.findq_iters,
                 log_dir=pathlib.Path(writer.log_dir),
                 rank=rank, world_size=world_size)

        # evaluate quantized model
        train_acc_model(model, train_loader)
        test_model(model, val_loader)

        if rank == 0:
            torch.save(model.state_dict(), args.save_path)
    finally:
        writer.close()


def dist_main(rank, world_size, args):
    """Worker started by ``mp.spawn``: join the NCCL process group, run, tear down.

    The process group is destroyed even if ``run`` raises.
    """
    # All workers are spawned on this machine by main(), so the rendezvous
    # address is always local; the port may be overridden via the environment.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ.setdefault('MASTER_PORT', '12355')

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        run(rank, world_size, args)
    finally:
        dist.destroy_process_group()


def main():
    """Parse CLI arguments, prepare the run directory and spawn one worker per GPU.

    A new ``runs/reactnet_<timestamp>_<host>`` directory is created. Every
    ``*.py`` file next to this script is copied into its ``source``
    subdirectory, and the quantized model will be saved there as
    ``<checkpoint stem>-quantized.pt``. ``dist_main`` is then spawned once per
    visible CUDA device.

    Raises:
        RuntimeError: if no CUDA device is visible (spawning zero workers
            would otherwise silently do nothing).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--xnor_net', default=False, action='store_true')
    parser.add_argument('--skip_until', default=None, type=str)
    parser.add_argument('--load_path', default='/d1/xxx/DBQ/baseline/1_step1/models_run1/model_best.pth.tar', type=str)
    parser.add_argument('--block_iters', default=5, type=int)
    parser.add_argument('--finetune_epochs', default=10, type=int)
    parser.add_argument('--block_size', default=7, type=int)
    parser.add_argument('--sample_portion', default=1.0, type=float)
    parser.add_argument('--findq_iters', default=20, type=int)
    args = parser.parse_args()

    # check before creating any run directory
    world_size = torch.cuda.device_count()
    if world_size == 0:
        raise RuntimeError("No CUDA device available: this script needs at least one GPU")

    current_time = datetime.now().strftime('%b%d_%H-%M-%S-%f')
    args.log_dir = pathlib.Path('runs') / f"reactnet_{current_time}_{socket.gethostname()}"
    load_path = pathlib.Path(args.load_path)
    args.save_path = args.log_dir / f"{load_path.stem}-quantized.pt"
    print("Will save to ", args.save_path)

    # snapshot the source files next to the results
    source_copy_dir = args.log_dir / 'source'
    source_copy_dir.mkdir(parents=True, exist_ok=False)
    current_dir = pathlib.Path(__file__).parent
    py_files = glob.glob(str(current_dir / '*.py'))  # + glob.glob(str(current_dir / '**' / '*.py'))
    for f in py_files:
        relpath = pathlib.Path(f).relative_to(current_dir)
        print(f"Copy {relpath} to {source_copy_dir}")
        target = source_copy_dir / relpath
        target.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(f, target)

    mp.set_start_method('forkserver')
    mp.spawn(dist_main, args=(world_size, args), nprocs=world_size, join=True)


# Run main() only when executed as a script, not when this module is imported
# (e.g. by worker processes started from main()).
if __name__ == '__main__':
    main()
