import torch
import lightning as L 
from utils import metrics
import model.lob_embedding as lob_embedding

class LOBAutoEncoder(L.LightningModule):
    """Base Lightning module for autoencoders over limit-order-book (LOB) inputs.

    Every keyword argument given to the constructor is stored verbatim as an
    instance attribute. The methods of this class read these attributes:

    * ``metrics`` -- sequence of keys into the ``utils.metrics`` lookup. The
      first key is the loss returned from each step; the remaining keys are
      only logged.
    * ``optimizer_name`` / ``lr`` -- name of a ``torch.optim`` class and the
      learning rate passed to it.
    * ``log_freq`` -- training metrics are sent to the logger only every
      ``log_freq`` global steps (validation/test metrics are always sent).
    * ``lob_embed_name`` / ``lob_embed_path`` / ``lob_embed_arguments`` --
      optional. When ``lob_embed_name`` is present, a class of that name is
      loaded from ``model.lob_embedding`` via ``load_from_checkpoint`` and
      used, frozen, to encode the leading LOB columns of every input.
    * ``lob_feat_dim`` -- optional number of leading feature columns handed to
      the embedding model (defaults to 20, the original hard-coded split).

    Subclasses must implement :meth:`forward` and :meth:`encode`.
    """

    # Leading feature columns treated as raw LOB input when no ``lob_feat_dim``
    # kwarg is supplied; matches the previously hard-coded split.
    _DEFAULT_LOB_FEAT_DIM = 20

    def __init__(self, **kwargs):
        super().__init__()
        for key, value in kwargs.items():
            setattr(self, key, value)
        # The presence of ``lob_embed_name`` switches on the frozen embedder.
        self.lob_to_embed = hasattr(self, "lob_embed_name")
        if self.lob_to_embed:
            self.set_lob_embedding()

    def set_lob_embedding(self):
        """Load the pre-trained LOB embedding model and freeze its parameters."""
        lob_class = getattr(lob_embedding, self.lob_embed_name)
        self.lob_embed = lob_class.load_from_checkpoint(self.lob_embed_path, **self.lob_embed_arguments)
        self.lob_embed.eval()
        for p in self.lob_embed.parameters():
            p.requires_grad = False
        print(f"Loaded LOB embedding model: {self.lob_embed_name} from {self.lob_embed_path}")

    def train(self, mode: bool = True):
        """Set train/eval mode while keeping the frozen LOB embedder in eval mode.

        ``nn.Module.train`` recurses into submodules, so putting this module
        back into training mode would otherwise also re-enable any dropout or
        batch-norm statistic updates inside the frozen embedder
        (``requires_grad=False`` does not stop running-stat updates).
        """
        super().train(mode)
        if getattr(self, "lob_to_embed", False):
            self.lob_embed.eval()
        return self

    def forward(self, x):
        raise NotImplementedError("This method should be implemented in the child class")

    def encode(self, x):
        raise NotImplementedError("Must implement encoder in the child class")

    def on_train_epoch_start(self):
        """Forward the epoch-start hook to the datamodule when it defines one."""
        if hasattr(self.trainer.datamodule, "on_train_epoch_start"):
            self.trainer.datamodule.on_train_epoch_start()

    def _prepare_input(self, src):
        """Build the model input from a raw batch tensor ``src``.

        Without an embedder, ``src`` is returned unchanged. Otherwise the first
        ``lob_feat_dim`` feature columns are encoded by the frozen embedder,
        and the result is concatenated in front of the remaining (manual)
        feature columns along the last axis.
        """
        if not self.lob_to_embed:
            return src
        lob_dim = getattr(self, "lob_feat_dim", self._DEFAULT_LOB_FEAT_DIM)
        with torch.no_grad():
            lob_repr = self.lob_embed.encode(src[..., :lob_dim])
        manual_feat = src[..., lob_dim:]
        # Assumes encode() yields one vector per sample ([batch, embed_dim]) --
        # TODO confirm; it is repeated along the sequence axis of manual_feat.
        lob_repr_expand = lob_repr.unsqueeze(1).expand(-1, manual_feat.shape[1], -1)  # [batch, seq_len, embed_dim]
        return torch.cat([lob_repr_expand, manual_feat], dim=-1)

    def _shared_step(self, batch, log_type):
        """Run encode + forward on one batch, log metrics, and return the loss."""
        src, label = batch  # label: fraud type
        x = self._prepare_input(src)
        z = self.encode(x)
        out = self.forward(x)
        return self.logger_metrics(out, x, z, label, log_type)

    def training_step(self, batch):
        return self._shared_step(batch, "train_")

    def validation_step(self, batch):
        return self._shared_step(batch, "val_")

    def test_step(self, batch):
        return self._shared_step(batch, "test_")

    def configure_optimizers(self):
        """Instantiate ``torch.optim.<optimizer_name>`` over all parameters with ``lr``."""
        optimizer = getattr(torch.optim, self.optimizer_name)(self.parameters(), lr=self.lr)
        return {"optimizer": optimizer}

    def compute_loss(self, reconstructed, target, embeddings, labels, log_type="train_"):
        """Compute the primary metric (``self.metrics[0]``) as the loss.

        Gradients are tracked only for ``log_type == "train_"``; other phases
        are evaluated under ``torch.no_grad()``.
        """
        loss_func = metrics[self.metrics[0]]
        if log_type == "train_":
            return loss_func(reconstructed=reconstructed, target=target, embeddings=embeddings, labels=labels)
        with torch.no_grad():
            return loss_func(reconstructed=reconstructed, target=target, embeddings=embeddings, labels=labels)

    def logger_metrics(self, reconstructed, target, embeddings, labels, log_type="train_"):
        """Compute the loss plus auxiliary metrics, log them, and return the loss.

        Metric names are prefixed with ``log_type``. Training metrics are
        logged only every ``log_freq`` global steps; validation/test metrics
        are logged on every call. Nothing is logged when no logger is attached.
        """
        # Forward log_type so val/test losses take the no-grad path.
        loss = self.compute_loss(reconstructed, target, embeddings, labels, log_type)
        logs = {log_type + self.metrics[0]: loss}
        with torch.no_grad():
            for mtc in self.metrics[1:]:
                logs[log_type + mtc] = metrics[mtc](
                    reconstructed=reconstructed, target=target, embeddings=embeddings, labels=labels
                )

        should_log = log_type != "train_" or self.global_step % self.log_freq == 0
        if should_log and self.logger is not None:
            self.logger.log_metrics(logs, step=self.global_step)

        return loss