from argparse import Namespace
from functools import partial

import pytorch_lightning as pl
import pytorch_lightning.plugins as plug
import torch.multiprocessing
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10, CIFAR100, ImageNet

from args import parse_args
from dbq import *
from quantized_modules import *
import mobilenetv2
from resnet import resnet20

# Share tensors between processes (e.g. DataLoader workers) through files in
# shared memory rather than by passing file descriptors. Presumably chosen to
# avoid "too many open files" errors with several workers — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')


class FPModule(pl.LightningModule):
    """Full-precision training wrapper around a quantization-ready backbone.

    The backbone (``resnet20`` or ``mobilenetv2.dbq_mobilenetv2``) is built
    with the quantization-related options taken from ``args`` and trained with
    SGD plus a per-epoch cosine learning-rate schedule.
    """

    def __init__(self, args: Namespace):
        """Build the backbone selected by ``args.model`` for ``args.dataset``.

        Raises:
            ValueError: if ``args.dataset`` or ``args.model`` is not supported.
        """
        super().__init__()
        self.save_hyperparameters()

        if args.dataset == 'cifar10':
            num_classes = 10
        elif args.dataset == 'cifar100':
            num_classes = 100
        elif args.dataset == 'imagenet':
            num_classes = 1000
        else:
            # An explicit raise (not `assert False`) still fires under `python -O`.
            raise ValueError(f"Unsupported dataset: {args.dataset!r}")

        if args.model == 'resnet20':
            model = resnet20
        elif args.model == 'mobilenetv2':
            model = mobilenetv2.dbq_mobilenetv2
        else:
            # Fail here with a clear message instead of calling `None(...)` below.
            raise ValueError(f"Unsupported model: {args.model!r}")

        self.model = model(
            quantize_first_layer=args.include_first_layer,
            quantize_last_layer=args.include_last_layer,
            quantize_shortcut_layer=args.include_shortcut_layer,
            num_classes=num_classes,
            decompose=args.decompose,
            num_branches=args.num_branches,
            num_branches_first=args.num_branches_first,
            num_branches_last=args.num_branches_last,
            gen_matrix_every_step=args.gen_matrix_every_step
        )
        self.criterion = torch.nn.CrossEntropyLoss()
        self.args = args

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        """SGD with momentum 0.9; the base LR is scaled linearly by ``args.gpus``."""
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.args.learning_rate * self.args.gpus,
            weight_decay=self.args.weight_decay,
            momentum=0.9,
        )
        return optimizer

    def training_epoch_end(self, outputs):
        """Set the LR for the next epoch from a cosine schedule (no warmup)."""
        # current_epoch is the epoch that just finished; schedule the next one.
        e = self.current_epoch + 1
        optim = self.optimizers()
        new_lr = self.args.learning_rate * self.args.gpus / 2 * (
                1 + math.cos(e / self.args.epochs * math.pi))
        for g in optim.param_groups:
            g['lr'] = new_lr
        if self.local_rank == 0:
            print(f"New lr = {new_lr}")
        self.log('lr', new_lr)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        pred = self.model(x)
        loss = self.criterion(pred, y)
        self.log('train_loss', loss, on_epoch=True)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Log epoch-level validation loss and top-1 accuracy (percent)."""
        x, y = val_batch
        pred = self.model(x)
        loss = self.criterion(pred, y)
        acc_1, = accuracy(pred, y)
        self.log('val_loss', loss, on_step=False, on_epoch=True)
        self.log('val_top1', acc_1, logger=True, on_step=False, on_epoch=True)
        return loss

    def on_load_checkpoint(self, checkpoint):
        """Patch ``checkpoint["state_dict"]`` in place so it fits this module.

        - Keys this module has but the checkpoint lacks are filled with the
          module's current values.
        - Keys whose tensor shape differs from the module's are replaced by
          the module's current values.
        - Keys the module does not have are removed.

        The optimizer state stored in the checkpoint is left untouched.
        """
        state_dict = checkpoint["state_dict"]
        model_state_dict = self.state_dict()
        for k in list(model_state_dict):
            if k not in state_dict:
                print("Skipping", k, "because it doesn't exist")
                state_dict[k] = model_state_dict[k]
        dropped_keys = []
        for k in state_dict:
            if k in model_state_dict:
                if state_dict[k].shape != model_state_dict[k].shape:
                    print(f"Skip loading parameter: {k}, "
                          f"required shape: {model_state_dict[k].shape}, "
                          f"loaded shape: {state_dict[k].shape}")
                    # Overwriting an existing key does not resize the dict,
                    # so this is safe while iterating over it.
                    state_dict[k] = model_state_dict[k]
            else:
                print(f"Dropping parameter {k}")
                dropped_keys.append(k)
        # Deleting is deferred until after iteration to avoid mutating the
        # dict's size mid-loop.
        for k in dropped_keys:
            del state_dict[k]


def lr_scheduler(e, args):
    """Return the fine-tuning learning rate for epoch ``e``.

    During the first ``args.warmup_steps`` epochs the rate ramps linearly,
    reaching ``args.fine_tune_lr`` at epoch ``warmup_steps - 1``. After that
    it follows the half-cosine ``fine_tune_lr / 2 * (1 + cos(pi * e / E))``
    with ``E = args.fine_tune_epochs``. The cosine is evaluated at the raw
    epoch index (not offset by the warmup), so it starts slightly below
    ``fine_tune_lr`` and decays towards 0 at ``e = E``.
    """
    base_lr = args.fine_tune_lr
    warmup = args.warmup_steps
    if e >= warmup:
        progress = e / args.fine_tune_epochs
        return base_lr / 2 * (1 + math.cos(progress * math.pi))
    return (e + 1) * base_lr / warmup


class DBQ(pl.LightningModule):
    """Fine-tuning wrapper that progressively shrinks a quantized model.

    Call :meth:`set_model` before training. After each validation epoch,
    local rank 0 runs :meth:`update_epoch`. That method calls ``shrink_model``
    with the count scheduled for the current epoch, updates the quantizer
    temperature and logs size/branch metrics.
    """

    def __init__(self, args: Namespace):
        super().__init__()
        self.save_hyperparameters()
        self.criterion = torch.nn.CrossEntropyLoss()
        self.args = args
        self.model = None  # assigned by set_model()
        self.remove_counts = None  # per-epoch shrink counts, built by set_model()

    def set_model(self, model):
        """Attach ``model`` and build the per-epoch removal schedule.

        ``remove_counts[e]`` is proportional to ``lr_scheduler(e, args)``.
        It is normalised so the counts sum to roughly
        ``args.remove_portion * total_quant_groups(model)``, then rounded to
        ints, so the rounded sum may differ slightly.
        """
        self.model = model
        self.remove_counts = np.array([lr_scheduler(e, self.args) for e in range(self.args.fine_tune_epochs)])
        self.remove_counts *= self.args.remove_portion * total_quant_groups(self.model) / np.sum(self.remove_counts)
        self.remove_counts = np.round(self.remove_counts).astype(int)
        print(f"remove counts = {self.remove_counts}")

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        """SGD with momentum 0.9, starting at the schedule's epoch-0 LR times ``gpus``."""
        optimizer = torch.optim.SGD(
            self.parameters(),
            # lr_scheduler(0) == fine_tune_lr / warmup_steps when warmup is
            # enabled, and it stays well-defined when warmup_steps == 0
            # (dividing by warmup_steps directly would raise ZeroDivisionError).
            lr=lr_scheduler(0, self.args) * self.args.gpus,
            weight_decay=self.args.weight_decay,
            momentum=0.9,
        )
        return optimizer

    def training_epoch_end(self, outputs):
        """Set the LR for the *next* epoch from :func:`lr_scheduler`."""
        # current_epoch is the epoch that just finished. Using it directly
        # would re-apply epoch 0's LR and lag the schedule (and its alignment
        # with remove_counts) by one epoch; FPModule uses the same +1.
        e = self.current_epoch + 1
        optim = self.optimizers()
        new_lr = lr_scheduler(e, self.args) * self.args.gpus
        for g in optim.param_groups:
            g['lr'] = new_lr
        self.log('lr', new_lr)

    def update_epoch(self):
        """Per-epoch model update and logging, keyed on ``self.current_epoch``.

        1. If ``remove_counts[e] > 0``, call ``shrink_model`` with that count
           and log the returned scores / errors / saved-bits histograms.
        2. Log the distribution of ``num_branches()`` over all ``Quantizer``s.
        3. Set the quantizer temperature to ``temp_init + e * temp_inc``.
        4. Apply ``update_quant_matrix`` to every submodule.
        5. Log ``n_values`` of ternary layers and several ``measure_size``
           model-size metrics, then flush the logger.
        """
        e = self.current_epoch

        # Shrink model
        if self.remove_counts[e] > 0:
            scores, errors, saved_bits = shrink_model(self.model, self.remove_counts[e],
                                                      verbose=self.args.print_horizontal_remove)
            self.log('remaining_values', total_n_values(self.model))
            self.logger.experiment.add_histogram('scores', scores, self.global_step)
            self.logger.experiment.add_histogram('errors', errors, self.global_step)
            self.logger.experiment.add_histogram('saved_bits', saved_bits, self.global_step)

        num_branches = torch.cat([
            m.num_branches() for m in self.model.modules() if isinstance(m, Quantizer)
        ])
        self.logger.experiment.add_histogram('num_branches', num_branches, self.global_step)
        self.log('mean_num_branches', num_branches.float().mean())

        # Update temperature parameter
        new_temp = self.args.temp_init + e * self.args.temp_inc
        self.model.apply(partial(update_temp, new_temp=new_temp))
        print(f"new temp = {new_temp}")
        self.log('temp', new_temp)

        # Reorder threshold values
        self.model.apply(update_quant_matrix)

        # Record metrics
        n_values = []
        for k, m in self.model.named_modules():
            if isinstance(m, TernaryConv2d) or isinstance(m, TernaryLinear):
                n_values.append(m.quantizer.n_values())
        n_values = torch.cat(n_values)
        self.logger.experiment.add_histogram('n_values', n_values, self.global_step)

        # Record model size
        theoretical_size = measure_size(self.model)
        actual_size = measure_size(self.model, actual_size=True)
        non_shrinkable = measure_size(self.model, count_only_non_shrinkable=True)
        quant_hyperparams = measure_size(self.model, count_only_quant_hyperparams=True)
        non_dbq = measure_size(self.model, count_non_dbq_modules=True)
        self.log('theoretical_net_size', theoretical_size)
        self.log('actual_net_size', actual_size)
        self.log('non_shrinkable_net_size', non_shrinkable)
        self.log('quant_hyperparams_net_size', quant_hyperparams)
        self.log('non_dbq_net_size', non_dbq)
        print(f"net_size = {theoretical_size:,.2f}B")

        self.logger.experiment.flush()

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        pred = self.model(x)
        loss = self.criterion(pred, y)
        self.log('train_loss', loss.item(), on_epoch=True)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Log validation loss / top-1 and return plain floats for epoch_end."""
        x, y = val_batch
        pred = self.model(x)
        loss = self.criterion(pred, y)
        acc_1, = accuracy(pred, y)
        self.log('val_loss', loss, on_step=False, on_epoch=True)
        self.log('val_top1', acc_1, prog_bar=True, on_step=False, on_epoch=True)
        return {'loss': loss.item(), 'acc_1': acc_1.item()}

    def validation_epoch_end(self, outputs):
        """On local rank 0: print mean top-1 accuracy, then run :meth:`update_epoch`."""
        # NOTE(review): only local rank 0 shrinks / re-parameterises the model.
        # Under multi-GPU DDP the other ranks keep their unmodified copy.
        # Confirm the ranks stay consistent.
        # NOTE(review): the pre-training sanity validation (main() sets
        # num_sanity_val_steps=-1) presumably also reaches this hook at
        # current_epoch 0, so epoch 0's update may run twice — confirm intent.
        if self.local_rank == 0:
            print(f"val acc_1 = {np.mean([m['acc_1'] for m in outputs])}%")
            # Update quantizers and temperature parameter
            self.update_epoch()


def report_model_size(net):
    """Print ``net``'s size breakdown as computed by ``measure_size``.

    The first line gives the theoretical size (2 decimals) and the actual
    size in bytes. The following lines give the sizes selected by
    ``count_only_non_shrinkable``, ``count_only_quant_hyperparams`` and
    ``count_non_dbq_modules``, each also shown as a percentage of the
    actual size.
    """
    theoretical = measure_size(net)
    actual = measure_size(net, actual_size=True)
    breakdown = [
        ("non-shrinkable", measure_size(net, count_only_non_shrinkable=True)),
        ("quant-hyperparams", measure_size(net, count_only_quant_hyperparams=True)),
        ("non-dbq", measure_size(net, count_non_dbq_modules=True)),
    ]

    print(f"theoretical {theoretical:,.2f} bytes, actual {actual:,.0f} bytes")
    for label, size in breakdown:
        share = size / actual * 100
        print(f"\t\t {label} {size:,.0f} bytes ({share:,.2f}%)")


@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) of ``output`` against ``target``.

    Args:
        output: 2-D score tensor; dim 1 indexes the classes.
        target: class indices, one per row of ``output``.
        topk: the k values to evaluate. A k larger than the number of classes
            behaves as if it were equal to it.

    Returns:
        A list with one 0-dim tensor per entry of ``topk``. Each is the
        percentage of rows whose target is among the k highest-scoring classes.
    """
    n_rows = target.size(0)
    k_cap = min(max(topk), output.shape[1])

    # Shape (k_cap, n_rows): row i holds every sample's i-th best class index.
    ranked = output.topk(k_cap, 1, True, True)[1].t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))

    scale = 100.0 / n_rows
    return [hits[:k].reshape(-1).float().sum(0).mul_(scale) for k in topk]


def main():
    """Run the pipeline: full-precision training, quantizer init, DBQ fine-tuning.

    All stages are driven by ``parse_args()``:

    1. Build CIFAR-style train/val loaders; ``args.batch_size`` is split
       evenly across ``args.gpus``.
    2. Obtain a full-precision ``FPModule``:
       - when ``args.fine_tune_from`` is None, train it (optionally resuming
         from ``args.resume``) and save ``<log_dir>/model_fp.pt``;
       - when it is the string ``"None"``, construct one without training;
       - otherwise load it from the checkpoint path it names.
    3. Move the backbone to CUDA and initialise its quantizers.
    4. Unless ``args.dry_run``, fine-tune with ``DBQ`` and save
       ``<log_dir>/model.pt``.

    Raises:
        ValueError: if ``args.dataset`` is not supported.
        RuntimeError: if fewer than one GPU is available / requested.
    """
    # parse args
    args = parse_args()

    # data
    DATASET_DIR = '.'

    if args.dataset == 'cifar10':
        dataset = CIFAR10
    elif args.dataset == 'cifar100':
        dataset = CIFAR100
    elif args.dataset == 'imagenet':
        # NOTE(review): torchvision's ImageNet takes `split=` rather than
        # `train=` / `download=`. The transforms below are also CIFAR-sized
        # (32px crops, CIFAR statistics), so this branch likely fails as
        # written — verify.
        dataset = ImageNet
    else:
        # An explicit raise (not `assert False`) still fires under `python -O`.
        raise ValueError(f"Unsupported dataset: {args.dataset!r}")

    if args.gpus is None:
        args.gpus = torch.cuda.device_count()
    if args.gpus < 1:
        # The per-GPU batch split below divides by args.gpus, and the network
        # is moved to CUDA unconditionally, so a GPU is required.
        raise RuntimeError(f"At least one GPU is required, got gpus={args.gpus}")

    # Per-channel mean / std (the commonly used CIFAR-10 values), shared by
    # the train and test pipelines.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
    train_set = dataset(DATASET_DIR, train=True, download=True,
                        transform=transforms.Compose([
                            transforms.RandomCrop(32, padding=4),
                            transforms.RandomHorizontalFlip(),
                            transforms.ToTensor(),
                            normalize,
                        ]))
    test_set = dataset(DATASET_DIR, train=False, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           normalize,
                       ]))
    per_gpu_batch = args.batch_size // args.gpus
    # Shuffle the training set explicitly. Without a distributed sampler
    # replacing it, the data would otherwise be iterated in a fixed order.
    train_loader = DataLoader(train_set, batch_size=per_gpu_batch, shuffle=True,
                              persistent_workers=True, num_workers=4)
    val_loader = DataLoader(test_set, batch_size=per_gpu_batch,
                            persistent_workers=True, num_workers=4)

    tb_logger = TensorBoardLogger("logs/" + args.exp_name)

    if args.fine_tune_from is None:
        # Full-precision training, optionally resumed from args.resume.
        if args.resume is None:
            model = FPModule(args)
        else:
            model = FPModule.load_from_checkpoint(args.resume, args=args)
        trainer = pl.Trainer(
            gpus=args.gpus,
            plugins=plug.DDPSpawnPlugin(find_unused_parameters=False),
            limit_train_batches=1.0,  # How much of training dataset to check
            max_epochs=args.epochs,
            gradient_clip_val=0.5,
            terminate_on_nan=True,
            resume_from_checkpoint=args.resume,
            logger=tb_logger,
        )
        trainer.fit(model, train_loader, val_loader)
        torch.save(model.model.state_dict(), f"{tb_logger.log_dir}/model_fp.pt")

    elif args.fine_tune_from == "None":
        # The literal string "None" skips both FP training and checkpoint loading.
        model = FPModule(args)
    else:
        model = FPModule.load_from_checkpoint(args.fine_tune_from, args=args)

    net = model.model
    print(net)

    n_groups = get_num_groups(net)
    print(f"Total number of quantization groups: {n_groups}")
    print("Before quantization: ")
    report_model_size(net)

    # Enable quantization
    net.cuda()
    net.train()
    net.apply(partial(
        init_quant_params,
        temp_init=args.temp_init,
        init_epochs=args.init_epochs,
        unfreeze_quant=not args.freeze_quant,
        adapt_params=not args.dry_run,
    ))

    print("After quantization: ")
    report_model_size(net)

    # fine-tuning
    if not args.dry_run:
        net.train()
        model = DBQ(args)
        model.set_model(net)
        trainer = pl.Trainer(
            gpus=args.gpus,
            plugins=plug.DDPSpawnPlugin(find_unused_parameters=False),
            limit_train_batches=1.0,  # How much of training dataset to check
            max_epochs=args.fine_tune_epochs,
            gradient_clip_val=0.5,
            terminate_on_nan=True,
            # -1: run the whole validation set as the pre-training sanity check.
            num_sanity_val_steps=-1,
            logger=tb_logger,
        )
        trainer.fit(model, train_loader, val_loader)
        torch.save(model.model.state_dict(), f"{tb_logger.log_dir}/model.pt")


# Script entry point: run the training / fine-tuning pipeline only when this
# file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
