import torch
import torch.nn as nn
import torch.nn.functional as F
from optimizer.scheduler import WarmupMultistepLR
from optimizer.lbfgs import LBFGSOptimizer
import torchvision
from pytorch_lightning.metrics.functional import accuracy
from models.cifar_resnet import get_model as get_cifar_model
from argparse import ArgumentParser

import pytorch_lightning as pl


def create_model(args):
    """Build the ResNet selected by ``args.model`` for ``args.dataset``.

    Args:
        args: Namespace-like object providing ``dataset`` (str), ``model``
            (str, matched case-insensitively) and ``no_bn`` (bool; when true
            the network is built without batch normalization).

    Returns:
        The constructed model. For ``"cifar10"`` it comes from
        ``get_cifar_model``; for ``"imagenet"`` and ``"flowers102"`` it is a
        torchvision ResNet, with the final ``fc`` layer replaced by a
        102-way linear layer for ``"flowers102"``.

    Raises:
        ValueError: If the model name is missing or not valid for the chosen
            dataset, or if the dataset itself is not supported.
    """
    cifar_models = (
        "resnet20",
        "resnet32",
        "resnet44",
        "resnet56",
        "resnet110",
        "resnet1202",
    )
    imagenet_models = (
        "resnet18",
        "resnet34",
        "resnet50",
        "resnet101",
        "resnet152",
    )
    # Output sizes of the datasets served by the torchvision architectures.
    num_classes_by_dataset = {"imagenet": 1000, "flowers102": 102}

    use_batch_norm = not args.no_bn
    # A missing model name (None) becomes "" so that it is rejected by the
    # whitelist checks below instead of failing on ``None.lower()``.
    model_name = (args.model or "").lower()

    if args.dataset == "cifar10":
        if model_name not in cifar_models:
            raise ValueError("Unexpected model type for CIFAR-10.")
        return get_cifar_model(model_name, bn=use_batch_norm)

    if args.dataset in num_classes_by_dataset:
        if model_name not in imagenet_models:
            raise ValueError("Unexpected model type for ImageNet.")

        # ``None`` keeps torchvision's default (batch norm); ``nn.Identity``
        # replaces every normalization layer with a pass-through.
        norm_layer = None if use_batch_norm else nn.Identity

        # The name was validated against ``imagenet_models`` above, so this
        # lookup can only resolve to one of the ResNet constructors.
        constructor = getattr(torchvision.models, model_name)
        model = constructor(norm_layer=norm_layer)

        # Adjust output layer based on dataset
        num_classes = num_classes_by_dataset[args.dataset]
        if num_classes != 1000:
            model.fc = nn.Linear(model.fc.in_features, num_classes)

        return model

    # Previously an unknown dataset fell through and returned None silently.
    raise ValueError(f"Unsupported dataset: {args.dataset!r}")


class LitResnet(pl.LightningModule):
    """Lightning module that trains a ResNet built by ``create_model``.

    The constructor arguments are stored through ``save_hyperparameters``;
    the parsed command-line namespace is read back as ``self.hparams.args``.
    """

    def __init__(self, args, **kwargs):
        """Store the hyperparameters and build the underlying network.

        Args:
            args: Parsed arguments (see ``add_model_specific_args``).
            **kwargs: Forwarded unchanged to ``pl.LightningModule.__init__``.
        """
        super().__init__(**kwargs)

        self.save_hyperparameters()
        self.model = create_model(args)

    def forward(self, x):
        """Return log-probabilities: log-softmax over dim 1 of the model output."""
        out = self.model(x)
        return F.log_softmax(out, dim=1)

    def _loss_and_accuracy(self, batch):
        """Compute the NLL loss and accuracy for one ``(inputs, targets)`` batch."""
        x, y = batch
        # Call the module rather than ``forward`` directly so that any
        # registered module hooks run as well.
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        preds = torch.argmax(log_probs, dim=1)
        return loss, accuracy(preds, y)

    def training_step(self, batch, batch_idx):
        """Log ``train_loss`` / ``train_acc`` and return the loss."""
        loss, acc = self._loss_and_accuracy(batch)
        self.log("train_loss", loss)
        self.log("train_acc", acc)
        return loss

    def evaluate(self, batch, stage=None):
        """Compute loss and accuracy for ``batch``.

        When ``stage`` is truthy the metrics are logged (and shown in the
        progress bar) as ``"{stage}_loss"`` and ``"{stage}_acc"``.
        """
        loss, acc = self._loss_and_accuracy(batch)

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch, logging under the ``val`` prefix."""
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch, logging under the ``test`` prefix."""
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        """Create the optimizer selected by ``--optimizer`` and its LR schedule.

        Returns:
            A pair ``([optimizer], [scheduler])``.

        Raises:
            ValueError: If the configured optimizer is neither ``"sgd"`` nor
                ``"lbfgs"``.
        """
        hp = self.hparams.args

        if hp.optimizer == "sgd":
            optimizer = torch.optim.SGD(
                self.parameters(),
                lr=hp.lr,
                momentum=hp.momentum,
                weight_decay=hp.weight_decay,
            )
        elif hp.optimizer == "lbfgs":
            optimizer = LBFGSOptimizer(
                self.parameters(),
                lr=hp.lr,
                momentum=hp.momentum,
                weight_decay=hp.weight_decay,
                mm_p=hp.stat_decay_param,
                mm_g=hp.stat_decay_grad,
                update_freq=hp.update_freq,
                hist_sz=hp.history_size,
                damping=hp.lbfgs_damping,
                kl_clip=hp.grad_clip,
                kl_clip_fix_scaling=hp.grad_clip_fix_scaling,
            )
        else:
            # Without this branch an unknown value surfaced as an
            # UnboundLocalError on ``optimizer`` at the return statement.
            raise ValueError(f"Unsupported optimizer: {hp.optimizer!r}")

        # Both optimizers share the same warmup + multistep schedule.
        scheduler = WarmupMultistepLR(
            optimizer,
            milestones=hp.lr_decay_at,
            warmup_period=hp.lr_num_warmup_epochs,
            gamma=hp.lr_gamma,
        )
        return [optimizer], [scheduler]

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with this model's options.

        Args:
            parent_parser: An ``ArgumentParser`` whose arguments are inherited.

        Returns:
            A new ``ArgumentParser`` (created with ``add_help=False``) that
            holds the model, optimizer and schedule options.
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)

        # Basic training params
        parser.add_argument(
            "--model",
            type=str,
            help="ResNet model to be trained. For CIFAR-10: resnet{20,32,44,56,110,1202}, for ImageNet: resnet{18,34,50,101,152} ",
        )
        parser.add_argument(
            "--no_bn",
            action="store_true",
            default=False,
            help="Do not use batch norm (default False)",
        )
        # NOTE(review): create_model also handles "imagenet" and "flowers102",
        # but only "cifar10" is accepted here -- confirm whether this
        # restriction is intentional before widening the choices.
        parser.add_argument(
            "--dataset",
            type=str,
            default="cifar10",
            choices=["cifar10"],
            help="dataset to train on",
        )
        parser.add_argument(
            "--optimizer",
            type=str,
            default="sgd",
            choices=["sgd", "lbfgs"],
            help="optimization method",
        )
        parser.add_argument(
            "--lr",
            default=0.05,
            type=float,
            help="Learning rate",
        )
        parser.add_argument(
            "--batch_size",
            default=32,
            type=int,
            help="Effective batch size. In case of multi-gpu training, batch_size/num_gpus will be loaded onto each gpu.",
        )
        parser.add_argument(
            "--momentum",
            default=0.9,
            type=float,
            help="momentum",
        )
        parser.add_argument(
            "--lr_gamma",
            default=0.1,
            type=float,
            help="Amount to decrease step size",
        )
        parser.add_argument(
            "--lr_decay_at",
            nargs="+",
            type=int,
            default=[35, 75, 90],
            help="epoch intervals to decay lr (default: [35, 75, 90])",
        )
        parser.add_argument(
            "--lr_num_warmup_epochs",
            type=int,
            default=5,
            help="number of warmup epochs (default: 5)",
        )
        parser.add_argument(
            "--weight_decay",
            "--wd",
            default=1e-4,
            type=float,
            help="weight decay (default: 1e-4)",
        )

        # S-LBFGS hyperparams
        parser.add_argument(
            "--stat_decay_param",
            default=0.9,
            type=float,
            help="stat decay for parameters",
        )
        parser.add_argument(
            "--stat_decay_grad",
            default=0.9,
            type=float,
            help="stat decay for gradients",
        )
        parser.add_argument(
            "--update_freq",
            default=200,
            type=int,
            help="update frequency for Hessian approximation",
        )
        parser.add_argument(
            "--history_size",
            default=20,
            type=int,
            help="history size for LBFGS-related vectors",
        )
        parser.add_argument(
            "--lbfgs_damping",
            default=0.2,
            type=float,
            help="LBFGS damping factor",
        )

        # Other training params
        parser.add_argument(
            "--grad_clip",
            default=0.05,
            type=float,
            help="gradient clipping",
        )
        parser.add_argument(
            "--grad_clip_fix_scaling",
            default=False,
            action="store_true",
            help="Gradient clipping strength is scaled by current learning rate (default, False) or fixed based on initial learning rate (True, must be set)",
        )

        return parser
