import pathlib
from argparse import ArgumentParser
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
import torch.nn.functional as F
import math
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from pytorch_lightning.callbacks import LearningRateMonitor

from cpr.adam_cpr import AdamCPR, group_cpr_parameters

# 'medium' allows float32 matrix multiplications to use faster reduced-precision
# (bfloat16-based) internal kernels when the hardware provides them.
torch.set_float32_matmul_precision(precision='medium')


### Data
def cifar100_task(cache_dir='./data', mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010),
                  download=True):
    """Return the CIFAR-100 ``(trainset, testset)`` pair as torchvision datasets.

    The training split uses reflect-padded random 32x32 crops and random
    horizontal flips. Both splits are converted to tensors and normalised per
    channel with the same ``mean``/``std``.

    Args:
        cache_dir: directory the dataset is stored in (and downloaded to).
        mean: per-channel normalisation means.
        std: per-channel normalisation standard deviations.
            NOTE: the default ``mean``/``std`` appear to be the widely copied
            CIFAR-10 statistics rather than CIFAR-100's. They are kept as
            defaults so existing runs stay reproducible.
        download: passed to torchvision; fetch the dataset into ``cache_dir``
            if it is not already there.

    Returns:
        ``(trainset, testset)`` torchvision ``CIFAR100`` datasets.
    """
    # One shared normalisation so the train and test pipelines cannot drift apart.
    normalize = transforms.Normalize(mean, std)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    trainset = torchvision.datasets.CIFAR100(root=cache_dir, train=True, download=download,
                                             transform=transform_train)
    testset = torchvision.datasets.CIFAR100(root=cache_dir, train=False, download=download,
                                            transform=transform_test)

    return trainset, testset


### Model
class BasicBlock(nn.Module):
    """Residual block of two 3x3 conv + BatchNorm layers (``expansion = 1``).

    The first conv applies ``stride``. The shortcut is the identity unless the
    stride or the channel count changes. In that case it is a strided 1x1 conv
    followed by BatchNorm. Output channels are ``expansion * planes``.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # The residual path needs a projection exactly when its shape would
        # otherwise differ from the main path (spatial stride or channel count).
        needs_projection = stride != 1 or in_planes != out_planes
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
            if needs_projection
            else nn.Sequential()
        )

    def forward(self, x):
        residual = self.shortcut(x)
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + residual)


class Bottleneck(nn.Module):
    """Residual block 1x1 reduce -> 3x3 (strided) -> 1x1 expand (``expansion = 4``).

    Each conv is followed by BatchNorm. The shortcut is the identity when the
    stride is 1 and the channel count is unchanged. Otherwise it is a strided
    1x1 conv followed by BatchNorm. Output channels are ``expansion * planes``.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            # Match the main path's spatial size and channel count.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        out = x
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = F.relu(bn(conv(out)))
        out = self.bn3(self.conv3(out)) + self.shortcut(x)
        return F.relu(out)


class ResNet(nn.Module):
    """CIFAR-style ResNet: 3x3 stride-1 stem, four residual stages, linear classifier.

    Args:
        block: residual block class (e.g. ``BasicBlock`` or ``Bottleneck``). It
            must expose an ``expansion`` attribute and accept
            ``(in_planes, planes, stride)``.
        num_blocks: number of blocks in each of the four stages.
        num_classes: number of output logits.
    """

    def __init__(self, block, num_blocks, num_classes):
        super().__init__()
        # Running input-channel count; _make_layer advances it as stages are built.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one applies ``stride``."""
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling. The stem and layer1 keep the resolution and
        # layers 2-4 each halve it, so a 32x32 input gives a 4x4 map here and
        # this equals a 4x4 average pool. Unlike a fixed kernel of 4, it also
        # yields a (N, C) feature vector for other input sizes.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        return self.linear(out)


### Lightning Module
class ResNetModule(pl.LightningModule):
    """Lightning wrapper that trains a 100-class ResNet with AdamW or AdamCPR.

    ``config`` is the argparse namespace built by the script. The attributes
    read here are:

    * ``model_name``, ``batch_size``;
    * ``optimizer``, ``lr``, ``beta1``, ``beta2``, ``weight_decay``;
    * ``mode``, ``kappa``, ``kappa_init_dependent``, ``kappa_init_warm_start``,
      ``kappa_adapt``, ``lagmul_rate``;
    * ``rescale_alpha``;
    * ``lr_decay_factor``, ``lr_warmup_steps``, ``max_train_steps``.

    The loss is cross-entropy with label smoothing 0.1.
    """

    # Supported architectures: model_name -> (block class, blocks per stage).
    _ARCHITECTURES = {
        "ResNet18": (BasicBlock, (2, 2, 2, 2)),
        "ResNet34": (BasicBlock, (3, 4, 6, 3)),
        "ResNet50": (Bottleneck, (3, 4, 6, 3)),
        "ResNet101": (Bottleneck, (3, 4, 23, 3)),
        "ResNet152": (Bottleneck, (3, 8, 36, 3)),
    }

    def __init__(self, config):
        super().__init__()

        self.cfg = config

        try:
            block, num_blocks = self._ARCHITECTURES[self.cfg.model_name]
        except KeyError:
            raise ValueError(
                f"Unknown model_name {self.cfg.model_name!r}; "
                f"expected one of {sorted(self._ARCHITECTURES)}") from None
        self.model = ResNet(block, list(num_blocks), num_classes=100)

        self.loss = nn.CrossEntropyLoss(label_smoothing=0.1)

        # Per-batch test statistics; aggregated and cleared in on_test_epoch_end.
        self.test_stats = []

    def _rescaled_weights(self):
        """Return the model parameters whose names end in ``"weight"``."""
        return [p for n, p in self.model.named_parameters() if n.endswith("weight")]

    def _global_weight_norm(self):
        """Return the global L2 norm over ``_rescaled_weights()`` as a 0-d tensor."""
        return torch.sqrt(sum(p.pow(2).sum() for p in self._rescaled_weights()))

    def configure_optimizers(self):
        """Build the optimizer and a per-step linear-warmup + cosine LR schedule.

        If ``cfg.rescale_alpha > 0.1``, every parameter whose name ends in
        ``"weight"`` is also multiplied by ``rescale_alpha``. The resulting
        global L2 norm is stored in ``self.rescale_norm`` and later enforced by
        ``on_train_batch_end``.

        Raises:
            ValueError: if ``cfg.optimizer`` is neither ``'adamw'`` nor ``'adamcpr'``.
        """
        # NOTE(review): BatchNorm layers inside each block's ``shortcut`` are
        # registered as ``...shortcut.1.weight``, which does not contain "bn".
        # Whether they are exempt from decay/constraint therefore depends on
        # group_cpr_parameters itself — verify.
        avoid_keywords = ["bias", "bn"]

        if self.cfg.optimizer == 'adamw':
            optim_cfg = {"lr": self.cfg.lr, "beta1": self.cfg.beta1,
                         "beta2": self.cfg.beta2,
                         "weight_decay": self.cfg.weight_decay}

            param_groups = group_cpr_parameters(self.model, optim_cfg, avoid_keywords=avoid_keywords)
            # NOTE(review): AdamW's own weight_decay argument stays at its default.
            # The effective decay relies on group_cpr_parameters writing a
            # per-group "weight_decay" from optim_cfg — verify.
            optimizer = torch.optim.AdamW(param_groups, lr=self.cfg.lr, betas=(self.cfg.beta1, self.cfg.beta2))
        elif self.cfg.optimizer == 'adamcpr':
            optim_cfg = {"lr": self.cfg.lr, "beta1": self.cfg.beta1,
                         "beta2": self.cfg.beta2, "mode": self.cfg.mode, "kappa": self.cfg.kappa,
                         "kappa_init_dependent": self.cfg.kappa_init_dependent, "lagmul_rate": self.cfg.lagmul_rate,
                         "kappa_init_warm_start": self.cfg.kappa_init_warm_start}

            param_groups = group_cpr_parameters(self.model, optim_cfg, avoid_keywords=avoid_keywords)
            optimizer = AdamCPR(param_groups, lr=self.cfg.lr, betas=(self.cfg.beta1, self.cfg.beta2),
                                mode=self.cfg.mode,
                                kappa=self.cfg.kappa, lagmul_rate=self.cfg.lagmul_rate,
                                kappa_init_dependent=self.cfg.kappa_init_dependent, apply_decay=None,
                                kappa_adapt=self.cfg.kappa_adapt,
                                kappa_init_warm_start=self.cfg.kappa_init_warm_start)
        else:
            raise ValueError(
                f"Unknown optimizer {self.cfg.optimizer!r}; expected 'adamw' or 'adamcpr'")

        if self.cfg.rescale_alpha > 0.1:
            with torch.no_grad():
                for p in self._rescaled_weights():
                    p.mul_(self.cfg.rescale_alpha)
                self.rescale_norm = self._global_weight_norm().item()

        lr_decay_factor = self.cfg.lr_decay_factor
        num_warmup_steps = self.cfg.lr_warmup_steps
        num_decay_steps = max(1, self.cfg.max_train_steps - num_warmup_steps)

        def lr_lambda(current_step: int) -> float:
            # Linear warmup from 0 to 1 over num_warmup_steps...
            if current_step < num_warmup_steps:
                return float(current_step) / float(max(1, num_warmup_steps))
            # ...then cosine decay from 1 to lr_decay_factor. Progress is clamped
            # so steps beyond max_train_steps stay at lr_decay_factor instead of
            # climbing back up the cosine.
            progress = min(1.0, (current_step - num_warmup_steps) / float(num_decay_steps))
            return lr_decay_factor + (1 - lr_decay_factor) * (1 + math.cos(math.pi * progress)) / 2

        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)

        return [optimizer], [{'scheduler': lr_scheduler, 'interval': 'step'}]

    def setup(self, stage: str) -> None:
        """Load the CIFAR-100 train and test datasets via ``cifar100_task``."""
        trainset, testset = cifar100_task()
        self.trainset = trainset
        self.testset = testset

    def train_dataloader(self):
        """Shuffled training loader with ``cfg.batch_size`` and 8 workers."""
        train_loader = torch.utils.data.DataLoader(self.trainset, batch_size=self.cfg.batch_size, shuffle=True,
                                                   num_workers=8)
        return train_loader

    def test_dataloader(self):
        """Unshuffled test loader with ``cfg.batch_size`` and 8 workers."""
        test_loader = torch.utils.data.DataLoader(self.testset, batch_size=self.cfg.batch_size, shuffle=False,
                                                  num_workers=8)
        return test_loader

    def _accuracy(self, y_hat, y):
        """Fraction of rows of ``y_hat`` whose argmax equals the label in ``y``."""
        return torch.sum(torch.argmax(y_hat, dim=1) == y).item() / len(y)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one ``(X, y)`` batch."""
        X, y = batch
        y_hat = self.model(X)
        loss = self.loss(y_hat, y)

        self.log('train_loss', loss)
        return loss

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """Rescale the ``"weight"`` parameters back to ``self.rescale_norm``.

        This runs after the backward pass and the optimizer step. Doing it inside
        ``training_step`` would change the weights between the forward and the
        backward pass of the same batch. Only active when
        ``cfg.rescale_alpha > 0.1``, the same condition under which
        ``configure_optimizers`` sets ``self.rescale_norm``.
        """
        if self.cfg.rescale_alpha > 0.1:
            with torch.no_grad():
                # Kept on-device: one norm reduction, no per-parameter .item() sync.
                scale = float(self.rescale_norm) / self._global_weight_norm()
                for p in self._rescaled_weights():
                    p.mul_(scale)

    def test_step(self, batch, batch_idx):
        """Record loss, correct predictions and batch size for one test batch.

        Aggregation and logging of ``test_loss`` / ``test_accuracy`` happen once
        in ``on_test_epoch_end``.
        """
        X, y = batch
        y_hat = self.model(X)
        loss = self.loss(y_hat, y)

        correct_pred = torch.sum(torch.argmax(y_hat, dim=1) == y).item()
        num_samples = len(y)
        self.test_stats.append({'loss': loss.item(), 'correct_pred': correct_pred, 'num_samples': num_samples})

        return loss

    def on_test_epoch_end(self):
        """Log the current LR plus dataset-level test loss and accuracy, then reset stats."""
        # Guard: the optimizer list may be empty, e.g. if test() runs without a prior fit().
        if self.trainer.optimizers:
            self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'])

        total_samples = sum(s['num_samples'] for s in self.test_stats)
        if total_samples > 0:
            # Weight each batch's mean loss by its size so a smaller final batch
            # does not skew the dataset-level average.
            valid_loss = sum(s['loss'] * s['num_samples'] for s in self.test_stats) / total_samples
            valid_accuracy = sum(s['correct_pred'] for s in self.test_stats) / total_samples
            self.log('test_loss', float(valid_loss))
            self.log('test_accuracy', float(valid_accuracy))
        self.test_stats = []


def train_cifar100_task(config):
    """Train one configuration on CIFAR-100, then evaluate it on the test split.

    The config dict (``config.npy``) and the ``trainer.test`` result
    (``result.npy``) are saved under
    ``<output_dir>/<session>/<model>_seed<seed>_steps<steps>/<expt_name>/``.
    ``expt_name`` encodes the optimizer hyperparameters. TensorBoard logging
    uses ``TensorBoardLogger(save_dir=<run dir>, name=<expt_name>)``.
    Training runs on a single GPU index taken from ``config.device``,
    falling back to 0.
    """
    run_dir = pathlib.Path(config.output_dir) / config.session / (
        f"{config.model_name}_seed{config.seed}_steps{config.max_train_steps}")
    run_dir.mkdir(parents=True, exist_ok=True)

    if config.optimizer != "adamcpr":
        expt_name = f"{config.optimizer}_l{config.lr}_w{config.weight_decay}_re{config.rescale_alpha}"
    else:
        expt_name = f"{config.optimizer}_l{config.lr}_{config.mode}_k{config.kappa}_k{config.kappa_init_dependent}_s{config.kappa_init_warm_start}_r{config.lagmul_rate}_ada{config.kappa_adapt}"

    out_dir = run_dir / expt_name
    out_dir.mkdir(parents=True, exist_ok=True)
    np.save(out_dir / "config.npy", vars(config))
    logger = TensorBoardLogger(save_dir=run_dir, name=expt_name)
    pl.seed_everything(config.seed)

    devices = [config.device] if config.device else [0]

    model = ResNetModule(config)

    trainer = pl.Trainer(
        devices=devices,
        accelerator="gpu",
        max_steps=config.max_train_steps,
        log_every_n_steps=config.log_interval,
        enable_progress_bar=config.enable_progress_bar,
        logger=logger,
        callbacks=[LearningRateMonitor(logging_interval='step')],
    )
    trainer.fit(model)

    # Evaluate the final weights on the test split and persist the result.
    result = trainer.test(model)
    np.save(out_dir / "result.npy", result)
    print(result)


def str_to_bool(value):
    """Parse a command-line boolean such as ``true``/``false``, ``yes``/``no``, ``1``/``0``.

    Used instead of ``type=bool``: argparse would call ``bool("False")``, which
    is ``True`` because every non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean string.
    """
    from argparse import ArgumentTypeError

    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--session", type=str, default='test_adapt_step')
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--max_train_steps", type=int, default=20000)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--model_name", type=str, default="ResNet18",
                        choices=["ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"])
    parser.add_argument("--optimizer", type=str, default='adamcpr', choices=['adamw', 'adamcpr'])
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--beta1", type=float, default=0.9)
    parser.add_argument("--beta2", type=float, default=0.98)
    parser.add_argument("--weight_decay", type=float, default=0.1)
    parser.add_argument("--lr_warmup_steps", type=int, default=500)
    parser.add_argument("--lr_decay_factor", type=float, default=0.1)
    parser.add_argument("--rescale_alpha", type=float, default=0)
    parser.add_argument("--mode", type=str, default='l2_constrain')
    parser.add_argument("--kappa", type=float, default=10.0)
    parser.add_argument("--kappa_init_dependent", type=float, default=0)
    parser.add_argument("--lagmul_rate", type=float, default=1.0)
    parser.add_argument("--kappa_adapt", type=str_to_bool, default=False)
    parser.add_argument("--kappa_init_warm_start", type=int, default=0)
    parser.add_argument("--log_interval", type=int, default=10)
    parser.add_argument("--enable_progress_bar", type=str_to_bool, default=True)
    parser.add_argument("--output_dir", type=str, default='cifar100')
    parser.add_argument("--device", type=int, default=0)
    args = parser.parse_args()

    # Weight rescaling is only meant for the AdamW path. parser.error is used
    # instead of assert so the check survives ``python -O``.
    # NOTE(review): ResNetModule only rescales when rescale_alpha > 0.1, so
    # values in (0, 0.1] pass this check but have no effect — confirm intent.
    if args.rescale_alpha > 0 and args.optimizer != 'adamw':
        parser.error("--rescale_alpha > 0 requires --optimizer adamw")

    train_cifar100_task(args)
