import time
from functools import partial
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple

import einops
import lightning as L
import torch
from torch import Tensor, nn

class LitModule(L.LightningModule):
    """Lightning wrapper that delegates loss computation to an ``orchestrator``.

    The orchestrator must expose ``train_step(batch)`` returning the loss
    tensor (its exact contract lives outside this file — TODO confirm). This
    module adds NaN checking, logging, optional ``torch.compile`` and
    optimizer/scheduler construction.

    Args:
        orchestrator: Module owning the trainable parameters and implementing
            ``train_step(batch)``.
        optimizer: Optimizer factory (e.g. a ``functools.partial`` of an
            optimizer class) called with a list of parameter groups.
        scheduler: Optional LR-scheduler factory, called as
            ``scheduler(optimizer=...)``.
        compile: If true, wrap ``orchestrator`` with ``torch.compile`` when
            fitting.
        metrics: Optional dict stored on ``self.metrics`` (not read by this
            class itself). Defaults to a fresh empty dict per instance.
    """

    def __init__(
        self,
        orchestrator: torch.nn.Module,
        optimizer: Callable[..., torch.optim.Optimizer],
        scheduler: Optional[Callable[..., Any]] = None,
        compile: bool = False,
        metrics: Optional[dict] = None,
    ):
        super().__init__()
        # The orchestrator is a registered submodule (its weights are already in
        # the state dict), so it is excluded from the saved hyperparameters.
        self.save_hyperparameters(logger=False, ignore=["orchestrator"])
        self.orchestrator: torch.nn.Module = orchestrator
        # Fresh dict per instance: a literal ``{}`` default would be shared by
        # every LitModule constructed without ``metrics``.
        self.metrics: dict = {} if metrics is None else metrics
        # Reference point for the "time/elapsed_time" log entry.
        self.start_time = time.time()

    def forward(self, *args, **kwargs):
        """Unused placeholder; computation goes through ``orchestrator.train_step``.

        Exists only so Lightning finds a ``forward``; calling it is an error.
        """
        # ``raise`` rather than ``assert False`` so the guard survives ``python -O``.
        raise NotImplementedError("dummy method to satisfy PTL")

    def training_step(self, batch: Mapping[str, Any], batch_idx: int) -> Tensor:
        """Compute, check and log the training loss for one batch.

        Args:
            batch: Dict-like batch; ``batch["field"].shape[0]`` is used as the
                batch size for logging.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The loss from ``orchestrator.train_step``. Under Lightning's default
            automatic optimization this is what gets backpropagated.

        Raises:
            ValueError: If the loss is NaN.
        """
        loss = self.orchestrator.train_step(batch)

        # Fail fast instead of silently training on NaN gradients.
        if torch.isnan(loss):
            raise ValueError('NaN encountered')

        batch_size = batch['field'].shape[0]
        self.log("train/loss", loss, prog_bar=True, batch_size=batch_size, sync_dist=True)

        elapsed = time.time() - self.start_time
        self.log("time/elapsed_time", elapsed, on_step=True, on_epoch=False)
        self.log("logging_step", self.global_step, reduce_fx=torch.min, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch: Mapping[str, Any], batch_idx: int) -> Tensor:
        """Compute and log the validation loss for one batch.

        Args:
            batch: Dict-like batch; ``batch["field"].shape[0]`` is the batch size.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The loss from ``orchestrator.train_step``.
        """
        # NOTE(review): validation reuses ``train_step``. Lightning runs
        # validation without gradients, so this assumes ``train_step`` only
        # computes the loss (no backward / optimizer step) — confirm.
        loss = self.orchestrator.train_step(batch)
        batch_size = batch['field'].shape[0]
        self.log("val/loss", loss, prog_bar=True, batch_size=batch_size, sync_dist=True)
        return loss

    def setup(self, stage: str):
        """Optionally compile the orchestrator before fitting.

        Args:
            stage: Lightning stage name; compilation only happens for ``"fit"``.
        """
        if self.hparams.compile and stage == "fit":
            self.orchestrator = torch.compile(self.orchestrator)

    def configure_optimizers(self) -> Dict[str, Any]:
        """Build the optimizer (and optional per-step LR scheduler).

        Trainable parameters whose name contains ``"bias"`` or ``"norm"`` go in a
        group with ``weight_decay=0.0``. All other trainable parameters form a
        second group that inherits the optimizer's default weight decay, i.e.
        whatever the ``optimizer`` factory was configured with.

        Returns:
            A Lightning optimizer config dict; it includes an ``"lr_scheduler"``
            entry only when a scheduler factory was given.
        """

        def _skip_decay(name: str) -> bool:
            # Normalization and bias parameters are conventionally not decayed.
            return "bias" in name or "norm" in name

        no_decay_params: List[Tensor] = []
        decay_params: List[Tensor] = []
        for name, param in self.orchestrator.named_parameters():
            if not param.requires_grad:
                continue
            if _skip_decay(name):
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        optimizer = self.hparams.optimizer(
            [
                {"params": no_decay_params, "weight_decay": 0.0},
                # No explicit "weight_decay": torch fills missing group options
                # from the optimizer defaults, which already hold the factory's
                # weight_decay. Reading ``.keywords["weight_decay"]`` broke for
                # factories that do not pass it as a partial keyword.
                {"params": decay_params},
            ]
        )

        if self.hparams.scheduler is None:
            return {"optimizer": optimizer}

        scheduler = self.hparams.scheduler(optimizer=optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                # Lightning consults "monitor" for metric-driven schedulers
                # such as ReduceLROnPlateau.
                "monitor": "val/loss",
                # Step the scheduler after every optimizer step, not per epoch.
                "interval": "step",
                "frequency": 1,
            },
        }
