import glob
import pathlib
import shutil

import mobilenetv2

from torchvision.datasets import CIFAR100
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from derivs import q, f_L_batched, v_dgdx_product_batched, v_dgdx_product_batched_2
from torch.utils.tensorboard import SummaryWriter
from utils import clip_grad_norm_, ScaledWeightConv2d


# Regularisation weight ``lamda`` that train() passes to modify_gradients.
LAMDA = 0.1


@torch.jit.script
def covar(tensor, rowvar: bool = True, bias: bool = False):
    """Return the covariance matrix of ``tensor`` (same convention as ``numpy.cov``).

    With ``rowvar=True`` each row of the last two dims is a variable and each
    column an observation; with ``rowvar=False`` the roles are swapped. Any
    leading dims are treated as a batch. ``bias=False`` normalises by N - 1,
    ``bias=True`` by N. The input tensor is left unmodified.
    """
    observations = tensor if rowvar else tensor.transpose(-1, -2)
    # Centre out of place: an in-place ``-=`` would write through to the
    # caller's tensor (directly, or via the transpose view).
    centred = observations - observations.mean(dim=-1, keepdim=True)
    factor = 1 / (centred.shape[-1] - int(not bool(bias)))
    return factor * (centred @ centred.transpose(-1, -2).conj())


def plot_covariance(s, V):
    """Plot the singular-value spectrum and, for 27 components, the principal components.

    Args:
        s: 1-D tensor of singular values, plotted as a line.
        V: matrix whose columns are the singular vectors; the caller passes
            the ``V`` returned by ``torch.svd_lowrank``.

    The component grid is drawn only when there are exactly 27 components.
    NOTE(review): this is presumably the 3-channel 3x3 first conv, whose
    27-dim patches can be viewed as 3x3 RGB images; confirm.
    """
    plt.plot(s.cpu())
    plt.show()

    if s.size(0) != 27:
        return

    # plot principal components
    plt.figure(figsize=(10, 10), facecolor=(1, 1, 1))
    for idx, component in enumerate(V.t()):
        plt.subplot(6, 5, idx + 1)
        # F.unfold flattens each patch channel-major, i.e. as (C, kh, kw);
        # imshow expects (H, W, C), so move the channel axis last.
        plt.imshow(component.reshape(3, 3, 3).permute(1, 2, 0).cpu() + 0.5)
        plt.axis('off')
    plt.show()


@torch.enable_grad()
def quantizer_func(w, wq_score_init, M, state, layer_name, writer, step,
                   lamda=0.1, nsamples=4, niter=50, init_niter=1000):
    """Approximately minimise ``f_L_batched(w, score, M, lamda).mean(0)`` over the scores.

    Runs ``nsamples`` rounds of ``niter`` SGD steps (``init_niter`` steps per
    round on the first call, i.e. when ``state`` is None). After each round
    it adds N(0, 0.01^2) noise to the scores. It keeps the best scores seen,
    judged by the objective evaluated just before each step, and writes them
    back into the score parameter at the end.

    Args:
        w: unquantised weights; the caller passes ``w_prime`` reshaped to
            (n_out, n_in).
        wq_score_init: initial scores; only used when ``state`` is None.
        M: matrix passed through to ``f_L_batched``.
        state: ``(score_parameter, optimizer)`` from a previous call, or None.
        layer_name, writer, step: the objective is logged as
            ``q_loss/<layer_name>`` at ``step * 2`` (before optimisation) and
            ``step * 2 + 1`` (after).
        lamda: passed through to ``f_L_batched``.
        nsamples, niter, init_niter: number of rounds / SGD steps per round.

    Returns:
        ``(scores, q(scores), state)`` where ``scores`` is the optimised
        parameter's data.
    """
    if state is None:
        tqdm.write(f"initializing optimizer for {layer_name}")
        # NOTE(review): nn.Parameter(t) is expected to share t's storage, so
        # the final copy_ below also updates the caller's tensor (e.g.
        # layer.quant_scores); callers appear to rely on this. Confirm.
        wq_score = nn.Parameter(wq_score_init)
        opt = torch.optim.SGD([wq_score], lr=0.01)  #, momentum=0.9, nesterov=True)
        state = (wq_score, opt)
        niter = init_niter
    wq_score, opt = state

    # The best objective and score snapshot are pure bookkeeping. Keep them
    # off the autograd graph, otherwise every torch.where/copy_ below chains
    # a new node onto best_wq_score for the whole run.
    best_y = f_L_batched(w, wq_score, M, lamda).mean(0).detach()
    writer.add_scalar(f'q_loss/{layer_name}', best_y, step * 2)
    best_wq_score = wq_score.detach().clone()

    for _ in range(nsamples):
        for _ in range(niter):
            opt.zero_grad()
            y = f_L_batched(w, wq_score, M, lamda).mean(0)

            with torch.no_grad():
                update = torch.le(y, best_y)
                best_y = torch.where(update, y, best_y)
                best_wq_score.copy_(torch.where(update, wq_score, best_wq_score))

            y.backward()
            opt.step()

        # Perturb the scores before the next round of SGD.
        with torch.no_grad():
            wq_score.add_(torch.randn_like(wq_score) * 0.01)

    with torch.no_grad():
        wq_score.copy_(best_wq_score)
    y = f_L_batched(w, wq_score, M, lamda).mean(0).detach()
    writer.add_scalar(f'q_loss/{layer_name}', y, step * 2 + 1)

    return wq_score.data, q(wq_score.data), state


@torch.no_grad()
def process_input_for_layer(inp, layer, layer_name, states, writer, step,
                            lamda=0.1, k=100, plot_cov=False):
    """Re-quantise ``layer`` using the statistics of its current input ``inp``.

    Steps:
      1. Extract the layer's receptive-field patches with ``F.unfold`` and
         compute their covariance, with one row per (patch, group).
      2. Take a rank-``min(dim, k)`` SVD of that covariance and form
         ``M = diag(sqrt(s)) @ V.T``, so that ``M.T @ M`` approximates it.
      3. Run :func:`quantizer_func` on ``layer.w_prime`` (viewed as
         (n_out, n_in)) and write the returned ``q(scores)`` into
         ``layer.weight``.

    The quantiser's state and ``M`` are stored in ``states[layer_name]``
    under ``'opt'`` and ``'M'``. The ratio ``s[0] / s[-1]`` is logged as
    ``sv_ratio/<layer_name>``; ``plot_cov`` additionally plots the spectrum.
    ``inp`` is expected to be (batch, C, H, W), as ``F.unfold`` requires.
    """
    # (batch, C*kh*kw, L) -> (batch, L, C*kh*kw): one row per output position.
    patches = F.unfold(
        inp,
        kernel_size=layer.kernel_size,
        dilation=layer.dilation,
        padding=layer.padding,
        stride=layer.stride,
    ).permute(0, 2, 1)

    dim = patches.shape[-1] // layer.groups
    n_out = layer.weight.shape[0]

    # unfold flattens features channel-major, and conv groups are contiguous
    # channel blocks, so each group's patch is a contiguous run of `dim`
    # features. Splitting the last axis as (groups, dim) is a plain reshape
    # (the previous (dim, groups) split mixed groups whenever groups > 1).
    patches = patches.reshape(-1, dim)
    cov = covar(patches, False)

    # svd_lowrank returns (U, S, V) with A ~= U diag(S) V^H; V is (dim, q).
    _, s, V = torch.svd_lowrank(cov, q=min(dim, k))
    writer.add_scalar(f'sv_ratio/{layer_name}', s[0] / s[-1], step)

    if plot_cov:
        plot_covariance(s, V)

    M = s[:, None] ** 0.5 * V.t()  # TODO: smoothing

    w_prime = layer.w_prime.reshape(n_out, -1)  # (n_out, n_in)
    wq_scores = layer.quant_scores.reshape(n_out, -1)

    # Optimize inner problem
    layer_state = states.setdefault(layer_name, {})
    # The returned scores are not used directly. NOTE(review): the optimised
    # scores presumably reach layer.quant_scores because the nn.Parameter that
    # quantizer_func builds from `wq_scores` shares its storage. Confirm.
    _, wq, layer_state['opt'] = quantizer_func(
        w_prime, wq_scores, M,
        layer_state.get('opt', None),
        layer_name, writer, step,
        lamda=lamda
    )

    layer_state['M'] = M
    layer.weight.data.copy_(wq.reshape(layer.weight.shape))


def process_inputs(model, X, states, layer_names, writer, step, lamda=0.1):
    """Run ``model`` on ``X`` through ``model.coroutine`` and return its final value.

    The coroutine yields ``(input, layer)`` pairs and is sent the same input
    back. For every yielded ``ScaledWeightConv2d`` the input is first handed to
    :func:`process_input_for_layer` (with ``lamda``), which re-quantises that
    layer. The value the coroutine returns when exhausted is returned here.
    NOTE(review): presumably each pair is yielded just before the layer is
    applied, so the freshly quantised weights are used; confirm in mobilenetv2.
    """
    coro = model.coroutine(X)
    inp = None
    while True:
        # Only the coroutine's own exhaustion ends the loop; a StopIteration
        # escaping from the layer processing below must not be mistaken for it.
        try:
            inp, layer = coro.send(inp)
        except StopIteration as ex:
            return ex.value

        if isinstance(layer, ScaledWeightConv2d):
            name = layer_names[layer]
            process_input_for_layer(inp, layer, name, states, writer, step, lamda=lamda)


def modify_gradients(layer, name, M, writer, step, lamda=0.1, max_norm=5.0):
    """Move the gradient from ``layer.weight`` to ``layer.w_prime``, clipping its norm.

    Reads ``layer.weight.grad`` (the loss gradient w.r.t. the quantised weights
    written by process_input_for_layer) and clears it. It then maps the
    gradient through ``v_dgdx_product_batched_2`` using ``w_prime``,
    ``quant_scores``, ``M`` and ``lamda``, rescales the result so its L2 norm
    is at most ``max_norm``, and stores it as ``layer.w_prime.grad``.

    The pre-clip norm (+1e-6) is logged as ``grad_norm/<name>``.
    """
    n_out = layer.weight.shape[0]

    w_prime = layer.w_prime.data.reshape(n_out, -1)  # (n_out, n_in)
    wq_scores = layer.quant_scores.reshape(n_out, -1)

    # Consume the gradient; torch optimizers skip parameters whose grad is None.
    df_dy = layer.weight.grad.reshape(n_out, -1)
    layer.weight.grad = None

    grad = v_dgdx_product_batched_2(df_dy, w_prime, wq_scores, M, lamda)
    grad = grad.reshape(layer.weight.shape).clone()

    # Rescale so that ||grad|| <= max_norm (1e-6 guards against division by zero).
    grad_norm = grad.norm(2) + 1e-6
    grad.mul_(torch.minimum(grad_norm, grad.new_tensor(max_norm)) / grad_norm)
    layer.w_prime.grad = grad

    writer.add_scalar(f"grad_norm/{name}", grad_norm, global_step=step)


def _copy_sources(log_dir):
    """Copy the ``*.py`` files next to this script into ``<log_dir>/source``."""
    source_copy_dir = pathlib.Path(log_dir) / 'source'
    source_copy_dir.mkdir()
    current_dir = pathlib.Path(__file__).parent
    py_files = glob.glob(str(current_dir / '*.py'))  # + glob.glob(str(current_dir / '**' / '*.py'))
    for f in py_files:
        relpath = pathlib.Path(f).relative_to(current_dir)
        print(f"Copy {relpath} to {source_copy_dir}")
        target = source_copy_dir / relpath
        target.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(f, target)


def _init_quantization(model):
    """Initialise scales/w_prime/quant_scores of each ScaledWeightConv2d and quantise its weight."""
    for m in model.modules():
        if not isinstance(m, ScaledWeightConv2d):
            continue
        w = m.weight.data
        w = w.reshape(w.size(0), -1)  # (n_out, n_in)
        # One scale per output channel: the 75th percentile of |w| over its inputs.
        scales = torch.quantile(w.abs(), 0.75, 1, keepdim=True)
        w_prime = (w / scales).reshape(m.weight.shape)
        m.scales.data.copy_(scales[:, :, None, None])
        m.w_prime.data.copy_(w_prime)
        m.quant_scores.copy_(w_prime)
        m.weight.data.copy_(q(m.quant_scores))


def _set_use_float(model, use_float):
    """Set the ``use_float`` flag on every ScaledWeightConv2d in ``model``."""
    for m in model.modules():
        if isinstance(m, ScaledWeightConv2d):
            m.use_float = use_float


def train(train_set, train_loader, test_set, val_loader):
    """Fine-tune a quantised MobileNetV2 (100 classes) and evaluate/checkpoint each epoch.

    Logs to a fresh TensorBoard run directory, into which the sources are also
    copied. Each epoch it trains, evaluates the quantised model (after
    recomputing BN statistics), saves a checkpoint, and evaluates the
    underlying float model.
    """
    writer = SummaryWriter()
    _copy_sources(writer.log_dir)

    # make model
    model = mobilenetv2.MobileNetV2(
        num_classes=100,
        first_layer_type=ScaledWeightConv2d,
        expansion_layer=ScaledWeightConv2d,
        # depthwise_layer=ScaledWeightConv2d, # skip depthwise for now
        pointwise_layer=ScaledWeightConv2d,
        shortcut_layer=ScaledWeightConv2d,
        conv2_layer=ScaledWeightConv2d,
    ).cuda()

    model_name = '/home/xxx/DBQ/mbnet_cifar100_state_dict.pt'
    state_dict = torch.load(model_name, map_location='cpu')
    model.load_state_dict(state_dict, strict=False)

    _init_quantization(model)

    # test_model(model, test_set, val_loader)

    epochs = 20
    weight_params = [param for name, param in model.named_parameters() if 'bn' not in name]
    bn_params = [param for name, param in model.named_parameters() if 'bn' in name]
    opt = torch.optim.Adam([
        dict(params=weight_params, weight_decay=1/500),
        dict(params=bn_params)
    ], lr=0.0001)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    criterion = torch.nn.CrossEntropyLoss(reduction='none')

    layer_names = {layer: name for name, layer in model.named_modules()}

    # Per-layer inner-optimizer state ('opt') and covariance factor ('M'),
    # keyed by module name.
    states = {}

    step = 0
    for epoch in range(epochs):
        print(f"Epoch {epoch}.")

        model.train()
        loss_accum = torch.tensor(0.0, device='cuda')
        correct_count = torch.tensor(0, device='cuda')
        for X, y in tqdm(train_loader, desc=f'Epoch {epoch}'):
            X = X.cuda(non_blocking=True)
            y = y.cuda(non_blocking=True)
            model.zero_grad()

            # Forward pass that also computes each quantised layer's input
            # covariance and re-solves its quantisation. It uses the same
            # lambda as the gradient correction below.
            logits = process_inputs(model, X, states, layer_names, writer, step, lamda=LAMDA)

            loss = criterion(logits, y)
            loss.mean().backward()

            # Turn gradients w.r.t. the quantised weights into gradients w.r.t. w_prime.
            for name, layer in model.named_modules():
                if isinstance(layer, ScaledWeightConv2d):
                    modify_gradients(layer, name, states[name]['M'], writer, step, lamda=LAMDA)
                    del states[name]['M']

            # update w_prime
            opt.step()

            loss_sum = loss.sum().data
            loss_accum += loss_sum
            n_correct = (torch.argmax(logits, dim=1) == y).sum().data
            correct_count += n_correct

            step_loss = (loss_sum / X.shape[0]).item()
            step_acc = (n_correct / X.shape[0]).item()
            tqdm.write(f'loss = {step_loss}')
            tqdm.write(f'acc = {step_acc:.4f}')
            writer.add_scalar('step_loss', step_loss, step)
            writer.add_scalar('step_acc', step_acc, step)
            step += 1

        epoch_loss = (loss_accum / len(train_set)).item()
        epoch_acc = (correct_count / len(train_set)).item()
        tqdm.write(f'epoch {epoch} loss = {epoch_loss}')
        tqdm.write(f'epoch {epoch} acc = {epoch_acc:.4f}')
        writer.add_scalar('epoch_loss', epoch_loss, epoch)
        writer.add_scalar('epoch_acc', epoch_acc, epoch)
        scheduler.step()

        # Evaluate the quantised model with BN statistics recomputed for it.
        set_running_stats(model, train_loader)
        test_loss, test_acc = test_model(model, test_set, val_loader)
        writer.add_scalar('test_loss', test_loss, epoch)
        writer.add_scalar('test_acc', test_acc, epoch)

        # Checkpoint now, while the BN running stats belong to the quantised
        # weights. The float evaluation below recomputes (overwrites) them.
        filename = f"{writer.log_dir}/checkpoint-{epoch}.pt"
        torch.save(model.state_dict(), filename)
        print("Saved in", filename)

        # evaluate underlying float model
        _set_use_float(model, True)
        set_running_stats(model, train_loader)
        test_loss, test_acc = test_model(model, test_set, val_loader)
        writer.add_scalar('float_test/loss', test_loss, epoch)
        writer.add_scalar('float_test/acc', test_acc, epoch)
        _set_use_float(model, False)

    writer.close()


def set_running_stats(model, train_loader):
    """Recompute the BatchNorm running statistics of ``model`` from ``train_loader``.

    The running stats are reset, and ``momentum`` is temporarily set to None
    so they accumulate as an equal-weight cumulative average over every batch
    (the procedure of ``torch.optim.swa_utils.update_bn``). With PyTorch's
    default momentum of 0.1, only the last few dozen batches would effectively
    count. The original momenta are restored afterwards, even on error.

    The model is left in train mode. Inputs are moved to the device of the
    model's first parameter (CUDA if it has none).
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    bn_layers = [m for m in model.modules()
                 if isinstance(m, nn.modules.batchnorm._BatchNorm)]
    saved_momenta = [bn.momentum for bn in bn_layers]
    for bn in bn_layers:
        bn.reset_running_stats()
        bn.momentum = None

    model.train()
    try:
        with torch.no_grad():
            for X, _ in tqdm(train_loader, desc='Setting running stats...'):
                model(X.to(device, non_blocking=True))
    finally:
        for bn, momentum in zip(bn_layers, saved_momenta):
            bn.momentum = momentum


def test_model(model, test_set, val_loader):
    model.eval()
    print("Test")
    loss_accum = torch.tensor(0.0, device='cuda')
    correct_count = torch.tensor(0, device='cuda')
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for X, y in val_loader:
            X = X.cuda(non_blocking=True)
            y = y.cuda(non_blocking=True)
            logits = model(X)
            loss = criterion(logits, y)
            loss_accum += loss.data
            correct_count += (torch.argmax(logits, dim=1) == y).sum()
        loss = (loss_accum / len(test_set)).item()
        acc = (correct_count / len(test_set)).item()
        print(f'test loss = {loss}')
        print(f'test acc = {acc:.4f}')
    return loss, acc


def main():
    """Build the CIFAR-100 train/test datasets and loaders, then run :func:`train`."""
    DATASET_DIR = '/home/xxx/DBQ/'
    # NOTE(review): these look like the commonly quoted CIFAR-10 channel
    # statistics rather than CIFAR-100's. They are kept unchanged because the
    # pretrained checkpoint was presumably trained with them; confirm first.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
    train_set = CIFAR100(DATASET_DIR, train=True, download=True,
                         transform=transforms.Compose([
                             transforms.RandomCrop(32, padding=4),
                             transforms.RandomHorizontalFlip(),
                             transforms.ToTensor(),
                             normalize,
                         ]))
    test_set = CIFAR100(DATASET_DIR, train=False, download=True,
                        transform=transforms.Compose([
                            transforms.ToTensor(),
                            normalize,
                        ]))
    # Reshuffle the training data every epoch; evaluation order is irrelevant.
    train_loader = DataLoader(train_set, batch_size=128, shuffle=True,
                              num_workers=4, pin_memory=True)
    val_loader = DataLoader(test_set, batch_size=128, num_workers=4, pin_memory=True)

    train(train_set, train_loader, test_set, val_loader)


if __name__ == "__main__":
    # Script entry point; importing this module does not start training.
    main()
