import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

import lightning as L

from medmnist import Evaluator

from .modules import RMSNorm
from .transformer import HighOrderTransformer

def mean_pool(x):
    """Average ``x`` over axes 1, 2 and 3.

    In ``ViHOT.forward`` this collapses the channels-last feature grid
    (bs, W, H, D, d) down to one vector per sample, (bs, d).
    """
    grid_axes = (1, 2, 3)
    return torch.mean(x, dim=grid_axes)

def flatten_pool(x):
    """Keep the batch axis and merge every remaining axis into one feature axis."""
    return torch.flatten(x, start_dim=1)

class ViHOT(L.LightningModule):
    """3D volume classifier: convolutional stem + ``HighOrderTransformer``.

    The stem is four ``Conv3d`` layers, each followed by ReLU, BatchNorm3d and
    Dropout. The second and fourth convolutions use stride 2, so with the
    default ``kernel=4`` a 28^3 volume becomes a 7^3 grid of ``d_hidden``
    features. That grid is passed channels-last to a ``HighOrderTransformer``,
    pooled (mean over the grid, or flattened), and classified by an
    ``RMSNorm`` + ``Linear`` head.

    Args:
        d_hidden: channel width of the stem and of the transformer.
        n_blocks, n_head, dropout, factorized, use_linear_att, feature_map,
        attention_ignore_list: forwarded to ``HighOrderTransformer`` (see its
            documentation for their meaning). ``dropout`` is also used after
            every stem convolution.
        n_class: number of output classes.
        evaluators: mapping with keys ``'train'``, ``'val'`` and ``'test'``
            whose values expose ``evaluate(y_score)`` returning ``(auc, acc)``
            (e.g. ``medmnist.Evaluator``).
        kernel: ``Conv3d`` kernel size used throughout the stem.
        pooling: ``'mean'`` averages over the feature grid; any other value
            flattens it.
        lr: AdamW learning rate.
        ce_weight: optional per-class weight for ``nn.CrossEntropyLoss``.
        img_size: side length of the (cubic) input volume. Only used to size
            the head when ``pooling`` is not ``'mean'``. Defaults to 28.
    """

    def __init__(
        self,
        d_hidden,
        n_blocks,
        n_head,
        n_class,
        evaluators,
        kernel=4,
        dropout=0.,
        factorized=True,
        pooling='mean',
        use_linear_att=True,
        feature_map='SMReg',
        lr=1e-3,
        ce_weight=None,
        attention_ignore_list=None,
        img_size=28,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.d_hidden = d_hidden
        self.lr = lr
        self.n_class = n_class

        self.conv = nn.Sequential(
            nn.Conv3d(1, d_hidden, kernel_size=kernel, padding=(kernel - 1)//2),
            nn.ReLU(),
            nn.BatchNorm3d(d_hidden),
            nn.Dropout(p=dropout),
            nn.Conv3d(d_hidden, d_hidden, kernel_size=kernel, padding=kernel//2, stride=2),
            nn.ReLU(),
            nn.BatchNorm3d(d_hidden),
            nn.Dropout(p=dropout),
            nn.Conv3d(d_hidden, d_hidden, kernel_size=kernel, padding=(kernel - 1)//2),
            nn.ReLU(),
            nn.BatchNorm3d(d_hidden),
            nn.Dropout(p=dropout),
            nn.Conv3d(d_hidden, d_hidden, kernel_size=kernel, padding=kernel//2, stride=2),
            nn.ReLU(),
            nn.BatchNorm3d(d_hidden),
            nn.Dropout(p=dropout)
        )

        self.encoder = HighOrderTransformer(
            d_hidden,
            n_blocks,
            n_head,
            dropout,
            factorized,
            use_linear_att,
            feature_map,
            rotary_emb_list=[1, 2, 3],  # presumably the three grid axes of (bs, W, H, D, d)
            ignore_list=attention_ignore_list
        )

        self.pooling_fn = mean_pool if pooling == 'mean' else flatten_pool
        if pooling == 'mean':
            d_pool = d_hidden
        else:
            # Flattening keeps every grid cell, so the head width depends on the
            # stem's output side length. Derive it from the actual Conv3d layers
            # (standard conv output-size formula) instead of assuming 28^3 input.
            side = img_size
            for layer in self.conv:
                if isinstance(layer, nn.Conv3d):
                    side = (
                        side
                        + 2 * layer.padding[0]
                        - layer.dilation[0] * (layer.kernel_size[0] - 1)
                        - 1
                    ) // layer.stride[0] + 1
            d_pool = d_hidden * side ** 3

        self.head = nn.Sequential(
            RMSNorm(d_pool),
            nn.Linear(
                d_pool,
                n_class
            )
        )
        self.criterion = nn.CrossEntropyLoss(weight=ce_weight)
        self.evaluators = evaluators
        # Per-split softmax scores accumulated over an epoch, consumed and
        # cleared by ``on_epoch_end``.
        self.step_outputs = {'train': [], 'val': [], 'test': []}

    def forward(self, x):                                     # (bs, 1, width, height, depth)
        """Return class logits of shape (bs, n_class)."""
        h = self.conv(x.float()).permute(0, 2, 3, 4, 1)       # (bs, width, height, depth, d)
        h = self.encoder(h)
        h = self.pooling_fn(h)                                # (bs, d_pool)
        return self.head(h)

    def step(self, batch, mode='train'):
        """Shared train/val/test step: compute and log the loss, stash scores.

        ``batch`` is an ``(x, y)`` pair. ``y`` must have a singleton axis 1
        holding class indices (it is squeezed before the loss). Returns the
        cross-entropy loss.
        """
        x, y = batch
        logits = self(x)
        # CrossEntropyLoss requires int64 class indices; label arrays may arrive
        # as a narrower integer dtype depending on platform.
        loss = self.criterion(logits, y.squeeze(1).long())
        # Log the tensor directly: ``.item()`` would force a device sync per step.
        self.log(f"{mode}_loss", loss)
        # Detach so the stored scores do not keep autograd graph nodes alive
        # until the end of the epoch.
        self.step_outputs[mode].append(F.softmax(logits, dim=-1).detach())
        return loss

    def on_epoch_end(self, mode):
        """Score the predictions collected for ``mode``, log AUC/accuracy, reset."""
        outputs = self.step_outputs[mode]
        if not outputs:
            # Nothing was collected this epoch; torch.cat would raise on [].
            return
        # .float(): numpy has no bfloat16, so low-precision runs would fail here.
        all_preds = torch.cat(outputs, dim=0).detach().float().cpu().numpy()
        # NOTE(review): evaluate() receives only the scores, so it presumably
        # aligns them by position with labels it holds. That only holds when
        # batches arrive in dataset order on a single device — a shuffled train
        # loader or multi-device run would misalign them. Confirm.
        auc, acc = self.evaluators[mode].evaluate(all_preds)
        self.log(f"{mode}_auc", auc)
        self.log(f"{mode}_acc", acc)
        outputs.clear()

    def training_step(self, batch, batch_idx):
        return self.step(batch, mode='train')

    def validation_step(self, batch, batch_idx):
        return self.step(batch, mode='val')

    def test_step(self, batch, batch_idx):
        return self.step(batch, mode='test')

    def on_train_epoch_end(self):
        self.on_epoch_end(mode='train')

    def on_validation_epoch_end(self):
        if self.trainer.sanity_checking:
            # The pre-fit sanity check runs only a few batches, so the score
            # count cannot match the evaluator's labels; discard them.
            self.step_outputs['val'].clear()
            return
        self.on_epoch_end(mode='val')

    def on_test_epoch_end(self):
        self.on_epoch_end(mode='test')

    def count_parameters(self):
        """Return the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def configure_optimizers(self):
        """Use AdamW over all parameters at ``self.lr``; no LR scheduler."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return [optimizer]