import torch
from pytorch_lightning.core import LightningModule
from lgdea import builder
from lgdea.models.stage2_model import Stage2LesionSemanticModel


class Stage2PretrainModel(LightningModule):
    """Lightning wrapper that pretrains a ``Stage2LesionSemanticModel``.

    The optimised objective is
    ``distillation_weight * distillation_loss + consistency_weight * consistency_loss``;
    both loss terms are computed by the wrapped model.
    """

    def __init__(self, cfg):
        """Build the wrapped Stage 2 model and read the loss weights from ``cfg``.

        Args:
            cfg: Project configuration. Reads ``cfg.lightning.trainer.lr`` and,
                from ``cfg.model``: ``stage1_checkpoint_path`` (required),
                ``distillation_weight``, ``consistency_weight`` and
                ``diversity_weight`` (each defaulting to 1.0), and
                ``use_compile`` (defaulting to True).

        Raises:
            ValueError: If ``cfg.model.stage1_checkpoint_path`` is missing or None.
        """
        super().__init__()
        self.cfg = cfg
        self.save_hyperparameters(cfg)
        self.lr = cfg.lightning.trainer.lr
        # Never assigned inside this class; it is forwarded as-is to
        # builder.build_scheduler. Presumably set externally before
        # configure_optimizers runs -- TODO confirm.
        self.dm = None

        # NOTE(review): the config key is assumed to match the keyword argument
        # of Stage2LesionSemanticModel below -- confirm against the config schema.
        stage1_checkpoint_path = getattr(cfg.model, 'stage1_checkpoint_path', None)
        if stage1_checkpoint_path is None:
            raise ValueError(
                "cfg.model.stage1_checkpoint_path must be set: Stage 2 "
                "pretraining requires a Stage 1 checkpoint."
            )

        self.model = Stage2LesionSemanticModel(cfg, stage1_checkpoint_path=stage1_checkpoint_path)
        self.distillation_weight = getattr(cfg.model, 'distillation_weight', 1.0)
        self.consistency_weight = getattr(cfg.model, 'consistency_weight', 1.0)
        # Read from the config but not used by any loss term in this class.
        self.diversity_weight = getattr(cfg.model, 'diversity_weight', 1.0)

        # Compilation is best-effort: on failure the uncompiled model is kept.
        if hasattr(torch, 'compile') and getattr(cfg.model, 'use_compile', True):
            try:
                print("Using torch.compile to optimize Stage2 model...")
                self.model = torch.compile(self.model, mode='reduce-overhead')
            except Exception as e:
                print(f"torch.compile failed: {e}, continuing without compilation")

    def configure_optimizers(self):
        """Create the optimizer and LR scheduler through ``lgdea.builder``.

        The optimizer receives, in this order: the vision-encoder parameters
        that currently have ``requires_grad`` set, ``model.queries``, and every
        parameter of ``query_attention``, ``vision_projection`` and
        ``prototype_head``.

        Returns:
            dict: ``{"optimizer": ..., "lr_scheduler": ...}``.
        """
        model = self.model

        # Only the vision encoder is filtered by requires_grad; the remaining
        # components are always handed to the optimizer.
        trainable_params = [
            param for param in model.vision_encoder.parameters() if param.requires_grad
        ]
        trainable_params.append(model.queries)
        for submodule in (model.query_attention, model.vision_projection, model.prototype_head):
            trainable_params.extend(submodule.parameters())

        class ParamWrapper:
            """Expose a fixed parameter list through a ``parameters()`` method.

            ``builder.build_optimizer`` is given this object in place of a
            model, so only the selected parameters are optimised.
            """

            def __init__(self, params):
                self.params = params

            def parameters(self):
                return iter(self.params)

        param_wrapper = ParamWrapper(trainable_params)
        optimizer = builder.build_optimizer(self.cfg, self.lr, param_wrapper)
        scheduler = builder.build_scheduler(self.cfg, optimizer, self.dm)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def training_step(self, batch, batch_idx):
        """Return the total loss for one training batch."""
        return self.shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Return the total loss for one validation batch."""
        return self.shared_step(batch, "val")

    def _forward_batch(self, batch):
        """Run the wrapped model on the entries of ``batch`` it consumes."""
        return self.model(
            images=batch['images'],
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            token_type_ids=batch['token_type_ids'],
            num_evidences=batch['num_evidences'],
        )

    def shared_step(self, batch, split):
        """Run one training/validation step and log its losses.

        Args:
            batch: Mapping with keys ``images``, ``input_ids``,
                ``attention_mask``, ``token_type_ids`` and ``num_evidences``.
            split: Prefix for the logged metric names (``"train"`` or
                ``"val"``). Per-step logging is enabled only for ``"train"``.

        Returns:
            The weighted sum of the distillation and consistency losses.
        """
        outputs = self._forward_batch(batch)

        # Semantic distillation loss.
        distillation_loss = self.model.compute_distillation_loss(
            image_prototype_dist=outputs['image_prototype_dist'],
            text_prototype_dist=outputs['text_prototype_dist'],
            num_evidences=batch['num_evidences'],
        )

        # Lesion-level semantic consistency loss.
        consistency_loss = self.model.compute_lesion_consistency_loss(
            leis=outputs['leis'],
            image_prototype_dist=outputs['image_prototype_dist'],
        )

        total_loss = self.distillation_weight * distillation_loss + self.consistency_weight * consistency_loss

        log_iter = (split == "train")
        self.log(f"{split}_loss", total_loss, on_epoch=True, on_step=log_iter, logger=True, prog_bar=True)
        self.log(f"{split}_distillation_loss", distillation_loss, on_epoch=True, on_step=log_iter, logger=True)
        self.log(f"{split}_consistency_loss", consistency_loss, on_epoch=True, on_step=log_iter, logger=True)

        return total_loss

    def get_lesion_representations(self, batch):
        """Run the model in eval mode without gradients and return selected outputs.

        The train/eval flag of every submodule is restored afterwards, so
        calling this in the middle of training does not leave the model in
        eval mode.

        Args:
            batch: Mapping with the same keys as consumed by ``shared_step``.

        Returns:
            dict: The ``leis``, ``image_prototype_dist`` and
            ``image_level_dist`` entries of the model output.
        """
        # Snapshot per-module flags rather than only the top-level one, so
        # submodules that were deliberately held in eval mode stay that way.
        saved_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                outputs = self._forward_batch(batch)
        finally:
            for module, was_training in saved_modes:
                module.training = was_training

        return {
            'leis': outputs['leis'],
            'image_prototype_dist': outputs['image_prototype_dist'],
            'image_level_dist': outputs['image_level_dist'],
        }

