import sys
import copy
import glob

import logging
import math
import socket
import shutil
import pathlib
import argparse
import itertools
import setproctitle
from datetime import datetime
from functools import partial

import os
import torch
import numpy as np
from torchvision import datasets, models
from torchvision.transforms import transforms
from tqdm import tqdm, trange
import networkx as nx
import more_itertools
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.tensorboard import SummaryWriter
from networkx.algorithms.dag import topological_sort
from torch.utils.data import DataLoader, DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

from reactnet import reactnet
sys.path.append('.')
from extract_cov import combine_cov
from kmeans import kmeans
import logging_util
from baseline_imagenet.utils.utils import Lighting

# Global cuDNN configuration, applied once at import time: make sure the
# cuDNN backend is used and let it benchmark convolution algorithms.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True


def make_graph(model):
    """Build the layer dependency graph reachable from ``'fc'``.

    Starting at the layer named ``'fc'``, follows
    ``model.dependent_layers(name)`` depth-first and records an edge
    ``dependency -> layer`` for every dependency encountered.

    Returns:
        nx.DiGraph: the dependency graph (empty when ``'fc'`` has no
        dependencies).
    """
    graph = nx.DiGraph()
    seen = {'fc'}
    exhausted = object()

    # Explicit DFS stack of (layer, iterator over its dependencies); edges are
    # inserted in exactly the order a recursive walk would insert them.
    pending = [('fc', iter(model.dependent_layers('fc')))]
    while pending:
        layer_name, deps = pending[-1]
        dep = next(deps, exhausted)
        if dep is exhausted:
            pending.pop()
            continue
        graph.add_edge(dep, layer_name)
        if dep not in seen:
            seen.add(dep)
            pending.append((dep, iter(model.dependent_layers(dep))))

    return graph


def dataset_cov(train_loader, loader_instance, permute_indices, stride, get_patches, dim, sample_portion=1.0):
    """Estimate the covariance of layer-input patches across all ranks.

    Every rank accumulates count / mean / covariance over
    ``ceil(len(train_loader) * sample_portion)`` batches. The per-rank
    statistics are then merged with a pairwise tree reduction built from
    point-to-point ``send``/``recv`` and finally broadcast from rank 0.

    Args:
        train_loader: loader yielding ``(X, y)`` batches; it is re-iterated
            whenever ``loader_instance`` is exhausted.
        loader_instance: iterator over ``train_loader`` to resume from, or
            ``None`` to start a fresh one.
        permute_indices: optional row index applied to the patch matrix; the
            average pooling by ``stride`` is only performed when this is given.
        stride: pooling window; the accumulators hold ``dim // stride``
            features.
        get_patches: callable turning a batch ``X`` into a patch matrix
            (``(dim, batch*height*width)`` according to the inline comment).
        dim: feature dimension before pooling.
        sample_portion: fraction of ``len(train_loader)`` batches to consume.

    Returns:
        tuple: ``(cov, loader_instance)`` -- the covariance cast to float32 and
        the (possibly re-created) iterator so the caller can keep consuming it.
    """
    dim = dim // stride
    # Running statistics are kept in float64 on the GPU.
    count = torch.zeros((), device='cuda', dtype=torch.long)
    mean = torch.zeros(dim, device='cuda', dtype=torch.float64)
    cov = torch.zeros(dim, dim, device='cuda', dtype=torch.float64)
    iters = math.ceil(len(train_loader) * sample_portion)
    if loader_instance is None:
        loader_instance = iter(train_loader)
    for _ in trange(iters):
        try:
            X, y = next(loader_instance)
        except StopIteration:
            # Loader exhausted: restart it and keep going.
            loader_instance = iter(train_loader)
            X, y = next(loader_instance)
        X = X.cuda(non_blocking=True).float()
        torch.cuda.synchronize()

        patches = get_patches(X)  # (dim, batch*height*width)
        # NOTE(review): pooling only happens together with the permutation,
        # while the accumulators above are always sized dim // stride. The
        # callers visible in this file pass stride=1 whenever
        # permute_indices is None, so the shapes agree there.
        if permute_indices is not None:
            patches = patches[permute_indices]
            patches = F.avg_pool1d(patches.T, stride).T

        # Statistics of this batch alone.
        other_count = torch.tensor(patches.size(1), dtype=torch.long)
        other_mean = patches.mean(1)
        other_cov = torch.cov(patches)

        if count == 0:
            count.copy_(other_count)
            mean.copy_(other_mean)
            cov.copy_(other_cov)
        else:
            # presumably merges the batch statistics into count/mean/cov in
            # place (the return value is ignored) -- confirm in extract_cov
            combine_cov(
                count, mean, cov,
                other_count, other_mean, other_cov
            )

    # Receive buffers for the cross-rank reduction.
    other_count = count.new_empty(count.shape)
    other_mean = mean.new_empty(mean.shape)
    other_cov = cov.new_empty(cov.shape)

    # Pairwise tree reduction: at each stage `rank` is the rank's index within
    # the current level and `gap` the distance (in real ranks) between the
    # receiving rank `dst` and its sending partner `src`.
    rank, world_size = dist.get_rank(), dist.get_world_size()
    gap = 2
    for stage in range(math.ceil(math.log2(world_size))):
        dst = rank // 2 * gap
        src = dst + gap // 2
        if rank % 2 == 0:
            # Even index at this level: receive from the partner, if it exists.
            if src < world_size:
                dist.recv(other_count, src)
                dist.recv(other_mean, src)
                dist.recv(other_cov, src)
                combine_cov(
                    count, mean, cov,
                    other_count, other_mean, other_cov
                )
        else:
            # Odd index: hand the statistics over and leave the reduction.
            dist.send(count, dst)
            dist.send(mean, dst)
            dist.send(cov, dst)
            break
        rank //= 2
        gap *= 2

    # Share rank 0's merged statistics with every rank.
    dist.broadcast(count, 0)
    dist.broadcast(mean, 0)
    dist.broadcast(cov, 0)

    return cov.float(), loader_instance


class Sign(torch.autograd.Function):
    """Sign function with a piecewise-linear surrogate gradient.

    Forward returns ``torch.sign(x)``. Backward scales the incoming gradient
    by ``max(2 * (1 - |x|), 0)``, which is zero wherever ``|x| >= 1``.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        # Keep the input around; the surrogate gradient depends on it.
        ctx.save_for_backward(x)
        return x.sign()

    @staticmethod
    def backward(ctx, g):
        (saved_input,) = ctx.saved_tensors
        surrogate = (1 - saved_input.abs()).mul(2).clamp(min=0.0)
        return g * surrogate


@torch.jit.script
def quantization_loss(q, M, w_hat, lamba: float = 1e-5):
    """Per-column squared error of ``M @ q`` against ``w_hat`` plus an L2 penalty.

    Returns a 1-D tensor with one entry per column of ``q``:
    ``sum_rows((M @ q - w_hat)**2) + lamba * ||q[:, j]||_2``.
    """
    residual = M @ q - w_hat
    column_loss = residual.pow(2).sum(0)
    column_loss += lamba * q.norm(2, dim=0)
    return column_loss


# knowledge distillation loss
@torch.jit.script
def distillation(logits, labels, teacher_scores, T: float, alpha: float):
    """Per-sample classification loss and combined distillation loss.

    Args:
        logits: student outputs.
        labels: hard class labels.
        teacher_scores: teacher outputs used as soft labels.
        T: softmax temperature applied to both student and teacher.
        alpha: mixing coefficient.

    Returns:
        ``(task_loss, total_loss)``, both un-reduced (one value per sample).
        ``total_loss = task_loss * (1 - alpha) + kl * (2 * T**2 + alpha)``.
    """
    student_log_probs = F.log_softmax(logits / T, dim=-1)
    teacher_log_probs = F.log_softmax(teacher_scores / T, dim=-1)

    # hard-label term
    task_loss = F.cross_entropy(logits, labels, reduction='none')

    # soft-label term: KL(teacher || student), summed over classes
    teacher_loss = F.kl_div(
        student_log_probs,
        teacher_log_probs,
        reduction='none',
        log_target=True
    ).sum(1)

    hard_weight = 1 - alpha
    # NOTE(review): the common formulation weights the soft term by
    # 2 * T**2 * alpha; here alpha is *added*. Kept as-is -- confirm intent.
    soft_weight = 2 * T ** 2 + alpha
    return task_loss, task_loss * hard_weight + teacher_loss * soft_weight


@torch.jit.script
def fused_clip(max_norm: float, total_norm):
    """Return the gradient scaling factor ``max_norm / (total_norm + 1e-6)`` limited to [0, 1]."""
    raw_coef = max_norm / (total_norm + 1e-6)
    return torch.clamp(raw_coef, 0.0, 1.0)


@torch.no_grad()
def clip_grad_norm_(
        parameters, max_norm: float, norm_type: float = 2.0) -> torch.Tensor:
    """Scale gradients in place so that their total norm is at most ``max_norm``.

    Args:
        parameters: a single tensor or an iterable of tensors; entries without
            a gradient are ignored.
        max_norm: upper bound for the total gradient norm.
        norm_type: order of the norm; ``float('inf')`` selects the max norm.

    Returns:
        The total gradient norm measured before clipping (``tensor(0.)`` when
        no parameter has a gradient).
    """
    candidates = [parameters] if isinstance(parameters, torch.Tensor) else parameters
    grads = [p.grad.detach() for p in candidates if p.grad is not None]
    if not grads:
        return torch.tensor(0.)

    if norm_type == float('inf'):
        peaks = [g.abs().max() for g in grads]
        total_norm = peaks[0] if len(peaks) == 1 else torch.max(torch.stack(peaks))
    else:
        per_tensor = torch.stack([torch.norm(g, norm_type) for g in grads])
        total_norm = torch.norm(per_tensor, norm_type)

    scale = fused_clip(max_norm, total_norm)
    # The detached views share storage with the real gradients, so this
    # rescales .grad itself.
    for g in grads:
        g.mul_(scale)
    return total_norm


@torch.enable_grad()
def train_model(model, opt, teacher, train_loader, n_steps,
                T=20.0, alpha=0.7, lr_decay=False):
    """Fine-tune ``model`` against ``teacher`` with the distillation loss.

    The model is wrapped in DDP and trained over ``train_loader`` repeatedly.
    After every completed pass the accumulated losses and accuracy are summed
    across ranks and printed. The function returns (``None``) from inside the
    batch loop as soon as the step counter exceeds ``n_steps``, i.e. after
    ``n_steps + 1`` optimiser steps.

    Args:
        model: student network to train.
        opt: optimiser whose parameter groups are stepped (and whose learning
            rates are rewritten when ``lr_decay`` is set).
        teacher: network providing the soft targets; put in eval mode.
        train_loader: iterable of ``(X, y)`` batches.
        n_steps: step budget, also the horizon of the linear LR decay.
        T: distillation temperature.
        alpha: distillation mixing coefficient.
        lr_decay: when true, decay every group's LR linearly from the first
            group's initial LR to 10% of it over ``n_steps`` steps.
    """
    teacher.eval()
    model = DDP(model)
    model.train()
    steps = 0
    orig_lr = opt.param_groups[0]['lr']
    # Flat list of all optimised parameters; only referenced by the
    # commented-out gradient clipping further down.
    params = sum([g['params'] for g in opt.param_groups], [])

    while True:
        # Per-pass accumulators, kept on the GPU so they can be all-reduced.
        count = torch.tensor(0, device='cuda')
        task_loss_accum = torch.tensor(0.0, device='cuda')
        loss_accum = torch.tensor(0.0, device='cuda')
        correct_count = torch.tensor(0, device='cuda')
        for X, y in tqdm(train_loader, desc="Train"):
            X = X.cuda(non_blocking=True).float()
            y = y.cuda(non_blocking=True).flatten()
            torch.cuda.synchronize()
            model.zero_grad()

            logits = model(X)
            with torch.no_grad():
                teacher_scores = teacher(X)
            # Both returned losses are per-sample (un-reduced).
            task_loss, loss = distillation(logits, y, teacher_scores, T, alpha)
            loss.mean().backward()

            # apply EmpCov to the first layer
            #w = cur_w.data.view(cur_w.shape[0], -1).T  # (in, out)
            #cur_w.grad.add_((inv_sigma @ w).T.reshape_as(cur_w))

            # clip_grad_norm_(params, 2.0)  # clip gradient
            if lr_decay:
                # Linear schedule: orig_lr at step 0 down to 0.1 * orig_lr at
                # step n_steps. Raises ZeroDivisionError if n_steps == 0.
                for g in opt.param_groups:
                    r = steps / n_steps
                    lr = orig_lr * 0.9 * (1 - r) + orig_lr * 0.1
                    g['lr'] = lr

            opt.step()
            count += X.shape[0]
            task_loss_accum += task_loss.sum().data
            loss_accum += loss.sum().data
            correct_count += (torch.argmax(logits, dim=1) == y).sum()

            steps += 1
            if steps > n_steps:
                # Budget exhausted: leave without reporting the partial pass.
                return
        # End of a full pass: sum the statistics over all ranks and report.
        dist.all_reduce(task_loss_accum)
        dist.all_reduce(loss_accum)
        dist.all_reduce(correct_count)
        dist.all_reduce(count)
        tqdm.write(f'task{(task_loss_accum / count).item()} / dist{(loss_accum / count).item()}')
        accuracy = (correct_count / count).item()
        tqdm.write(f'{accuracy:.4f}')


class QuantizationModel(torch.nn.Module):
    """Wrap a model together with trainable ``V`` and ``log_s`` parameters.

    The forward pass simply delegates to the wrapped model; ``V`` and
    ``log_s`` are registered as parameters of this module so they are part of
    its parameter set.
    """

    def __init__(self, model, log_s, V):
        super().__init__()
        self.model = model
        # Registration order (V first, then log_s) is kept stable.
        self.V = torch.nn.Parameter(V)
        self.log_s = torch.nn.Parameter(log_s)

    def forward(self, x):
        """Run the wrapped model on ``x``."""
        wrapped = self.model
        return wrapped(x)


def padded_stack(ws):
    """Place 2-D matrices side by side, zero-padding rows to the tallest one.

    The result is a CUDA tensor of shape
    ``(max rows over ws, sum of columns over ws)``; matrix ``i`` occupies the
    top rows of its own column band and everything else is zero.
    """
    n_rows = max(w.shape[0] for w in ws)
    n_cols = sum(w.shape[1] for w in ws)
    stacked = torch.zeros((n_rows, n_cols), device='cuda')

    left = 0
    for w in ws:
        height, width = w.shape[0], w.shape[1]
        right = left + width
        stacked[:height, left:right] = w
        left = right
    return stacked


def _permute_and_pool(mat, permute, max_cov_dim):
    """Reorder the columns of ``mat`` by ``permute`` and average-pool them.

    Pooling is applied only when ``mat`` has more than ``max_cov_dim``
    columns; the window is the next power of two of
    ``columns // max_cov_dim``. ``permute`` may be ``None``.
    """
    if permute is not None:
        assert mat.shape[1] == permute.shape[0]
        mat = mat[:, permute]
    if mat.shape[1] > max_cov_dim:
        stride = next_power_of_2(mat.shape[1] // max_cov_dim)
        assert mat.shape[1] % stride == 0
        mat = F.avg_pool1d(mat[None], stride, stride)[0]
    return mat


@torch.enable_grad()
def quantize_layer(model, teacher, train_loader, val_loader, n_steps,
                   bn_params, other_params, cur_layers,
                   distance,
                   log_s, V, orig_log_s=None, orig_V=None,
                   permute_indices=None,
                   add_distance=True,
                   T=20.0, distill_alpha=0.7, lr=1.25e-3 / 4,
                   max_cov_dim=256, weight_decay=1e-5,
                   distance_coeff=1.0, lr_decay=True,
                   findq_nsamples=20):
    """Train the quantized form of a group of layers with distillation.

    All non-BatchNorm parameters are frozen, every layer in ``cur_layers`` is
    switched to ``'quantize'`` mode and its ``alpha``/``scores`` are trained
    together with ``bn_params`` and ``other_params``. With ``add_distance``
    the loss also contains ``distance(log_s, V, wqs, ws)`` between the current
    quantized weights and the original weights, plus regularisers on ``V`` and
    ``log_s`` (which are trained as well). Afterwards the layers are put back
    into ``'original'`` mode and the ``requires_grad`` flags are restored.

    Args:
        model: student network (evaluated un-wrapped, trained through DDP).
        teacher: network providing soft targets; put in eval mode.
        train_loader, val_loader: iterables of ``(X, y)`` batches.
        n_steps: step budget; training stops once the step counter exceeds it.
        bn_params: BatchNorm parameters to optimise (no weight decay).
        other_params: extra parameters to optimise alongside alpha/scores.
        cur_layers: list of layer lists; each inner list is concatenated along
            the output dimension and treated as one weight matrix.
        distance: callable ``(log_s, V, wqs, ws) -> tensor``.
        log_s, V: current spectrum (log scale) and basis used by ``distance``.
        orig_log_s, orig_V: when given, the last group's alpha/scores are
            initialised from ``find_q`` using these.
        permute_indices: one column permutation (or ``None``) per entry of
            ``cur_layers``; ``None`` means no permutation for any entry.
        add_distance: enable the distance term and the regularisers.
        T, distill_alpha: distillation temperature and mixing coefficient.
        lr: initial learning rate.
        max_cov_dim: column count above which weights are average-pooled.
        weight_decay: weight decay for alpha/scores and ``other_params``.
        distance_coeff: weight of the distance term.
        lr_decay: decay the LR linearly to 10% over ``n_steps`` steps.
        findq_nsamples: ``nsamples`` forwarded to ``find_q``.

    Returns:
        tuple: ``(log_s, V)`` -- the trained parameters when ``add_distance``
        is set, otherwise the inputs unchanged.
    """
    if permute_indices is None:
        # Without this the zip() calls below would fail on the default value.
        permute_indices = [None] * len(cur_layers)

    # freeze all layers
    original_requires_grad = dict()
    for module_name, module in model.named_modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            continue
        for name, p in module.named_parameters(recurse=False):
            original_requires_grad[module_name + '.' + name] = p.requires_grad
            p.requires_grad = False
            p.grad = None

    flat_layers_list = []
    orig_weight_params = []
    cur_params = [*other_params]

    ws = []
    for i, (layers, permute) in enumerate(zip(cur_layers, permute_indices)):
        w = torch.cat([layer.weight.data for layer in layers], dim=0)
        w = w.view(w.shape[0], -1)  # (output, input)

        # Reference weights must go through the same permute-then-pool
        # transform as the quantized weights in the training loop below;
        # previously the un-permuted matrix was pooled here.
        ws.append(_permute_and_pool(w, permute, max_cov_dim))

        is_last_group = i == len(cur_layers) - 1
        if is_last_group and orig_log_s is not None:
            # (in, out)
            alpha, scores = find_q(w.T, torch.diag(torch.exp(orig_log_s)) @ orig_V.T, nsamples=findq_nsamples)

        for layer in layers:
            flat_layers_list.append(layer)
            # set_mode returns whatever must be handed back when restoring
            # 'original' mode at the end of this function.
            orig_weight_params.append(layer.set_mode('quantize'))
            cur_params.extend([layer.alpha, layer.scores])

        # apply quantized layer weight
        if is_last_group and orig_log_s is not None:
            if len(layers) == 2:
                # Two layers share one concatenated matrix: split the
                # output dimension in half between them.
                half_out = alpha.shape[1] // 2
                layers[0].alpha.data.copy_(alpha[:, :half_out].reshape_as(layers[0].alpha))
                layers[0].scores.data.copy_(scores[:, :half_out].T.reshape_as(layers[0].scores))
                layers[1].alpha.data.copy_(alpha[:, half_out:].reshape_as(layers[1].alpha))
                layers[1].scores.data.copy_(scores[:, half_out:].T.reshape_as(layers[1].scores))
            else:
                layers[0].alpha.data.copy_(alpha.reshape_as(layers[0].alpha))
                layers[0].scores.data.copy_(scores.T.reshape_as(layers[0].scores))

    ws = padded_stack(ws)

    # Target for the sum of log_s, used by the det regulariser.
    det = log_s.data.sum()

    orig_model = model
    if add_distance:
        # Wrap so that log_s and V become parameters seen by DDP.
        model = QuantizationModel(model, log_s, V)
        log_s, V = model.log_s, model.V

    # unfreeze all trained params
    for param in cur_params:
        param.requires_grad = True

    opt = torch.optim.Adam(
        [{'params': bn_params},
         {'params': cur_params, 'weight_decay': weight_decay},
         {'params': [log_s, V]}],
        lr=lr
    )
    orig_lr = lr

    test_model(orig_model, val_loader)

    teacher.eval()
    model = DDP(model)
    steps = 0
    epoch = 0
    while True:
        epoch += 1
        count = torch.tensor(0, device='cuda')
        dist_loss_accum = torch.tensor(0.0, device='cuda')
        task_loss_accum = torch.tensor(0.0, device='cuda')
        loss_accum = torch.tensor(0.0, device='cuda')
        correct_count = torch.tensor(0, device='cuda')

        model.train()
        with tqdm(train_loader, desc=f"Train e{epoch}") as pbar:
            for X, y in pbar:
                X = X.cuda(non_blocking=True).float()
                y = y.cuda(non_blocking=True).flatten()
                torch.cuda.synchronize()
                model.zero_grad()

                logits = model(X)
                with torch.no_grad():
                    teacher_scores = teacher(X)
                task_loss, loss = distillation(logits, y, teacher_scores, T, distill_alpha)
                dist_loss = loss

                # add distance metric
                if add_distance:
                    wqs = []
                    for layers, permute in zip(cur_layers, permute_indices):
                        alpha = torch.cat([l.alpha for l in layers], dim=0)
                        scores = torch.cat([l.scores for l in layers], dim=0)
                        wq = alpha * scores
                        wq = wq.view(wq.shape[0], -1)
                        wqs.append(_permute_and_pool(wq, permute, max_cov_dim))
                    wqs = padded_stack(wqs)
                    distance_metric = distance(log_s, V, wqs, ws)
                    loss = loss + distance_coeff * distance_metric

                    # add regularization
                    VVt = (V @ V.T)
                    VVt.diagonal().sub_(1.0)
                    ortho_reg = VVt.norm(2) ** 2  # squared frobenius norm
                    det_reg = ((log_s.sum() - det) / det)**2
                    reg = ortho_reg * 0.1 + det_reg
                    loss += reg * 10
                    pbar.set_postfix({
                        'ortho_reg': ortho_reg.item(),
                        'det_reg': det_reg.item(),
                        'dist': distance_metric.mean().item()
                    })

                # backward pass
                loss.mean().backward()

                # optimizer step
                opt.step()
                count += X.shape[0]
                task_loss_accum += task_loss.sum().data
                dist_loss_accum += dist_loss.sum().data
                loss_accum += loss.sum().data
                correct_count += (torch.argmax(logits, dim=1) == y).sum()

                steps += 1

                if lr_decay:
                    # Linear schedule from orig_lr down to 0.1 * orig_lr.
                    for g in opt.param_groups:
                        r = steps / n_steps
                        lr = orig_lr * 0.9 * (1 - r) + orig_lr * 0.1
                        g['lr'] = lr
                if steps > n_steps:
                    break
        if steps > n_steps:
            # Budget exhausted mid-epoch: skip the per-epoch report.
            break

        dist.all_reduce(task_loss_accum)
        dist.all_reduce(dist_loss_accum)
        dist.all_reduce(loss_accum)
        dist.all_reduce(correct_count)
        dist.all_reduce(count)

        tqdm.write(f'task{(task_loss_accum / count).item()} / '
                   f'dist{(dist_loss_accum / count).item()} / '
                   f'reg{(loss_accum / count).item()}')
        accuracy = (correct_count / count).item()
        tqdm.write(f'{accuracy:.4f}')

        test_model(orig_model, val_loader)

    # restore layer mode
    for layer, param in zip(flat_layers_list, orig_weight_params):
        layer.set_mode('original', param)

    # restore frozen state before
    params = dict(orig_model.named_parameters())
    for name, requires_grad in original_requires_grad.items():
        params[name].requires_grad = requires_grad
        if not requires_grad:
            params[name].grad = None

    return log_s, V


def weight_distance(log_s, V, wq, w):
    """Mean quantization loss between ``wq`` and ``w`` under ``diag(exp(log_s)) @ V.T``.

    Both weight matrices are transposed before being projected, and the
    per-column losses from ``quantization_loss`` are averaged into a scalar.
    """
    projection = torch.exp(log_s).diag() @ V.T
    target = projection @ w.T
    per_column = quantization_loss(wq.T, projection, target)
    return per_column.mean()


def expand_S_V(old_log_s: torch.Tensor, old_V: torch.Tensor,
               new_log_s: torch.Tensor, new_V: torch.Tensor):
    """Append a new spectrum/basis block to an existing block-diagonal pair.

    Args:
        old_log_s: existing vector of length ``old_dim`` (may be empty).
        old_V: existing ``(old_dim, old_dim)`` matrix (may be ``0 x 0``).
        new_log_s: vector of length ``dim`` to append (may be empty).
        new_V: ``(dim, dim)`` matrix to append as a new diagonal block.

    Returns:
        tuple: ``(large_log_s, large_V)`` where ``large_log_s`` is the
        concatenation of both vectors and ``large_V`` is block-diagonal with
        ``old_V`` in the top-left and ``new_V`` in the bottom-right corner.
        Both results take dtype and device from the ``old_*`` inputs.
    """
    dim = new_V.shape[0]
    old_dim = old_V.shape[0]
    new_dim = old_dim + dim

    # Slice from old_dim rather than with [-dim:]: a negative-zero slice
    # ([-0:]) would select the whole tensor when the new block is empty.
    large_log_s = old_log_s.new_zeros((new_dim,))
    large_log_s[:old_dim] = old_log_s
    large_log_s[old_dim:] = new_log_s

    large_V = old_V.new_zeros((new_dim, new_dim))
    large_V[:old_dim, :old_dim] = old_V
    large_V[old_dim:, old_dim:] = new_V

    return large_log_s, large_V


def eigh(cov):
    """Eigendecomposition of a symmetric matrix with a more robust fallback.

    First tries ``torch.linalg.eigh`` on ``cov`` as given. If that raises a
    linear-algebra error, the decomposition is retried in float64 on the CPU
    (instead of repeating the identical call) and the results are moved back
    to ``cov``'s device and dtype. A failure of the fallback propagates.

    Returns:
        ``(eigenvalues, eigenvectors)``, unpackable as a 2-tuple.
    """
    try:
        return torch.linalg.eigh(cov)
    except torch._C._LinAlgError:
        eigenvalues, eigenvectors = torch.linalg.eigh(cov.double().cpu())
        return (
            eigenvalues.to(device=cov.device, dtype=cov.dtype),
            eigenvectors.to(device=cov.device, dtype=cov.dtype),
        )


def next_power_of_2(x):
    """Return the smallest power of two that is ``>= x`` (``1`` for ``x == 0``).

    Integer-valued inputs are handled with exact integer arithmetic, so large
    values are not affected by floating-point rounding in ``log2``.
    Non-integral or negative inputs fall back to the ``log2`` formula, which
    raises ``ValueError`` for negative numbers.
    """
    if x == 0:
        return 1
    n = int(x)
    if n != x or n < 0:
        # Non-integral (or invalid negative) input: keep the float formula.
        return 2**math.ceil(math.log2(x))
    return 1 << (n - 1).bit_length()


@torch.no_grad()
def reorder_channels(train_loader, loader_instance, get_patches, dim, num_clusters, total_size=2**25):
    """Cluster the rows of the patch matrix and return a cluster-sorted order.

    Patch columns are collected from successive batches into a
    ``(dim, size)`` CUDA buffer with
    ``size = min(total_size // dim, len(train_loader) * columns of the first batch)``.
    The ``dim`` rows are then clustered with ``kmeans`` and the indices that
    sort the rows by cluster id are returned.

    Returns:
        tuple: ``(sort_indices, loader_instance)`` -- the permutation of
        length ``dim`` and the (possibly re-created) loader iterator.
    """
    if loader_instance is None:
        loader_instance = iter(train_loader)

    features = None
    filled = 0
    while features is None or filled != features.shape[1]:
        try:
            X, _ = next(loader_instance)
        except StopIteration:
            # Loader exhausted: start over.
            loader_instance = iter(train_loader)
            X, _ = next(loader_instance)
        X = X.cuda(non_blocking=True).float()
        torch.cuda.synchronize()

        patches = get_patches(X)  # (dim, batch*height*width)
        n_cols = patches.shape[1]

        if features is None:
            # Buffer is sized lazily, once the per-batch column count is known.
            capacity = min(total_size // dim, len(train_loader) * n_cols)
            features = torch.zeros([dim, capacity], device='cuda')

        print(f"{filled}/{features.shape[1]}")

        take = min(n_cols, features.shape[1] - filled)
        features[:, filled:filled + take] = patches[:, :take]
        filled += take

    print("Running Kmeans")
    cluster_ids, _ = kmeans(
        X=features, num_clusters=num_clusters, distance='euclidean', device=torch.device('cuda')
    )
    sort_indices = torch.argsort(cluster_ids)
    assert sort_indices.shape[0] == dim

    return sort_indices, loader_instance


@torch.enable_grad()
def find_q(w, M, nsteps=2000, nsamples=40, xnor_net=False):
    """Search for a scaled-sign approximation of ``w`` that is close under ``M``.

    Minimises ``quantization_loss(alpha * sign(w_latent), M, M @ w)`` with SGD
    over a latent weight matrix and a per-column scale ``alpha``, running
    ``nsamples`` rounds of ``nsteps`` steps with a small random perturbation of
    the latent weights after each round. The best candidate seen on any rank
    is broadcast to all ranks.

    Args:
        w: weight matrix; scales are computed per column (``mean(0)``).
        M: matrix applied to both the target and the candidate.
        nsteps: SGD steps per round (also ``T_max`` of the cosine schedule).
        nsamples: number of rounds.
        xnor_net: when true, skip the search and return the closed-form
            ``alpha * sign(w)``.

    Returns:
        ``(best_alpha, best_w)`` -- the best scale and latent weights. With
        ``xnor_net=True`` a single tensor ``q`` is returned instead.
    """
    if xnor_net:
        # NOTE: this path returns one tensor, not the (alpha, w) pair below.
        alpha = w.abs().mean(0, keepdim=True)
        q = alpha * torch.sign(w)
        return q

    # Target in the projected space.
    w_hat = M @ w

    # Initialise the scale with the mean absolute value of each column and
    # the latent weights with the correspondingly normalised weights.
    alpha = torch.nn.Parameter(w.abs().mean(0, keepdim=True))
    w = torch.nn.Parameter(w / alpha)

    initial_alpha_scale = alpha.data.norm(2).clone()
    print(f"initial_alpha={initial_alpha_scale}")

    opt = torch.optim.SGD([w, alpha], lr=1e-4, momentum=0.9, nesterov=True)
    # The scheduler is stepped nsamples * nsteps times in total, i.e. beyond
    # T_max after the first round.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=nsteps)

    loss = quantization_loss(
        alpha * Sign.apply(w),
        M, w_hat
    )
    print(f"Initial loss={loss.mean().item()}")

    # Best loss so far and the parameters that produced it.
    best = torch.tensor(10. ** 9, device='cuda')
    best_w = copy.deepcopy(w.data)
    best_alpha = copy.deepcopy(alpha.data)
    for _ in range(nsamples):
        for step in range(nsteps):
            opt.zero_grad()

            loss = quantization_loss(
                alpha * Sign.apply(w),
                M, w_hat
            )
            lsum = loss.mean()
            lsum.backward()

            # update best
            # (done before opt.step so the stored parameters are the ones the
            # loss was computed with; torch.where keeps everything on-device)
            cond = torch.le(lsum, best)
            best.data.copy_(torch.where(cond, lsum, best))
            best_w.data.copy_(torch.where(cond, w.data, best_w.data))
            best_alpha.data.copy_(torch.where(cond, alpha.data, best_alpha.data))

            opt.step()
            scheduler.step()
        # Perturb the latent weights before the next round.
        w.data.add_(torch.randn_like(w) * 0.01)
        print("best=", best.item())

    # Pick the rank with the lowest loss and adopt its result everywhere.
    best_list = [torch.zeros_like(best) for _ in range(dist.get_world_size())]
    dist.all_gather(best_list, best)
    best_idx = torch.argmin(torch.stack(best_list), dim=0).item()

    dist.broadcast(best_alpha, best_idx)
    dist.broadcast(best_w, best_idx)

    # `best` printed here is this rank's own best loss, not the global one.
    print(f"final best={best}, best_alpha={best_alpha.norm(2)}")
    return best_alpha, best_w


@torch.no_grad()
def quantize(model, teacher_model, train_loader, svd_train_loader, val_loader,
             block_iters, finetune_epochs, world_size, rank, args,
             xnor_net=False, skip_until=None,
             block_size=7, sample_portion=1.0, log_dir=None,
             finetune_lr=1.25e-3 / 4, weight_decay=1e-5,
             quantize_lr=1.25e-3 / 4,
             max_cov_dim=256, distance_coeff=1.0,
             exclude_layers=None,
             no_distance=False, resume_sv=None):
    """Quantize the model layer by layer in sliding groups, fine-tuning in between.

    Layers are visited in topological order of the dependency graph. For each
    layer the input-patch covariance is estimated and eigendecomposed, the
    result is appended to the running block-diagonal ``(log_s, V)`` pair, the
    current group is trained with ``quantize_layer``, the whole model is
    fine-tuned with ``train_model``, and rank 0 saves checkpoints.

    Args:
        model: student network; must provide ``dependent_layers`` and
            ``inputs_for`` (used via ``make_graph`` / ``extract_patches``).
        teacher_model: network providing soft targets.
        train_loader: loader used for quantization training and fine-tuning.
        svd_train_loader: loader used for covariance / clustering statistics.
        val_loader: loader used for evaluation.
        block_iters: divisor turning the total step budget into the per-call
            budget ``n_steps``.
        finetune_epochs: number of epochs defining the total step budget.
        world_size: not used in this function body.
        rank: only rank 0 writes checkpoints.
        args: namespace; ``block_size``, ``find_q``, ``lr_decay`` and
            ``findq_nsamples`` are read from it.
        xnor_net: not used in this function body.
        skip_until: layer name; quantization is skipped for every group before
            the one ending in this layer.
        block_size: maximum number of layers per group.
        sample_portion: forwarded to ``dataset_cov``.
        log_dir: directory for checkpoints; it is combined with ``/`` so it
            must be path-like on rank 0 (the ``None`` default would fail).
        finetune_lr, weight_decay: Adam settings for fine-tuning.
        quantize_lr: learning rate forwarded to ``quantize_layer``.
        max_cov_dim: feature count above which channels are clustered and
            pooled.
        distance_coeff: forwarded to ``quantize_layer``.
        exclude_layers: layer names to skip; defaults to ``['fc']``.
        no_distance: disables the distance term in ``quantize_layer``.
        resume_sv: optional checkpoint with ``log_s`` / ``V`` loaded when
            ``skip_until`` is reached.
    """
    if exclude_layers is None:
        exclude_layers = ['fc']
    graph = make_graph(model)
    modules = dict(model.named_modules())
    # Split the directly-owned parameters into BatchNorm and everything else.
    weight_params = sum([list(m.parameters(recurse=False)) for m in model.modules()
                         if not isinstance(m, torch.nn.BatchNorm2d)], [])
    bn_params = sum([list(m.parameters(recurse=False)) for m in model.modules()
                     if isinstance(m, torch.nn.BatchNorm2d)], [])
    print(f"{len(weight_params)} weight params, {len(bn_params)} bn params")

    teacher = teacher_model
    total_steps = math.ceil(len(train_loader) * finetune_epochs)
    n_steps = total_steps // block_iters

    # for layer_name in topological_sort(graph):
    layers = list(topological_sort(graph))
    print(f"Exclude layers: {exclude_layers}")
    layers = [l for l in layers if l not in exclude_layers]
    # Sliding window: group i holds layer i and up to block_size - 1 predecessors.
    groups = [layers[max(0, i-block_size+1):i+1] for i in range(len(layers))]
    print(f"Layer groups: {groups}")

    # Shared iterator over svd_train_loader, threaded through the helpers.
    cov_iter = None

    # Running block-diagonal basis / log-spectrum and per-layer bookkeeping.
    V = torch.zeros((0, 0), device='cuda')
    log_s = torch.zeros((0,), device='cuda')
    permute_indices_group = []
    group_dims = []

    for group_idx, layer_group in enumerate(groups):
        print("Quantizing", layer_group)

        # Resolve layer names to modules; a name ending in '_down' stands for
        # the pair of modules with suffixes '1' and '2'.
        group_layers = []
        for layer_name in layer_group:
            if layer_name.endswith('_down'):
                cur_layers = [modules[layer_name + '1'], modules[layer_name + '2']]
            else:
                cur_layers = [modules[layer_name]]
            group_layers.append(cur_layers)

        # The newest layer of the group is the one being added this round.
        layer_name = layer_group[-1]

        # skip quantization until specified layer is reached
        if layer_name == skip_until:
            skip_until = None
            if resume_sv is not None:
                sv = torch.load(resume_sv, map_location='cuda')
                log_s = sv['log_s']
                V = sv['V']

        if layer_name.endswith('_down'):
            cur_layers = [modules[layer_name + '1'], modules[layer_name + '2']]
        else:
            cur_layers = [modules[layer_name]]

        # determine earlier trainable layers
        # (layer names are expected to have exactly three dot-separated parts)
        _, idx, name = layer_name.split('.')
        idx = int(idx)
        if "3x3" in layer_name:
            if idx == 1:
                other_modules = [
                    modules["feature.0.conv1"],
                    #modules["feature.0.bn1"],
                    modules["feature.1.move11"]
                ]
            else:
                other_modules = [
                    #*([f"feature.{idx-1}.bn2_1",
                    #   f"feature.{idx-1}.bn2_2"]
                    #  if f"feature.{idx-1}.bn2_2" in modules
                    #  else [f"feature.{idx-1}.bn2"]),
                    modules[f"feature.{idx-1}.move22"],
                    modules[f"feature.{idx-1}.prelu2"],
                    modules[f"feature.{idx-1}.move23"],
                    modules[f"feature.{idx}.move11"]
                ]
        else:
            other_modules = [
                #modules[f"feature.{idx}.bn1"],
                modules[f"feature.{idx}.move12"],
                modules[f"feature.{idx}.prelu1"],
                modules[f"feature.{idx}.move13"],
                modules[f"feature.{idx}.move21"]
            ]
        print(f"Other modules = {other_modules}")
        other_params = []
        for module in other_modules:
            other_params.extend(module.parameters())

        cur_params = [layer.weight for layer in cur_layers]
        # NOTE(review): this rebinds `groups` (the list being iterated above)
        # to the layer's `groups` attribute. The loop is unaffected because
        # enumerate() already holds the original list, and the new value is
        # not read anywhere in this function.
        groups = cur_layers[0].groups
        get_patches = partial(extract_patches, model, cur_layers[0])
        dim = np.prod(cur_layers[0].weight.shape[1:])
        print(f"dim = {dim}")

        # freeze layer
        for p in cur_params:
            p.requires_grad = False
            p.grad = None

        # Pooling stride that brings `dim` down to about max_cov_dim features.
        if dim > max_cov_dim:
            stride = dim // max_cov_dim
            stride = next_power_of_2(stride)
            print(f"stride = {stride}")
            assert dim % stride == 0
        else:
            stride = 1

        permute_indices = None
        if stride > 1:
            # Cluster input channels so that pooled neighbours are similar;
            # rank 0's ordering is used on every rank.
            num_clusters = dim // stride
            permute_indices, cov_iter = reorder_channels(
                svd_train_loader, cov_iter, get_patches, dim, num_clusters)
            dist.broadcast(permute_indices, 0)
        permute_indices_group.append(permute_indices)

        # remove first element in group
        # NOTE(review): the window length here (and for group_dims below)
        # comes from args.block_size, while `groups` was built from the
        # `block_size` parameter -- confirm the caller keeps them equal.
        if len(permute_indices_group) > args.block_size:
            permute_indices_group = permute_indices_group[1:]

        # calculate input covariance
        orig_log_s = orig_V = None
        if args.find_q:
            model.eval()
            cov_orig, cov_iter = dataset_cov(
                svd_train_loader, cov_iter,
                None, 1,
                get_patches, dim, sample_portion)
            s, orig_V = eigh(cov_orig)
            # Clamp negative eigenvalues and add an epsilon before the log.
            orig_log_s = torch.log(torch.clip(s, min=0.0) + 1e-6)

        # grouped input covariance
        cov, cov_iter = dataset_cov(
            svd_train_loader, cov_iter,
            permute_indices, stride,
            get_patches, dim, sample_portion)
        print("Cov finite=", torch.all(torch.isfinite(cov)))

        s, new_V = eigh(cov)
        new_log_s = torch.log(torch.clip(s, min=0.0) + 1e-6)

        # remove first layer
        group_dims.append(dim // stride)
        if len(group_dims) > args.block_size:
            # Drop the oldest layer's block from the front of log_s / V.
            first_dim = group_dims[0]
            group_dims = group_dims[1:]
            log_s = log_s[first_dim:]
            V = V[first_dim:, first_dim:]

        # expand covariance matrix
        log_s, V = expand_S_V(log_s, V, new_log_s, new_V)

        print(f"V matrix size = {V.shape}")

        # skip quantization until specified layer is reached
        if skip_until is not None:
            continue

        # quantize layer
        # (weight decay is fixed to 0 here; the `weight_decay` argument of
        # this function only applies to the fine-tuning optimiser below)
        print("Finding quantized values")
        log_s, V = quantize_layer(
            model, teacher, train_loader, val_loader, n_steps,
            bn_params=bn_params, other_params=other_params,
            cur_layers=group_layers,
            distance=weight_distance, log_s=log_s, V=V,
            orig_log_s=orig_log_s, orig_V=orig_V,
            lr=quantize_lr,
            permute_indices=permute_indices_group,
            add_distance=(not no_distance),
            max_cov_dim=max_cov_dim, weight_decay=0,
            distance_coeff=distance_coeff,
            lr_decay=args.lr_decay,
            findq_nsamples=args.findq_nsamples
        )

        test_model(model, val_loader)

        print("Fine-tuning")

        # Only parameters that still require grad take part in fine-tuning;
        # the list shrinks as layers get frozen above.
        weight_params = [p for p in weight_params if p.requires_grad]
        opt = torch.optim.Adam(
            [{'params': bn_params},
             {'params': weight_params, 'weight_decay': weight_decay}],
            lr=finetune_lr
        )

        train_model(model, opt, teacher, train_loader,
                    n_steps=n_steps, lr_decay=args.lr_decay)
        test_model(model, val_loader)

        if rank == 0:
            print("Saving to ", log_dir / f"block-{group_idx}.pth")
            torch.save(model.state_dict(), log_dir / f"block-{group_idx}.pth")
            print("Saving s, V to ", log_dir / f"sV-{group_idx}.pth")
            torch.save({'log_s': log_s, 'V': V}, log_dir / f"sV-{group_idx}.pth")


@torch.no_grad()
def extract_patches(model, layer, X):
    """Return the inputs of ``layer`` for batch ``X`` as a feature matrix.

    The model is put in eval mode and ``model.inputs_for(layer, X)`` provides
    the layer input. For ``Conv2d`` layers the input is unfolded into sliding
    patches using the layer's own kernel/dilation/padding/stride and reshaped
    to ``(ch*ks*ks, batch*height*width)``; for any other layer the input is
    returned unchanged.
    """
    model.eval()
    layer_input = model.inputs_for(layer, X)
    if not isinstance(layer, torch.nn.Conv2d):
        return layer_input

    unfolded = F.unfold(
        layer_input,
        kernel_size=layer.kernel_size,
        dilation=layer.dilation,
        padding=layer.padding,
        stride=layer.stride,
    )  # (batch, ch*ks*ks, height*width)
    # Bring the feature axis to the front, then merge batch and position.
    return unfolded.transpose(0, 1).flatten(1)  # (dim, batch*height*width)


def train_acc_model(model, train_loader: DataLoader):
    """Measure loss and accuracy on ``train_loader`` with the model in train mode.

    The model is wrapped in DDP, run without gradients over every batch, and
    the summed loss / correct count / sample count are all-reduced over ranks.

    Returns:
        tuple: ``(mean cross-entropy loss, accuracy)`` as Python floats.
    """
    ddp_model = DDP(model)
    ddp_model.train()
    print("Train Acc")
    n_seen = torch.tensor(0, device='cuda')
    summed_loss = torch.tensor(0.0, device='cuda')
    n_correct = torch.tensor(0, device='cuda')
    with torch.no_grad():
        for X, y in train_loader:
            X = X.cuda(non_blocking=True).float()
            y = y.cuda(non_blocking=True).flatten()
            torch.cuda.synchronize()
            logits = ddp_model(X)
            n_seen += X.shape[0]
            summed_loss += F.cross_entropy(logits, y, reduction='sum').data
            n_correct += logits.argmax(dim=1).eq(y).sum()
        # Combine the per-rank totals (same order on every rank).
        for stat in (summed_loss, n_correct, n_seen):
            dist.all_reduce(stat)
        loss = (summed_loss / n_seen).item()
        acc = (n_correct / n_seen).item()
        print(f'{loss}')
        print(f'{acc:.4f}')
    return loss, acc


def test_model(model, val_loader: DataLoader):
    """Evaluate ``model`` on ``val_loader`` across all ranks.

    The model is wrapped in DDP and switched to eval mode. Each rank prints
    its local correct/total counts, then the totals are all-reduced.

    Returns:
        tuple: ``(mean cross-entropy loss, accuracy)`` as Python floats.
    """
    ddp_model = DDP(model)
    ddp_model.eval()
    print("Test")
    n_seen = torch.tensor(0, device='cuda')
    summed_loss = torch.tensor(0.0, device='cuda')
    n_correct = torch.tensor(0, device='cuda')
    with torch.no_grad():
        for X, y in tqdm(val_loader, desc="Test"):
            X = X.cuda(non_blocking=True).float()
            y = y.cuda(non_blocking=True).flatten()
            torch.cuda.synchronize()
            logits = ddp_model(X)
            n_seen += X.shape[0]
            summed_loss += F.cross_entropy(logits, y, reduction='sum').data
            n_correct += logits.argmax(dim=1).eq(y).sum()
        # Local (per-rank) counts, before the reduction.
        print(n_correct, n_seen)
        for stat in (summed_loss, n_correct, n_seen):
            dist.all_reduce(stat)
        loss = (summed_loss / n_seen).item()
        acc = (n_correct / n_seen).item()
        print(f'{loss}')
        print(f'{acc:.4f}')
    return loss, acc


def test_model_orig(model, val_loader: DataLoader):
    """Evaluate ``model`` on ``val_loader`` on this process only.

    Unlike ``test_model`` there is no DDP wrapper and no cross-rank reduction.

    Returns:
        tuple: ``(mean cross-entropy loss, accuracy)`` as Python floats.
    """
    model.eval()
    print("Test")
    n_seen = torch.tensor(0, device='cuda')
    summed_loss = torch.tensor(0.0, device='cuda')
    n_correct = torch.tensor(0, device='cuda')
    with torch.no_grad():
        for X, y in val_loader:
            X = X.cuda(non_blocking=True).float()
            y = y.cuda(non_blocking=True).flatten()
            torch.cuda.synchronize()
            logits = model(X)
            n_seen += X.shape[0]
            summed_loss += F.cross_entropy(logits, y, reduction='sum').data
            n_correct += logits.argmax(dim=1).eq(y).sum()
        loss = (summed_loss / n_seen).item()
        acc = (n_correct / n_seen).item()
        print(f'{loss}')
        print(f'{acc:.4f}')
    return loss, acc


def load_model(path):
    """Build a 1000-class ReActNet on the GPU and load weights from ``path``.

    Two checkpoint layouts are accepted:

    * a training checkpoint containing ``'epoch'``, ``'best_top1_acc'`` and
      ``'state_dict'``; leading ``'module.'`` prefixes (as produced by a
      DataParallel/DDP wrapper) are stripped from the keys;
    * a bare state dict, used as-is.

    Weights are loaded with ``strict=False`` and the resulting report of
    missing/unexpected keys is printed.

    Returns:
        The model with the checkpoint weights applied.
    """
    prefix = 'module.'

    def strip_wrapper_prefix(key):
        # Only remove the prefix at the start of the key; str.replace would
        # also mangle names that merely contain 'module.' further inside
        # (e.g. 'block.submodule.weight').
        while key.startswith(prefix):
            key = key[len(prefix):]
        return key

    model = reactnet(num_classes=1000).cuda()
    # NOTE: torch.load unpickles the file; only use it on trusted checkpoints.
    checkpoint = torch.load(path, map_location='cpu')
    if 'epoch' in checkpoint:
        print(checkpoint['epoch'], checkpoint['best_top1_acc'])
        state_dict = {strip_wrapper_prefix(k): v
                      for k, v in checkpoint['state_dict'].items()}
    else:
        state_dict = checkpoint
    print(model.load_state_dict(state_dict, strict=False))
    return model


def torch_imagenet(args):
    """Create ImageFolder-based ImageNet loaders built from torchvision transforms.

    The dataset root is hard-coded in this function; ``args.data`` is not
    consulted here. Only ``args.batch_size`` and ``args.workers`` are read.

    Args:
        args: namespace providing ``batch_size`` and ``workers``.

    Returns:
        ``(train_loader, train_loader, val_loader)``; the same training loader
        object is returned in the first two positions.
    """
    root = "/d1/dataset/ILSVRC2012/"
    train_root = os.path.join(root, 'train')
    val_root = os.path.join(root, 'val')
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    # Keyword arguments shared by both loaders.
    loader_common = dict(
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
    )

    # Training pipeline: random crop, lighting jitter, horizontal flip.
    min_crop_fraction = 0.08
    lighting_strength = 0.1
    augmentation = [
        transforms.RandomResizedCrop(224, scale=(min_crop_fraction, 1.0)),
        Lighting(lighting_strength),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]
    train_dataset = datasets.ImageFolder(
        train_root, transform=transforms.Compose(augmentation))
    train_loader = torch.utils.data.DataLoader(
        train_dataset, shuffle=True, prefetch_factor=8, **loader_common)

    # Validation pipeline: deterministic resize followed by a centre crop.
    eval_steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]
    val_dataset = datasets.ImageFolder(val_root, transforms.Compose(eval_steps))
    val_loader = torch.utils.data.DataLoader(
        val_dataset, shuffle=False, **loader_common)

    return train_loader, train_loader, val_loader


def run(rank, world_size, args):
    """Per-process worker: evaluate, quantize and re-evaluate the model on GPU ``rank``.

    Steps, in order: set up logging, select the CUDA device, load the
    pretrained model and deep-copy it as the teacher, optionally load
    ``args.resume`` into the student, build the data loaders, evaluate, call
    ``quantize``, evaluate again and, on rank 0 only, save the resulting state
    dict to ``args.save_path``.

    ``dist.barrier()`` is called, so the default process group must already be
    initialised when this function runs.

    Args:
        rank: index of this process; also used as the CUDA device index.
        world_size: total number of processes, forwarded to ``quantize``.
        args: parsed command-line namespace.

    Raises:
        ValueError: if the first entry of ``args.exclude_layers`` is not the
            single layer name this function knows how to handle.
    """
    writer = SummaryWriter(log_dir=args.log_dir)
    try:
        numba_logger = logging.getLogger('numba')
        numba_logger.setLevel(logging.WARNING)

        # NOTE(review): screen logging is enabled for rank 1 only, so a
        # single-process run logs nothing to the screen -- confirm this is not
        # meant to be rank 0.
        logging_util.setup_logging(
            pathlib.Path(args.log_dir) / f'log_rank{rank}.txt',
            log_to_screen=(rank == 1)
        )

        print(f"Executing {__file__}.")
        print("===== Args =====")
        for name, value in vars(args).items():
            print(f"   {name}: {value}")
        print("================")

        print(f"Using GPU {rank}")
        torch.cuda.set_device(f"cuda:{rank}")

        # construct model
        model = load_model("/d1/xxx/DBQ/imagenet-step1-bn-adjusted.pth")
        if args.exclude_layers:
            first_excluded = args.exclude_layers[0]
            # An explicit exception instead of ``assert False``: assertions are
            # removed under ``python -O``, which would silently accept an
            # unsupported layer name.
            if first_excluded != 'feature.12.binary_pw_down':
                raise ValueError(
                    f"Unsupported entry in exclude_layers: {first_excluded!r}; "
                    "only 'feature.12.binary_pw_down' is handled."
                )
            model.feature[12].quantize_down = False
        # The teacher is copied before any resume weights are loaded, so it
        # keeps the weights returned by load_model.
        teacher_model = copy.deepcopy(model)
        if args.resume is not None:
            print(f"Resuming from {args.resume}")
            model.load_state_dict(torch.load(args.resume, map_location='cuda'))
        dist.barrier()

        # load dataset
        print("Loading dataset...")
        if args.torch_imgnet:
            train_loader, svd_train_loader, val_loader = torch_imagenet(args)
        else:
            train_loader, svd_train_loader, val_loader = imagenet_dataset(args)
        print("Loaded dataset.")

        # evaluate the model before quantization
        test_model(model, val_loader)

        # quantize model
        quantize(model, teacher_model,
                 train_loader, svd_train_loader, val_loader,
                 args=args,
                 block_iters=args.block_iters,
                 xnor_net=args.xnor_net,
                 skip_until=args.skip_until,
                 resume_sv=args.resume_sv,
                 finetune_epochs=args.finetune_epochs,
                 block_size=args.block_size,
                 sample_portion=args.sample_portion,
                 log_dir=pathlib.Path(writer.log_dir),
                 finetune_lr=args.finetune_lr,
                 quantize_lr=args.quantize_lr,
                 weight_decay=args.weight_decay,
                 distance_coeff=args.distance_coeff,
                 max_cov_dim=args.max_cov_dim,
                 exclude_layers=args.exclude_layers,
                 rank=rank, world_size=world_size)

        # evaluate quantized model
        train_acc_model(model, train_loader)
        test_model(model, val_loader)

        # Only one process writes the final weights.
        if rank == 0:
            torch.save(model.state_dict(), args.save_path)
    finally:
        # Flush and release the TensorBoard event file even if a step failed.
        writer.close()


def imagenet_dataset(args):
    """Create the training and validation loaders through the ``imagenet`` module.

    Args:
        args: namespace providing ``train_path``, ``test_path``, ``workers``,
            ``batch_size`` and ``portion``.

    Returns:
        ``(train_loader, train_loader, val_loader)``; the same training loader
        object is returned in the first two positions.
    """
    # Imported inside the function, so the module is only needed when this
    # loader path is actually used.
    import imagenet

    train_options = dict(distributed=True, in_memory=True, portion=args.portion)
    shared_train_loader = imagenet.create_train_loader(
        args.train_path, args.workers, args.batch_size, **train_options)

    # 224 and True are passed positionally; their meaning is defined by
    # imagenet.create_val_loader (presumably resolution and a distributed
    # flag -- TODO confirm).
    validation_loader = imagenet.create_val_loader(
        args.test_path, args.workers, args.batch_size, 224, True)

    return shared_train_loader, shared_train_loader, validation_loader


def dist_main(rank, world_size, args):
    """Entry point of one spawned process: join the NCCL group, run, tear down.

    Args:
        rank: index of this process within the group.
        world_size: total number of processes in the group.
        args: parsed command-line namespace, forwarded to ``run``.
    """
    setproctitle.setproctitle(f"proc{rank}")

    # The master address is always forced to localhost; the port keeps any
    # value already present in the environment.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ.setdefault('MASTER_PORT', '12355')

    # initialize the process group
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    run(rank, world_size, args)
    dist.destroy_process_group()


def main():
    """Parse the command line, snapshot the sources and launch one worker per GPU.

    After parsing, a timestamped run directory is derived under ``runs/`` and
    stored on ``args`` as ``log_dir`` (with ``save_path`` inside it). Every
    ``*.py`` file next to this script is copied into ``<log_dir>/source``, and
    ``dist_main`` is spawned once per visible CUDA device.

    Raises:
        RuntimeError: if no CUDA device is visible. Without this check the
            spawn call would start zero workers and return silently.
        FileExistsError: if ``<log_dir>/source`` already exists.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    # %(default)s keeps the help text in sync with the actual default value.
    parser.add_argument('-j', '--workers', default=20, type=int, metavar='N',
                        help='number of data loading workers (default: %(default)s)')
    # NOTE(review): torch_imagenet uses a hard-coded dataset root rather than
    # this option -- confirm whether --data is still read anywhere.
    parser.add_argument('--data', default="/w14/dataset/ILSVRC2012/", type=str)
    parser.add_argument('--xnor_net', default=False, action='store_true')
    parser.add_argument('--skip_until', default=None, type=str)
    parser.add_argument('--block_iters', default=1, type=int)
    parser.add_argument('--finetune_epochs', default=5, type=float)
    parser.add_argument('--block_size', default=4, type=int)
    parser.add_argument('--sample_portion', default=1.0, type=float)
    parser.add_argument('--finetune_lr', default=1.25e-3 / 4, type=float)
    parser.add_argument('--quantize_lr', default=1.25e-3 / 4, type=float)
    parser.add_argument('--portion', default=None, type=int)
    parser.add_argument('--weight_decay', default=0.0, type=float)
    parser.add_argument('--distance_coeff', default=1.0, type=float)
    parser.add_argument('--name', required=True, type=str)
    parser.add_argument('--max_cov_dim', default=256, type=int)
    parser.add_argument('--train_path', type=str, default='/v4/xxx/imagenet_ffcv_train.ffcv')
    parser.add_argument('--test_path', type=str, default='/v4/xxx/imagenet_ffcv_val.ffcv')
    parser.add_argument('--exclude_layers', type=str, nargs='*', default=[])
    parser.add_argument('--resume', type=str, default=None)
    # The hyphenated spelling stays first so the destination remains
    # ``resume_sv``; the underscore alias matches every other option.
    parser.add_argument('--resume-sv', '--resume_sv', type=str, default=None)
    parser.add_argument('--find_q', default=False, action='store_true')
    parser.add_argument('--findq_nsamples', default=20, type=int)
    parser.add_argument('--lr_decay', default=False, action='store_true')
    parser.add_argument('--torch_imgnet', default=False, action='store_true')
    args = parser.parse_args()

    # Fail before creating any run directory when there is nothing to run on.
    world_size = torch.cuda.device_count()
    if world_size == 0:
        raise RuntimeError("No CUDA devices found: at least one GPU is required.")

    current_time = datetime.now().strftime('%b%d_%H-%M-%S-%f')
    args.log_dir = pathlib.Path('runs') / f"reactnet_imagenet_{args.name}_{current_time}_{socket.gethostname()}"
    args.save_path = args.log_dir / "model-quantized.pt"
    print("Will save to ", args.save_path)

    # Keep a copy of the sources used for this run next to its logs.
    source_copy_dir = args.log_dir / 'source'
    source_copy_dir.mkdir(parents=True, exist_ok=False)
    current_dir = pathlib.Path(__file__).parent
    py_files = glob.glob(str(current_dir / '*.py'))
    for f in py_files:
        relpath = pathlib.Path(f).relative_to(current_dir)
        print(f"Copy {relpath} to {source_copy_dir}")
        target = source_copy_dir / relpath
        target.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(f, target)

    mp.set_start_method('forkserver')
    mp.spawn(dist_main, args=(world_size, args), nprocs=world_size, join=True)


# Run the launcher only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
