import pytorch_lightning as pl
import torch, os
import torch.nn as nn
import torch.nn.functional as F
from models.multi_causal_gat_model import CausalAttentionRegressor

class LitGATCausalRegressor(pl.LightningModule):
    """
    LightningModule wrapping CausalAttentionRegressor.
    """
    def __init__(self,
                 mode: str,
                 max_atomic_num: int = 118,
                 n_peptide_types: int = None,
                 **kwargs):
        """
        Build the backbone(s), the projection layers and the prediction heads.

        Args:
            mode: Case-insensitive mode name. 'FUSION' builds three backbones
                (SMILES, peptide, geometry); any other value builds a single
                backbone whose hyperparameter prefix is ``mode.lower()``.
            max_atomic_num: Forwarded as ``max_atomic_num`` to the SMILES and
                geometry backbones.
            n_peptide_types: Forwarded as ``max_atomic_num`` to the peptide
                backbone.
            **kwargs: Remaining hyperparameters, read back through
                ``self.hparams``. This method reads ``hidden_dim_projector``,
                ``hidden_dim_smiles``, ``hidden_dim_peptide``,
                ``hidden_dim_geometry`` and ``num_causal_blocks``; other
                methods of this class read further keys.
        """
        super().__init__()
        # Makes the constructor arguments available through self.hparams (used below).
        self.save_hyperparameters()
        self.mode = self.hparams.mode.upper()

        # training_step calls the optimizers and schedulers itself.
        self.automatic_optimization = False

        # projection dimension (common hidden size for fusion)
        proj_dim = self.hparams.hidden_dim_projector

        # Per-modality projections into the shared space.
        # NOTE: all three hidden_dim_* hyperparameters are read here regardless of mode.
        self.project_smiles  = nn.Linear(self.hparams.hidden_dim_smiles,  proj_dim)
        self.project_peptide = nn.Linear(self.hparams.hidden_dim_peptide, proj_dim)
        self.project_geometry = nn.Linear(self.hparams.hidden_dim_geometry, proj_dim)

        # Learnable softmax temperature for the fusion gate; forward() clamps it to [0.6, 2.5].
        self.tau_causal = nn.Parameter(torch.tensor(1.5))
        # Gate network: concatenated projections (3 * proj_dim) -> one logit per modality.
        self.gate_causal = nn.Sequential(
            nn.Linear(proj_dim * 3, proj_dim),
            nn.ReLU(),
            nn.Linear(proj_dim, 3)
        )

        # Build sub-models based on mode
        if self.mode == 'FUSION':
            # Graph model for SMILES
            self.model_smiles = self._make_model(
                mode='SMILES',
                backbone='gat',
                max_types=max_atomic_num,
                prefix='smiles'
            )
            # Graph model for PEPTIDE
            self.model_peptide = self._make_model(
                mode='PEPTIDE',
                backbone='gat',
                max_types=n_peptide_types,
                prefix='peptide'
            )
            # Geometry model ('egnn' backbone)
            self.model_geometry = self._make_model(
                mode='GEOMETRY',
                backbone='egnn',
                max_types=max_atomic_num,
                prefix='geometry'
            )
        else:
            prefix = self.mode.lower()
            max_types = (max_atomic_num
                            if self.mode == 'SMILES' or self.mode == 'GEOMETRY'
                            else n_peptide_types)
            # Single-modality model: 'gat' for SMILES/PEPTIDE, 'egnn' otherwise
            backbone = 'gat' if self.mode in ('SMILES', 'PEPTIDE') else 'egnn'
            # NOTE(review): `mode` is not passed here, so _make_model forwards its
            # default (None) to CausalAttentionRegressor -- confirm this is intended.
            self.model = self._make_model(
                backbone=backbone,
                max_types=max_types,
                prefix=prefix
            )
        # Input width of the unimodal projections, chosen by mode
        if self.mode == "SMILES":
            in_dim = self.hparams.hidden_dim_smiles
        elif self.mode == "PEPTIDE":
            in_dim = self.hparams.hidden_dim_peptide
        elif self.mode == "GEOMETRY":
            in_dim = self.hparams.hidden_dim_geometry
        else:
            # Any other mode (including FUSION): the unimodal projections are
            # built with a placeholder width of 1; forward()'s FUSION branch
            # does not use them.
            in_dim = 1


        # Causal heads: one scalar head per causal block (used in both modes)
        self.layer_heads = nn.ModuleList([nn.Linear(proj_dim, 1) for _ in range(self.hparams.num_causal_blocks)])
        # Single LayerNorm applied after every projection in forward()
        self.ln = nn.LayerNorm(proj_dim)

        # Trivial heads for FUSION mode: one scalar head per causal block and modality
        self.head_smiles  = nn.ModuleList([nn.Linear(proj_dim, 1) for _ in range(self.hparams.num_causal_blocks)])
        self.head_geo     = nn.ModuleList([nn.Linear(proj_dim, 1) for _ in range(self.hparams.num_causal_blocks)])
        self.head_peptide = nn.ModuleList([nn.Linear(proj_dim, 1) for _ in range(self.hparams.num_causal_blocks)])


        # Unimodal parameters: separate projections for causal and trivial features
        self.projection_unimodal_causal = nn.Linear(in_dim, proj_dim)
        self.projection_unimodal_trivial = nn.Linear(in_dim, proj_dim)
        self.unimodal_heads = nn.ModuleList([nn.Linear(proj_dim, 1) for _ in range(self.hparams.num_causal_blocks)]) # Trivial heads (unimodal)


    def _make_model(self, backbone: str, max_types: int, prefix: str, mode: str = None):
        """
        Build one CausalAttentionRegressor backbone.

        Args:
            backbone: Backbone identifier forwarded to the regressor; callers
                in this class pass 'gat' or 'egnn'.
            max_types: Forwarded as ``max_atomic_num``.
            prefix: 'smiles', 'peptide' or 'geometry'. Selects the
                ``emb_dim_<prefix>``, ``hidden_dim_<prefix>``,
                ``num_gc_layers_<prefix>`` and ``heads_<prefix>``
                hyperparameters from ``self.hparams``.
            mode: Forwarded unchanged ('SMILES', 'PEPTIDE', 'GEOMETRY', or
                ``None`` when the caller does not supply it).

        Returns:
            The constructed CausalAttentionRegressor.

        Raises:
            AttributeError: If a required hyperparameter is missing.
        """
        hp = self.hparams

        def for_prefix(name):
            # Per-modality hyperparameters are stored as "<name>_<prefix>".
            return getattr(hp, f'{name}_{prefix}')

        return CausalAttentionRegressor(
            mode=mode,
            backbone=backbone,
            max_atomic_num=max_types,
            emb_dim=for_prefix('emb_dim'),
            hidden_dim=for_prefix('hidden_dim'),
            num_backbone_layers=for_prefix('num_gc_layers'),
            num_causal_blocks=hp.num_causal_blocks,
            dropout=hp.dropout,
            heads=for_prefix('heads'),
            lambda_unif=hp.lambda_unif,
            lambda_caus=hp.lambda_caus,
        )


    def fusion_loss(
        self,
        y_c_all,              # list of [B] OR tensor [B, L]  (causal per-layer outputs)
        y_t,                  # [B]  trivial correction
        y_true: torch.Tensor,            # [B]
        rho_min: float = 0.5,
        rho_max: float = 0.8,   #0.8
        eps: float = 1e-8,
    ) -> torch.Tensor:
        if isinstance(y_c_all, (list, tuple)):
            # list of [B] -> [L,B] -> [B,L]
            y_c_all = torch.stack(y_c_all, dim=0).transpose(0, 1)
        B, L = y_c_all.shape

        y_true = y_true.to(dtype=y_c_all.dtype)
        y_t    = y_t.to(dtype=y_c_all.dtype)

        # 1) Causal: per-layer Pearson corr to targets
        # center
        yc_mean   = y_c_all.mean(dim=0, keepdim=True)              # [1, L]
        yc_center = y_c_all - yc_mean                              # [B, L]
        yt_center = y_true - y_true.mean()                         # [B]

        num = (yc_center * yt_center.unsqueeze(1)).sum(dim=0)      # [L]
        den = (
            yc_center.pow(2).sum(dim=0).clamp_min(eps).sqrt()
            * yt_center.pow(2).sum().clamp_min(eps).sqrt()
        )                                                          # [L]
        corrs = (num / den).clamp(-1.0, 1.0)                       # [L]

        rho_targets = torch.linspace(rho_min, rho_max, steps=L,
                                    device=y_c_all.device, dtype=y_c_all.dtype)  # [L]
        loss_corr = (corrs - rho_targets).pow(2).mean()

        if L >= 2:
            mono_margin = 0.0
            diffs = corrs[:-1] - corrs[1:] + mono_margin           # [L-1]
            mono_penalty = F.relu(diffs).mean()
        else:
            mono_penalty = y_c_all.new_tensor(0.0)

        # 2) Trivial: fit residual of causal
        y_c_base = y_c_all[:, -1]                                   # [B] #Last layer

        residual = (y_true - y_c_base).detach()                     # stop-grad through causal
        loss_trivial = F.mse_loss(y_t, residual)                   
        y_final = y_c_base + y_t                                    # [B]
        loss_final = F.mse_loss(y_final, y_true)

        total_loss = (
            loss_final
            + self.hparams.lambda_caus * loss_corr
            + self.hparams.lambda_unif * loss_trivial
            + self.hparams.lambda_mono * mono_penalty
        )
        return total_loss


    def forward(self, data):
        if self.mode == 'FUSION':
            mol, pep, geo = data

            # 1) embeddings
            _, _, z_c_sm_layers, z_t_sm_layers = self.model_smiles(mol)
            _, _, z_c_pe_layers, z_t_pe_layers = self.model_peptide(pep)
            _, _, z_c_ge_layers, z_t_ge_layers = self.model_geometry(geo)

            # print('z_c_sm_layers.shape',len(z_c_sm_layers),z_c_sm_layers[0].shape)

            L = len(z_t_sm_layers)

            # 2) causal projections
            z_c_sm_p_all = [self.ln(self.project_smiles(z)) for z in z_c_sm_layers]  # list of [B,D]
            z_c_pe_p_all = [self.ln(self.project_peptide(z)) for z in z_c_pe_layers]
            z_c_ge_p_all = [self.ln(self.project_geometry(z)) for z in z_c_ge_layers] # list of [B,D]
            
            cons_loss  = 0.0                              

            per_layer_scalar = []

            for l in range(L):
                z_sm, z_pe, z_ge = z_c_sm_p_all[l], z_c_pe_p_all[l], z_c_ge_p_all[l]  # [B,D] each

                # ----- gating -----
                z_concat = torch.cat([z_sm, z_pe, z_ge], dim=-1)   # [B, 3D]
                logits = self.gate_causal(z_concat)                # [B, 3]

                # per-sample standardization (stable gating like your code)
                mu = logits.mean(dim=-1, keepdim=True)
                sd = logits.std(dim=-1, keepdim=True).clamp_min(1e-3)
                logits = (logits - mu) / sd

                # temperature
                tau = torch.clamp(self.tau_causal, 0.6, 2.5) 
                w = F.softmax(logits / tau, dim=-1)                # [B, 3]

                # ----- weighted fuse -----
                z_stack = torch.stack([z_sm, z_pe, z_ge], dim=1)   # [B,3,D]
                z_fused = (w.unsqueeze(-1) * z_stack).sum(dim=1)   # [B,D]

                s_l = self.layer_heads[l](z_fused)                 # [B,1]
                per_layer_scalar.append(s_l)

            cons_loss  = cons_loss  / L                   # ***

            # stack scalars across layers -> [L,B,1] -> transpose to [B,L]
            y_c_all = torch.stack(per_layer_scalar, dim=0).squeeze(-1).transpose(0, 1)  # [B,L]

            # 3) trivial projections
            z_t_sm_p_all = [self.ln(self.project_smiles(z)) for z in z_t_sm_layers]  # list [B,D]
            z_t_pe_p_all = [self.ln(self.project_peptide(z)) for z in z_t_pe_layers]  # list [B,D]
            z_t_ge_p_all = [self.ln(self.project_geometry(z)) for z in z_t_ge_layers] # list [B,D]

            per_layer_sum = []
            for l in range(L):
                s = self.head_smiles[l](z_t_sm_p_all[l])   # [B,1]
                p = self.head_peptide[l](z_t_pe_p_all[l])
                g = self.head_geo[l](z_t_ge_p_all[l])      # [B,1]
                per_layer_sum.append(s + p + g)       

            y_t = torch.stack(per_layer_sum, dim=0).sum(dim=0).squeeze(-1)  

            # print('w:',w[:5])

            return y_c_all, y_t, w, cons_loss

        else:
            # Forward backbone -> per-layer causal/trivial features (list of [B, D])
            _, _, z_c_layers, z_t_layers = self.model(data)

            # Project causal and trivial layers -> list of [B, D]
            z_c_proj = [self.ln(self.projection_unimodal_causal(z)) for z in z_c_layers]
            z_t_proj = [self.ln(self.projection_unimodal_trivial(z)) for z in z_t_layers]

            L = len(z_c_proj)
            per_layer_scalar = []

            # Per-layer causal prediction -> each is [B, 1]
            for l in range(L):
                s_l = self.layer_heads[l](z_c_proj[l])
                per_layer_scalar.append(s_l)

            # Stack per-layer causal outputs -> [B, L]
            y_c_all = torch.cat(per_layer_scalar, dim=1)

            # Per-layer trivial prediction -> each is [B, 1]
            per_layer_sum = []
            for l in range(L):
                s = self.unimodal_heads[l](z_t_proj[l])
                per_layer_sum.append(s)

            # Sum across layers -> [B, 1]
            y_t = torch.stack(per_layer_sum, dim=0).sum(dim=0).squeeze(-1)

            return y_c_all, y_t




    def training_step(self, batch, batch_idx):
        """
        One manual-optimization training step.

        Computes ``fusion_loss`` (plus ``w_cons_loss * cons_loss`` in FUSION
        mode), logs it as ``train_loss``, then zeroes, back-propagates and
        steps every optimizer.

        Learning-rate schedulers are advanced once per epoch, on the last
        batch, because ``configure_optimizers`` parameterises them in epochs
        (``T_max=max_epochs``, ``step_size=10``); stepping them every batch
        would run through the schedule within the first few batches.
        ``ReduceLROnPlateau`` needs a metric, so it is stepped with the most
        recent ``val_loss`` from ``trainer.callback_metrics`` (one epoch
        behind at this point) and skipped while none is available.

        Args:
            batch: ``((mol, pep, geo), y)`` in FUSION mode, where labels are
                read from ``mol.y``; otherwise a single object with a ``.y``
                attribute.
            batch_idx: Unused.

        Returns:
            None (optimization is done manually).
        """
        if self.mode == 'FUSION':
            (mol, pep, geo), _ = batch
            y_c_all, y_t, _, cons_loss = self((mol, pep, geo))
            labels = mol.y
            loss = self.fusion_loss(y_c_all, y_t, labels) + self.hparams.w_cons_loss * cons_loss
        else:
            y_c_all, y_t = self(batch)
            labels = batch.y
            loss = self.fusion_loss(y_c_all, y_t, labels)

        self.log('train_loss', loss, on_epoch=True, prog_bar=True, batch_size=labels.size(0), sync_dist=True)

        # ---- manual optimization ----
        opts = self.optimizers()
        if not isinstance(opts, (list, tuple)):
            opts = [opts]

        for opt in opts:
            opt.zero_grad(set_to_none=True)

        self.manual_backward(loss)

        for opt in opts:
            opt.step()

        # ---- schedulers: once per epoch ----
        if not self.trainer.is_last_batch:
            return

        schedulers = self.lr_schedulers()
        if schedulers is None:
            # No scheduler configured.
            return
        if not isinstance(schedulers, (list, tuple)):
            schedulers = [schedulers]

        for sch in schedulers:
            if isinstance(sch, torch.optim.lr_scheduler.ReduceLROnPlateau):
                monitored = self.trainer.callback_metrics.get('val_loss')
                if monitored is not None:
                    sch.step(monitored)
            else:
                sch.step()
        


    def on_validation_start(self):
        self._val_losses = []


    def validation_step(self, batch, batch_idx):

        if self.mode == 'FUSION':
            (mol, pep, geo), y = batch
            y_c_all, y_t,  _, _ = self((mol, pep, geo))

            y_c_base = y_c_all[:, -1]  
            # y_c_base = y_c_all.mean(dim=1)####average layers

            y_final = y_t + y_c_base                       # [B]
            # y_final = y_c_base  #w/o trivial branch
            loss = F.mse_loss(y_final, mol.y)
        
        else:
            y_c_all, y_t = self(batch)
            
            y_c_base = y_c_all[:, -1]  
            # y_c_base = y_c_all.mean(dim=1)####average layers

            y_final = y_t + y_c_base                       # [B]\
            # y_final = y_c_base  #w/o trivial branch
            loss = F.mse_loss(y_final, batch.y)

            # print('y_c_base',y_c_base.shape)

        self._val_losses.append(loss.detach())


    def on_validation_epoch_end(self):
        losses_local = torch.stack(self._val_losses).to(self.device) 

        if self.trainer.world_size > 1:
            losses_all = self.all_gather(losses_local)          # [world_size, N_local]
            mean_loss  = losses_all.view(-1).mean()             # [N_total] → scalar
        else:
            mean_loss  = losses_local.mean()

        self.log('val_loss', mean_loss, prog_bar=True)

        self._val_losses.clear()

    def on_test_start(self):
        self._test_preds = []
        self._test_trues = []

        self._test_cor_1 = []
        self._test_cor_2 = []
        self._test_cor_3 = []
        self._test_cor_4 = []
        self._test_cor_5 = []

    def test_step(self, batch, batch_idx):

        if self.mode == 'FUSION':
            (mol, pep, geo), y = batch
            y_c_all, y_t,  _, _ = self((mol, pep, geo))

        else:
            y_c_all, y_t = self(batch)

        labels = batch.y if self.mode != 'FUSION' else mol.y  # Get labels from batch

        # y_t_all_tensor = torch.stack(y_t_all, dim=1)   # [B, L]
        y_c_base = y_c_all[:, -1]  
        # y_c_base = y_c_all.mean(dim=1)####average layers

        y_final = y_t + y_c_base                       # [B]
        # y_final = y_c_base  #w/o trivial branch
        # loss = F.mse_loss(y_final, mol.y)

        self._test_preds.append((y_final).detach())
        self._test_trues.append(labels.detach())

        self._test_cor_1.append(y_c_all[:, 0].detach())
        self._test_cor_2.append(y_c_all[:, 1].detach())
        self._test_cor_3.append(y_c_all[:, 2].detach())
        self._test_cor_4.append(y_c_all[:, 3].detach())
        self._test_cor_5.append(y_c_all[:, 4].detach())
        
    def on_test_epoch_end(self):

        def _pearsonr(x: torch.Tensor, y: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
            x = x.float().view(-1)
            y = y.float().view(-1)
            x = x - x.mean()
            y = y - y.mean()
            denom = torch.sqrt((x * x).sum()) * torch.sqrt((y * y).sum())
            return (x * y).sum() / (denom + eps)


        preds_local = torch.cat(self._test_preds, dim=0)  # [N_local]
        trues_local = torch.cat(self._test_trues, dim=0)  # [N_local]

        local_cor_1 = torch.cat(self._test_cor_1, dim=0)
        local_cor_2 = torch.cat(self._test_cor_2, dim=0)
        local_cor_3 = torch.cat(self._test_cor_3, dim=0)
        local_cor_4 = torch.cat(self._test_cor_4, dim=0)
        local_cor_5 = torch.cat(self._test_cor_5, dim=0)


        if self.trainer.world_size > 1:
            preds_all = self.all_gather(preds_local)  # [world_size, N_local]
            trues_all = self.all_gather(trues_local)

            cor_1 = self.all_gather(local_cor_1).view(-1)
            cor_2 = self.all_gather(local_cor_2).view(-1)
            cor_3 = self.all_gather(local_cor_3).view(-1)
            cor_4 = self.all_gather(local_cor_4).view(-1)
            cor_5 = self.all_gather(local_cor_5).view(-1)

            preds = preds_all.view(-1)                # [N_total]
            trues = trues_all.view(-1)
        else:
            preds = preds_local
            trues = trues_local

            cor_1 = local_cor_1
            cor_2 = local_cor_2
            cor_3 = local_cor_3
            cor_4 = local_cor_4
            cor_5 = local_cor_5

        corr1 = _pearsonr(cor_1, trues)
        corr2 = _pearsonr(cor_2, trues)
        corr3 = _pearsonr(cor_3, trues)
        corr4 = _pearsonr(cor_4, trues)
        corr5 = _pearsonr(cor_5, trues)

        self.log('test_corr_layer1', corr1, on_epoch=True, prog_bar=True)
        self.log('test_corr_layer2', corr2, on_epoch=True, prog_bar=True)
        self.log('test_corr_layer3', corr3, on_epoch=True, prog_bar=True)
        self.log('test_corr_layer4', corr4, on_epoch=True, prog_bar=True)
        self.log('test_corr_layer5', corr5, on_epoch=True, prog_bar=True)

        corrs = torch.stack([corr1, corr2, corr3, corr4, corr5]).detach().cpu().tolist()
        self.print("Correlation:", [f"{v:.4f}" for v in corrs])

        ss_res = torch.sum((trues - preds) ** 2)
        ss_tot = torch.sum((trues - trues.mean()) ** 2)
        mse = F.mse_loss(preds, trues)
        mae = F.l1_loss(preds, trues)
        r2  = 1 - ss_res / ss_tot

        self.log('test_mse', mse, on_epoch=True, prog_bar=True)
        self.log('test_mae', mae, on_epoch=True, prog_bar=True)
        self.log('test_r2',  r2,  on_epoch=True, prog_bar=True)

        self._test_preds.clear()
        self._test_trues.clear()



        self._test_cor_1.clear()
        self._test_cor_2.clear()
        self._test_cor_3.clear()
        self._test_cor_4.clear()
        self._test_cor_5.clear()


    def configure_optimizers(self):
        h = self.hparams

        if self.mode == 'FUSION':
            # create one optimizer per sub-model
            opt_sm = torch.optim.Adam(self.model_smiles.parameters(),   lr=h.lr_smiles)
            opt_pe = torch.optim.Adam(self.model_peptide.parameters(),  lr=h.lr_peptide)
            opt_sq = torch.optim.Adam(self.model_geometry.parameters(), lr=h.lr_geometry)

            # collect projector + regressor parameters for head optimizer
            backbone_ids = {
                id(p) for m in (self.model_smiles, self.model_peptide, self.model_geometry)
                for p in m.parameters()
            }
            head_params = [p for p in self.parameters() if id(p) not in backbone_ids]
            opt_hd = torch.optim.Adam(head_params, lr=h.lr)

            # attach a scheduler to each optimizer
            if h.scheduler_type == 'cosine':
                sched_sm = torch.optim.lr_scheduler.CosineAnnealingLR(opt_sm, T_max=self.trainer.max_epochs, eta_min=h.lr_smiles * 0.01)
                sched_pe = torch.optim.lr_scheduler.CosineAnnealingLR(opt_pe, T_max=self.trainer.max_epochs, eta_min=h.lr_peptide * 0.01)
                sched_sq = torch.optim.lr_scheduler.CosineAnnealingLR(opt_sq, T_max=self.trainer.max_epochs, eta_min=h.lr_geometry * 0.01)
                sched_hd = torch.optim.lr_scheduler.CosineAnnealingLR(opt_hd, T_max=self.trainer.max_epochs, eta_min=h.lr * 0.01)
            else:
                # example: step scheduler
                sched_sm = torch.optim.lr_scheduler.StepLR(opt_sm, step_size=10, gamma=0.5)
                sched_pe = torch.optim.lr_scheduler.StepLR(opt_pe, step_size=10, gamma=0.5)
                sched_sq = torch.optim.lr_scheduler.StepLR(opt_sq, step_size=10, gamma=0.5)
                sched_hd = torch.optim.lr_scheduler.StepLR(opt_hd, step_size=10, gamma=0.5)

            # return lists of optimizers and schedulers
            return [opt_sm, opt_pe, opt_sq, opt_hd], [sched_sm, sched_pe, sched_sq, sched_hd]

        # --- single-modality: always return lists, even for plateau ---
        optimizer = torch.optim.Adam(self.parameters(), lr=h.lr)
        if h.scheduler_type == 'plateau':
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode='min', factor=0.5, patience=20, verbose=True
            )
        elif h.scheduler_type == 'cosine':
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=self.trainer.max_epochs, eta_min=h.lr * 0.01
            )
        else:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=10, T_mult=2, eta_min=h.lr * 0.01
            )


        return [optimizer], [scheduler]
