import torch
import torch.nn as nn

import lightning.pytorch as L

class _MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Dropout -> Linear.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of the single hidden layer.
        output_dim: Size of the last dimension of the output.
        dropout: Dropout probability applied after the hidden activation.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dropout):
        super().__init__()
        # The attribute name ``model`` and this layer order determine the
        # state_dict keys (model.0.*, model.3.*), so keep them stable for
        # checkpoint compatibility.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x``; the output has no final activation."""
        out = self.model(x)
        return out


class LogisticRegressionModule(L.LightningModule):
    """Binary classifier: a single-logit ``_MLP`` trained with BCE-with-logits.

    Batches are dicts with an ``"x"`` input tensor and, for training and
    validation, a ``"y"`` target tensor.

    Args:
        input_dim: Number of input features.
        hidden_dim: Hidden-layer width of the MLP.
        dropout: Dropout probability inside the MLP.
        lr: Learning rate for Adam.
    """

    def __init__(self, input_dim, hidden_dim, dropout, lr):
        super().__init__()
        self.save_hyperparameters()
        self.mlp = _MLP(input_dim, hidden_dim, 1, dropout)

    def forward(self, batch):
        """Return one logit per example (the trailing size-1 dim is squeezed)."""
        return self.mlp(batch["x"]).squeeze(-1)

    def loss(self, batch, pred):
        """Binary cross-entropy between logits ``pred`` and ``batch["y"]``.

        NOTE(review): ``batch["y"]`` presumably holds float 0/1 targets shaped
        like ``pred``, as ``binary_cross_entropy_with_logits`` requires —
        confirm against the data module.
        """
        return nn.functional.binary_cross_entropy_with_logits(pred, batch["y"])

    def training_step(self, batch, batch_idx):
        """Compute and log the epoch-averaged training loss under ``"train"``."""
        loss = self.loss(batch, self(batch))
        self.log("train", loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the epoch-averaged validation loss under ``"val"``."""
        loss = self.loss(batch, self(batch))
        # Use a key distinct from training_step's "train" so the validation
        # loss does not collide with / overwrite the training metric.
        self.log("val", loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        """Return raw logits for ``batch`` (no sigmoid applied)."""
        return self(batch)

    def configure_optimizers(self):
        """Adam over all parameters with the ``lr`` hyperparameter."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)


class LatentConformalLogisticRegression(L.LightningModule):
    """Three single-logit MLPs trained jointly on three datasets.

    ``mlp1`` is fit to ``data1``, ``mlp2`` to ``data2`` and ``mlp3`` to
    ``data3`` with BCE-with-logits.

    On ``data3``, the residuals of ``mlp1`` and ``mlp2`` are measured against
    a reference ``z3``. ``z3`` is either the oracle ``batch["data3"]["z"]``
    (when ``oracle_z``) or ``mlp3``'s detached logits. The residuals are
    penalised by:

    * ``l_orth`` — squared mean of the residual product, weighted by ``alpha``.
    * ``l_symm`` — sum of each residual's squared mean, weighted by ``beta``.

    Args:
        input_dim: Number of input features (shared by all three MLPs).
        hidden_dim: Hidden-layer width of each MLP.
        dropout: Dropout probability inside each MLP.
        alpha: Weight of the orthogonality penalty.
        beta: Weight of the symmetry penalty.
        oracle_z: If true, use ``batch["data3"]["z"]`` as the reference
            instead of ``mlp3``'s predictions.
        lr: Learning rate for Adam.
    """

    def __init__(self, input_dim, hidden_dim, dropout, alpha, beta, oracle_z, lr):
        super().__init__()
        self.save_hyperparameters()

        self.mlp1 = _MLP(input_dim, hidden_dim, 1, dropout)
        self.mlp2 = _MLP(input_dim, hidden_dim, 1, dropout)
        self.mlp3 = _MLP(input_dim, hidden_dim, 1, dropout)

    def forward(self, batch):
        """Return per-example logits for each (model, dataset) pair used by ``loss``.

        ``batch`` has keys ``"data1"``, ``"data2"`` and ``"data3"``, each a dict
        with an ``"x"`` tensor.
        """
        return {
            "z1": self.mlp1(batch["data1"]["x"]).squeeze(-1),
            "z2": self.mlp2(batch["data2"]["x"]).squeeze(-1),
            "z1_3": self.mlp1(batch["data3"]["x"]).squeeze(-1),
            "z2_3": self.mlp2(batch["data3"]["x"]).squeeze(-1),
            "z3": self.mlp3(batch["data3"]["x"]).squeeze(-1),
        }

    def loss(self, batch, pred):
        """Compute all loss terms; ``"loss"`` is the one that is optimised.

        ``l1_3`` and ``l2_3`` (mlp1/mlp2 evaluated on data3 labels) are
        returned for monitoring only and are NOT part of ``"loss"``.
        """
        bce = nn.functional.binary_cross_entropy_with_logits
        l1 = bce(pred["z1"], batch["data1"]["y"])
        l2 = bce(pred["z2"], batch["data2"]["y"])
        l3 = bce(pred["z3"], batch["data3"]["y"])

        l1_3 = bce(pred["z1_3"], batch["data3"]["y"])
        l2_3 = bce(pred["z2_3"], batch["data3"]["y"])

        if self.hparams.oracle_z:
            # NOTE(review): assumes the oracle z is on the same (logit) scale
            # as the MLP outputs — confirm with the data generator.
            z3 = batch["data3"]["z"]
        else:
            # Detach so the penalties do not push gradients into mlp3.
            z3 = pred["z3"].detach()

        u1 = pred["z1_3"] - z3
        u2 = pred["z2_3"] - z3

        l_orth = (u1 * u2).mean() ** 2

        l_symm_1 = u1.mean() ** 2
        l_symm_2 = u2.mean() ** 2
        l_symm = l_symm_1 + l_symm_2

        loss = l1 + l2 + l3 + self.hparams.alpha * l_orth + self.hparams.beta * l_symm

        return {
            'l1': l1,
            'l2': l2,
            'l3': l3,
            'l1_3': l1_3,
            'l2_3': l2_3,
            'l_orth': l_orth,
            'l_symm_1': l_symm_1,
            'l_symm_2': l_symm_2,
            'loss': loss
        }

    def _log_losses(self, stage, losses):
        """Log every loss term as ``<stage>/<name>``; only the total goes to the progress bar."""
        for name, value in losses.items():
            self.log(f'{stage}/{name}', value, on_step=False, on_epoch=True, prog_bar=(name == 'loss'))

    def training_step(self, batch, batch_idx):
        """Compute losses, log them under ``train/``, and return the total."""
        losses = self.loss(batch, self(batch))
        self._log_losses('train', losses)
        return losses["loss"]

    def validation_step(self, batch, batch_idx):
        """Compute losses, log them under ``val/``, and return the total."""
        losses = self.loss(batch, self(batch))
        self._log_losses('val', losses)
        return losses["loss"]

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        """Return logits from the MLPs relevant to the given predict dataloader.

        * ``0`` -> ``{"z1_hat"}`` from mlp1
        * ``1`` -> ``{"z2_hat"}`` from mlp2
        * ``2`` -> ``{"z1_3_hat", "z2_3_hat", "z3_hat"}`` from mlp1/2/3
        * ``3`` -> ``{"z1_4_hat", "z2_4_hat", "z3_4_hat"}`` from mlp1/2/3

        Raises:
            ValueError: if ``dataloader_idx`` is None or not in 0..3.
        """
        if dataloader_idx is None:
            raise ValueError("dataloader_idx must be specified")
        if dataloader_idx not in (0, 1, 2, 3):
            # Message previously omitted index 3, which is accepted below.
            raise ValueError("dataloader_idx must be 0, 1, 2, or 3")

        x = batch["x"]
        if dataloader_idx == 0:
            return {"z1_hat": self.mlp1(x).squeeze(-1)}
        if dataloader_idx == 1:
            return {"z2_hat": self.mlp2(x).squeeze(-1)}
        if dataloader_idx == 2:
            return {
                "z1_3_hat": self.mlp1(x).squeeze(-1),
                "z2_3_hat": self.mlp2(x).squeeze(-1),
                "z3_hat": self.mlp3(x).squeeze(-1),
            }
        return {
            "z1_4_hat": self.mlp1(x).squeeze(-1),
            "z2_4_hat": self.mlp2(x).squeeze(-1),
            "z3_4_hat": self.mlp3(x).squeeze(-1),
        }

    def configure_optimizers(self):
        """Adam over all three MLPs with the ``lr`` hyperparameter."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
