import sys
import copy
import glob
import math
import socket
import shutil
import pathlib
import argparse
import itertools
from datetime import datetime
from functools import partial

import os
import torch
import numpy as np
from tqdm import tqdm
import networkx as nx
import more_itertools
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.tensorboard import SummaryWriter
from networkx.algorithms.dag import topological_sort
from torch.utils.data import DataLoader, DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

from reactnet import reactnet
sys.path.append('.')
from cifar_torch import cifar100, cifar10
from extract_cov import combine_cov
from kmeans import kmeans
import logging_util
from resnet_pcifar import ResNet18

# Enable cuDNN and let it benchmark convolution algorithms for the shapes it sees.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True


def make_graph(model):
    """Build the layer dependency graph reachable from the ``'fc'`` layer.

    Starting at ``'fc'``, the layers returned by ``model.dependent_layers`` are
    followed depth-first. Every dependency adds an edge ``dependency -> layer``.

    Returns:
        The resulting ``networkx.DiGraph``.
    """
    graph = nx.DiGraph()
    seen = {'fc'}
    # Explicit DFS stack of (layer, iterator over its dependencies); resuming
    # the iterator keeps the edge insertion order of a recursive traversal.
    pending = [('fc', iter(model.dependent_layers('fc')))]

    while pending:
        consumer, producers = pending[-1]
        for producer in producers:
            graph.add_edge(producer, consumer)
            if producer not in seen:
                seen.add(producer)
                pending.append((producer, iter(model.dependent_layers(producer))))
                break
        else:
            # every dependency of this layer has been handled
            pending.pop()

    return graph


def dataset_cov(train_loader, loader_instance, permute_indices, get_patches, dim, sample_portion=1.0):
    """Estimate the covariance of layer-input patches across all ranks.

    Args:
        train_loader: loader yielding ``(X, y)`` batches; restarted when exhausted.
        loader_instance: an existing iterator over ``train_loader`` to continue
            from, or ``None`` to start a fresh one.
        permute_indices: optional index applied to the first axis of the patches.
        get_patches: callable mapping a CUDA batch to a ``(dim, n)`` patch matrix.
        dim: number of patch features (size of the covariance matrix).
        sample_portion: fraction of ``len(train_loader)`` batches to consume.

    Returns:
        ``(cov, loader_instance)``: the float32 covariance, identical on every
        rank after the broadcast, and the iterator so the caller can continue
        consuming the loader.
    """
    # Running statistics are kept in float64 on the GPU.
    count = torch.zeros((), device='cuda', dtype=torch.long)
    mean = torch.zeros(dim, device='cuda', dtype=torch.float64)
    cov = torch.zeros(dim, dim, device='cuda', dtype=torch.float64)
    iters = math.ceil(len(train_loader) * sample_portion)
    if loader_instance is None:
        loader_instance = iter(train_loader)
    for _ in range(iters):
        try:
            X, y = next(loader_instance)
        except StopIteration:
            # loader exhausted: start over and keep sampling
            loader_instance = iter(train_loader)
            X, y = next(loader_instance)
        X = X.cuda(non_blocking=True)
        patches = get_patches(X)  # (dim, batch*height*width)
        if permute_indices is not None:
            patches = patches[permute_indices]

        # statistics of this batch alone
        other_count = torch.tensor(patches.size(1), dtype=torch.long)
        other_mean = patches.mean(1)
        other_cov = torch.cov(patches)

        if count == 0:
            # first batch: adopt its statistics directly
            count.copy_(other_count)
            mean.copy_(other_mean)
            cov.copy_(other_cov)
        else:
            # NOTE(review): the return value is discarded, so combine_cov
            # presumably updates count/mean/cov in place - confirm in extract_cov.
            combine_cov(
                count, mean, cov,
                other_count, other_mean, other_cov
            )

    # receive buffers for the statistics of another rank
    other_count = count.new_empty(count.shape)
    other_mean = mean.new_empty(mean.shape)
    other_cov = cov.new_empty(cov.shape)

    # Pairwise tree reduction towards rank 0. `rank` is halved after every
    # stage, so it acts as a virtual rank among the processes still active;
    # `gap` is the distance (in real ranks) between neighbouring active ranks.
    rank, world_size = dist.get_rank(), dist.get_world_size()
    gap = 2
    for stage in range(math.ceil(math.log2(world_size))):
        dst = rank // 2 * gap
        src = dst + gap // 2
        if rank % 2 == 0:
            # even virtual rank: absorb the partner's statistics, if it exists
            if src < world_size:
                dist.recv(other_count, src)
                dist.recv(other_mean, src)
                dist.recv(other_cov, src)
                combine_cov(
                    count, mean, cov,
                    other_count, other_mean, other_cov
                )
        else:
            # odd virtual rank: hand the statistics over and leave the reduction
            dist.send(count, dst)
            dist.send(mean, dst)
            dist.send(cov, dst)
            break
        rank //= 2
        gap *= 2

    # distribute rank 0's result to every rank
    dist.broadcast(count, 0)
    dist.broadcast(mean, 0)
    dist.broadcast(cov, 0)

    return cov.float(), loader_instance


class Sign(torch.autograd.Function):
    """Sign function with a piecewise-linear surrogate gradient."""

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        """Return ``sign(x)`` and keep ``x`` for the backward pass."""
        ctx.save_for_backward(x)
        return x.sign()

    @staticmethod
    def backward(ctx, g):
        """Scale the incoming gradient by ``max(0, 2 * (1 - |x|))``."""
        (inputs,) = ctx.saved_tensors
        slope = ((1 - inputs.abs()) * 2).clamp(min=0.0)
        return g * slope


@torch.jit.script
def quantization_loss(q, M, w_hat, lamba: float = 1e-5):
    """Per-column squared error between ``M @ q`` and ``w_hat`` plus an L2 penalty.

    Returns a vector with one entry per column of ``q``: the sum over rows of
    ``(M @ q - w_hat) ** 2`` plus ``lamba`` times the 2-norm of that column.
    """
    residual = torch.matmul(M, q) - w_hat
    squared_error = residual.pow(2).sum(0)
    penalty = q.norm(2, dim=0) * lamba
    return squared_error + penalty


# knowledge distillation loss
@torch.jit.script
def distillation(logits, labels, teacher_scores, T: float, alpha: float):
    """Per-sample classification loss and combined distillation loss.

    Args:
        logits: student outputs.
        labels: hard class labels.
        teacher_scores: teacher outputs used as soft targets.
        T: softmax temperature applied to both student and teacher outputs.
        alpha: weight of the soft-target term; ``1 - alpha`` weights the
            hard-label term.

    Returns:
        ``(task_loss, loss)``, both unreduced (one value per sample):
        ``task_loss`` is the cross entropy against ``labels`` and ``loss`` is
        ``(1 - alpha) * task_loss + 2 * T**2 * alpha * KL(teacher || student)``.
    """
    task_loss = F.cross_entropy(logits, labels, reduction='none')
    teacher_loss = F.kl_div(
        F.log_softmax(logits / T, dim=-1),
        F.log_softmax(teacher_scores / T, dim=-1),
        reduction='none',
        log_target=True
    ).sum(1)
    # The soft-target term is scaled by T**2 (to offset the temperature's
    # effect on gradient magnitude) and weighted by alpha. It must be
    # multiplied by alpha, not offset by it, so that alpha actually trades
    # the two terms off against each other.
    return task_loss, task_loss * (1 - alpha) + teacher_loss * (2 * T ** 2 * alpha)


@torch.jit.script
def fused_clip(max_norm: float, total_norm):
    """Return the gradient scaling factor ``max_norm / (total_norm + 1e-6)`` limited to [0, 1]."""
    ratio = max_norm / (total_norm + 1e-6)
    return ratio.clamp(0.0, 1.0)


@torch.no_grad()
def clip_grad_norm_(
        parameters, max_norm: float, norm_type: float = 2.0) -> torch.Tensor:
    """Rescale gradients in place so that their total norm is at most ``max_norm``.

    Args:
        parameters: a tensor or an iterable of tensors; entries without a
            gradient are ignored.
        max_norm: upper bound for the total gradient norm.
        norm_type: order of the norm; ``float('inf')`` selects the max norm.

    Returns:
        The total gradient norm before clipping (``tensor(0.)`` when no
        parameter has a gradient).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    with_grad = [p for p in parameters if p.grad is not None]
    if not with_grad:
        return torch.tensor(0.)

    # detached views share storage with the gradients, so scaling them
    # in place below scales the gradients themselves
    grads = [p.grad.detach() for p in with_grad]
    if norm_type == float('inf'):
        peaks = [g.abs().max() for g in grads]
        total_norm = peaks[0] if len(peaks) == 1 else torch.max(torch.stack(peaks))
    else:
        per_param = torch.stack([torch.norm(g, norm_type) for g in grads])
        total_norm = torch.norm(per_param, norm_type)

    scale = fused_clip(max_norm, total_norm)
    for g in grads:
        g.mul_(scale)
    return total_norm


@torch.enable_grad()
def train_model(model, opt, teacher, train_loader, n_steps,
                T=20.0, alpha=0.7):
    """Fine-tune ``model`` against ``teacher`` with the distillation loss.

    Both networks are wrapped in DDP. The learning rate of every parameter
    group follows a triangular schedule: it rises linearly from 0 to the
    initial rate of the first group at ``n_steps / 2`` and falls back to 0.
    Gradients are clipped to a total norm of 2.0. The function returns once
    the step counter exceeds ``n_steps``; after every full pass over
    ``train_loader`` the losses and accuracy are reduced across ranks and
    written out.
    """
    teacher = DDP(teacher)
    teacher.eval()
    model = DDP(model)
    model.train()

    base_lr = opt.param_groups[0]['lr']
    trainable = [p for group in opt.param_groups for p in group['params']]
    step = 0

    while True:
        # per-epoch statistics, kept on the GPU to avoid host syncs
        n_seen = torch.tensor(0, device='cuda')
        task_total = torch.tensor(0.0, device='cuda')
        loss_total = torch.tensor(0.0, device='cuda')
        n_correct = torch.tensor(0, device='cuda')

        for X, y in tqdm(train_loader, desc="Train"):
            X = X.cuda(non_blocking=True)
            y = y.cuda(non_blocking=True).flatten()
            model.zero_grad()

            logits = model(X)
            with torch.no_grad():
                teacher_scores = teacher(X)
            task_loss, loss = distillation(logits, y, teacher_scores, T, alpha)
            loss.mean().backward()

            clip_grad_norm_(trainable, 2.0)  # clip gradient

            # triangular learning-rate schedule over the n_steps budget
            progress = step / n_steps
            lr = base_lr * min(progress, 1 - progress) * 2
            for group in opt.param_groups:
                group['lr'] = lr
            opt.step()

            n_seen += X.shape[0]
            task_total += task_loss.sum().detach()
            loss_total += loss.sum().detach()
            n_correct += (logits.argmax(dim=1) == y).sum()

            step += 1
            if step > n_steps:
                return

        # end of epoch: aggregate the statistics over all ranks
        for stat in (task_total, loss_total, n_correct, n_seen):
            dist.all_reduce(stat)
        tqdm.write(f'task{(task_total / n_seen).item()} / dist{(loss_total / n_seen).item()}')
        accuracy = (n_correct / n_seen).item()
        tqdm.write(f'{accuracy:.4f}')


class QuantizationModel(torch.nn.Module):
    """Wraps a model and owns trainable copies of ``V`` and ``log_s``.

    The forward pass simply delegates to the wrapped model; ``V`` and
    ``log_s`` are registered as parameters of this module.
    """

    def __init__(self, model, log_s, V):
        super().__init__()
        self.model = model
        # registration order (V, then log_s) is kept on purpose
        self.register_parameter('V', torch.nn.Parameter(V.clone()))
        self.register_parameter('log_s', torch.nn.Parameter(log_s.clone()))

    def forward(self, x):
        """Run the wrapped model on ``x``."""
        output = self.model(x)
        return output


def padded_stack(ws):
    """Concatenate 2-D matrices along columns, zero-padding missing rows.

    Args:
        ws: non-empty sequence of ``(n_out, n_in)`` tensors.

    Returns:
        A ``(max n_out, sum n_in)`` tensor in which matrix ``i`` occupies its
        own column range starting at row 0; the remaining rows of that column
        range stay zero. The result lives on the device of the first matrix
        (previously it was always allocated on ``'cuda'``).
    """
    max_n_out = max(w.shape[0] for w in ws)
    sum_n_in = sum(w.shape[1] for w in ws)
    stacked = torch.zeros((max_n_out, sum_n_in), device=ws[0].device)
    offset = 0
    for w in ws:
        n_out, n_in = w.shape[0], w.shape[1]
        stacked[:n_out, offset:offset + n_in] = w
        offset += n_in
    return stacked


def _flatten_permute_pool(w, permute, max_cov_dim):
    """Flatten a weight to (output, input), permute its inputs and pool them to at most ~max_cov_dim."""
    w = w.view(w.shape[0], -1)  # (output, input)
    if permute is not None:
        assert w.shape[1] == permute.shape[0]
        w = w[:, permute]

    if w.shape[1] > max_cov_dim:
        # average neighbouring inputs so the width matches the pooled covariance
        stride = next_power_of_2(w.shape[1] // max_cov_dim)
        assert w.shape[1] % stride == 0
        w = F.avg_pool1d(w[None], stride, stride)[0]
    return w


@torch.enable_grad()
def quantize_layer(model, teacher, train_loader, n_steps,
                   bn_params, cur_layers, log_s, V, distance,
                   permute_indices=None,
                   add_distance=True,
                   T=20.0, distill_alpha=0.7, max_cov_dim=256):
    """Learn quantized weights for ``cur_layers`` while everything else is frozen.

    All non-BatchNorm parameters are frozen, the given layers are switched to
    ``'quantize'`` mode and their ``alpha``/``scores`` are trained together
    with the BatchNorm parameters (and ``log_s``/``V`` when ``add_distance``
    is set) using the distillation loss, optionally plus ``distance`` and
    regularisers on ``V`` and ``log_s``. Afterwards the layers are switched
    back to ``'original'`` mode and the ``requires_grad`` flags are restored.

    Args:
        model, teacher: student and teacher networks (both wrapped in DDP here).
        train_loader: loader yielding ``(X, y)`` batches.
        n_steps: training stops once the step counter exceeds this value.
        bn_params: BatchNorm parameters that stay trainable.
        cur_layers: list of layer groups; the weights of each group are
            concatenated along the output dimension.
        log_s, V: current scaling / basis used by ``distance``.
        distance: callable ``(log_s, V, wq, w) -> scalar``.
        permute_indices: one input permutation (or ``None``) per entry of
            ``cur_layers``; ``None`` means no permutation for any entry.
        add_distance: whether to add the distance term and regularisers.
        T, distill_alpha: distillation temperature and weight.
        max_cov_dim: inputs wider than this are average-pooled.

    Returns:
        ``(log_s, V)`` after training.
    """
    # The default used to crash in zip(); treat it as "no permutation anywhere".
    if permute_indices is None:
        permute_indices = [None] * len(cur_layers)

    # freeze all layers
    original_requires_grad = dict()
    for module_name, module in model.named_modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            continue
        for name, p in module.named_parameters(recurse=False):
            # Parameters owned by the root module have no module prefix in
            # named_parameters(), so do not prepend a dot for them.
            full_name = f"{module_name}.{name}" if module_name else name
            original_requires_grad[full_name] = p.requires_grad
            p.requires_grad = False
            p.grad = None

    flat_layers_list = []
    orig_weight_params = []
    cur_params = []

    # full-precision reference weights, one matrix per layer group
    ws = []
    for layers, permute in zip(cur_layers, permute_indices):
        w = torch.cat([layer.weight.data for layer in layers], dim=0)
        ws.append(_flatten_permute_pool(w, permute, max_cov_dim))
        for layer in layers:
            flat_layers_list.append(layer)
            orig_weight_params.append(layer.set_mode('quantize'))
            cur_params.extend([layer.alpha, layer.scores])
    ws = padded_stack(ws)

    # reference value for the regulariser that keeps sum(log_s) in place
    det = log_s.sum()

    orig_model = model
    if add_distance:
        # wrap so that log_s and V become parameters of the DDP-wrapped module
        model = QuantizationModel(model, log_s, V)
        log_s, V = model.log_s, model.V

    opt = torch.optim.Adam(
        [{'params': bn_params},
         {'params': cur_params, 'weight_decay': 1e-5},
         {'params': [log_s, V]}],
        lr=1.25e-3 / 4
    )

    teacher = DDP(teacher)
    teacher.eval()
    model = DDP(model)
    model.train()
    steps = 0
    while True:
        count = torch.tensor(0, device='cuda')
        dist_loss_accum = torch.tensor(0.0, device='cuda')
        task_loss_accum = torch.tensor(0.0, device='cuda')
        loss_accum = torch.tensor(0.0, device='cuda')
        correct_count = torch.tensor(0, device='cuda')
        for X, y in tqdm(train_loader, desc="Train"):
            X = X.cuda(non_blocking=True)
            y = y.cuda(non_blocking=True).flatten()
            model.zero_grad()

            logits = model(X)
            with torch.no_grad():
                teacher_scores = teacher(X)
            task_loss, loss = distillation(logits, y, teacher_scores, T, distill_alpha)
            dist_loss = loss

            # add distance metric
            if add_distance:
                wqs = []
                for layers, permute in zip(cur_layers, permute_indices):
                    layer_alpha = torch.cat([l.alpha for l in layers], dim=0)
                    layer_scores = torch.cat([l.scores for l in layers], dim=0)
                    wqs.append(_flatten_permute_pool(
                        layer_alpha * layer_scores, permute, max_cov_dim))
                wqs = padded_stack(wqs)
                dist_value = distance(log_s, V, wqs, ws)
                loss = loss + dist_value

                # add regularization
                VVt = (V @ V.T)
                VVt.diagonal().sub_(1.0)
                ortho_reg = VVt.norm(2) ** 2  # squared frobenius norm
                det_reg = ((log_s.sum() - det) / det)**2
                reg = ortho_reg * 0.1 + det_reg
                loss += reg * 10
                tqdm.write(f"ortho_reg={ortho_reg.item()}, det_reg={det_reg.item()}")

            # backward pass
            loss.mean().backward()

            # optimizer step
            opt.step()
            count += X.shape[0]
            task_loss_accum += task_loss.sum().data
            dist_loss_accum += dist_loss.sum().data
            loss_accum += loss.sum().data
            correct_count += (torch.argmax(logits, dim=1) == y).sum()

            steps += 1
            if steps > n_steps:
                break
        if steps > n_steps:
            break

        dist.all_reduce(task_loss_accum)
        dist.all_reduce(dist_loss_accum)
        dist.all_reduce(loss_accum)
        dist.all_reduce(correct_count)
        dist.all_reduce(count)

        tqdm.write(f'task{(task_loss_accum / count).item()} / '
                   f'dist{(dist_loss_accum / count).item()} / '
                   f'reg{(loss_accum / count).item()}')
        accuracy = (correct_count / count).item()
        tqdm.write(f'{accuracy:.4f}')

    # restore layer mode
    for layer, param in zip(flat_layers_list, orig_weight_params):
        layer.set_mode('original', param)

    # restore frozen state before
    params = dict(orig_model.named_parameters())
    for name, requires_grad in original_requires_grad.items():
        params[name].requires_grad = requires_grad
        if not requires_grad:
            params[name].grad = None

    return log_s, V


def weight_distance(log_s, V, wq, w):
    """Mean quantization loss between ``wq`` and ``w`` under ``diag(exp(log_s)) @ V.T``.

    The full-precision weights ``w`` are projected with that matrix and
    ``quantization_loss`` compares them with the projected quantized weights
    ``wq``; the per-column losses are averaged into a scalar.
    """
    transform = torch.exp(log_s).diag() @ V.T
    target = transform @ w.T
    per_column = quantization_loss(wq.T, transform, target)
    return per_column.mean()


def expand_S_V(old_log_s: torch.Tensor, old_V: torch.Tensor,
               new_log_s: torch.Tensor, new_V: torch.Tensor):
    """Append a new ``(log_s, V)`` block to an existing block-diagonal pair.

    Args:
        old_log_s: vector of length ``old_dim``.
        old_V: ``(old_dim, old_dim)`` matrix.
        new_log_s: vector of length ``dim``.
        new_V: ``(dim, dim)`` matrix.

    Returns:
        ``(large_log_s, large_V)``: the concatenated vector and the
        block-diagonal matrix ``[[old_V, 0], [0, new_V]]``, allocated with the
        dtype and device of the old tensors.
    """
    dim = new_V.shape[0]
    old_dim = old_V.shape[0]
    new_dim = old_dim + dim

    # Slice from ``old_dim`` rather than ``-dim``: with dim == 0 the slice
    # ``[-0:]`` would select the whole tensor instead of an empty tail.
    large_log_s = old_log_s.new_zeros((new_dim,))
    large_log_s[:old_dim] = old_log_s
    large_log_s[old_dim:] = new_log_s

    large_V = old_V.new_zeros((new_dim, new_dim))
    large_V[:old_dim, :old_dim] = old_V
    large_V[old_dim:, old_dim:] = new_V

    return large_log_s, large_V


def eigh(cov):
    """Eigendecomposition of a symmetric matrix with a CPU/float64 fallback.

    Returns ``(eigenvalues, eigenvectors)`` as ``torch.linalg.eigh`` does.
    If the decomposition fails on the input's device/dtype, it is retried in
    float64 on the CPU (repeating the identical call would just fail the same
    way) and the result is converted back to the dtype and device of ``cov``.
    A failure of the fallback propagates to the caller.
    """
    try:
        return torch.linalg.eigh(cov)
    except torch.linalg.LinAlgError:
        vals, vecs = torch.linalg.eigh(cov.detach().to('cpu', torch.float64))
        return (vals.to(device=cov.device, dtype=cov.dtype),
                vecs.to(device=cov.device, dtype=cov.dtype))


def next_power_of_2(x):
    """Return the smallest power of two that is >= ``x`` (1 for ``x == 0``).

    Python ints are handled exactly with ``int.bit_length``; going through a
    floating-point ``log2`` loses precision for large values (for example
    ``2**53 + 1`` would round down to ``2**53``). Other numeric types keep
    the logarithm-based computation.

    Raises:
        ValueError: if ``x`` is negative.
    """
    if x == 0:
        return 1
    if isinstance(x, int):
        if x < 0:
            raise ValueError("next_power_of_2 requires a non-negative value")
        return 1 << (x - 1).bit_length()
    return 2**math.ceil(math.log2(x))


@torch.no_grad()
def reorder_channels(train_loader, loader_instance, get_patches, dim, num_clusters, total_size=2**25):
    """Cluster the patch features with k-means and return an ordering by cluster id.

    Patches are collected from the loader (restarting it when exhausted)
    until a ``(dim, size)`` buffer is full, where ``size`` is the smaller of
    ``total_size // dim`` and ``len(train_loader)`` times the patch count of
    the first batch. The rows of that buffer are clustered into
    ``num_clusters`` groups.

    Returns:
        ``(sort_indices, loader_instance)``: the argsort of the cluster ids
        (length ``dim``) and the loader iterator for further use.
    """
    if loader_instance is None:
        loader_instance = iter(train_loader)

    features = None
    filled = 0
    # at least one batch is always read; then keep going until the buffer is full
    while features is None or filled != features.shape[1]:
        try:
            X, y = next(loader_instance)
        except StopIteration:
            loader_instance = iter(train_loader)
            X, y = next(loader_instance)
        X = X.cuda(non_blocking=True)

        patches = get_patches(X)  # (dim, batch*height*width)

        if features is None:
            # the buffer size is only known once the first batch has been seen
            capacity = min(total_size // dim, len(train_loader) * patches.shape[1])
            features = torch.zeros([dim, capacity], device='cuda')

        take = min(patches.shape[1], features.shape[1] - filled)
        features[:, filled:filled + take] = patches[:, :take]
        filled += take

    print("Running Kmeans")
    cluster_ids, _ = kmeans(
        X=features, num_clusters=num_clusters, distance='euclidean', device=torch.device('cuda')
    )
    sort_indices = torch.argsort(cluster_ids)
    assert sort_indices.shape[0] == dim

    return sort_indices, loader_instance


@torch.no_grad()
def quantize(model, train_loader, val_loader, args,
             block_iters, finetune_epochs, world_size,
             xnor_net=False, no_distance=False, skip_until=None,
             block_size=7, sample_portion=1.0, log_dir=None, findq_iters=20, rank=0, max_cov_dim=256):
    """Quantize ``model`` block by block with intermediate fine-tuning.

    Layers are visited in topological order of the dependency graph and
    processed in groups of about ``block_size``. For each group and each of
    the ``block_iters`` iterations, every layer is frozen and quantized with
    ``quantize_layer`` (using the input covariance, or random ``V``/``log_s``
    when ``args.random_sv`` is set), with fine-tuning in between. Rank 0 saves
    a model checkpoint per iteration and the final ``log_s``/``V``/permutations
    per group into ``log_dir``.

    Args:
        model: network to quantize (modified in place).
        train_loader, val_loader: data loaders for training and evaluation.
        args: parsed command-line arguments; only ``args.random_sv`` is read.
        block_iters: number of quantization passes per layer group.
        finetune_epochs: epochs used for each block-level fine-tuning.
        skip_until: name of the layer at which quantization starts; earlier
            layers are frozen but neither quantized nor fine-tuned.
        no_distance: disable the distance term in ``quantize_layer``.
        block_size: target number of layers per group.
        sample_portion: fraction of the loader used per covariance estimate.
        log_dir: ``pathlib.Path`` of the output directory.
        rank: rank of this process; only rank 0 writes files.
        max_cov_dim: maximum covariance dimension before inputs are pooled.
        world_size, xnor_net, findq_iters: accepted but not used here.
    """
    graph = make_graph(model)
    modules = dict(model.named_modules())
    weight_params = sum([list(m.parameters(recurse=False)) for m in model.modules()
                         if not isinstance(m, torch.nn.BatchNorm2d)], [])
    bn_params = sum([list(m.parameters(recurse=False)) for m in model.modules()
                     if isinstance(m, torch.nn.BatchNorm2d)], [])
    print(f"{len(weight_params)} weight params, {len(bn_params)} bn params")

    teacher = copy.deepcopy(model)

    total_steps = len(train_loader) * finetune_epochs
    n_steps = total_steps // block_iters

    layers = list(topological_sort(graph))
    layer_groups = list(more_itertools.chunked_even(layers, block_size))
    print("Layer groups", layer_groups)

    finetune_lr = 1.25e-3 / 4
    weight_decay = 1e-5

    cov_iter = None

    def finetune(steps):
        """Fine-tune with a fresh Adam optimizer for ``steps`` steps, then evaluate."""
        opt = torch.optim.Adam(
            [{'params': bn_params},
             {'params': weight_params, 'weight_decay': weight_decay}],
            lr=finetune_lr
        )
        train_model(model, opt, teacher, train_loader, n_steps=steps)
        test_model(model, val_loader)

    for group_idx, layer_group in enumerate(layer_groups):
        print(f"Group {group_idx}: {layer_group}")

        for it in range(block_iters):
            print(f"Iteration {it}.")

            # Skip the very first block/iteration, and never fine-tune while
            # layers are still being skipped. The parentheses matter: `and`
            # binds tighter than `or`.
            if (it > 0 or group_idx > 0) and skip_until is None:
                print("Block initial fine-tuning")
                finetune(total_steps)

            V = torch.zeros((0, 0), device='cuda')
            log_s = torch.zeros((0,), device='cuda')

            new_frozen = []             # every layer group frozen in this pass
            quantized_layers = []       # the subset that is actually quantized
            permute_indices_group = []  # one entry per quantized layer group
            for layer_idx, layer_name in enumerate(layer_group):
                if layer_name == skip_until:
                    skip_until = None

                # skip last layer for now
                if layer_name == 'fc':
                    continue

                print("Quantizing", layer_name)
                if layer_name.endswith('_down'):
                    cur_layers = [modules[layer_name + '1'], modules[layer_name + '2']]
                else:
                    cur_layers = [modules[layer_name]]
                cur_params = [layer.weight for layer in cur_layers]
                groups = cur_layers[0].groups
                get_patches = partial(extract_patches, model, cur_layers[0])
                dim = np.prod(cur_layers[0].weight.shape[1:])

                # freeze layer
                new_frozen.append(cur_layers)
                for p in cur_params:
                    p.requires_grad = False
                    p.grad = None

                # FIXME skip depthwise layer for now
                if groups > 1:
                    continue
                # skip quantization until specified layer is reached
                if skip_until is not None:
                    continue

                if dim > max_cov_dim:
                    stride = dim // max_cov_dim
                    stride = next_power_of_2(stride)
                    print(f"stride = {stride}")
                    assert dim % stride == 0
                else:
                    stride = 1

                permute_indices = None
                num_clusters = dim
                if stride > 1:
                    num_clusters = dim // stride
                    permute_indices, cov_iter = reorder_channels(
                        train_loader, cov_iter, get_patches, dim, num_clusters)
                    dist.broadcast(permute_indices, 0)
                # Keep the layer list aligned with the permutation list and
                # with the blocks of log_s / V: frozen-but-skipped layers must
                # not be handed to quantize_layer.
                quantized_layers.append(cur_layers)
                permute_indices_group.append(permute_indices)

                if args.random_sv:
                    new_V = torch.randn(num_clusters, num_clusters, device='cuda')
                    new_log_s = torch.randn(num_clusters, device='cuda')
                else:
                    # calculate input covariance
                    model.eval()
                    cov, cov_iter = dataset_cov(
                        train_loader, cov_iter, permute_indices, get_patches, dim, sample_portion)

                    if stride > 1:
                        cov = F.avg_pool2d(cov[None, None], stride, stride)[0, 0]

                    s, new_V = eigh(cov)
                    new_log_s = torch.log(torch.clip(s, min=0.0) + 1e-6)

                # expand covariance matrix
                log_s, V = expand_S_V(log_s, V, new_log_s, new_V)

                # quantize layer
                print("Finding quantized values")
                log_s, V = quantize_layer(
                    model, teacher, train_loader, n_steps,
                    bn_params=bn_params, cur_layers=quantized_layers,
                    distance=weight_distance, log_s=log_s, V=V,
                    permute_indices=permute_indices_group,
                    add_distance=(not no_distance),
                    max_cov_dim=max_cov_dim
                )

                test_model(model, val_loader)

                if layer_idx < len(layer_group) - 1:
                    print("Fine-tuning")
                    finetune(n_steps)

            if skip_until is None:
                print("Block final fine-tuning")
                finetune(total_steps)

            if rank == 0:
                print("Saving to ", log_dir / f"block-{group_idx}.pth")
                torch.save(model.state_dict(), log_dir / f"block-{group_idx}.pth")

            # unfreeze block layers
            if it < block_iters - 1:
                print(f"Unfreezing {len(new_frozen)} layers")
                for layers in new_frozen:
                    for layer in layers:
                        layer.weight.requires_grad = True
            else:
                # last pass: the block stays frozen, drop it from fine-tuning
                weight_params = [p for p in weight_params if p.requires_grad]

        # Only one process may write the file; every rank holds the same data.
        if rank == 0:
            print("Saving s, V to ", log_dir / f"sV-{group_idx}.pth")
            torch.save({'log_s': log_s, 'V': V, 'permute': permute_indices_group},
                       log_dir / f"sV-{group_idx}.pth")


@torch.no_grad()
def extract_patches(model, layer, X):
    """Return the inputs of ``layer`` for batch ``X``, unfolded into patches for convolutions.

    ``model.inputs_for(layer, X)`` provides the layer input. For a
    ``Conv2d`` layer it is unfolded with the layer's kernel size, dilation,
    padding and stride and reshaped to ``(ch*ks*ks, batch*height*width)``;
    for any other layer it is returned unchanged.
    """
    X = model.inputs_for(layer, X)
    if not isinstance(layer, torch.nn.Conv2d):
        return X

    columns = F.unfold(
        X,
        kernel_size=layer.kernel_size,
        dilation=layer.dilation,
        padding=layer.padding,
        stride=layer.stride,
    )  # (batch, ch*ks*ks, height*width)
    columns = columns.transpose(0, 1)  # (ch*ks*ks, batch, height*width)
    return columns.reshape(columns.shape[0], -1)  # (dim, batch*height*width)


def _strip_ddp_prefix(key):
    """Remove a leading ``'module.'`` (added by DataParallel/DDP wrappers) from a state-dict key."""
    prefix = 'module.'
    return key[len(prefix):] if key.startswith(prefix) else key


def _load_reactnet(path, num_classes):
    """Build a ReActNet with ``num_classes`` outputs on the GPU and load the checkpoint at ``path``."""
    model = reactnet(num_classes=num_classes).cuda()
    checkpoint = torch.load(path, map_location='cpu')
    print(checkpoint['epoch'], checkpoint['best_top1_acc'])
    # Only strip the wrapper prefix; a blanket str.replace would also mangle
    # keys that merely contain "module." somewhere in the middle.
    state_dict = {_strip_ddp_prefix(k): v for k, v in checkpoint['state_dict'].items()}
    # strict=False tolerates mismatching keys; report them instead of
    # silently continuing with partially initialised weights.
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        print("Missing keys in checkpoint:", result.missing_keys)
    if result.unexpected_keys:
        print("Unexpected keys in checkpoint:", result.unexpected_keys)
    return model


def load_model(path):
    """Load a 100-class (CIFAR-100) ReActNet checkpoint from ``path``."""
    return _load_reactnet(path, num_classes=100)


def load_model_cifar10(path):
    """Load a 10-class (CIFAR-10) ReActNet checkpoint from ``path``."""
    return _load_reactnet(path, num_classes=10)


def train_acc_model(model, train_loader: DataLoader):
    """Measure loss and accuracy on the training set across all ranks.

    The model is wrapped in DDP and put in training mode (``model.train()``)
    while being evaluated without gradients.
    NOTE(review): training mode during evaluation looks intentional here but
    is worth confirming.

    Returns:
        ``(loss, acc)``: mean cross-entropy loss and accuracy over all ranks.
    """
    model = DDP(model)
    model.train()
    print("Train Acc")
    n_seen = torch.tensor(0, device='cuda')
    loss_total = torch.tensor(0.0, device='cuda')
    n_correct = torch.tensor(0, device='cuda')
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for X, y in train_loader:
            X = X.cuda(non_blocking=True)
            y = y.cuda(non_blocking=True).flatten()
            logits = model(X)
            n_seen += X.shape[0]
            loss_total += criterion(logits, y).data
            n_correct += (logits.argmax(dim=1) == y).sum()
        # sum the per-rank statistics before normalising
        for stat in (loss_total, n_correct, n_seen):
            dist.all_reduce(stat)
        loss = (loss_total / n_seen).item()
        acc = (n_correct / n_seen).item()
        print(f'{loss}')
        print(f'{acc:.4f}')
    return loss, acc


def test_model(model, val_loader: DataLoader):
    """Evaluate ``model`` on ``val_loader`` in eval mode, aggregated over all ranks.

    Returns:
        ``(loss, acc)``: mean cross-entropy loss and accuracy over all ranks.
    """
    model = DDP(model)
    model.eval()
    print("Test")
    n_seen = torch.tensor(0, device='cuda')
    loss_total = torch.tensor(0.0, device='cuda')
    n_correct = torch.tensor(0, device='cuda')
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for X, y in val_loader:
            X = X.cuda(non_blocking=True)
            y = y.cuda(non_blocking=True).flatten()
            logits = model(X)
            n_seen += X.shape[0]
            loss_total += criterion(logits, y).data
            n_correct += (logits.argmax(dim=1) == y).sum()
        # sum the per-rank statistics before normalising
        for stat in (loss_total, n_correct, n_seen):
            dist.all_reduce(stat)
        loss = (loss_total / n_seen).item()
        acc = (n_correct / n_seen).item()
        print(f'{loss}')
        print(f'{acc:.4f}')
    return loss, acc


def test_model_orig(model, val_loader: DataLoader):
    """Evaluate ``model`` on ``val_loader`` in eval mode on this process only.

    Unlike the distributed variant, the model is not wrapped in DDP and the
    statistics are not reduced across ranks.

    Returns:
        ``(loss, acc)``: mean cross-entropy loss and accuracy.
    """
    model.eval()
    print("Test")
    n_seen = torch.tensor(0, device='cuda')
    loss_total = torch.tensor(0.0, device='cuda')
    n_correct = torch.tensor(0, device='cuda')
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for X, y in val_loader:
            X = X.cuda(non_blocking=True)
            y = y.cuda(non_blocking=True).flatten()
            logits = model(X)
            n_seen += X.shape[0]
            loss_total += criterion(logits, y).data
            n_correct += (logits.argmax(dim=1) == y).sum()
        loss = (loss_total / n_seen).item()
        acc = (n_correct / n_seen).item()
        print(f'{loss}')
        print(f'{acc:.4f}')
    return loss, acc


def run(rank, world_size, args):
    """Per-process entry point: load the model and data, quantize, evaluate and save.

    Args:
        rank: rank of this process, also used as the CUDA device index.
        world_size: total number of processes.
        args: parsed command-line arguments (see ``main``), including
            ``log_dir`` and ``save_path``.

    Raises:
        ValueError: if ``args.dataset`` is neither ``'cifar100'`` nor ``'cifar10'``.
    """
    # Validate first: previously an unknown dataset surfaced as a NameError
    # on ``model`` before the dataset check was ever reached.
    if args.dataset not in ('cifar100', 'cifar10'):
        raise ValueError(f"Unsupported dataset: {args.dataset!r}")

    writer = SummaryWriter(log_dir=args.log_dir)
    try:
        logging_util.setup_logging(
            pathlib.Path(args.log_dir) / f'log_rank{rank}.txt',
            log_to_screen=(rank == 0)
        )
        print("===== Args =====")
        for name, value in vars(args).items():
            print(f"   {name}: {value}")
        print("================")

        print(f"Using GPU {rank}")
        torch.cuda.set_device(f"cuda:{rank}")

        # construct model
        if args.dataset == 'cifar100':
            model = load_model(args.load_path)
        else:
            model = load_model_cifar10(args.load_path)
        print(model)

        # load dataset
        if args.dataset == 'cifar100':
            train_loader, val_loader = cifar100(batch_size=200, workers=4, distributed=True)
        else:
            train_loader, val_loader = cifar10(batch_size=200, workers=4, distributed=True)

        # evaluate the full-precision model
        train_acc_model(model, train_loader)
        test_model(model, val_loader)

        # quantize model
        quantize(model, train_loader, val_loader, args,
                 block_iters=args.block_iters,
                 xnor_net=args.xnor_net,
                 skip_until=args.skip_until,
                 finetune_epochs=args.finetune_epochs,
                 block_size=args.block_size,
                 sample_portion=args.sample_portion,
                 findq_iters=args.findq_iters,
                 no_distance=args.no_distance,
                 log_dir=pathlib.Path(writer.log_dir),
                 rank=rank, world_size=world_size,
                 max_cov_dim=args.max_cov_dim)

        # evaluate quantized model
        train_acc_model(model, train_loader)
        test_model(model, val_loader)

        if rank == 0:
            torch.save(model.state_dict(), args.save_path)
    finally:
        # release the event-file handle even if something above failed
        writer.close()


def dist_main(rank, world_size, args):
    """Worker entry point: set up the NCCL process group, call ``run`` and tear it down.

    ``MASTER_ADDR`` is always set to ``localhost``; ``MASTER_PORT`` falls back
    to ``12355`` only when it is not already present in the environment.
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ.setdefault('MASTER_PORT', '12355')

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    run(rank, world_size, args)
    dist.destroy_process_group()


def main():
    """Parse arguments, snapshot the sources into the run directory and spawn one worker per GPU.

    Raises:
        RuntimeError: if no CUDA device is visible.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--xnor_net', default=False, action='store_true')
    parser.add_argument('--skip_until', default=None, type=str)
    parser.add_argument('--load_path', default='/d1/xxx/DBQ/baseline/1_step1/models_run1/model_best.pth.tar', type=str)
    parser.add_argument('--block_iters', default=1, type=int)
    parser.add_argument('--finetune_epochs', default=40, type=int)
    parser.add_argument('--block_size', default=4, type=int)
    parser.add_argument('--sample_portion', default=1.0, type=float)
    parser.add_argument('--findq_iters', default=20, type=int)
    parser.add_argument('--no_distance', default=False, action='store_true')
    parser.add_argument('--max_cov_dim', default=256, type=int)
    # restrict to the datasets the workers know how to load, so a typo
    # fails here instead of inside every spawned process
    parser.add_argument('--dataset', default='cifar100', type=str,
                        choices=['cifar100', 'cifar10'])
    parser.add_argument('--random_sv', default=False, action='store_true')

    args = parser.parse_args()

    # With zero devices mp.spawn would start no worker at all and the script
    # would exit silently after creating the run directory.
    world_size = torch.cuda.device_count()
    if world_size < 1:
        raise RuntimeError("No CUDA devices available; at least one GPU is required")

    current_time = datetime.now().strftime('%b%d_%H-%M-%S-%f')
    args.log_dir = pathlib.Path('runs') / f"reactnet_{current_time}_{socket.gethostname()}"
    load_path = pathlib.Path(args.load_path)
    args.save_path = args.log_dir / f"{load_path.stem}-quantized.pt"
    print("Will save to ", args.save_path)

    # keep a copy of the scripts next to the results for reproducibility
    source_copy_dir = args.log_dir / 'source'
    source_copy_dir.mkdir(parents=True, exist_ok=False)
    current_dir = pathlib.Path(__file__).parent
    py_files = glob.glob(str(current_dir / '*.py'))
    for f in py_files:
        relpath = pathlib.Path(f).relative_to(current_dir)
        print(f"Copy {relpath} to {source_copy_dir}")
        target = source_copy_dir / relpath
        target.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(f, target)

    mp.set_start_method('forkserver')
    mp.spawn(dist_main, args=(world_size, args), nprocs=world_size, join=True)


# Run the launcher only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
