import argparse
from argparse import Namespace
from functools import partial

import pytorch_lightning as pl
import pytorch_lightning.plugins as plug
import torch.multiprocessing
from pytorch_lightning.loggers import TensorBoardLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10, CIFAR100, ImageNet

import mobilenetv2
from args import parse_args
from dbq import *
from quantized_modules import *
import mobilenetv2_binact
from resnet import resnet20
from hessian_eigenthings import compute_hessian_eigenthings

# Use the 'file_system' tensor-sharing strategy for inter-process tensor
# transfer (e.g. DataLoader workers) instead of the default file-descriptor one.
# NOTE(review): presumably set to avoid "too many open files" errors with
# num_workers=4 + persistent_workers; confirm before removing.
torch.multiprocessing.set_sharing_strategy('file_system')


# knowledge distillation loss
@torch.jit.script
def distillation(logits, labels, teacher_scores, T: float, alpha: float):
    """Per-sample knowledge-distillation loss.

    Convex mix of the student's hard-label cross-entropy and the KL divergence
    between the temperature-softened teacher and student distributions:

        loss = (1 - alpha) * CE(logits, labels)
               + alpha * 2 * T**2 * KL(softmax(teacher / T) || softmax(logits / T))

    Args:
        logits: student logits. Softmax is taken over dim -1 while
            cross-entropy and the KL sum use dim 1, so this assumes 2-D
            (batch, num_classes) input.
        labels: hard integer class labels, one per sample.
        teacher_scores: teacher logits, same shape as ``logits``.
        T: softmax temperature.
        alpha: weight of the distillation term; ``1 - alpha`` weights the
            hard-label term.

    Returns:
        Unreduced loss with one value per sample (the caller reduces it).
    """
    task_loss = F.cross_entropy(logits, labels, reduction='none')
    # With log_target=True, kl_div(input, target) is exp(target) * (target - input)
    # element-wise, i.e. KL(teacher || student); summing over classes gives one
    # value per sample.
    teacher_loss = F.kl_div(
        F.log_softmax(logits / T, dim=-1),
        F.log_softmax(teacher_scores / T, dim=-1),
        reduction='none',
        log_target=True,
    ).sum(1)
    # Soft-target gradients shrink as 1/T**2, so the KD term is rescaled by
    # T**2 (with an extra factor of 2). alpha multiplies this weight so that
    # alpha and (1 - alpha) form a proper mix of the two terms.
    return task_loss * (1 - alpha) + teacher_loss * (2 * T ** 2 * alpha)


def load_model(path, num_classes=100):
    """Load a full-precision ``mobilenetv2.MobileNetV2`` from a saved state dict.

    Args:
        path: file containing a ``state_dict`` saved with ``torch.save``.
        num_classes: number of output classes of the checkpoint's classifier
            (defaults to 100, the value this function always used before).

    Returns:
        The model, with weights loaded on CPU.
    """
    model = mobilenetv2.MobileNetV2(num_classes=num_classes, final_bias_trick=False)
    state_dict = torch.load(path, map_location='cpu')
    # Drop every key mentioning 'quantizer'. The model built here is created
    # without them, and strict load_state_dict would reject unexpected keys.
    # NOTE(review): presumably leftovers from a quantized training run; confirm.
    state_dict = {k: p for k, p in state_dict.items() if 'quantizer' not in k}
    # With no separate classifier bias, the bias is stored as the last column
    # of linear.weight (presumably a checkpoint saved with final_bias_trick=True).
    # Split it back into separate weight and bias tensors.
    if 'linear.bias' not in state_dict:
        folded = state_dict['linear.weight']
        state_dict['linear.bias'] = folded[:, -1]
        state_dict['linear.weight'] = folded[:, :-1]
    model.load_state_dict(state_dict)
    return model


class FPModule(pl.LightningModule):
    """Lightning module training a MobileNetV2 student by knowledge distillation.

    The student is ``mobilenetv2_binact.MobileNetV2``. The frozen teacher is
    loaded from a fixed checkpoint path via ``load_model``. Training minimizes
    ``distillation`` against the teacher's logits. Validation logs plain
    cross-entropy (``val_loss``) and top-1 accuracy (``val_top1``). The learning
    rate is driven by a ``ReduceLROnPlateau`` scheduler on ``val_top1``, which
    is stepped manually in ``training_epoch_end``.

    NOTE(review): ``load_model`` builds the teacher with 100 output classes, so
    the distillation loss only shape-matches the student when
    ``args.dataset == 'cifar100'``.
    """

    def __init__(self, args: Namespace):
        super().__init__()
        self.save_hyperparameters()

        num_classes_by_dataset = {'cifar10': 10, 'cifar100': 100, 'imagenet': 1000}
        if args.dataset not in num_classes_by_dataset:
            raise ValueError(f'unsupported dataset: {args.dataset!r}')
        num_classes = num_classes_by_dataset[args.dataset]

        self.model = mobilenetv2_binact.MobileNetV2(
            num_classes=num_classes,
            decompose=args.decompose,
            final_bias_trick=False,
        )
        # Frozen teacher: no gradients. It is switched to eval mode before
        # every use in training_step.
        self.teacher = load_model('/d1/xxx/TBQ/TBQ_experiments/model_fp_0.pt')
        for p in self.teacher.parameters():
            p.requires_grad = False
        self.criterion = torch.nn.CrossEntropyLoss()
        self.args = args

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        # Only the student is optimized. The teacher's parameters never
        # receive a .grad, so listing them would only add dead entries.
        optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=self.args.learning_rate * self.args.gpus,  # linear LR scaling with GPU count
            weight_decay=self.args.weight_decay,
            momentum=0.9,
        )
        # Stepped manually in training_epoch_end; halves the LR when val_top1
        # has not improved for 5 epochs.
        self.scheduler = ReduceLROnPlateau(
            optimizer, 'max', factor=0.5, patience=5, min_lr=1e-7)
        return optimizer

    def training_epoch_end(self, outputs):
        optim = self.optimizers()

        # val_top1 may be absent, e.g. if validation is disabled or has not
        # run yet when this hook fires. In that case the LR is left unchanged.
        val_top1 = self.trainer.callback_metrics.get("val_top1")
        if val_top1 is not None:
            val_top1 = val_top1.item()
            if self.global_rank == 0:
                print("val_top1=", val_top1)
            self.scheduler.step(val_top1)

        lr = optim.param_groups[0]['lr']
        if self.global_rank == 0:
            print(f"New lr = {lr}")
        self.log('lr', lr)

        if self.global_rank == 0:
            self._log_hessian_eigenvalues()

    def _log_hessian_eigenvalues(self, num_eigenthings: int = 20):
        """Compute the top Hessian eigenvalues of the student's CE loss and log them."""
        # NOTE(review): this runs on one rank only, while the other ranks wait
        # at their next collective. compute_hessian_eigenthings presumably
        # passes over the training data repeatedly, so on large datasets this
        # can take long enough to hit a distributed timeout. Confirm the cost.
        eigenvals, _ = compute_hessian_eigenthings(
            self.model, self.trainer.train_dataloader, F.cross_entropy, num_eigenthings)
        if self.logger is not None:
            self.logger.log_metrics(
                {f'hessian_eigenval_{i}': float(v) for i, v in enumerate(eigenvals)},
                step=self.global_step,
            )

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.model(x)

        # Lightning puts the whole module, teacher included, into train mode
        # during fit. Force the frozen teacher back to inference behaviour
        # before using its outputs as targets.
        self.teacher.eval()
        with torch.no_grad():
            # Teacher logits are per-sample and must stay on this rank. Under
            # DDP every rank sees a different shard of the data, so reducing
            # them across ranks would pair the student with the averaged
            # teacher outputs for other images.
            teacher_scores = self.teacher(x)

        loss = distillation(logits, y, teacher_scores, 20.0, 0.7).mean()

        self.log('train_loss', loss, on_epoch=True)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        pred = self.model(x)
        loss = self.criterion(pred, y)
        acc_1, = accuracy(pred, y)
        # sync_dist=True averages the metrics across DDP ranks, so every rank
        # steps ReduceLROnPlateau on the same val_top1 and the LR stays identical.
        self.log('val_loss', loss, on_step=False, on_epoch=True, sync_dist=True)
        self.log('val_top1', acc_1, logger=True, on_step=False, on_epoch=True,
                 sync_dist=True)
        return loss


@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for every k in ``topk``.

    Args:
        output: (batch, num_classes) scores or logits.
        target: (batch,) integer class labels.
        topk: k values to evaluate. Any k larger than ``num_classes`` behaves
            like ``num_classes``.

    Returns:
        A list with one 0-dim float tensor per entry of ``topk``.
    """
    num_samples = target.size(0)
    k_limit = min(max(topk), output.shape[1])

    # (batch, k_limit) indices of the highest-scoring classes, best first.
    ranked = output.topk(k_limit, dim=1, largest=True, sorted=True).indices
    # hits[r, i] is True when the r-th ranked guess for sample i is its label.
    hits = (ranked == target.view(-1, 1)).t()

    return [
        hits[:k].reshape(-1).float().sum(0).mul_(100.0 / num_samples)
        for k in topk
    ]


def _build_arg_parser():
    """Return the command-line parser for this training script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', '-b', default=100, type=int)
    parser.add_argument('--weight_decay', '-wd', default=1e-4, type=float)
    parser.add_argument('--epochs', '-e', default=200, type=int)
    parser.add_argument('--fine_tune_epochs', '-fe', default=50, type=int)
    parser.add_argument('--learning_rate', '-lr', default=0.1, type=float)
    parser.add_argument('--fine_tune_lr', '-flr', default=0.0025, type=float)
    parser.add_argument('--temp_init', '-t', default=5.0, type=float)
    parser.add_argument('--temp_inc', '-tinc', default=2.5, type=float)
    parser.add_argument('--init_epochs', '-ie', default=100, type=int)
    parser.add_argument('--resume', '-r', default=None, type=str)
    parser.add_argument('--fine_tune_from', '-ft', default=None, type=str)
    parser.add_argument('--gpus', '-g', default=None, type=int)
    parser.add_argument('--warmup_steps', '-w', default=5, type=int)
    parser.add_argument('--freeze_quant', '-fzq', default=False, action='store_true')
    parser.add_argument('--include_first_layer', '-if', default=False, action='store_true')
    parser.add_argument('--include_last_layer', '-il', default=False, action='store_true')
    parser.add_argument('--include_shortcut_layer', '-is', default=False, action='store_true')
    parser.add_argument('--decompose', default=None, type=int)
    parser.add_argument('--num_branches_first', '-nbf', default=2, type=int)
    parser.add_argument('--num_branches', '-nb', default=2, type=int)
    parser.add_argument('--num_branches_last', '-nbl', default=2, type=int)
    parser.add_argument('--dry_run', default=False, action='store_true')
    parser.add_argument('--gen_matrix_every_step', '-gm', default=False, action='store_true')
    parser.add_argument('--remove_portion', '-rp', default=0.9, type=float)
    parser.add_argument('--one_by_one', default=False, action='store_true')
    parser.add_argument('--dataset', '-d', required=True, choices=[
        'cifar10', 'cifar100', 'imagenet'
    ])
    return parser


def main():
    """Parse CLI args, build the CIFAR data loaders, and fit ``FPModule``.

    The trained student weights are written to ``<log_dir>/model_fp.pt`` of
    the TensorBoard logger.
    """
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True

    parser = _build_arg_parser()
    args = parser.parse_args()

    # torchvision's ImageNet takes `split=` (not `train=`) and does not
    # support download=True, and the transforms below are CIFAR-specific
    # (32x32 crops). Reject it up front instead of crashing inside torchvision.
    if args.dataset == 'imagenet':
        parser.error("--dataset imagenet is not supported by this script "
                     "(data loading and transforms are CIFAR-only)")

    args.exp_name = f'dbq_d{args.dataset}_nb{args.num_branches}'
    if args.include_first_layer:
        args.exp_name += f'_nbf{args.num_branches_first}'
    if args.include_last_layer:
        args.exp_name += f'_nbl{args.num_branches_last}'

    # data
    DATASET_DIR = '.'
    dataset = {'cifar10': CIFAR10, 'cifar100': CIFAR100}[args.dataset]

    if args.gpus is None:
        args.gpus = torch.cuda.device_count()
    if args.gpus < 1:
        # The per-device batch size below divides by the GPU count, and the
        # LR in FPModule is scaled by it, so zero GPUs cannot work.
        parser.error(f"at least one GPU is required (got --gpus {args.gpus})")

    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
    train_set = dataset(DATASET_DIR, train=True, download=True,
                        transform=transforms.Compose([
                            transforms.RandomCrop(32, padding=4),
                            transforms.RandomHorizontalFlip(),
                            transforms.ToTensor(),
                            normalize,
                        ]))
    test_set = dataset(DATASET_DIR, train=False, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           normalize,
                       ]))

    # args.batch_size is the global batch; each DDP process gets its share.
    per_device_batch = args.batch_size // args.gpus
    # Under DDP, Lightning normally swaps in a shuffling DistributedSampler.
    # shuffle=True keeps training shuffled even when that replacement does
    # not happen.
    train_loader = DataLoader(train_set, batch_size=per_device_batch, shuffle=True,
                              persistent_workers=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(test_set, batch_size=per_device_batch,
                            persistent_workers=True, num_workers=4, pin_memory=True)

    tb_logger = TensorBoardLogger("pretrain_logs/" + args.exp_name)

    model = FPModule(args)
    if args.resume:
        model.model.load_state_dict(torch.load(args.resume, map_location="cpu"))
    trainer = pl.Trainer(
        gpus=args.gpus,
        plugins=plug.DDPSpawnPlugin(find_unused_parameters=False),
        limit_train_batches=1.0,  # fraction of the training set used per epoch
        max_epochs=args.epochs,
        logger=tb_logger,
    )
    trainer.fit(model, train_loader, val_loader)
    # NOTE(review): with ddp_spawn, training runs in child processes. This
    # relies on Lightning copying the final weights back into `model` in this
    # process before fit() returns; confirm for the pinned Lightning version.
    torch.save(model.model.state_dict(), f"{tb_logger.log_dir}/model_fp.pt")


# Run training only when executed as a script, not when imported. Imports
# matter here because ddp_spawn child processes re-import this module.
if __name__ == "__main__":
    main()
