import copy
import pdb
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import OrderedDict

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR100
from tqdm import tqdm
from argparse import ArgumentParser
from utils import clip_grad_norm_

import mobilenetv2
from dbq import update_temp
from quantized_modules import init_quant_params, TernaryConv2d, TernaryLinear
import random
import yaml

from wage_quantizer import QW


def main():
    """Jointly train a full-precision and a weight-quantized MobileNetV2 on CIFAR-100.

    Each epoch trains the quantized model (normalized gradients, weights snapped
    back onto the quantization grid afterwards) and the full-precision model
    (per-layer gradient rescaling by ``scales``), then evaluates both. With
    ``--agg`` the two models' parameters are averaged after every epoch, either
    into the models themselves or, with ``--prox``, into the anchors of the
    proximal penalty. Checkpoints and per-epoch history are written to a fresh
    timestamped directory under ``--exp_root``.
    """
    parser = ArgumentParser()
    parser.add_argument('--agg', default=False, action='store_true')
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--prox', default=False, action='store_true')
    # NOTE(review): --ternary is parsed (and logged to args.yaml) but not used below.
    parser.add_argument('--ternary', default=False, action='store_true')
    parser.add_argument('--prox_coeff', type=float, default=1e-4)
    parser.add_argument('--exp_root', type=str, default='/v11/xxx/exps_fixedact',
                        help='directory under which a timestamped run directory is created')
    parser.add_argument('--pretrained', type=str, default=None,
                        help="optional checkpoint (with a 'state_dict' entry) used to "
                             "initialise both models")
    args = parser.parse_args()

    run_id = f'{datetime.now().strftime("%Y%m%d-%H%M%S")}-{random.randint(0, 1000):04d}'
    exp_name = Path(args.exp_root) / run_id
    exp_name.mkdir(parents=True, exist_ok=True)
    with open(exp_name / 'args.yaml', 'w') as fp:
        yaml.safe_dump(vars(args), fp)
    print(exp_name)

    fp_model = mobilenetv2.ws_mobilenetv2(num_classes=100).cuda().train()
    print(fp_model)

    wl_weight = 2
    wl_activate = 8
    qu_model = mobilenetv2.simpleq_mobilenetv2(
        num_classes=100,
        wl_weight=wl_weight, wl_activate=wl_activate
    ).cuda().train()
    print(qu_model)

    if args.pretrained is not None:
        state_dict = torch.load(args.pretrained, map_location='cpu')['state_dict']
        # Drop the first 6 characters of every key (presumably a 'model.' prefix -- confirm
        # against the checkpoint) and skip quantizer state; loading is non-strict.
        state_dict = {k[6:]: p for k, p in state_dict.items() if 'quantizer' not in k}
        fp_model.load_state_dict(state_dict, strict=False)
        qu_model.load_state_dict(state_dict, strict=False)

    # For every Conv2d/Linear layer: record max|w_qu| / max|w_fp| of the initial weights,
    # then copy qu_model's (not yet quantized) weights into fp_model.
    scales = {}
    for name, fp_m, qu_m in dict_zip(dict(fp_model.named_modules()), dict(qu_model.named_modules())):
        if isinstance(fp_m, (nn.Conv2d, nn.Linear)):
            scales[name] = qu_m.weight.data.abs().max() / fp_m.weight.data.abs().max()
            fp_m.weight.data.copy_(qu_m.weight.data)

    ensure_quant(qu_model, wl_weight)

    train_loader, test_loader = _build_loaders('.')

    epochs = args.epochs
    fp_optim = torch.optim.SGD(fp_model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
    fp_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(fp_optim, T_max=epochs)
    qu_optim = torch.optim.SGD(qu_model.parameters(), lr=1.0, momentum=0.9, weight_decay=1e-4)
    qu_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(qu_optim, T_max=epochs)

    fp_tr_hist = []
    fp_te_hist = []
    qu_tr_hist = []
    qu_te_hist = []

    fp_te_hist.append(test_epoch(test_loader, fp_model, 'FP', final_scale=scales['linear']))
    qu_te_hist.append(test_epoch(test_loader, qu_model, 'QU'))

    # Proximal anchors are detached copies: as deep-copied nn.Parameters they would
    # require grad, so every backward pass would also compute (and never clear)
    # gradients for them.
    l2_weights_fp = ({k: p.detach().clone() for k, p in fp_model.named_parameters()}
                     if args.prox else None)
    l2_weights_qu = ({k: p.detach().clone() for k, p in qu_model.named_parameters()}
                     if args.prox else None)

    for epoch in range(epochs):
        # Update temperature parameter (linear schedule starting at 5.0)
        new_temp = 5.0 + epoch * 2.5
        qu_model.apply(partial(update_temp, new_temp=new_temp))
        print(f"new temp = {new_temp}")

        if epoch % 10 == 0:
            print(f"Saving model to {exp_name / f'model-{epoch}.pth'}")
            torch.save({
                'qu_model': qu_model.state_dict(),
                'fp_model': fp_model.state_dict()
            }, exp_name / f'model-{epoch}.pth')

        qu_tr_hist.append(train_epoch(
            train_loader, f"QU Epoch {epoch}", qu_model, qu_optim,
            l2_weights_qu, args.prox_coeff, normalize_grad=True))
        ensure_quant(qu_model, wl_weight)

        fp_tr_hist.append(train_epoch(
            train_loader, f"FP Epoch {epoch}", fp_model, fp_optim,
            l2_weights_fp, args.prox_coeff, normalize_grad=False,
            final_scale=scales['linear'], scales=scales))

        fp_te_hist.append(test_epoch(test_loader, fp_model, 'FP', final_scale=scales['linear']))
        qu_te_hist.append(test_epoch(test_loader, qu_model, 'QU'))

        if args.agg:
            print("aggregating weights")
            _aggregate_weights(fp_model, qu_model, l2_weights_fp, l2_weights_qu, args.prox)

        torch.save({
            'fp_tr_hist': fp_tr_hist,
            'fp_te_hist': fp_te_hist,
            'qu_tr_hist': qu_tr_hist,
            'qu_te_hist': qu_te_hist,
        }, exp_name / f'results-{epoch}.pth')

        fp_scheduler.step()
        qu_scheduler.step()

    print(f"Saving model to {exp_name / 'model-final.pth'}")
    torch.save({
        'qu_model': qu_model.state_dict(),
        'fp_model': fp_model.state_dict()
    }, exp_name / 'model-final.pth')


def _build_loaders(dataset_dir='.', batch_size=128):
    """Return ``(train_loader, test_loader)`` for CIFAR-100 stored under ``dataset_dir``.

    The training split uses random-crop + horizontal-flip augmentation; both
    splits are normalized with the same per-channel mean/std constants.
    """
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
    train_data = CIFAR100(dataset_dir, train=True, download=True,
                          transform=transforms.Compose([
                              transforms.RandomCrop(32, padding=4),
                              transforms.RandomHorizontalFlip(),
                              transforms.ToTensor(),
                              normalize,
                          ]))
    test_data = CIFAR100(dataset_dir, train=False, download=True,
                         transform=transforms.Compose([
                             transforms.ToTensor(),
                             normalize,
                         ]))

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True,
                              # Drop a trailing batch of exactly one sample: a size-1
                              # batch can break BatchNorm in train mode.
                              drop_last=(len(train_data) % batch_size == 1),
                              num_workers=2, persistent_workers=True, pin_memory=True)
    test_loader = DataLoader(test_data, batch_size=batch_size,
                             num_workers=2, persistent_workers=True, pin_memory=True)
    return train_loader, test_loader


@torch.no_grad()
def _aggregate_weights(fp_model, qu_model, l2_weights_fp, l2_weights_qu, prox):
    """Average the parameters of ``fp_model`` and ``qu_model``.

    Conv2d/Linear weights (selected by the module type in ``qu_model``) are
    averaged first, then every remaining parameter name present in both models.
    Without ``prox`` the average overwrites both models' parameters in place;
    with ``prox`` it overwrites the proximal anchor tensors instead and leaves
    the models untouched. Runs under ``no_grad`` so no autograd graph is built.
    """
    qu_modules = dict(qu_model.named_modules())
    aggregated_weights = set()
    for name, fp_module in fp_model.named_modules():
        qu_module = qu_modules[name]
        if not isinstance(qu_module, (nn.Conv2d, nn.Linear)):
            continue
        weight_name = name + '.weight'
        aggregated_weights.add(weight_name)
        avg = (fp_module.weight + qu_module.weight) / 2
        if prox:
            l2_weights_fp[weight_name].copy_(avg)
            l2_weights_qu[weight_name].copy_(avg)
        else:
            fp_module.weight.copy_(avg)
            qu_module.weight.copy_(avg)

    qu_params = dict(qu_model.named_parameters())
    for name, fp_param in fp_model.named_parameters():
        if name in aggregated_weights or name not in qu_params:
            continue
        qu_param = qu_params[name]
        avg = (qu_param + fp_param) / 2
        if prox:
            l2_weights_fp[name].copy_(avg)
            l2_weights_qu[name].copy_(avg)
        else:
            fp_param.copy_(avg)
            qu_param.copy_(avg)


@torch.no_grad()
def ensure_quant(model: nn.Module, wl_weight):
    """Re-quantize, in place, the weight of every Conv2d/Linear submodule of ``model``.

    Each weight is replaced by ``QW(weight, wl_weight)``; all other parameters
    (biases, normalization layers, ...) are left as they are.
    """
    quantizable_types = (nn.Conv2d, nn.Linear)
    for module in model.modules():
        if not isinstance(module, quantizable_types):
            continue
        weight = module.weight.data
        weight.copy_(QW(weight, wl_weight))


def squared_l2_dist(params_a, params_b):
    """Return ``sum_k ||params_a[k] - params_b[k]||^2`` over the keys present in both.

    Args:
        params_a: mapping name -> tensor (e.g. ``dict(model.named_parameters())``).
        params_b: mapping name -> tensor with matching shapes for shared names.

    Returns:
        A 0-dim tensor. Gradients flow to whichever inputs require grad. If the
        two mappings share no names, returns a zero tensor instead of raising.
    """
    # Sum each difference separately rather than concatenating every difference
    # into one large temporary; iterate params_a so the summation order is fixed.
    terms = [(a - params_b[name]).pow(2).sum()
             for name, a in params_a.items() if name in params_b]
    if not terms:
        return torch.zeros(())
    return torch.stack(terms).sum()


def train_epoch(train_loader, desc, model, optim, l2_weights, prox_coeff, normalize_grad,
                final_scale=1.0, scales=None):
    """Run one training epoch of ``model`` and return ``(mean_loss, accuracy)``.

    Args:
        train_loader: iterable of ``(X, Y)`` batches.
        desc: label for the tqdm progress bar.
        model: network to train (switched to train mode here).
        optim: optimizer stepping ``model``'s parameters.
        l2_weights: ``None`` or a mapping name -> anchor tensor; when given,
            ``prox_coeff * squared_l2_dist(params, l2_weights)`` is added to the
            optimized loss.
        prox_coeff: weight of the proximal term.
        normalize_grad: if True, each parameter's gradient is divided by its
            maximum absolute value (an all-zero gradient stays zero).
        final_scale: logits are divided by this before the loss.
        scales: optional mapping module name -> factor; Conv2d/Linear weight
            gradients are multiplied by ``factor ** 2``.

    Returns:
        ``(train_loss, train_acc)`` as 0-dim tensors. ``train_loss`` is the mean
        cross-entropy only; the proximal penalty is not included.
    """
    model.train()
    # Move batches to wherever the model lives (CUDA in main()).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    sum_train_loss = 0
    ncorrect = 0
    inst_count = 0
    criterion = torch.nn.CrossEntropyLoss()
    for X, Y in tqdm(train_loader, desc=desc):
        X = X.to(device, non_blocking=True)
        Y = Y.to(device, non_blocking=True)

        model.zero_grad()

        logits = model(X) / final_scale
        loss = criterion(logits, Y)
        # Out-of-place add: `loss_sum += ...` would modify `loss` in place and
        # leak the proximal penalty into the reported training loss.
        loss_sum = loss
        if l2_weights is not None:
            loss_sum = loss + prox_coeff * squared_l2_dist(dict(model.named_parameters()), l2_weights)
        loss_sum.backward()

        if normalize_grad:
            for p in model.parameters():
                if p.grad is None:  # unused parameter; zero_grad() may leave None
                    continue
                # clamp_min keeps an all-zero gradient at zero instead of 0/0 = NaN
                p.grad.div_(p.grad.abs().max().clamp_min(torch.finfo(p.grad.dtype).tiny))
        if scales is not None:
            for name, m in model.named_modules():
                if isinstance(m, (nn.Conv2d, nn.Linear)) and m.weight.grad is not None:
                    m.weight.grad.mul_(scales[name] ** 2)

        batch_size = X.shape[0]
        sum_train_loss += loss.detach() * batch_size
        ncorrect += (torch.argmax(logits, dim=1) == Y).sum().detach()
        inst_count += batch_size

        clip_grad_norm_(model.parameters(), 0.5)

        optim.step()
    train_loss = sum_train_loss / inst_count
    train_acc = ncorrect / inst_count

    print(f"TRAIN {train_loss:.4f}, acc={train_acc * 100:.4f}%")
    return train_loss, train_acc


@torch.no_grad()
def test_epoch(test_loader, model, name, final_scale=1.0):
    model.eval()

    sum_test_loss = 0
    ncorrect = 0
    inst_count = 0
    criterion = torch.nn.CrossEntropyLoss()
    for X, Y in test_loader:
        X = X.cuda(non_blocking=True)
        Y = Y.cuda(non_blocking=True)

        logits = model(X) / final_scale
        loss = criterion(logits, Y)

        sum_test_loss += loss.data * X.shape[0]
        ncorrect += (torch.argmax(logits, dim=1) == Y).sum().data
        inst_count += X.shape[0]

    test_loss = sum_test_loss / inst_count
    test_acc = ncorrect / inst_count

    print(f"{name} TEST  {test_loss:.4f}, acc={test_acc * 100:.4f}%")
    return test_loss, test_acc


def dict_zip(*dcts):
    """Yield ``(key, dcts[0][key], dcts[1][key], ...)`` for each key present in all mappings.

    Keys are produced in the iteration order of the first mapping, so the
    output order is deterministic across runs (iterating a ``set`` of string
    keys is not, because of hash randomization). Yields nothing when called
    with no mappings.
    """
    if not dcts:
        return
    first, *rest = dcts
    for key in first:
        if all(key in d for d in rest):
            yield (key,) + tuple(d[key] for d in dcts)


def copy_back(source, target):
    """Recursively copy tensor data from ``source`` into ``target`` in place.

    ``target`` must mirror the structure of ``source``:
      * an object with a ``.data`` attribute (e.g. a tensor) receives ``source``
        via ``target.data.copy_(source, non_blocking=True)``;
      * dicts (including ``OrderedDict`` such as a ``state_dict()``) must have
        the same key set and are recursed key by key;
      * lists and tuples must have the same length and are recursed element-wise;
      * ``int``/``float``/``bool``/``str`` leaves are left untouched.

    Raises:
        AssertionError: on a type or structure mismatch, or an unsupported
            leaf/container type. Raised explicitly so the checks survive ``python -O``.
    """
    if hasattr(target, 'data'):
        target.data.copy_(source, non_blocking=True)
        return
    if type(source) != type(target):
        raise AssertionError(f"{type(source)} != {type(target)}")
    # isinstance rather than `type(x) == ...`: the old comparison against
    # typing.OrderedDict (a generic alias, not a class) never matched a real
    # OrderedDict, and exact-type checks rejected dict/tuple subclasses.
    if isinstance(target, dict):
        if set(source.keys()) != set(target.keys()):
            raise AssertionError(f"key mismatch: {set(source.keys()) ^ set(target.keys())}")
        for key in target:
            copy_back(source[key], target[key])
    elif isinstance(target, (list, tuple)):
        if len(source) != len(target):
            raise AssertionError(f"length mismatch: {len(source)} != {len(target)}")
        for src_item, dst_item in zip(source, target):
            copy_back(src_item, dst_item)
    elif type(target) in [int, float, bool, str]:
        pass  # immutable scalars: nothing to copy into
    else:
        raise AssertionError(f"Unsupported type {type(target)}, {type(source)}")


if __name__ == "__main__":
    # Start training only when run as a script; importing this module has no side effects.
    main()
