import argparse
import copy
import glob
import math
import os
import pathlib
import socket
import sys
from datetime import datetime

import more_itertools
import networkx as nx
import shutil
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
from networkx.algorithms.dag import topological_sort
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from reactnet_new import reactnet, QuantizedConv2d
from extract_cov import dataset_cov

sys.path.append('.')
from cifar_torch import cifar100
import logging_util

# Global cuDNN configuration, applied at import time: keep cuDNN enabled and
# let it benchmark candidate kernels so it can pick a fast one.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True


def make_graph(model):
    """Build the layer dependency graph of ``model`` as a ``networkx.DiGraph``.

    Starting from the layer named ``'fc'``, every layer ``dep`` reported by
    ``model.dependent_layers(name)`` contributes an edge ``dep -> name``. Each
    layer is expanded once, depth-first, in the order its dependencies are
    reported.

    Returns:
        The directed graph made of all edges discovered from ``'fc'``.
    """
    dep_graph = nx.DiGraph()
    root = 'fc'
    expanded = {root}
    exhausted = object()  # sentinel: a dependency iterator has run dry

    # Explicit stack of (layer, remaining dependencies) frames. Edges are added
    # in exactly the order a recursive depth-first walk would add them.
    frames = [(root, iter(model.dependent_layers(root)))]
    while frames:
        consumer, pending = frames[-1]
        producer = next(pending, exhausted)
        if producer is exhausted:
            frames.pop()
            continue
        dep_graph.add_edge(producer, consumer)
        if producer not in expanded:
            expanded.add(producer)
            frames.append((producer, iter(model.dependent_layers(producer))))

    return dep_graph


class Sign(torch.autograd.Function):
    """Sign function with a piecewise-linear surrogate gradient.

    Forward returns ``torch.sign(x)``. Backward multiplies the incoming
    gradient by ``max(0, 2 * (1 - |x|))``, which is zero wherever ``|x| >= 1``.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        """Return the element-wise sign of ``x`` and keep ``x`` for backward."""
        ctx.save_for_backward(x)
        return x.sign()

    @staticmethod
    def backward(ctx, g):
        """Return ``g`` scaled by the surrogate derivative evaluated at ``x``."""
        (saved_x,) = ctx.saved_tensors
        surrogate = (2 * (1 - saved_x.abs())).clamp(min=0.0)
        # Only ``x`` receives a gradient; the second slot is always None.
        return g * surrogate, None


@torch.jit.script
def quantization_loss(q, M, w_hat, lamba: float = 1e-5):
    """Per-column squared reconstruction error plus an L2 penalty on ``q``.

    Computes ``((M @ q - w_hat) ** 2).sum(0) + lamba * q.norm(2, dim=0)``.
    Both reductions run over dim 0, so the result has one entry per column
    of ``q``.
    """
    residual = torch.matmul(M, q) - w_hat
    fit_error = torch.square(residual).sum(0)
    penalty = lamba * q.norm(2, dim=0)
    return fit_error + penalty


def eigh(cov):
    """Return ``torch.linalg.eigh(cov)``, retrying once on a linalg error.

    The single retry is a workaround for a torch bug: a first call may raise
    ``torch._C._LinAlgError``. An error raised by the second attempt propagates
    to the caller.
    """
    decompose = torch.linalg.eigh
    try:
        result = decompose(cov)
    except torch._C._LinAlgError:  # workaround for torch bug
        result = decompose(cov)
    return result


@torch.enable_grad()
def find_q(layer_names, layers, log_s_dict, V_dict,
           orig_weight_dict,
           nsteps=2000, nsamples=20, xnor_net=False):
    """Fit each layer's ``alpha`` and ``scores`` to its original weight.

    For every layer the quantized weight ``alpha * Sign(scores)`` is optimised
    with SGD so that ``M @ q`` matches ``M @ orig_w``, where
    ``M = exp(log_s)[:, None] * V.T`` (see ``quantization_loss``). The lowest
    loss seen per layer is tracked together with the parameters that produced
    it. After all rounds the best parameters across ranks are selected with
    ``torch.distributed`` collectives and written back into the layers in place.

    Gradients are enabled by the decorator even if the caller runs under
    ``torch.no_grad()``. ``torch.distributed`` collectives are called here, and
    the loss bookkeeping tensor is created on ``'cuda'``.

    Args:
        layer_names: names of the layers to fit; used as keys into the dicts.
        layers: layer objects exposing ``alpha`` and ``scores``, in the same
            order as ``layer_names``.
        log_s_dict: per-layer ``log_s`` tensors.
        V_dict: per-layer ``V`` tensors.
        orig_weight_dict: per-layer original weights (``.data`` is read).
        nsteps: SGD steps per round.
        nsamples: number of rounds; ``scores`` are perturbed with Gaussian
            noise (std 0.01) after each round.
        xnor_net: NOTE(review): accepted but never read in this function.

    Returns:
        None. The layers' ``alpha`` and ``scores`` are overwritten in place.
    """
    # One optimizer over alpha and scores of every layer in the group.
    params = []
    for layer in layers:
        params.extend([layer.alpha, layer.scores])
    opt = torch.optim.SGD(params, lr=0.01, momentum=0.9, nesterov=True)
    # NOTE(review): T_max is nsteps, but scheduler.step() is called
    # nsteps * nsamples times in total (the scheduler is not reset per round).
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=nsteps)

    layers_dict = dict(zip(layer_names, layers))

    # Per-layer best loss so far (initialised to a large value) and copies of
    # the parameters that achieved it.
    best_losses = torch.full([len(layer_names)], 10.**9, device='cuda')
    best_alpha_dict = {
        layer_name: copy.deepcopy(layer.alpha.data)
        for layer_name, layer in layers_dict.items()
    }
    best_scores_dict = {
        layer_name: copy.deepcopy(layer.scores.data)
        for layer_name, layer in layers_dict.items()
    }

    for _ in tqdm(range(nsamples), desc='sample'):
        for step in tqdm(range(nsteps), desc='step'):
            opt.zero_grad()

            losses = []
            for layer_name in layer_names:
                alpha = layers_dict[layer_name].alpha
                scores = layers_dict[layer_name].scores
                log_s = log_s_dict[layer_name]
                V = V_dict[layer_name]

                orig_w = orig_weight_dict[layer_name].data
                orig_w = orig_w.reshape(orig_w.shape[0], -1).T  # (in, out)

                # Projection built from log_s and V; applied to both the
                # original and the quantized weight.
                M = torch.exp(log_s)[:, None] * V.T
                w_hat = M @ orig_w

                q = alpha * Sign.apply(scores)
                q = q.reshape(q.shape[0], -1).T  # (in, out)

                loss = quantization_loss(q, M, w_hat).mean()
                losses.append(loss)

            losses = torch.stack(losses)
            total_loss = losses.sum()
            total_loss.backward()

            # Update the per-layer best. The losses were computed with the
            # current (pre-step) parameters, so those are the ones snapshotted.
            conds = torch.le(losses, best_losses)
            best_losses.data.copy_(torch.where(conds, losses, best_losses))
            for cond, layer_name in zip(conds, layer_names):
                scores, best_scores = layers_dict[layer_name].scores, best_scores_dict[layer_name]
                alpha, best_alpha = layers_dict[layer_name].alpha, best_alpha_dict[layer_name]
                best_scores.data.copy_(torch.where(cond, scores.data, best_scores.data))
                best_alpha.data.copy_(torch.where(cond, alpha.data, best_alpha.data))

            opt.step()
            scheduler.step()

        # Perturb the scores before the next round.
        # NOTE(review): the enumerate index ``i`` is unused.
        for i, layer_name in enumerate(layer_names):
            scores = layers_dict[layer_name].scores
            scores.data.add_(torch.randn_like(scores) * 0.01)

        print("best=", best_losses.mean().item())

    # Gather every rank's best losses so that, per layer, the rank with the
    # lowest loss can be identified.
    best_list = [torch.zeros_like(best_losses) for _ in range(dist.get_world_size())]
    dist.all_gather(best_list, best_losses)

    best_ranks = torch.argmin(torch.stack(best_list), dim=0).cpu().numpy()
    for layer_name, best_rank in zip(layer_names, best_ranks):
        scores, best_scores = layers_dict[layer_name].scores, best_scores_dict[layer_name]
        alpha, best_alpha = layers_dict[layer_name].alpha, best_alpha_dict[layer_name]
        # Broadcast the winning rank's snapshot, then load it into the layer.
        dist.broadcast(best_scores, best_rank)
        dist.broadcast(best_alpha, best_rank)
        alpha.data.copy_(best_alpha)
        scores.data.copy_(best_scores)

    final_bests = torch.amin(torch.stack(best_list), dim=0)
    print(f"final best={final_bests.mean().item()}")


# knowledge distillation loss
@torch.jit.script
def distillation(logits, labels, teacher_scores, T: float, alpha: float):
    """Per-sample knowledge-distillation loss (task loss + teacher loss).

    Args:
        logits: student logits, one row per sample.
        labels: hard class labels for the cross-entropy term.
        teacher_scores: teacher logits (soft targets), same shape as ``logits``.
        T: softmax temperature applied to both student and teacher logits.
        alpha: mixing weight. The task loss is scaled by ``1 - alpha`` and the
            teacher loss by ``2 * T ** 2 * alpha``, so ``alpha == 0`` yields
            the plain task loss.

    Returns:
        ``(task_loss, combined_loss)``, both unreduced (one value per sample).
    """
    task_loss = F.cross_entropy(logits, labels, reduction='none')
    # KL divergence between the temperature-softened distributions, summed
    # over classes. Both inputs are log-probabilities (log_target=True).
    teacher_loss = F.kl_div(
        F.log_softmax(logits / T, dim=-1),
        F.log_softmax(teacher_scores / T, dim=-1),
        reduction='none',
        log_target=True
    ).sum(1)
    # The teacher weight must scale with alpha (multiplied, not added);
    # otherwise the teacher term cannot be switched off by alpha.
    teacher_weight = 2 * T ** 2 * alpha
    return task_loss, task_loss * (1 - alpha) + teacher_loss * teacher_weight


@torch.jit.script
def fused_clip(max_norm: float, total_norm):
    """Return the gradient scale ``max_norm / (total_norm + 1e-6)`` limited to ``[0, 1]``."""
    return torch.clamp(max_norm / (total_norm + 1e-6), min=0.0, max=1.0)


@torch.no_grad()
def clip_grad_norm_(
        parameters, max_norm: float, norm_type: float = 2.0) -> torch.Tensor:
    """Scale the gradients of ``parameters`` in place so their norm is at most ``max_norm``.

    Args:
        parameters: a single tensor or an iterable of tensors; entries whose
            ``.grad`` is ``None`` are ignored.
        max_norm: upper bound for the total gradient norm.
        norm_type: order of the norm; ``float('inf')`` selects the max-abs norm.

    Returns:
        The total gradient norm measured before scaling, or ``torch.tensor(0.)``
        when no parameter has a gradient.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return torch.tensor(0.)

    if norm_type == float('inf'):
        peaks = [g.detach().abs().max() for g in grads]
        total_norm = peaks[0] if len(peaks) == 1 else torch.max(torch.stack(peaks))
    else:
        # Norm of the per-tensor norms equals the norm over all gradients.
        per_tensor = [torch.norm(g.detach(), norm_type) for g in grads]
        total_norm = torch.norm(torch.stack(per_tensor), norm_type)

    scale = fused_clip(max_norm, total_norm)
    for g in grads:
        g.detach().mul_(scale)
    return total_norm


def SV_regularization(log_s_dict, V_dict, orig_log_s_sum_dict):
    """Sum of per-layer regularizers on ``V`` and ``log_s``.

    For each key the penalty is
    ``0.1 * ||V @ V.T - I||_F ** 2 + ((sum(log_s) - ref) / ref) ** 2``, where
    ``ref`` is the corresponding entry of ``orig_log_s_sum_dict``.

    Returns:
        A scalar tensor: the penalties summed over all keys of ``log_s_dict``.
    """
    penalties = []
    for name, log_s in log_s_dict.items():
        basis = V_dict[name]
        reference_sum = orig_log_s_sum_dict[name]

        # Deviation of V @ V.T from the identity, squared Frobenius norm.
        gram_minus_eye = torch.matmul(basis, basis.T)
        gram_minus_eye.diagonal().sub_(1.0)
        orthogonality = gram_minus_eye.norm(2) ** 2

        # Relative drift of sum(log_s) away from its reference value.
        relative_drift = (log_s.sum() - reference_sum) / reference_sum
        penalties.append(orthogonality * 0.1 + relative_drift ** 2)

    return torch.stack(penalties).sum()


@torch.enable_grad()
def find_S_V(model,
             orig_log_s_sum_dict, log_s_dict, V_dict,
             opt, teacher, train_loader, n_steps,
             T=20.0, alpha=0.7):
    """Train the parameters held by ``opt`` with a distillation objective.

    ``model`` is wrapped in DDP and trained against ``teacher`` using
    ``distillation`` plus ``10 * SV_regularization``. After every optimizer
    step the ``noise_scale`` attribute of each ``QuantizedConv2d`` module is
    set to ``10 ** (-2 + 2 * steps / n_steps)``.

    Gradients are enabled by the decorator even if the caller runs under
    ``torch.no_grad()``. Metric tensors are created on ``'cuda'`` and reduced
    with ``torch.distributed`` collectives.

    Args:
        model: student model; wrapped in ``DistributedDataParallel`` here.
        orig_log_s_sum_dict: per-layer reference values for ``sum(log_s)``.
        log_s_dict: per-layer ``log_s`` tensors (regularized).
        V_dict: per-layer ``V`` tensors (regularized).
        opt: optimizer whose ``step()`` is called once per batch.
        teacher: model producing the soft targets; put in eval mode.
        train_loader: yields ``(X, y)`` batches.
        n_steps: step budget; see the note at the return statement.
        T: distillation temperature.
        alpha: distillation mixing weight.

    Returns:
        None. The function returns from inside the batch loop once the step
        counter exceeds ``n_steps``.
    """
    teacher.eval()
    model = DDP(model)
    model.train()
    steps = 0

    while True:
        # Per-epoch accumulators, reduced across ranks after the epoch.
        count = torch.tensor(0, device='cuda')
        task_loss_accum = torch.tensor(0.0, device='cuda')
        dist_loss_accum = torch.tensor(0.0, device='cuda')
        loss_accum = torch.tensor(0.0, device='cuda')
        correct_count = torch.tensor(0, device='cuda')
        for X, y in tqdm(train_loader, desc="Train"):
            X = X.cuda(non_blocking=True)
            y = y.cuda(non_blocking=True).flatten()
            # NOTE(review): only gradients of parameters registered on
            # ``model`` are cleared here; confirm that every parameter in
            # ``opt`` (e.g. log_s / V) is registered, otherwise its gradient
            # accumulates across steps.
            model.zero_grad()

            logits = model(X)
            with torch.no_grad():
                teacher_scores = teacher(X)
            task_loss, dist_loss = distillation(logits, y, teacher_scores, T, alpha)

            # Per-step debug output; .item() reads the value back from the device.
            print(dist_loss.mean().item())

            # add regularization (scalar, broadcast onto the per-sample losses)
            reg_loss = dist_loss + SV_regularization(log_s_dict, V_dict, orig_log_s_sum_dict) * 10

            reg_loss.mean().backward()

            # clip_grad_norm_(params, 2.0)  # clip gradient

            opt.step()
            count += X.shape[0]
            task_loss_accum += task_loss.sum().data
            dist_loss_accum += dist_loss.sum().data
            loss_accum += reg_loss.sum().data
            correct_count += (torch.argmax(logits, dim=1) == y).sum()

            # Noise schedule: grows geometrically from about 1e-2 towards 1
            # as ``steps`` approaches ``n_steps``.
            steps += 1
            new_noise_scale = 10 ** (-2 + 2 * (steps / n_steps))
            print(f"noise_scale = {new_noise_scale}")
            for module in model.modules():
                if isinstance(module, QuantizedConv2d):
                    module.noise_scale = new_noise_scale

            # NOTE(review): with ``>`` the loop performs n_steps + 1 optimizer
            # steps. Returning here also skips the epoch summary below for the
            # epoch in progress.
            if steps > n_steps:
                return
        # Epoch summary: sum the accumulators over all ranks, then report
        # per-sample averages and accuracy.
        dist.all_reduce(task_loss_accum)
        dist.all_reduce(dist_loss_accum)
        dist.all_reduce(loss_accum)
        dist.all_reduce(correct_count)
        dist.all_reduce(count)
        tqdm.write(f'task{(task_loss_accum / count).item()} / '
                   f'dist{(dist_loss_accum / count).item()} / '
                   f'reg{(loss_accum / count).item()}')
        accuracy = (correct_count / count).item()
        tqdm.write(f'{accuracy:.4f}')


@torch.no_grad()
def quantize(model, train_loader, val_loader,
             sample_portion=1.0, findq_iters=20,
             log_dir=None):
    """Quantize ``model`` in place, one group of layers at a time.

    Layers are ordered by a topological sort of the dependency graph built by
    ``make_graph`` (the last node of that order is excluded) and processed in
    chunks of four. For every chunk the function:

    1. collects per-layer covariances with ``dataset_cov`` and
       eigendecomposes them into ``log_s`` / ``V`` parameters,
    2. switches the layers to ``'quantize'`` mode with noise injection and
       trains ``log_s`` / ``V``, the weights of later groups and all
       BatchNorm parameters via ``find_S_V``, using a deep copy of the
       original model as teacher,
    3. turns noise injection off, evaluates, and fits the quantized weights
       with ``find_q``,
    4. saves an intermediate checkpoint (rank 0 only) and evaluates again.

    Args:
        model: the network to quantize; modified in place.
        train_loader: training data for covariance collection and ``find_S_V``.
        val_loader: validation data passed to ``test_model``.
        sample_portion: forwarded to ``dataset_cov``.
        findq_iters: number of rounds (``nsamples``) used by ``find_q``.
        log_dir: directory for intermediate checkpoints; when ``None`` no
            intermediate checkpoint is written.
    """
    block_size = 4     # layers per group
    max_rank = 10240   # keep at most this many leading eigenpairs

    teacher = copy.deepcopy(model)

    graph = make_graph(model)
    modules = dict(model.named_modules())
    weight_params = [
        p for m in model.modules()
        if not isinstance(m, torch.nn.BatchNorm2d)
        for p in m.parameters(recurse=False)
    ]
    bn_params = [
        p for m in model.modules()
        if isinstance(m, torch.nn.BatchNorm2d)
        for p in m.parameters(recurse=False)
    ]
    print(f"{len(weight_params)} weight params, {len(bn_params)} bn params")

    layer_names_all = list(topological_sort(graph))[:-1]

    layer_groups = list(more_itertools.chunked_even(layer_names_all, block_size))
    print("Layer groups", layer_groups)

    for group_idx, layer_group in enumerate(layer_groups):
        print(f"Group {group_idx}: {layer_group}")

        layers = [modules[layer_name] for layer_name in layer_group]

        # Only the covariances are used; the other two return values are not.
        covs, _, _ = dataset_cov(
            train_loader, None,
            model, layers, layer_group,
            sample_portion
        )

        log_s_dict, V_dict = {}, {}
        orig_log_s_sum_dict = {}
        for layer_name, cov in covs.items():
            s, new_V = eigh(cov)
            # eigh returns eigenvalues in ascending order: keep the largest.
            s = s[-max_rank:]
            new_V = new_V[:, -max_rank:]
            # log of the square root of the (clipped) eigenvalues
            new_log_s = torch.log(torch.clip(s, min=1e-6)) * 0.5
            orig_log_s_sum_dict[layer_name] = new_log_s.sum()
            log_s_dict[layer_name] = torch.nn.Parameter(new_log_s)
            V_dict[layer_name] = torch.nn.Parameter(new_V)

        orig_weight_dict = {}
        alpha_scores_dict = {}
        for layer_name, layer in zip(layer_group, layers):
            w = layer.set_mode('quantize')
            orig_weight_dict[layer_name] = w
            alpha, scores = layer.set_noise_injection(
                True, original_w=w.data,
                log_s=log_s_dict[layer_name], V=V_dict[layer_name]
            )
            alpha_scores_dict[layer_name] = (alpha, scores)

        # find S,V
        print("Finding S,V")

        # Weights of every layer in the groups that come after this one.
        later_weights = [
            modules[other_layer_name].weight
            for other_layer_group in layer_groups[group_idx + 1:]
            for other_layer_name in other_layer_group
        ]

        optimizer = torch.optim.Adam(
            [{'params': list(log_s_dict.values())},
             {'params': list(V_dict.values())},
             {'params': later_weights, 'lr': 1.25e-3 / 4, 'weight_decay': 1e-5},
             {'params': bn_params, 'lr': 1.25e-3 / 4}],
            lr=1.25e-3 * 10
        )

        find_S_V(
            model=model,
            orig_log_s_sum_dict=orig_log_s_sum_dict,
            log_s_dict=log_s_dict,
            V_dict=V_dict,
            opt=optimizer,
            teacher=teacher,
            train_loader=train_loader,
            n_steps=len(train_loader) * 20
        )

        for layer_name, layer in zip(layer_group, layers):
            alpha, scores = alpha_scores_dict[layer_name]
            layer.set_noise_injection(False, alpha=alpha, scores=scores)

        # test model
        test_model(model, val_loader)

        # find W_q
        print("Finding W")
        # The number of rounds follows the caller's findq_iters (default 20).
        find_q(layer_group, layers, log_s_dict, V_dict,
               orig_weight_dict,
               nsteps=1000, nsamples=findq_iters)

        # All ranks share the same log_dir, so only rank 0 writes the file;
        # without a log_dir there is nowhere to write it.
        if log_dir is not None and dist.get_rank() == 0:
            torch.save(model, pathlib.Path(log_dir) / f"intermediate2-grp{group_idx}.pt")

        # test model
        test_model(model, val_loader)


def load_model(path):
    """Build a 100-class ``reactnet`` on the GPU and load weights from ``path``.

    The checkpoint is expected to contain ``'epoch'``, ``'best_top1_acc'`` and
    ``'state_dict'``. Leading ``'module.'`` prefixes are stripped from the
    state-dict keys. For every ``...down1.weight`` entry a merged
    ``...down.weight`` entry is added by concatenating it with the matching
    ``...down2.weight`` along dim 0.

    Args:
        path: checkpoint file readable by ``torch.load``.

    Returns:
        The model with the checkpoint weights loaded (``strict=False``, so
        keys that do not match the model are ignored).

    Raises:
        KeyError: if a required checkpoint entry, or the ``down2`` partner of
            a ``down1`` weight, is missing.
    """
    model = reactnet(num_classes=100).cuda()
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(path, map_location='cuda')
    print(checkpoint['epoch'], checkpoint['best_top1_acc'])

    wrapper_prefix = 'module.'

    def strip_wrapper_prefix(key):
        # Remove only leading prefixes. A plain str.replace would also mangle
        # keys that merely contain the text, e.g. 'submodule.weight'.
        while key.startswith(wrapper_prefix):
            key = key[len(wrapper_prefix):]
        return key

    state_dict = {
        strip_wrapper_prefix(k): v
        for k, v in checkpoint['state_dict'].items()
    }
    for key in list(state_dict.keys()):
        if key.endswith('down1.weight'):
            down2_key = key.replace('down1', 'down2')
            new_key = key.replace('down1', 'down')
            state_dict[new_key] = torch.cat([state_dict[key], state_dict[down2_key]], dim=0)
    model.load_state_dict(state_dict, strict=False)
    return model


def _evaluate(model, loader: DataLoader, all_reduce: bool):
    """Compute the mean cross-entropy loss and accuracy of ``model`` on ``loader``.

    Runs under ``torch.no_grad()`` with accumulators on ``'cuda'``. When
    ``all_reduce`` is true the sums are combined across all ranks before
    averaging. Prints the loss and the accuracy and returns them as floats.
    """
    count = torch.tensor(0, device='cuda')
    loss_accum = torch.tensor(0.0, device='cuda')
    correct_count = torch.tensor(0, device='cuda')
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for X, y in loader:
            X = X.cuda(non_blocking=True)
            y = y.cuda(non_blocking=True).flatten()
            logits = model(X)
            loss = criterion(logits, y)
            count += X.shape[0]
            loss_accum += loss.data
            correct_count += (torch.argmax(logits, dim=1) == y).sum()
        if all_reduce:
            dist.all_reduce(loss_accum)
            dist.all_reduce(correct_count)
            dist.all_reduce(count)
        loss = (loss_accum / count).item()
        acc = (correct_count / count).item()
        print(f'{loss}')
        print(f'{acc:.4f}')
    return loss, acc


def train_acc_model(model, train_loader: DataLoader):
    """Return ``(loss, accuracy)`` on ``train_loader`` across all ranks.

    The model is wrapped in DDP and evaluated in *train* mode (no gradients
    are computed, but train-mode layer behaviour applies).
    """
    model = DDP(model)
    model.train()
    print("Train Acc")
    return _evaluate(model, train_loader, all_reduce=True)


def test_model(model, val_loader: DataLoader):
    """Return ``(loss, accuracy)`` on ``val_loader`` across all ranks.

    The model is wrapped in DDP and evaluated in eval mode.
    """
    model = DDP(model)
    model.eval()
    print("Test")
    return _evaluate(model, val_loader, all_reduce=True)


def test_model_orig(model, val_loader: DataLoader):
    """Return ``(loss, accuracy)`` on ``val_loader`` for this process only.

    Unlike ``test_model`` the model is not wrapped in DDP and the statistics
    are not reduced across ranks.
    """
    model.eval()
    print("Test")
    return _evaluate(model, val_loader, all_reduce=False)


def run(rank, world_size, args):
    """Per-process entry point: load the model, quantize it, evaluate and save.

    Args:
        rank: index of this process; also used as the CUDA device index and
            as the random seed.
        world_size: total number of processes (not read here).
        args: parsed command-line namespace; ``log_dir``, ``load_path``,
            ``sample_portion``, ``findq_iters`` and ``save_path`` are used.

    The final state dict is written to ``args.save_path`` by rank 0 only.
    """
    writer = SummaryWriter(log_dir=args.log_dir)
    try:
        logging_util.setup_logging(
            pathlib.Path(args.log_dir) / f'log_rank{rank}.txt',
            log_to_screen=(rank == 0)
        )
        print("===== Args =====")
        for name, value in vars(args).items():
            print(f"   {name}: {value}")
        print("================")

        print(f"Using GPU {rank}")
        torch.cuda.set_device(f"cuda:{rank}")
        torch.manual_seed(rank)

        # construct model
        model = load_model(args.load_path)
        print(model)

        # load dataset
        train_loader, val_loader = cifar100(batch_size=128, workers=4, distributed=True)

        # quantize model
        quantize(model, train_loader, val_loader,
                 sample_portion=args.sample_portion,
                 findq_iters=args.findq_iters,
                 log_dir=pathlib.Path(writer.log_dir))

        # evaluate quantized model
        train_acc_model(model, train_loader)
        test_model(model, val_loader)

        if rank == 0:
            torch.save(model.state_dict(), args.save_path)
    finally:
        # Flush and release the event file even when quantization fails.
        writer.close()


def dist_main(rank, world_size, args):
    """Worker entry point for ``mp.spawn``: set up NCCL, call ``run``, tear down.

    ``MASTER_ADDR`` is always set to ``'localhost'``; ``MASTER_PORT`` defaults
    to ``'12355'`` unless it is already present in the environment.

    Args:
        rank: index of this process within the group.
        world_size: total number of processes in the group.
        args: parsed command-line namespace forwarded to ``run``.
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ.setdefault('MASTER_PORT', '12355')

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        run(rank, world_size, args)
    finally:
        # Tear the group down even if run() raises.
        dist.destroy_process_group()


def main():
    """Parse the command line, prepare the run directory and spawn the workers.

    A fresh ``runs/reactnet_<timestamp>_<hostname>`` directory is created, the
    ``*.py`` files next to this script are copied into its ``source``
    sub-directory, and one ``dist_main`` process is spawned per visible CUDA
    device.
    """
    option_specs = (
        ('--xnor_net', dict(default=False, action='store_true')),
        ('--skip_until', dict(default=None, type=str)),
        ('--load_path', dict(
            default='/d1/xxx/DBQ/baseline/1_step1/models_run1/model_best.pth.tar',
            type=str)),
        ('--block_iters', dict(default=5, type=int)),
        ('--finetune_epochs', dict(default=10, type=int)),
        ('--block_size', dict(default=7, type=int)),
        ('--sample_portion', dict(default=1.0, type=float)),
        ('--findq_iters', dict(default=20, type=int)),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    # Derived paths are attached to the namespace so the workers can see them.
    stamp = datetime.now().strftime('%b%d_%H-%M-%S-%f')
    args.log_dir = pathlib.Path('runs') / f"reactnet_{stamp}_{socket.gethostname()}"
    checkpoint_stem = pathlib.Path(args.load_path).stem
    args.save_path = args.log_dir / f"{checkpoint_stem}-quantized.pt"
    print("Will save to ", args.save_path)

    # Snapshot the source files of this run; the directory must not exist yet.
    source_copy_dir = args.log_dir / 'source'
    source_copy_dir.mkdir(parents=True, exist_ok=False)
    script_dir = pathlib.Path(__file__).parent
    for src in glob.glob(str(script_dir / '*.py')):
        relpath = pathlib.Path(src).relative_to(script_dir)
        print(f"Copy {relpath} to {source_copy_dir}")
        dst = source_copy_dir / relpath
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(src, dst)

    n_devices = torch.cuda.device_count()
    mp.set_start_method('forkserver')
    mp.spawn(dist_main, args=(n_devices, args), nprocs=n_devices, join=True)


# Run main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
