import argparse
import math
import os

import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchmetrics
import wandb
from lightning.pytorch import LightningModule
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader, TensorDataset

import models
from models import S4D_Enc

##### Seeding
torch.manual_seed(1111)


def _is_global_rank_zero():
    """Return True unless distributed-launch environment variables mark a non-zero rank.

    With several devices, Lightning's default DDP strategy launches this script
    again for each additional device, so the module-level code below runs in
    every process. The launchers expose rank information via environment
    variables; variables that are absent are treated as rank 0, which is the
    single-process case.
    NOTE(review): launchers that use other variables (e.g. SLURM_PROCID) are
    not detected and fall back to the previous behaviour (treated as rank 0).
    """
    return all(
        os.environ.get(var, "0") == "0" for var in ("RANK", "LOCAL_RANK", "NODE_RANK")
    )


## Logging into WandB
# Only the rank-zero process creates a real W&B run. The other processes get a
# disabled run: any wandb.log calls there become no-ops instead of each
# process opening its own extra run.
if _is_global_rank_zero():
    wandb.login()
    wandb.init()
else:
    wandb.init(mode="disabled")
wandb_logger = WandbLogger()


##### Path to where you stored the processed data (train_X.pt, train_y.pt, test_X.pt, test_y.pt). If you don't have this data, you can download it from: https://drive.google.com/drive/folders/1DA2c0ELOFqOPiO4GgY4KetNhsHEvf5tx?dmr=1&ec=wgc-drive-hero-goto
#### Thanks to the amazing S4 library: the data was processed entirely with their repo: https://github.com/state-spaces/s4


def create_data_loaders(path, batch_size=32, shuffle=True, num_workers=4):
    """Build the train and test DataLoaders from pre-processed tensor files.

    Loads ``train_X.pt``, ``train_y.pt``, ``test_X.pt`` and ``test_y.pt`` from
    ``path`` with ``torch.load`` and wraps each X/y pair in a ``TensorDataset``.

    Args:
        path: Directory containing the tensor files. An empty string resolves
            relative to the current working directory.
        batch_size: Mini-batch size used by both loaders.
        shuffle: Whether to shuffle the *training* loader. The test loader is
            never shuffled, so evaluation order is deterministic.
        num_workers: Number of worker processes per DataLoader.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """

    # Load one split's X/y tensors and wrap them in a DataLoader.
    def _make_loader(split, shuffle_split):
        # os.path.join avoids turning path="" into "/<file>" (filesystem root).
        data = torch.load(os.path.join(path, f"{split}_X.pt"))
        labels = torch.load(os.path.join(path, f"{split}_y.pt"))
        return DataLoader(
            TensorDataset(data, labels),
            batch_size=batch_size,
            shuffle=shuffle_split,
            num_workers=num_workers,
        )

    train_loader = _make_loader("train", shuffle)
    test_loader = _make_loader("test", False)
    return train_loader, test_loader


class LitSSMClassifier(LightningModule):
    """LightningModule that trains an ``S4D_Enc`` classifier with cross-entropy.

    The training loss (every 20 batches) and the epoch-level validation
    accuracy are sent directly to Weights & Biases with ``wandb.log``, not
    through a Lightning logger. Logging happens on global rank 0 only.

    Args:
        input_dim: Size of the last input dimension (x is (B, L, input_dim)),
            forwarded to ``S4D_Enc``.
        d_model: Forwarded to ``S4D_Enc``.
        d_state: Forwarded to ``S4D_Enc``.
        dropout: Forwarded to ``S4D_Enc``.
        transposed: Forwarded to ``S4D_Enc``.
        num_classes: Number of classes; used by the model and the accuracy metric.
        depth: Forwarded to ``S4D_Enc``.
        lr: Learning rate for the Adam optimizer.
    """

    def __init__(
        self,
        input_dim,
        d_model,
        d_state,
        dropout,
        transposed,
        num_classes,
        depth,
        lr=1e-2,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.model = S4D_Enc(
            input_dim=input_dim,
            d_model=d_model,
            d_state=d_state,
            dropout=dropout,
            transposed=transposed,
            num_classes=num_classes,
            depth=depth,
        )
        self.lr = lr
        self.criterion = torch.nn.CrossEntropyLoss()
        # NOTE: process-global setting; permits reduced-precision (e.g. TF32)
        # float32 matmuls on hardware that supports them.
        torch.set_float32_matmul_precision("high")
        # torchmetrics' MulticlassAccuracy defaults to average="macro" (mean of
        # per-class recall). "micro" is plain top-1 accuracy: correct / total.
        self.val_acc = torchmetrics.classification.MulticlassAccuracy(
            num_classes=num_classes, average="micro"
        )

    def forward(self, x):
        """Return the logits produced by the wrapped ``S4D_Enc`` model."""
        return self.model(x)  # expects x of shape (B, L, input_dim)

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one batch and periodically log it."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        # Log every 20 batches, from one process only: under multi-GPU DDP every
        # rank runs this step. .item() sends a plain float, not a graph tensor.
        if batch_idx % 20 == 0 and self.global_rank == 0:
            wandb.log({"S4D_loss_SC": loss.item()})
        return loss

    def validation_step(self, batch, batch_idx):
        """Accumulate argmax predictions into the validation accuracy metric."""
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        self.val_acc.update(preds, y)

    def on_validation_epoch_end(self):
        """Compute, reset and (on rank 0) log the epoch's validation accuracy."""
        # compute() synchronises metric state across processes, so it must run
        # on every rank, even those that will not log.
        acc = self.val_acc.compute()
        self.val_acc.reset()
        # Skip Lightning's pre-training sanity check: it covers only a few
        # batches and would otherwise be logged as a real epoch accuracy.
        if self.trainer.sanity_checking or self.global_rank != 0:
            return
        wandb.log({"val_acc_epoch": acc.item()})

    def configure_optimizers(self):
        """Optimize all parameters with Adam at the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)


def str_to_bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string, including ``"False"``. Boolean options therefore need an
    explicit converter.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    if normalized in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Parse CLI arguments, build the model and data loaders, and run training."""
    parser = argparse.ArgumentParser(description="Train SSMs on SC_35 and SC_10")

    # --- Paths ---
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        default="",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        default="",
        help="Directory containing train_X.pt, train_y.pt, test_X.pt and test_y.pt.",
    )

    # --- Data ---
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--input_dim", type=int, default=1)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--epochs", type=int, default=40)

    # --- Model ---
    parser.add_argument("--model_dim", type=int, default=128)
    parser.add_argument("--num_layers", type=int, default=4)
    parser.add_argument("--d_state", type=int, default=64)
    parser.add_argument("--dropout", type=float, default=0.0)
    # type=bool would turn "--transposed False" into True; parse it explicitly.
    parser.add_argument("--transposed", type=str_to_bool, default=True)
    parser.add_argument("--num_classes", type=int, default=35)

    # --- GPUS ---
    parser.add_argument("--gpus", type=int, nargs="+", default=[0, 1, 2, 3])
    args = parser.parse_args()

    model = LitSSMClassifier(
        input_dim=args.input_dim,
        d_model=args.model_dim,
        d_state=args.d_state,
        dropout=args.dropout,
        transposed=args.transposed,
        num_classes=args.num_classes,
        depth=args.num_layers,
        lr=args.lr,
    )

    trainer = L.Trainer(
        devices=args.gpus,
        accelerator="gpu",
        max_epochs=args.epochs,
        default_root_dir=args.checkpoint_path,
    )

    # Pass --batch_size through; previously it was parsed but ignored, so
    # create_data_loaders' own default (32) was always used.
    train_loader, test_loader = create_data_loaders(
        path=args.data_path, batch_size=args.batch_size
    )

    trainer.fit(model, train_loader, test_loader)


if __name__ == "__main__":
    main()
