import torch
from pytorch_lightning.core import LightningModule
from lgdea import builder
from lgdea.models.prototype_model import DiagnosticSemanticPrototypeModel


class PrototypePretrainModel(LightningModule):
    """Lightning wrapper that pretrains ``DiagnosticSemanticPrototypeModel``.

    The training objective is the model's reconstruction loss, scaled by
    ``cfg.model.reconstruction_weight`` (default ``1.0``).
    """

    def __init__(self, cfg):
        """Build the wrapped model from ``cfg``.

        Args:
            cfg: Experiment config. Reads ``cfg.lightning.trainer.lr`` and,
                optionally, ``cfg.model.reconstruction_weight`` (default 1.0)
                and ``cfg.model.use_compile`` (default True).
        """
        super().__init__()
        self.cfg = cfg
        self.save_hyperparameters(cfg)
        self.lr = cfg.lightning.trainer.lr
        # Passed to builder.build_scheduler; presumably assigned by the caller
        # (e.g. the datamodule) before configure_optimizers runs — TODO confirm.
        self.dm = None
        self.model = DiagnosticSemanticPrototypeModel(cfg)
        self.reconstruction_weight = getattr(cfg.model, 'reconstruction_weight', 1.0)

        if hasattr(torch, 'compile') and getattr(cfg.model, 'use_compile', True):
            # torch.compile is lazy: most compilation errors surface on the
            # first forward call, not here, so this guard only catches
            # setup-time failures.
            # NOTE(review): the compiled wrapper prefixes state_dict keys with
            # "_orig_mod.", so checkpoints saved with and without compilation
            # are not directly interchangeable.
            try:
                print("Using torch.compile to optimize model...")
                self.model = torch.compile(self.model, mode='reduce-overhead')
            except Exception as e:
                print(f"torch.compile failed: {e}, continuing without compilation")

    def configure_optimizers(self):
        """Build the optimizer and LR scheduler via the project builder."""
        optimizer = builder.build_optimizer(self.cfg, self.lr, self.model)
        scheduler = builder.build_scheduler(self.cfg, optimizer, self.dm)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def training_step(self, batch, batch_idx):
        """Return the weighted training loss for one batch."""
        return self.shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Return the weighted validation loss for one batch."""
        return self.shared_step(batch, "val")

    def _encode_batch(self, batch):
        """Run the wrapped model on the tensors stored in ``batch``.

        ``batch`` must provide ``input_ids``, ``attention_mask``,
        ``token_type_ids`` and ``num_evidences``. Returns the model's output
        mapping unchanged.
        """
        return self.model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            token_type_ids=batch['token_type_ids'],
            num_evidences=batch['num_evidences'],
        )

    def shared_step(self, batch, split):
        """Compute, log and return the weighted reconstruction loss.

        Args:
            batch: Mapping with the keys consumed by the wrapped model.
            split: Prefix for logged metric names (``"train"`` or ``"val"``).
                Per-step logging is enabled only for ``"train"``.

        Returns:
            ``reconstruction_weight * reconstruction_loss``.
        """
        outputs = self._encode_batch(batch)

        reconstruction_loss = self.model.compute_reconstruction_loss(
            evidence_embeddings=outputs['evidence_embeddings'],
            reconstructed_embeddings=outputs['reconstructed_embeddings'],
            num_evidences=batch['num_evidences'],
        )

        total_loss = self.reconstruction_weight * reconstruction_loss

        log_iter = (split == "train")
        self.log(f"{split}_loss", total_loss, on_epoch=True, on_step=log_iter, logger=True, prog_bar=True)
        self.log(f"{split}_reconstruction_loss", reconstruction_loss, on_epoch=True, on_step=log_iter, logger=True)

        return total_loss

    def get_report_representations(self, batch):
        """Return ``report_representations`` for ``batch`` without gradients.

        The module is switched to eval mode for the forward pass. Its previous
        train/eval mode is then restored, so calling this during training
        (e.g. from a callback) does not leave the model in eval mode.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                outputs = self._encode_batch(batch)
        finally:
            self.train(was_training)
        return outputs['report_representations']

