import torch
from torch import nn
import torch.nn.functional as F
from torchmetrics.functional import f1_score, accuracy, matthews_corrcoef
import lightning as L

from .modules import RMSNorm
from .transformer import HighOrderTransformer

def mean_pool(x):
    """Average ``x`` over dimension 2 and drop that dimension.

    In this file it is applied to the encoder output, which the model's
    inline comments describe as ``(bs, n, t, d)``, giving ``(bs, n, d)``.
    """
    pooled = torch.mean(x, dim=2)
    return pooled

def flatten_pool(x):
    """Collapse dimension 2 and every dimension after it into a single one.

    In this file it is applied to the encoder output, which the model's
    inline comments describe as ``(bs, n, t, d)``, giving ``(bs, n, t * d)``.
    """
    flattened = torch.flatten(x, start_dim=2)
    return flattened

class TSHOT(L.LightningModule):
    """Patch-based multivariate forecaster built around ``HighOrderTransformer``.

    Every variable of the input window is cut into non-overlapping patches by a
    strided 1-D convolution, the patch embeddings go through the encoder, are
    pooled over the patch axis, and a linear head emits ``n_pred`` future
    values per variable.

    Args:
        d_hidden: Patch embedding width (conv output channels / encoder width).
        n_blocks: Forwarded to ``HighOrderTransformer``.
        n_head: Forwarded to ``HighOrderTransformer``.
        patch_size: Kernel size and stride of the patching convolution.
        context_length: Input window length. Required when
            ``pooling='flatten'``, where it sizes the head as
            ``d_hidden * (context_length // patch_size)``.
        pooling: ``'mean'`` or ``'flatten'``; any other key raises ``KeyError``.
        normalize: If true, each series is standardised over time before the
            model and the statistics are re-applied to the output.
        dropout: Dropout rate after the patching convolution; also forwarded
            to ``HighOrderTransformer``.
        use_linear_att: Forwarded to ``HighOrderTransformer``.
        feature_map: Forwarded to ``HighOrderTransformer``.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        attention_ignore_list: Forwarded to ``HighOrderTransformer`` as
            ``ignore_list``.
        n_pred: Number of future steps predicted per variable.

    Raises:
        ValueError: If ``pooling='flatten'`` and ``context_length`` is ``None``.
        KeyError: If ``pooling`` is not a known pooling name.
    """

    # Forecast horizons whose MSE/MAE are logged at every step.
    EVAL_HORIZONS = (96, 192, 336, 720)

    def __init__(
        self,
        d_hidden,
        n_blocks,
        n_head,
        patch_size=4,
        context_length=None,
        pooling='mean',
        normalize=False,
        dropout=0.,
        use_linear_att=True,
        feature_map='SMReg',
        lr=1e-3,
        weight_decay=0.,
        attention_ignore_list=None,
        n_pred=720,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.d_hidden = d_hidden
        self.lr = lr
        self.weight_decay = weight_decay
        self.normalize = normalize
        self.n_pred = n_pred
        # Standard horizons shorter than the output length, then the full
        # length itself; with the default n_pred this is [96, 192, 336, 720].
        self.horizons = [h for h in self.EVAL_HORIZONS if h < n_pred] + [n_pred]

        # Non-overlapping patch embedding: one output token per `patch_size` steps.
        self.conv = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=d_hidden, kernel_size=patch_size, stride=patch_size),
            nn.Dropout(dropout),
        )

        self.encoder = HighOrderTransformer(
            d_hidden,
            n_blocks,
            n_head,
            dropout,
            use_linear_att,
            feature_map,
            rotary_emb_list=[2],
            ignore_list=attention_ignore_list
        )
        self.pooling_fn = {
            'mean': mean_pool,
            'flatten': flatten_pool
        }[pooling]
        d_pool = d_hidden
        if pooling == 'flatten':
            # Explicit error instead of `assert`, which is stripped under -O.
            if context_length is None:
                raise ValueError("context_length must be provided when pooling='flatten'")
            d_pool = d_hidden * (context_length // patch_size)
        self.head = nn.Sequential(RMSNorm(d_pool), nn.Linear(d_pool, n_pred))

    def forward(self, x):
        """Forecast ``n_pred`` steps for every variable in ``x``.

        Args:
            x: Tensor of shape ``(bs, n, t)``. It does not need to be contiguous.

        Returns:
            Tensor of shape ``(bs, n, n_pred)`` on the scale of ``x``.
        """
        bs, n, t = x.shape

        if self.normalize:
            mu = x.mean(2, keepdim=True)
            std = torch.sqrt(torch.var(x, dim=2, keepdim=True, unbiased=False) + 1e-5)
        else:
            mu = 0.
            std = 1.
        x_norm = (x - mu) / std

        # reshape rather than view: elementwise ops keep the stride layout of
        # their input, so a transposed batch would make `.view` raise.
        h = x_norm.reshape(bs * n, 1, t)          ## (bs * n, 1, t)
        h = self.conv(h).transpose(1, 2)          ## (bs * n, t, d)
        h = h.reshape(bs, n, *h.shape[1:])        ## (bs, n, t, d)

        h = self.encoder(h)                       ## (bs, n, t, d)
        h = self.pooling_fn(h)                    ## (bs, n, d_pool)
        logits = self.head(h)                     ## (bs, n, n_pred)
        # Undo the per-series standardisation (no-op when normalize is False).
        return (logits * std) + mu

    def calc_metrics(self, preds, labels):
        """Return ``(mse, mae)`` between ``preds`` and ``labels``."""
        return F.mse_loss(preds, labels), F.l1_loss(preds, labels)

    def step(self, batch, mode='train'):
        """Run one batch, log MSE/MAE per horizon, and return the loss.

        Args:
            batch: Pair ``(x, y)``; ``y`` is sliced along its last axis to
                each horizon.
            mode: Prefix for the logged metric names.

        Returns:
            The MSE over the longest horizon (``n_pred`` steps).
        """
        x, y = batch
        preds = self(x)
        loss = None
        for horizon in self.horizons:
            mse, mae = self.calc_metrics(preds[:, :, :horizon], y[:, :, :horizon])
            # Log the tensors directly; `.item()` would force a host-device
            # synchronisation for every metric on every step.
            self.log(f"{mode}-mse-{horizon}", mse)
            self.log(f"{mode}-mae-{horizon}", mae)
            loss = mse  # the last (longest) horizon ends up as the loss
        return loss

    def training_step(self, batch, batch_idx):
        """Lightning hook: compute and log metrics with the ``train`` prefix."""
        return self.step(batch, mode='train')

    def validation_step(self, batch, batch_idx):
        """Lightning hook: compute and log metrics with the ``val`` prefix."""
        return self.step(batch, mode='val')

    def test_step(self, batch, batch_idx):
        """Lightning hook: compute and log metrics with the ``test`` prefix."""
        return self.step(batch, mode='test')

    def count_parameters(self):
        """Return the number of trainable scalar parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def configure_optimizers(self):
        """Build Adam with a cosine-annealed learning rate.

        The schedule uses ``T_max=100`` and ``eta_min=1e-6``.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=1e-6, T_max=100)
        return [optimizer], [lr_scheduler]

    
