import torch
import torchvision
import lightning as pl
from lightning.pytorch.utilities import grad_norm
from hydra.utils import instantiate
from einops import rearrange

from src.utils.misc import fold, unfold


class BaseLightningModel(pl.LightningModule):
    """Shared training / validation / optimizer plumbing for sequence models.

    Subclasses are expected to override :meth:`loss`. It receives the
    preprocessed ``(xs, params, params2)`` batch and returns a pair
    ``(losses, imgs)``:

    - ``losses`` is a dict of named loss terms. It must contain ``'total'``,
      the value returned for backprop by :meth:`training_step`.
    - ``imgs`` is an iterable of image tensors. They are logged on the first
      validation batch.

    ``self.hparams`` is read for ``optimizer`` (a Hydra config). It is also
    optionally read for ``scheduler`` and ``scheduler_extra``; see
    :meth:`configure_optimizers`.
    """

    def seq_first(self, xs):
        """Return ``xs`` with the sequence axis leading.

        A tensor is treated as ``(batch, seq, ...)`` and rearranged to
        ``(seq, batch, ...)``. Anything else is assumed to be a sequence of
        tensors and is stacked along a new leading dimension.
        """
        if torch.is_tensor(xs):
            return rearrange(xs, 'batch seq ... -> seq batch ...')
        return torch.stack(xs, dim=0)

    def preprocess(self, batch):
        """Unpack ``(xs, params, params2)`` and move ``xs`` to seq-first layout.

        ``params`` is converted too when its second axis matches the
        sequence length of ``xs``, i.e. when it looks like a per-step
        ``(batch, seq, ...)`` tensor. ``params2`` is passed through unchanged.
        """
        xs, params, params2 = batch
        xs = self.seq_first(xs)
        # Compare against the leading (sequence) axis of the already-converted
        # xs. This is equivalent to the pre-conversion ``xs.shape[1]`` for
        # tensors. It also works when xs arrives as a list of tensors, which
        # has no ``.shape`` of its own.
        # NOTE(review): a (batch, feature) params tensor whose feature size
        # happens to equal the sequence length is swapped as well.
        if len(params.shape) > 1 and params.shape[1] == xs.shape[0]:
            params = self.seq_first(params)
        return xs, params, params2

    def training_step(self, batch, batch_idx):
        """Compute and log every training loss term; return ``'total'`` for backprop."""
        losses, _ = self.loss(self.preprocess(batch))
        for loss_type, loss_value in losses.items():
            self.log(f'loss/{loss_type}/train', loss_value)
        return losses['total']

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Evaluate one validation batch.

        Results from dataloader 0 are logged under the ``'test'`` split.
        Results from any further dataloader ``k`` are logged as ``'ood<k>'``.
        """
        batch = self.preprocess(batch)
        split = 'test' if dataloader_idx == 0 else f'ood{dataloader_idx}'
        self._validation_step(batch, batch_idx == 0, split)

    def _validation_step(self, batch, is_first_batch, split):
        """Log the loss terms for ``split``.

        On the first batch, also log each image returned by :meth:`loss`.
        """
        losses, imgs = self.loss(batch)
        for loss_type, loss_value in losses.items():
            # add_dataloader_idx=False: the split name already encodes which
            # dataloader the metric came from.
            self.log(f'loss/{loss_type}/{split}', loss_value, sync_dist=True, prog_bar=True, add_dataloader_idx=False)
        if is_first_batch:
            for i, img in enumerate(imgs):
                self.save_image(img, split, i, 'prediction')

    def save_image(self, img, split, step, title, num_imgs=8):
        """Log a grid of the first ``num_imgs`` entries along axis 1 of ``img``.

        The grid is logged under the tag ``'<title>/<split>'`` at ``step``.
        It does nothing when no logger is attached.

        NOTE(review): assumes ``img`` is seq-first ``(seq, batch, ...)`` as
        produced by :meth:`preprocess`. Also assumes ``fold`` merges the first
        two axes in row-major order, so each grid row shows one axis-0 entry.
        Finally, assumes the logger's ``experiment`` exposes a
        TensorBoard-style ``add_image``. Confirm all three.
        """
        if self.logger is None:
            return
        img = img[:, :num_imgs]
        # Use the number of entries actually kept, not num_imgs. Otherwise grid
        # rows drift out of alignment when the batch is smaller than num_imgs.
        ncols = img.shape[1]
        img, _ = fold(img)
        grid = torchvision.utils.make_grid(img, normalize=True, nrow=ncols, pad_value=1)
        self.logger.experiment.add_image(f'{title}/{split}', grid, step)

    def loss(self, xs):
        """Placeholder loss, meant to be overridden by subclasses.

        Must return ``(losses, imgs)``, where ``losses`` contains ``'total'``.
        This default reports a zero total and echoes its input back as
        ``imgs``.
        """
        return dict(total=0), xs

    def configure_optimizers(self):
        """Build the optimizer and, if configured, its LR scheduler from ``self.hparams``.

        ``hparams.optimizer`` is instantiated with this module's parameters.
        When ``hparams.scheduler`` is present and not ``None``, it is
        instantiated with the optimizer. It is wrapped in Lightning's
        ``lr_scheduler`` config dict, merged with any extra config keys
        (e.g. ``interval``, ``monitor``) from ``hparams.scheduler_extra``.
        A missing ``scheduler_extra`` is treated as empty.
        """
        optimizer = instantiate(self.hparams.optimizer, params=self.parameters())
        optim = {'optimizer': optimizer}
        scheduler_cfg = self.hparams.get('scheduler')
        if scheduler_cfg is not None:
            scheduler_extra = self.hparams.get('scheduler_extra') or {}
            optim['lr_scheduler'] = {
                'scheduler': instantiate(scheduler_cfg, optimizer=optimizer),
                **scheduler_extra,
            }
        return optim