import argparse
import copy
import glob
import math
import os
import pathlib
import socket
import sys
from datetime import datetime

import networkx as nx
import shutil
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
from networkx.algorithms.dag import topological_sort
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from reactnet import reactnet, QuantizedConv2d
from extract_cov import dataset_cov

sys.path.append('.')
from cifar_torch import cifar100
import logging_util

# Turn cuDNN on and let it benchmark/auto-select convolution algorithms.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True


def make_graph(model):
    """Build the layer dependency DAG reachable backwards from ``'fc'``.

    Starting at the layer named ``'fc'``, asks ``model.dependent_layers(name)``
    for the layers that ``name`` depends on and records an edge
    ``dep -> name`` for each one. Every node therefore precedes the layers
    that depend on it. Layers are explored depth-first, each expanded once.

    Note: if ``'fc'`` has no dependencies, the returned graph is empty.
    """
    graph = nx.DiGraph()
    exhausted = object()
    seen = {'fc'}
    # Explicit DFS stack of (layer, iterator over its dependencies). It visits
    # layers and adds edges in exactly the order a recursive walk would.
    stack = [('fc', iter(model.dependent_layers('fc')))]
    while stack:
        consumer, pending = stack[-1]
        producer = next(pending, exhausted)
        if producer is exhausted:
            stack.pop()
            continue
        graph.add_edge(producer, consumer)
        if producer not in seen:
            seen.add(producer)
            stack.append((producer, iter(model.dependent_layers(producer))))
    return graph


class Sign(torch.autograd.Function):
    """Binarize with ``sign`` and back-propagate a piecewise-linear surrogate.

    Forward returns ``sign(x)``. Backward replaces the true derivative (zero
    almost everywhere) with ``max(0, 2 * (1 - |x|))``. This is a triangle that
    peaks at 2 for ``x == 0`` and vanishes for ``|x| >= 1``.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        ctx.save_for_backward(x)
        return x.sign()

    @staticmethod
    def backward(ctx, g):
        (inp,) = ctx.saved_tensors
        surrogate = torch.clamp(2 * (1 - torch.abs(inp)), min=0.0)
        # The trailing None is a surplus gradient slot; autograd drops extra
        # None entries beyond the number of forward inputs.
        return g * surrogate, None


@torch.jit.script
def quantization_loss(q, M, w_hat, lamba: float = 1e-5):
    """Per-column reconstruction error of ``M @ q`` plus an L2 penalty on ``q``.

    For each column ``j`` returns
    ``sum_i ((M @ q) - w_hat)[i, j] ** 2 + lamba * ||q[:, j]||_2``.
    """
    residual = torch.matmul(M, q) - w_hat
    return residual.pow(2).sum(0) + lamba * q.norm(2, dim=0)


def eigh(cov):
    """Eigendecomposition of a symmetric matrix via ``torch.linalg.eigh``.

    Returns ``(eigenvalues, eigenvectors)`` exactly as ``torch.linalg.eigh``
    does. If the first call raises ``torch._C._LinAlgError``, it is retried once
    as a workaround for a torch bug. A second failure propagates.
    """
    decompose = torch.linalg.eigh
    try:
        result = decompose(cov)
    except torch._C._LinAlgError:  # workaround for torch bug: retry once
        result = decompose(cov)
    return result


@torch.enable_grad()
def find_q(layer_names, layers, log_s_dict, V_dict,
           orig_weight_dict,
           nsteps=2000, nsamples=20, xnor_net=False):
    """Search binary weights ``q = alpha * sign(scores)`` for each layer.

    For each layer, minimizes ``quantization_loss(q, M, M @ W).mean()``.
    Here ``M = exp(log_s)[:, None] * V.T`` and ``W`` is the layer's original
    weight reshaped to ``(out, -1)`` and transposed. Only each layer's
    ``alpha`` and ``scores`` are optimized (SGD, lr 0.01, Nesterov momentum
    0.9, cosine LR schedule).

    The search runs ``nsamples`` rounds of ``nsteps`` steps. Between rounds,
    ``scores`` are perturbed with Gaussian noise (std 0.01). The lowest-loss
    ``alpha``/``scores`` seen for each layer are tracked throughout.
    Finally, for each layer, the rank with the lowest best loss broadcasts its
    snapshot, and every rank copies it into the layer.

    Args:
        layer_names: layer names, parallel to ``layers``.
        layers: modules exposing ``alpha`` and ``scores`` parameters.
        log_s_dict, V_dict: per-layer reparametrization tensors. They are
            treated as constants here; no gradient flows into them.
        orig_weight_dict: per-layer original weights.
        nsteps: optimizer steps per round.
        nsamples: number of rounds.
        xnor_net: unused; kept for interface compatibility.

    Requires an initialized ``torch.distributed`` process group and CUDA.
    """
    params = []
    for layer in layers:
        params.extend([layer.alpha, layer.scores])
    opt = torch.optim.SGD(params, lr=0.01, momentum=0.9, nesterov=True)
    # NOTE(review): T_max covers a single round, but the scheduler is stepped
    # nsteps * nsamples times, so the LR oscillates 0.01 -> 0 -> 0.01 ...
    # across rounds. Confirm this is intended.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=nsteps)

    layers_dict = dict(zip(layer_names, layers))

    # M and w_hat depend only on log_s, V and the original weights, none of
    # which are optimized here. Build them once, outside the autograd graph.
    # This avoids a wasted per-step backward pass into log_s / V and stops
    # stale gradients accumulating in their .grad.
    M_dict, w_hat_dict = {}, {}
    with torch.no_grad():
        for layer_name in layer_names:
            orig_w = orig_weight_dict[layer_name].data
            orig_w = orig_w.reshape(orig_w.shape[0], -1).T  # (in, out)
            M = torch.exp(log_s_dict[layer_name])[:, None] * V_dict[layer_name].T
            M_dict[layer_name] = M
            w_hat_dict[layer_name] = M @ orig_w

    best_losses = torch.full([len(layer_names)], 10.**9, device='cuda')
    best_alpha_dict = {
        layer_name: layer.alpha.detach().clone()
        for layer_name, layer in layers_dict.items()
    }
    best_scores_dict = {
        layer_name: layer.scores.detach().clone()
        for layer_name, layer in layers_dict.items()
    }

    for _sample in tqdm(range(nsamples), desc='sample'):
        for _step in tqdm(range(nsteps), desc='step'):
            opt.zero_grad()

            losses = []
            for layer_name in layer_names:
                layer = layers_dict[layer_name]
                q = layer.alpha * Sign.apply(layer.scores)
                q = q.reshape(q.shape[0], -1).T  # (in, out)
                losses.append(quantization_loss(
                    q, M_dict[layer_name], w_hat_dict[layer_name]).mean())

            losses = torch.stack(losses)
            losses.sum().backward()

            # Snapshot per-layer improvements. This runs before opt.step() so
            # the stored parameters match the loss that was just computed.
            with torch.no_grad():
                conds = torch.le(losses, best_losses)
                best_losses.copy_(torch.where(conds, losses, best_losses))
                for cond, layer_name in zip(conds, layer_names):
                    layer = layers_dict[layer_name]
                    best_scores = best_scores_dict[layer_name]
                    best_alpha = best_alpha_dict[layer_name]
                    best_scores.copy_(torch.where(cond, layer.scores, best_scores))
                    best_alpha.copy_(torch.where(cond, layer.alpha, best_alpha))

            opt.step()
            scheduler.step()

        # Perturb scores before the next round to explore other sign patterns.
        with torch.no_grad():
            for layer_name in layer_names:
                scores = layers_dict[layer_name].scores
                scores.add_(torch.randn_like(scores) * 0.01)

        print("best=", best_losses.mean().item())

    best_list = [torch.zeros_like(best_losses) for _ in range(dist.get_world_size())]
    dist.all_gather(best_list, best_losses)
    gathered = torch.stack(best_list)  # (world_size, len(layer_names))

    # For each layer, adopt the snapshot from the rank with the lowest loss.
    # tolist() yields plain Python ints, as expected for the broadcast src rank.
    best_ranks = torch.argmin(gathered, dim=0).tolist()
    for layer_name, best_rank in zip(layer_names, best_ranks):
        layer = layers_dict[layer_name]
        best_scores = best_scores_dict[layer_name]
        best_alpha = best_alpha_dict[layer_name]
        dist.broadcast(best_scores, best_rank)
        dist.broadcast(best_alpha, best_rank)
        layer.alpha.data.copy_(best_alpha)
        layer.scores.data.copy_(best_scores)

    final_bests = torch.amin(gathered, dim=0)
    print(f"final best={final_bests.mean().item()}")


# knowledge distillation loss
@torch.jit.script
def distillation(logits, labels, teacher_scores, T: float, alpha: float):
    """Per-sample knowledge-distillation loss (hard-label CE + soft-target KL).

    Args:
        logits: student logits, shape (batch, classes).
        labels: hard class labels, shape (batch,).
        teacher_scores: teacher logits, used as soft targets.
        T: softmax temperature.
        alpha: weight of the distillation term; ``1 - alpha`` weights the
            hard-label cross-entropy.

    Returns:
        ``(task_loss, total_loss)``, both per-sample (no reduction).
        ``task_loss`` is the cross-entropy against ``labels``.
        ``total_loss = (1 - alpha) * task_loss + 2 * T**2 * alpha * KL``,
        where ``KL = KL(softmax(teacher/T) || softmax(student/T))`` summed
        over classes.
    """
    task_loss = F.cross_entropy(logits, labels, reduction='none')
    teacher_loss = F.kl_div(
        F.log_softmax(logits / T, dim=-1),
        F.log_softmax(teacher_scores / T, dim=-1),
        reduction='none',
        log_target=True
    ).sum(1)
    # The T**2 factor keeps soft-target gradients comparable across
    # temperatures. alpha must scale (not be added to) that factor, or it no
    # longer controls the task/teacher balance.
    return task_loss, task_loss * (1 - alpha) + teacher_loss * (2 * T ** 2 * alpha)


@torch.jit.script
def fused_clip(max_norm: float, total_norm):
    """Gradient scale factor ``clamp(max_norm / (total_norm + 1e-6), 0, 1)``.

    The 1e-6 guards against division by zero.
    """
    return torch.clip(max_norm / (total_norm + 1e-6), 0.0, 1.0)


@torch.no_grad()
def clip_grad_norm_(
        parameters, max_norm: float, norm_type: float = 2.0) -> torch.Tensor:
    """Scale gradients in place so their combined norm is at most ``max_norm``.

    Similar to ``torch.nn.utils.clip_grad_norm_``. Parameters without a
    gradient are ignored. The total norm is the ``norm_type``-norm of the
    per-parameter gradient norms, or the largest absolute entry when
    ``norm_type`` is ``inf``. Every gradient is then multiplied by
    ``fused_clip(max_norm, total_norm)``.

    Returns:
        The total gradient norm before scaling, or ``tensor(0.)`` if no
        parameter has a gradient.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    # Detached views share storage with p.grad, so mul_ below edits it in place.
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    if not grads:
        return torch.tensor(0.)

    if norm_type == float('inf'):
        per_grad = [g.abs().max() for g in grads]
        total_norm = per_grad[0] if len(per_grad) == 1 else torch.max(torch.stack(per_grad))
    else:
        per_grad = [torch.norm(g, norm_type) for g in grads]
        total_norm = torch.norm(torch.stack(per_grad), norm_type)

    scale = fused_clip(max_norm, total_norm)
    for g in grads:
        g.mul_(scale)
    return total_norm


def SV_regularization(log_s_dict, V_dict, orig_log_s_sum_dict):
    """Penalty keeping each V near-orthonormal and each sum(log_s) near its start.

    For every key:
    ``0.1 * ||V @ V.T - I||_F**2 + ((sum(log_s) - ref) / ref)**2``,
    where ``ref = orig_log_s_sum_dict[key]``. Returns the sum over all keys.
    Raises if the dicts are empty, because ``torch.stack`` needs at least one
    item.
    """
    def layer_penalty(key):
        V = V_dict[key]
        ref = orig_log_s_sum_dict[key]
        gram = V @ V.T
        gram.diagonal().sub_(1.0)  # gram now holds V V^T - I
        orthogonality = gram.norm(2) ** 2  # squared Frobenius norm
        log_det_drift = ((log_s_dict[key].sum() - ref) / ref) ** 2
        return orthogonality * 0.1 + log_det_drift

    return torch.stack([layer_penalty(key) for key in log_s_dict]).sum()


@torch.enable_grad()
def find_S_V(model,
             orig_log_s_sum_dict, log_s_dict, V_dict,
             opt, teacher, train_loader, n_steps,
             T=20.0, alpha=0.7):
    """Optimize the ``log_s`` / ``V`` reparametrization by distilling from ``teacher``.

    Wraps ``model`` in DDP, puts it in train mode and runs ``opt`` for exactly
    ``n_steps`` steps over ``train_loader``, looping over epochs as needed.
    The per-step objective is the batch-mean distillation loss plus
    ``10 * SV_regularization(...)``.

    After step ``k``, ``noise_scale`` on every ``QuantizedConv2d`` is set to
    ``10 ** (-2 + 2 * k / n_steps)``, ramping from about 0.01 to 1.0. At the
    end of each full epoch, the loss/accuracy sums are all-reduced across
    ranks and printed.
    """
    teacher.eval()
    model = DDP(model)
    model.train()
    steps = 0

    while True:
        count = torch.tensor(0, device='cuda')
        task_loss_accum = torch.tensor(0.0, device='cuda')
        dist_loss_accum = torch.tensor(0.0, device='cuda')
        loss_accum = torch.tensor(0.0, device='cuda')
        correct_count = torch.tensor(0, device='cuda')
        for X, y in tqdm(train_loader, desc="Train"):
            X = X.cuda(non_blocking=True)
            y = y.cuda(non_blocking=True).flatten()
            model.zero_grad()
            # log_s / V may not be registered on the model; clear them explicitly.
            opt.zero_grad()

            logits = model(X)
            with torch.no_grad():
                teacher_scores = teacher(X)
            task_loss, dist_loss = distillation(logits, y, teacher_scores, T, alpha)

            print(dist_loss.mean().item())

            # distillation() returns per-sample losses. Reduce to a scalar
            # before backward(): backward() on a non-scalar tensor raises.
            reg = SV_regularization(log_s_dict, V_dict, orig_log_s_sum_dict) * 10
            reg_loss = dist_loss.mean() + reg
            reg_loss.backward()

            # Gradient clipping (clip_grad_norm_ on log_s/V, max_norm 2.0) is disabled.

            opt.step()
            batch = X.shape[0]
            count += batch
            task_loss_accum += task_loss.sum().detach()
            dist_loss_accum += dist_loss.sum().detach()
            # Weight by batch size so loss_accum / count is a per-sample mean.
            loss_accum += reg_loss.detach() * batch
            correct_count += (torch.argmax(logits, dim=1) == y).sum()

            steps += 1
            new_noise_scale = 10 ** (-2 + 2 * (steps / n_steps))
            print(f"noise_scale = {new_noise_scale}")
            for module in model.modules():
                if isinstance(module, QuantizedConv2d):
                    module.noise_scale = new_noise_scale

            # Stop after exactly n_steps optimizer steps (noise_scale ends at 1.0).
            if steps >= n_steps:
                return
        dist.all_reduce(task_loss_accum)
        dist.all_reduce(dist_loss_accum)
        dist.all_reduce(loss_accum)
        dist.all_reduce(correct_count)
        dist.all_reduce(count)
        tqdm.write(f'task{(task_loss_accum / count).item()} / '
                   f'dist{(dist_loss_accum / count).item()} / '
                   f'reg{(loss_accum / count).item()}')
        accuracy = (correct_count / count).item()
        tqdm.write(f'{accuracy:.4f}')


def _init_reparam_params(covs, max_rank):
    """Create trainable per-layer ``(log_s, V)`` from each layer's covariance.

    For each ``layer_name -> cov``, keeps the ``max_rank`` largest eigenpairs
    (``torch.linalg.eigh`` returns eigenvalues in ascending order, so the tail
    is taken). It then sets ``log_s = 0.5 * log(max(eigenvalue, 1e-6))``.

    Returns:
        ``(log_s_dict, V_dict, orig_log_s_sum_dict)``. The first two map layer
        names to ``torch.nn.Parameter``. The last holds the initial
        ``log_s.sum()`` per layer, which ``SV_regularization`` uses as a
        reference.
    """
    log_s_dict, V_dict, orig_log_s_sum_dict = {}, {}, {}
    for layer_name, cov in covs.items():
        s, new_V = eigh(cov)
        s = s[-max_rank:]
        new_V = new_V[:, -max_rank:]
        new_log_s = torch.log(torch.clip(s, min=1e-6)) * 0.5
        orig_log_s_sum_dict[layer_name] = new_log_s.sum()
        log_s_dict[layer_name] = torch.nn.Parameter(new_log_s)
        V_dict[layer_name] = torch.nn.Parameter(new_V)
    return log_s_dict, V_dict, orig_log_s_sum_dict


@torch.no_grad()
def quantize(model, train_loader, val_loader,
             sample_portion=1.0, findq_iters=20,
             log_dir=None, max_rounds=None):
    """Quantize ``model`` in place by alternating S,V and binary-weight searches.

    1. Snapshot the unmodified model as the distillation teacher.
    2. Collect the layers feeding ``'fc'`` (via ``make_graph``) in
       topological order and estimate their covariances with ``dataset_cov``.
    3. Initialize per-layer ``log_s`` / ``V`` from those covariances and switch
       each layer to ``'quantize-reparam'`` mode.
    4. Repeat: fit ``log_s`` / ``V`` by distillation (``find_S_V``, 20 epochs
       of steps), search binary weights (``find_q``, 1000 steps x
       ``findq_iters`` rounds), then evaluate on ``val_loader``.

    Args:
        model: model to quantize (modified in place).
        train_loader, val_loader: data loaders.
        sample_portion: forwarded to ``dataset_cov``.
        findq_iters: number of ``find_q`` search rounds (its ``nsamples``).
        log_dir: ``pathlib.Path``. When given, rank 0 saves
            ``intermediate.pt`` / ``intermediate_2.pt`` there after each phase.
        max_rounds: number of outer rounds; ``None`` (default) repeats
            indefinitely.

    Requires an initialized ``torch.distributed`` process group.
    """
    teacher = copy.deepcopy(model)

    graph = make_graph(model)
    modules = dict(model.named_modules())
    weight_params = [p for m in model.modules()
                     if not isinstance(m, torch.nn.BatchNorm2d)
                     for p in m.parameters(recurse=False)]
    bn_params = [p for m in model.modules()
                 if isinstance(m, torch.nn.BatchNorm2d)
                 for p in m.parameters(recurse=False)]
    print(f"{len(weight_params)} weight params, {len(bn_params)} bn params")

    # Every graph node has a path to 'fc', so 'fc' is last in topological
    # order. Drop it and quantize the layers that feed it.
    layer_names = list(topological_sort(graph))[:-1]
    layers = [modules[layer_name] for layer_name in layer_names]

    covs, _layer_ids, _cov_iter = dataset_cov(
        train_loader, None,
        model, layers, layer_names,
        sample_portion
    )

    log_s_dict, V_dict, orig_log_s_sum_dict = _init_reparam_params(covs, max_rank=10240)

    orig_weight_dict = {}
    for layer_name, layer in zip(layer_names, layers):
        orig_weight_dict[layer_name] = layer.set_mode(
            'quantize-reparam',
            log_s=log_s_dict[layer_name],
            V=V_dict[layer_name])

    def save_checkpoint(filename):
        # All ranks share log_dir. The ranks are expected to hold identical
        # models after DDP / broadcast synchronisation, so only rank 0 writes,
        # avoiding concurrent writes to one file.
        if log_dir is not None and dist.get_rank() == 0:
            torch.save(model, log_dir / filename)

    round_idx = 0
    while max_rounds is None or round_idx < max_rounds:
        print("Finding S,V")

        optimizer = torch.optim.Adam(
            [{'params': list(log_s_dict.values())},
             {'params': list(V_dict.values())}],
            lr=1.25e-3 * 10
        )

        find_S_V(
            model=model,
            orig_log_s_sum_dict=orig_log_s_sum_dict,
            log_s_dict=log_s_dict,
            V_dict=V_dict,
            opt=optimizer,
            teacher=teacher,
            train_loader=train_loader,
            n_steps=len(train_loader) * 20
        )

        save_checkpoint("intermediate.pt")

        print("Finding W")
        find_q(layer_names, layers, log_s_dict, V_dict,
               orig_weight_dict,
               nsteps=1000, nsamples=findq_iters)

        save_checkpoint("intermediate_2.pt")

        test_model(model, val_loader)
        round_idx += 1


def load_model(path):
    """Build a 100-class ReActNet on CUDA and load weights from checkpoint ``path``.

    The checkpoint must be a dict with ``'epoch'``, ``'best_top1_acc'`` and
    ``'state_dict'`` entries.

    Leading ``'module.'`` prefixes (typically added by DataParallel/DDP) are
    stripped from parameter names. Only the prefix is removed, so names that
    merely contain ``'module.'`` are left intact.

    Each ``*down1.weight`` / ``*down2.weight`` pair is concatenated along dim 0
    into a single ``*down.weight`` entry.

    Loading is non-strict. Missing and unexpected keys are printed rather than
    silently ignored.

    Note: ``torch.load`` unpickles the file, so only load trusted checkpoints.
    """
    model = reactnet(num_classes=100).cuda()
    checkpoint = torch.load(path, map_location='cpu')
    print(checkpoint['epoch'], checkpoint['best_top1_acc'])

    prefix = 'module.'
    state_dict = {}
    for key, value in checkpoint['state_dict'].items():
        while key.startswith(prefix):
            key = key[len(prefix):]
        state_dict[key] = value

    for key in list(state_dict):
        if key.endswith('down1.weight'):
            down2_key = key.replace('down1', 'down2')
            new_key = key.replace('down1', 'down')
            state_dict[new_key] = torch.cat([state_dict[key], state_dict[down2_key]], dim=0)

    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        print("missing keys:", result.missing_keys)
    if result.unexpected_keys:
        print("unexpected keys:", result.unexpected_keys)
    return model


def train_acc_model(model, train_loader: DataLoader):
    """Report mean cross-entropy and accuracy of ``model`` over ``train_loader``.

    The model is wrapped in DDP and run in *train* mode without gradients.
    Layers such as BatchNorm therefore use batch statistics and update their
    running stats. Loss/correct/sample sums are all-reduced across ranks
    before averaging; the results are printed and returned.

    Returns:
        ``(loss, acc)`` as Python floats.
    """
    ddp_model = DDP(model)
    ddp_model.train()
    print("Train Acc")
    n_seen = torch.tensor(0, device='cuda')
    loss_total = torch.tensor(0.0, device='cuda')
    n_correct = torch.tensor(0, device='cuda')
    loss_fn = torch.nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for inputs, targets in train_loader:
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True).flatten()
            logits = ddp_model(inputs)
            batch_loss = loss_fn(logits, targets)
            n_seen += inputs.shape[0]
            loss_total += batch_loss.data
            n_correct += (logits.argmax(dim=1) == targets).sum()
        for stat in (loss_total, n_correct, n_seen):
            dist.all_reduce(stat)
        loss = (loss_total / n_seen).item()
        acc = (n_correct / n_seen).item()
    print(f'{loss}')
    print(f'{acc:.4f}')
    return loss, acc


def test_model(model, val_loader: DataLoader):
    """Evaluate ``model`` on ``val_loader`` across all ranks.

    The model is wrapped in DDP, put in eval mode and run without gradients.
    Summed cross-entropy, correct predictions and sample counts are
    all-reduced across ranks. The mean loss and accuracy (4 decimals) are then
    printed.

    Returns:
        ``(loss, acc)`` as Python floats.
    """
    ddp_model = DDP(model)
    ddp_model.eval()
    print("Test")
    n_seen = torch.tensor(0, device='cuda')
    loss_total = torch.tensor(0.0, device='cuda')
    n_correct = torch.tensor(0, device='cuda')
    loss_fn = torch.nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True).flatten()
            logits = ddp_model(inputs)
            batch_loss = loss_fn(logits, targets)
            n_seen += inputs.shape[0]
            loss_total += batch_loss.data
            n_correct += (logits.argmax(dim=1) == targets).sum()
        for stat in (loss_total, n_correct, n_seen):
            dist.all_reduce(stat)
        loss = (loss_total / n_seen).item()
        acc = (n_correct / n_seen).item()
    print(f'{loss}')
    print(f'{acc:.4f}')
    return loss, acc


def test_model_orig(model, val_loader: DataLoader):
    """Evaluate ``model`` on ``val_loader`` in the current process only.

    Unlike ``test_model``, this neither wraps the model in DDP nor all-reduces
    the statistics. The model is set to eval mode and run without gradients.
    The mean cross-entropy and accuracy (4 decimals) are printed.

    Returns:
        ``(loss, acc)`` as Python floats.
    """
    model.eval()
    print("Test")
    n_seen = torch.tensor(0, device='cuda')
    loss_total = torch.tensor(0.0, device='cuda')
    n_correct = torch.tensor(0, device='cuda')
    loss_fn = torch.nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True).flatten()
            logits = model(inputs)
            batch_loss = loss_fn(logits, targets)
            n_seen += inputs.shape[0]
            loss_total += batch_loss.data
            n_correct += (logits.argmax(dim=1) == targets).sum()
        loss = (loss_total / n_seen).item()
        acc = (n_correct / n_seen).item()
    print(f'{loss}')
    print(f'{acc:.4f}')
    return loss, acc


def run(rank, world_size, args):
    """Per-process worker: load, quantize, evaluate and save the model.

    Sets up per-rank logging under ``args.log_dir``, pins this process to
    ``cuda:{rank}`` and seeds torch with ``rank``. Only rank 0 writes the final
    ``state_dict`` to ``args.save_path``. ``world_size`` is currently unused;
    the process group is initialized by the caller.

    The ``SummaryWriter`` is always closed on exit, even if an exception
    occurs, so its event file is flushed.
    """
    writer = SummaryWriter(log_dir=args.log_dir)
    try:
        logging_util.setup_logging(
            pathlib.Path(args.log_dir) / f'log_rank{rank}.txt',
            log_to_screen=(rank == 0)
        )
        print("===== Args =====")
        for name, value in vars(args).items():
            print(f"   {name}: {value}")
        print("================")

        print(f"Using GPU {rank}")
        torch.cuda.set_device(f"cuda:{rank}")
        # Per-rank seed; presumably so find_q's random perturbations differ
        # across ranks before the best result is picked.
        torch.manual_seed(rank)

        model = load_model(args.load_path)
        print(model)

        train_loader, val_loader = cifar100(batch_size=128, workers=4, distributed=True)

        # Optionally evaluate the FP model first with train_acc_model / test_model.

        quantize(model, train_loader, val_loader,
                 sample_portion=args.sample_portion,
                 findq_iters=args.findq_iters,
                 log_dir=pathlib.Path(writer.log_dir))

        # evaluate quantized model
        train_acc_model(model, train_loader)
        test_model(model, val_loader)

        if rank == 0:
            torch.save(model.state_dict(), args.save_path)
    finally:
        writer.close()


def dist_main(rank, world_size, args):
    """``mp.spawn`` target: join the NCCL process group, run, then tear it down.

    Rendezvous is always on localhost. ``MASTER_PORT`` defaults to 12355
    unless it is already set in the environment.
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ.setdefault('MASTER_PORT', '12355')

    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    run(rank, world_size, args)
    dist.destroy_process_group()


def main():
    """CLI entry point: parse args, snapshot sources and spawn one worker per GPU.

    Creates ``runs/reactnet_<timestamp>_<hostname>/``. It fails if the
    directory's ``source`` subfolder already exists. The top-level ``*.py``
    files next to this script are copied into that subfolder. The output path
    is ``<log_dir>/<checkpoint stem>-quantized.pt``. Finally ``dist_main`` is
    spawned once per visible CUDA device.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add('--xnor_net', default=False, action='store_true')
    add('--skip_until', default=None, type=str)
    add('--load_path', default='/d1/xxx/DBQ/baseline/1_step1/models_run1/model_best.pth.tar', type=str)
    add('--block_iters', default=5, type=int)
    add('--finetune_epochs', default=10, type=int)
    add('--block_size', default=7, type=int)
    add('--sample_portion', default=1.0, type=float)
    add('--findq_iters', default=20, type=int)
    args = parser.parse_args()

    stamp = datetime.now().strftime('%b%d_%H-%M-%S-%f')
    args.log_dir = pathlib.Path('runs') / f"reactnet_{stamp}_{socket.gethostname()}"
    args.save_path = args.log_dir / f"{pathlib.Path(args.load_path).stem}-quantized.pt"
    print("Will save to ", args.save_path)

    # Keep a copy of the code that produced this run.
    source_copy_dir = args.log_dir / 'source'
    source_copy_dir.mkdir(parents=True, exist_ok=False)
    here = pathlib.Path(__file__).parent
    for src in glob.glob(str(here / '*.py')):
        rel = pathlib.Path(src).relative_to(here)
        print(f"Copy {rel} to {source_copy_dir}")
        dst = source_copy_dir / rel
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(src, dst)

    n_gpus = torch.cuda.device_count()
    mp.set_start_method('forkserver')
    mp.spawn(dist_main, args=(n_gpus, args), nprocs=n_gpus, join=True)


if __name__ == "__main__":
    # Guard needed because mp.spawn starts workers that re-import this module.
    main()
