import torch
from pytorch_lightning.core import LightningModule
from lgdea import builder
from lgdea.models.stage3_relation_model import Stage3RelationPropagationModel


class Stage3RelationModel(LightningModule):
    """Lightning wrapper that trains ``Stage3RelationPropagationModel``.

    The wrapped model is built from ``cfg`` plus the Stage-1 and Stage-2
    checkpoint paths, both of which are required. The training objective is
    ``cfg.model.infonce_weight`` (default 1.0) times the model's weighted
    InfoNCE loss.
    """

    def __init__(self, cfg):
        """Build the Stage-3 model from ``cfg``.

        Args:
            cfg: experiment config. Reads ``cfg.lightning.trainer.lr`` and
                ``cfg.model.{stage1_checkpoint_path, stage2_checkpoint_path,
                infonce_weight, use_compile}``.

        Raises:
            ValueError: if either Stage-1 or Stage-2 checkpoint path is missing.
        """
        super().__init__()
        self.cfg = cfg
        self.save_hyperparameters(cfg)
        self.lr = cfg.lightning.trainer.lr
        # Passed to builder.build_scheduler; presumably assigned externally
        # (e.g. to the datamodule) before training starts -- TODO confirm.
        self.dm = None

        stage1_checkpoint_path = getattr(cfg.model, 'stage1_checkpoint_path', None)
        stage2_checkpoint_path = getattr(cfg.model, 'stage2_checkpoint_path', None)
        if stage1_checkpoint_path is None:
            raise ValueError("stage1_checkpoint_path must be provided in config")
        if stage2_checkpoint_path is None:
            raise ValueError("stage2_checkpoint_path must be provided in config")

        self.model = Stage3RelationPropagationModel(
            cfg,
            stage1_checkpoint_path=stage1_checkpoint_path,
            stage2_checkpoint_path=stage2_checkpoint_path,
        )

        self.infonce_weight = getattr(cfg.model, 'infonce_weight', 1.0)

        if hasattr(torch, 'compile') and getattr(cfg.model, 'use_compile', True):
            # NOTE(review): torch.compile is lazy. This try/except only catches
            # errors raised while wrapping; most compilation failures surface at
            # the first forward call instead. The compiled wrapper also prefixes
            # state_dict keys with "_orig_mod." -- verify checkpoint compatibility.
            try:
                print("Using torch.compile to optimize Stage3 model...")
                self.model = torch.compile(self.model, mode='reduce-overhead')
            except Exception as e:
                print(f"torch.compile failed: {e}, continuing without compilation")

    @staticmethod
    def _dedupe_params(params):
        """Return ``params`` as a list in first-seen order, dropping repeats by identity."""
        # dict keeps first-insertion order, so a shared parameter stays at its
        # first position and is handed to the optimizer only once.
        return list({id(p): p for p in params}.values())

    def configure_optimizers(self):
        """Create the optimizer and LR scheduler for the Stage-3 trainable parameters.

        Parameters with ``requires_grad`` set are gathered from the vision
        encoder (skipped when ``cfg.model.vision.freeze_cnn``), the learned
        queries, the query attention, the vision projection, the text encoder
        (skipped when ``cfg.model.text.freeze_bert``) and the text projection.
        If nothing qualifies, both projection heads are used unconditionally.

        Returns:
            dict with ``"optimizer"`` and ``"lr_scheduler"`` keys, as built by
            ``builder.build_optimizer`` / ``builder.build_scheduler``.
        """
        model = self.model
        freeze_cnn = getattr(self.cfg.model.vision, 'freeze_cnn', False)
        freeze_bert = getattr(self.cfg.model.text, 'freeze_bert', False)

        # Order matches the original collection order so optimizer state
        # indices stay stable across this refactor.
        candidate_sources = []
        if not freeze_cnn:
            candidate_sources.append(model.vision_encoder.parameters())
        candidate_sources.append([model.queries])
        candidate_sources.append(model.query_attention.parameters())
        candidate_sources.append(model.vision_projection.parameters())
        if not freeze_bert:
            candidate_sources.append(model.text_encoder.parameters())
        candidate_sources.append(model.text_projection.parameters())

        trainable_params = self._dedupe_params(
            param
            for source in candidate_sources
            for param in source
            if param.requires_grad
        )

        if not trainable_params:
            # Fallback: never hand the optimizer an empty list; use the
            # projection heads regardless of their requires_grad flag.
            trainable_params = self._dedupe_params(
                [*model.vision_projection.parameters(),
                 *model.text_projection.parameters()]
            )

        class ParamWrapper:
            """Minimal adapter exposing ``.parameters()`` as builder.build_optimizer expects."""

            def __init__(self, params):
                self.params = params

            def parameters(self):
                # Fresh iterator on each call, like nn.Module.parameters().
                return iter(self.params)

        optimizer = builder.build_optimizer(self.cfg, self.lr, ParamWrapper(trainable_params))
        scheduler = builder.build_scheduler(self.cfg, optimizer, self.dm)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch`` (see ``shared_step``)."""
        return self.shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for ``batch`` (see ``shared_step``)."""
        return self.shared_step(batch, "val")

    def _forward_batch(self, batch):
        """Run the wrapped model on a batch dict.

        Reads keys ``images``, ``input_ids``, ``attention_mask``,
        ``token_type_ids``, ``num_evidences`` and ``paired_matrix`` via
        ``batch.get``. Returns the model's output mapping, or ``None`` when
        ``images`` or ``input_ids`` is absent.
        """
        images = batch.get('images')
        input_ids = batch.get('input_ids')
        if images is None or input_ids is None:
            return None
        return self.model(
            images=images,
            input_ids=input_ids,
            attention_mask=batch.get('attention_mask'),
            token_type_ids=batch.get('token_type_ids'),
            num_evidences=batch.get('num_evidences'),
            paired_matrix=batch.get('paired_matrix'),
        )

    def shared_step(self, batch, split):
        """Compute and log the weighted InfoNCE loss for one batch.

        Args:
            batch: dict-like batch (see ``_forward_batch`` for the keys read).
            split: ``"train"`` or ``"val"``; used as the logging prefix.

        Returns:
            ``infonce_weight * infonce_loss``. If the batch lacks ``images`` or
            ``input_ids``, returns a zero scalar with ``requires_grad=True``
            and logs nothing.
        """
        outputs = self._forward_batch(batch)
        if outputs is None:
            # Zero leaf tensor so backward() is still callable on an empty batch.
            return torch.tensor(0.0, device=self.device, requires_grad=True)

        infonce_loss = self.model.compute_weighted_infonce_loss(
            image_global_emb=outputs['image_global_emb'],
            text_global_emb=outputs['text_global_emb'],
            relation_matrix=outputs['relation_matrix'],
        )
        total_loss = self.infonce_weight * infonce_loss

        # Per-step logging only while training; every split logs epoch means.
        log_per_step = split == "train"
        self.log(f"{split}_loss", total_loss, on_epoch=True, on_step=log_per_step, logger=True, prog_bar=True)
        self.log(f"{split}_infonce_loss", infonce_loss, on_epoch=True, on_step=log_per_step, logger=True)

        return total_loss

    def get_relation_matrix(self, batch):
        """Return the model's ``relation_matrix`` for ``batch`` without gradients.

        The module runs in eval mode for this call. Every submodule's previous
        ``training`` flag is restored afterwards, even on error, so calling this
        mid-training does not leave dropout or batch-norm disabled.

        Returns:
            ``outputs['relation_matrix']``, or ``None`` if the batch lacks
            ``images`` or ``input_ids``.
        """
        # Record per-module modes instead of a single flag, so submodules that
        # were deliberately kept in eval mode stay that way after restoring.
        prior_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                outputs = self._forward_batch(batch)
        finally:
            for module, mode in prior_modes:
                module.training = mode

        if outputs is None:
            return None
        # NOTE(review): with torch.compile(mode='reduce-overhead') (CUDA graphs),
        # outputs may be overwritten by a later run; callers keeping this tensor
        # may need to .clone() it -- confirm.
        return outputs['relation_matrix']

