import torch
import torch.nn as nn
import torch.nn.functional as F
from optimizer.scheduler import WarmupMultistepLR
from optimizer.lbfgs import LBFGSOptimizer
import torchvision
from pytorch_lightning.metrics.functional import accuracy
from vit_pytorch import ViT

from argparse import ArgumentParser

import pytorch_lightning as pl


def create_model(args):
    """Construct a ``ViT`` classifier configured for ``args.dataset``.

    The dataset name selects the number of output classes and the input
    image side length. All architecture hyper-parameters are read from the
    ``vit_*`` attributes of ``args``.

    Args:
        args: namespace exposing ``dataset`` plus ``vit_patch_size``,
            ``vit_dim``, ``vit_depth``, ``vit_heads``, ``vit_mlp_dim``,
            ``vit_dropout`` and ``vit_emb_dropout``.

    Returns:
        The newly constructed ``ViT`` module.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    # dataset name -> (num_classes, image_size)
    dataset_specs = {
        "cifar10": (10, 32),
        "imagenet": (1000, 224),
        "flowers102": (102, 224),
    }
    spec = dataset_specs.get(args.dataset)
    if spec is None:
        raise ValueError("Unsupported dataset {}".format(args.dataset))
    n_classes, side = spec

    return ViT(
        image_size=side,
        patch_size=args.vit_patch_size,
        num_classes=n_classes,
        dim=args.vit_dim,
        depth=args.vit_depth,
        heads=args.vit_heads,
        mlp_dim=args.vit_mlp_dim,
        dropout=args.vit_dropout,
        emb_dropout=args.vit_emb_dropout,
    )


class LitVit(pl.LightningModule):
    """Lightning wrapper around the ViT classifier built by ``create_model``.

    The full ``args`` namespace is stored via ``save_hyperparameters`` and is
    read back as ``self.hparams.args`` when configuring the optimizer.
    """

    def __init__(self, args, **kwargs):
        """Store hyper-parameters and build the underlying model.

        Args:
            args: parsed command-line namespace (see
                ``add_model_specific_args``).
            **kwargs: forwarded to ``pl.LightningModule.__init__``.
        """
        super().__init__(**kwargs)

        self.save_hyperparameters()
        self.model = create_model(args)

    def forward(self, x):
        """Return per-class log-probabilities (log-softmax over dim 1)."""
        out = self.model(x)
        return F.log_softmax(out, dim=1)

    def training_step(self, batch, batch_idx):
        """Compute NLL loss on a ``(x, y)`` batch and log loss/accuracy."""
        x, y = batch
        # forward() already applies log_softmax, so NLL loss is appropriate.
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        preds = torch.argmax(log_probs, dim=1)
        acc = accuracy(preds, y)
        self.log("train_loss", loss)
        self.log("train_acc", acc)
        return loss

    def evaluate(self, batch, stage=None):
        """Compute loss and accuracy on a batch.

        Metrics are logged as ``{stage}_loss`` and ``{stage}_acc``, and only
        when ``stage`` is truthy.
        """
        x, y = batch
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        preds = torch.argmax(log_probs, dim=1)
        acc = accuracy(preds, y)

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        """Build the optimizer chosen by ``--optimizer`` and its LR scheduler.

        Returns:
            ``([optimizer], [scheduler])``.

        Raises:
            ValueError: if ``args.optimizer`` is neither ``"sgd"`` nor
                ``"lbfgs"``.
        """
        print("hparams: ", self.hparams)
        args = self.hparams.args
        if args.optimizer == "sgd":
            optimizer = torch.optim.SGD(
                self.parameters(),
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay,
            )
        elif args.optimizer == "lbfgs":
            master_params = self.parameters()
            optimizer = LBFGSOptimizer(
                master_params,
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay,
                mm_p=args.stat_decay_param,
                mm_g=args.stat_decay_grad,
                update_freq=args.update_freq,
                hist_sz=args.history_size,
                damping=args.lbfgs_damping,
                kl_clip=args.grad_clip,
                kl_clip_fix_scaling=args.grad_clip_fix_scaling,
            )
        else:
            # Fail with a clear message instead of an UnboundLocalError on
            # `optimizer` when the return statement is reached.
            raise ValueError("Unsupported optimizer {}".format(args.optimizer))

        # Both optimizers use the same warmup + milestone-decay schedule.
        scheduler = WarmupMultistepLR(
            optimizer,
            milestones=args.lr_decay_at,
            warmup_period=args.lr_num_warmup_epochs,
            gamma=args.lr_gamma,
        )
        return [optimizer], [scheduler]

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with this model's flags.

        Args:
            parent_parser: an ``ArgumentParser`` whose arguments are inherited.

        Returns:
            A new ``ArgumentParser`` (created with ``add_help=False``).
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)

        # Basic training params
        parser.add_argument(
            "--dataset",
            type=str,
            default="cifar10",
            # Kept in sync with the datasets handled by create_model().
            choices=["cifar10", "imagenet", "flowers102"],
            help="dataset to train on",
        )
        parser.add_argument(
            "--optimizer",
            type=str,
            default="sgd",
            choices=["sgd", "lbfgs"],
            help="optimization method",
        )
        parser.add_argument(
            "--lr",
            default=0.05,
            type=float,
            help="Learning rate",
        )
        parser.add_argument(
            "--batch_size",
            default=32,
            type=int,
            help="Batch size",
        )
        parser.add_argument(
            "--momentum",
            default=0.9,
            type=float,
            help="momentum",
        )
        parser.add_argument(
            "--lr_gamma",
            default=0.1,
            type=float,
            help="Amount to decrease step size",
        )
        parser.add_argument(
            "--lr_decay_at",
            nargs="+",
            type=int,
            default=[35, 75, 90],
            help="epoch intervals to decay lr (default: [35, 75, 90])",
        )
        parser.add_argument(
            "--lr_num_warmup_epochs",
            type=int,
            default=5,
            help="number of warmup epochs (default: 5)",
        )
        parser.add_argument(
            "--weight_decay",
            "--wd",
            default=1e-4,
            type=float,
            help="weight decay (default: 1e-4)",
        )

        # SLIM-QN hyperparams
        parser.add_argument(
            "--stat_decay_param",
            default=0.9,
            type=float,
            help="stat decay for parameters",
        )
        parser.add_argument(
            "--stat_decay_grad",
            default=0.9,
            type=float,
            help="stat decay for gradients",
        )
        parser.add_argument(
            "--update_freq",
            default=200,
            type=int,
            help="update frequency for Hessian approximation",
        )
        parser.add_argument(
            "--history_size",
            default=20,
            type=int,
            help="history size for LBFGS-related vectors",
        )
        parser.add_argument(
            "--lbfgs_damping",
            default=0.2,
            type=float,
            help="LBFGS damping factor",
        )

        # Other training params
        parser.add_argument(
            "--grad_clip",
            default=0.05,
            type=float,
            help="gradient clipping",
        )
        parser.add_argument(
            "--grad_clip_fix_scaling",
            default=False,
            action="store_true",
            help="Gradient clipping strength is scaled by current learning rate (default, False) or fixed based on initial learning rate (True, must be set)",
        )

        # Model specific arguments
        parser.add_argument(
            "--vit_patch_size",
            default=4,
            type=int,
            help="Patch side length in pixels; image size must be divisible by it (typical values: CIFAR10: 4, ImageNet: 16)",
        )
        parser.add_argument(
            "--vit_dim",
            default=512,
            type=int,
            help="Embedding dimension D",
        )
        parser.add_argument(
            "--vit_depth",
            default=6,
            type=int,
            help="Number of transformer stacks",
        )
        parser.add_argument(
            "--vit_heads",
            default=8,
            type=int,
            help="Number of attention heads",
        )
        parser.add_argument(
            "--vit_mlp_dim",
            default=512,
            type=int,
            help="MLP dimension",
        )
        parser.add_argument(
            "--vit_dropout",
            default=0.1,
            type=float,
            help="Dropout probability",
        )
        parser.add_argument(
            "--vit_emb_dropout",
            default=0.1,
            type=float,
            help="Embedding dropout probability",
        )

        return parser
