import copy
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import OrderedDict

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR100
from tqdm import tqdm
from argparse import ArgumentParser
from utils import clip_grad_norm_

import mobilenetv2
from dbq import update_temp
from quantized_modules import init_quant_params
import random
import yaml


def main():
    """Train a full-precision and a quantized MobileNetV2 side by side on CIFAR-100.

    Both models are initialised from the same pretrained checkpoint. After each
    epoch their parameters can optionally be averaged (``--agg``). Without
    ``--prox`` the average is copied into both models. With ``--prox`` it is
    written into a proximal anchor (``l2_weights``) that ``train_epoch`` pulls
    both models towards.

    Arguments, per-epoch metrics and periodic model checkpoints are written to a
    timestamped experiment directory.
    """
    parser = ArgumentParser()
    parser.add_argument('--agg', default=False, action='store_true')
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--prox', default=False, action='store_true')
    parser.add_argument('--ternary', default=False, action='store_true')
    parser.add_argument('--prox_coeff', type=float, default=1e-4)
    parser.add_argument('--act_quant', default=False, action='store_true')
    parser.add_argument('--branches', type=int, default=2)
    parser.add_argument('--dw', default=False, action='store_true')
    parser.add_argument('--no-include-first-last', default=False, action='store_true')
    args = parser.parse_args()

    exp_name = Path(f'/v11/xxx/exps_fixedact/{datetime.now().strftime("%Y%m%d-%H%M%S")}-{random.randint(0, 1000):04d}')
    exp_name.mkdir(parents=True, exist_ok=True)
    with open(exp_name / 'args.yaml', 'w') as fp:
        yaml.safe_dump(vars(args), fp)
    print(exp_name)

    fp_model = mobilenetv2.MobileNetV2(num_classes=100).cuda().train()
    print(fp_model)

    qu_model = mobilenetv2.dbq_mobilenetv2(
        num_classes=100,
        num_branches=args.branches,
        num_branches_first=args.branches,
        num_branches_last=args.branches,
        quantize_first_layer=(not args.no_include_first_last),
        quantize_last_layer=(not args.no_include_first_last),
        quantize_conv2_layer=True,
        quantize_activations=args.act_quant,
        quantize_expansion_layer=True,
        quantize_depthwise_layer=args.dw,
        quantize_pointwise_layer=True
    ).cuda().train()
    print(qu_model)

    # Load the pretrained weights into both models. The first 6 characters of
    # every key are stripped (presumably a "model." prefix — TODO confirm
    # against the checkpoint) and quantizer entries are skipped.
    state_dict = torch.load("mobilenetv2_cifar100_pretrained.ckpt", map_location='cpu')['state_dict']
    state_dict = {k[6:]: p for k, p in state_dict.items() if 'quantizer' not in k}
    qu_model.load_state_dict(state_dict, strict=False)
    fp_model.load_state_dict(state_dict, strict=False)

    # enable quantizers
    qu_model.apply(partial(
        init_quant_params,
        temp_init=2.5,
        init_epochs=5,
        unfreeze_quant=True,
        adapt_params=True,
    ))

    DATASET_DIR = '.'
    train_data = CIFAR100(DATASET_DIR, train=True, download=True,
                          transform=transforms.Compose([
                              transforms.RandomCrop(32, padding=4),
                              transforms.RandomHorizontalFlip(),
                              transforms.ToTensor(),
                              transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
                          ]))
    test_data = CIFAR100(DATASET_DIR, train=False, download=True,
                         transform=transforms.Compose([
                             transforms.ToTensor(),
                             transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
                         ]))

    train_loader = DataLoader(train_data, batch_size=128, shuffle=True,
                              # Avoid a final batch of size 1, which BatchNorm cannot train on.
                              drop_last=(len(train_data) % 128 == 1),
                              num_workers=2, persistent_workers=True, pin_memory=True)
    test_loader = DataLoader(test_data, batch_size=128,
                             num_workers=2, persistent_workers=True, pin_memory=True)

    epochs = args.epochs
    fp_optim = torch.optim.SGD(fp_model.parameters(), lr=0.0025, momentum=0.9, weight_decay=1e-4)
    fp_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(fp_optim, T_max=epochs)
    qu_optim = torch.optim.SGD(qu_model.parameters(), lr=0.0025, momentum=0.9, weight_decay=1e-4)
    qu_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(qu_optim, T_max=epochs)

    fp_tr_hist = []
    fp_te_hist = []
    qu_tr_hist = []
    qu_te_hist = []

    # Baseline evaluation before any training.
    fp_te_hist.append(test_epoch(test_loader, fp_model))
    qu_te_hist.append(test_epoch(test_loader, qu_model))

    # Proximal anchor. Use detached copies: a deepcopy of the Parameters would
    # keep requires_grad=True, so the penalty in train_epoch would keep
    # accumulating (never-zeroed) gradients into these tensors.
    l2_weights = (
        {name: p.detach().clone() for name, p in fp_model.named_parameters()}
        if args.prox else None
    )

    for epoch in range(epochs):
        # Quantizer temperature schedule: 5.0, 7.5, 10.0, ...
        new_temp = 5.0 + epoch * 2.5
        qu_model.apply(partial(update_temp, new_temp=new_temp))
        print(f"new temp = {new_temp}")

        if epoch % 10 == 0:
            _save_models(exp_name / f'model-{epoch}.pth', fp_model, qu_model)

        fp_tr_hist.append(train_epoch(train_loader, f"FP Epoch {epoch}", fp_model, fp_optim, l2_weights, args.prox_coeff))
        qu_tr_hist.append(train_epoch(train_loader, f"QU Epoch {epoch}", qu_model, qu_optim, l2_weights, args.prox_coeff))

        fp_te_hist.append(test_epoch(test_loader, fp_model))
        qu_te_hist.append(test_epoch(test_loader, qu_model))

        if args.agg:
            print("aggregating weights")
            _aggregate_weights(fp_model, qu_model, l2_weights, args.ternary, args.prox)

        torch.save({
            'fp_tr_hist': fp_tr_hist,
            'fp_te_hist': fp_te_hist,
            'qu_tr_hist': qu_tr_hist,
            'qu_te_hist': qu_te_hist,
        }, exp_name / f'results-{epoch}.pth')

        fp_scheduler.step()
        qu_scheduler.step()

    _save_models(exp_name / 'model-final.pth', fp_model, qu_model)


def _save_models(path, fp_model, qu_model):
    """Save the state dicts of both models to ``path`` under 'qu_model' / 'fp_model'."""
    print(f"Saving model to {path}")
    torch.save({
        'qu_model': qu_model.state_dict(),
        'fp_model': fp_model.state_dict()
    }, path)


@torch.no_grad()
def _aggregate_weights(fp_model, qu_model, l2_weights, use_quantized_weight, prox):
    """Average the parameters of ``fp_model`` and ``qu_model`` in place.

    Conv2d/Linear weights are handled first; when ``use_quantized_weight`` is set,
    the quantized model contributes ``get_weight()`` instead of its raw weight.
    All remaining parameters are averaged directly.

    If ``prox`` is false, the average is copied into both models. Otherwise it
    is written into the ``l2_weights`` anchor and the models are left untouched.
    Runs under ``no_grad`` because only the averaged values are needed, not an
    autograd graph.
    """
    qu_modules = dict(qu_model.named_modules())
    aggregated_weights = set()
    for name, fp_module in fp_model.named_modules():
        qu_module = qu_modules[name]
        if not isinstance(qu_module, (nn.Conv2d, nn.Linear)):
            continue
        weight_name = name + '.weight'
        aggregated_weights.add(weight_name)
        qu_weight = qu_module.get_weight() if use_quantized_weight else qu_module.weight
        avg = (fp_module.weight + qu_weight) / 2
        if prox:
            l2_weights[weight_name].data.copy_(avg)
        else:
            fp_module.weight.data.copy_(avg)
            qu_module.weight.data.copy_(avg)

    qu_params = dict(qu_model.named_parameters())
    for name, fp_param in fp_model.named_parameters():
        if name in aggregated_weights:
            continue  # already averaged above
        qu_param = qu_params[name]
        avg = (qu_param + fp_param) / 2
        if prox:
            l2_weights[name].data.copy_(avg)
        else:
            fp_param.data.copy_(avg)
            qu_param.data.copy_(avg)


def squared_l2_dist(params_a, params_b):
    """Return the squared Euclidean distance between two name->tensor mappings.

    Only names present in both mappings contribute. Names are visited in the
    iteration order of ``params_a``, so the summation order (and hence the
    floating-point result) is reproducible across runs.

    Returns a differentiable 0-dim tensor. If the mappings share no names,
    returns a zero 0-dim tensor instead of failing on an empty concatenation.
    """
    shared = [name for name in params_a if name in params_b]
    if not shared:
        return torch.zeros(())
    return sum(((params_a[name] - params_b[name]) ** 2).sum() for name in shared)


def train_epoch(train_loader, desc, model, optim, l2_weights, prox_coeff):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        train_loader: iterable of ``(inputs, labels)`` batches; both are moved
            to CUDA.
        desc: label shown on the tqdm progress bar.
        model: the network to train (put into train mode).
        optim: optimizer stepping ``model``'s parameters.
        l2_weights: optional name->tensor anchor. When not ``None``, the
            optimised objective also includes
            ``prox_coeff * squared_l2_dist(model params, l2_weights)``.
        prox_coeff: weight of the proximal term.

    Returns:
        ``(mean_loss, accuracy)`` as 0-dim tensors, where ``mean_loss`` is the
        per-sample cross-entropy only (the proximal penalty is not included).
    """
    model.train()

    sum_train_loss = 0
    ncorrect = 0
    inst_count = 0
    criterion = torch.nn.CrossEntropyLoss()
    for X, Y in tqdm(train_loader, desc=desc):
        X = X.cuda(non_blocking=True)
        Y = Y.cuda(non_blocking=True)

        model.zero_grad()

        logits = model(X)
        loss = criterion(logits, Y)
        # Build a new tensor instead of using `+=`: an in-place add would also
        # modify `loss`, making the logged loss silently include the penalty.
        loss_sum = loss
        if l2_weights is not None:
            loss_sum = loss + prox_coeff * squared_l2_dist(dict(model.named_parameters()), l2_weights)
        loss_sum.backward()

        batch_size = X.shape[0]
        sum_train_loss += loss.detach() * batch_size
        ncorrect += (torch.argmax(logits, dim=1) == Y).sum().detach()
        inst_count += batch_size

        clip_grad_norm_(model.parameters(), 0.5)

        optim.step()
    train_loss = sum_train_loss / inst_count
    train_acc = ncorrect / inst_count

    print(f"TRAIN {train_loss:.4f}, acc={train_acc * 100:.4f}%")
    return train_loss, train_acc


@torch.no_grad()
def test_epoch(test_loader, model):
    model.eval()

    sum_test_loss = 0
    ncorrect = 0
    inst_count = 0
    criterion = torch.nn.CrossEntropyLoss()
    for X, Y in test_loader:
        X = X.cuda(non_blocking=True)
        Y = Y.cuda(non_blocking=True)

        logits = model(X)
        loss = criterion(logits, Y)

        sum_test_loss += loss.data * X.shape[0]
        ncorrect += (torch.argmax(logits, dim=1) == Y).sum().data
        inst_count += X.shape[0]

    test_loss = sum_test_loss / inst_count
    test_acc = ncorrect / inst_count

    print(f"TEST  {test_loss:.4f}, acc={test_acc * 100:.4f}%")
    return test_loss, test_acc


def dict_zip(*dcts):
    """Yield ``(key, v0, v1, ...)`` for every key present in all of ``dcts``.

    Keys are produced in the iteration order of the first mapping. This keeps
    the output deterministic across runs; iterating a ``set`` of string keys is
    not, because string hashing is randomized per process. Yields nothing when
    called with no arguments.
    """
    if not dcts:
        return
    first, *rest = dcts
    for key in first:
        if all(key in d for d in rest):
            yield (key,) + tuple(d[key] for d in dcts)


def copy_back(source, target):
    """Recursively copy tensor data from ``source`` into ``target`` in place.

    Objects with a ``.data`` attribute (tensors/parameters) receive
    ``target.data.copy_(source, non_blocking=True)``. Mappings (``dict``,
    ``collections.OrderedDict``, other ``dict`` subclasses) and sequences
    (``list``, ``tuple``, including namedtuples) are walked element-wise.
    ``int``/``float``/``bool``/``str`` leaves are left alone.

    Raises:
        AssertionError: if the two structures differ in type, keys or length,
            or contain an unsupported leaf type.
    """
    if hasattr(target, 'data'):
        target.data.copy_(source, non_blocking=True)
        return

    assert type(source) == type(target), f"{type(source)} != {type(target)}"
    # isinstance (not an exact-type comparison) so that OrderedDict — e.g. a
    # state_dict — and namedtuples are supported. The old check compared
    # against typing.OrderedDict, which never equals collections.OrderedDict.
    if isinstance(target, dict):
        assert set(source.keys()) == set(target.keys())
        for key in target:
            copy_back(source[key], target[key])
    elif isinstance(target, (list, tuple)):
        assert len(source) == len(target)
        for s, t in zip(source, target):
            copy_back(s, t)
    elif type(target) in [int, float, bool, str]:
        pass
    else:
        # Explicit raise: a bare `assert False` is stripped under `python -O`.
        raise AssertionError(f"Unsupported type {type(target)}, {type(source)}")


if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()
