import torch
import torch.nn as nn

import lightning.pytorch as L

class BradeleyTerryModel(nn.Module):
    """Bradley-Terry pairwise model with an MLP scoring function.

    Each item is mapped to a scalar score by ``net``
    (Linear -> ReLU -> Dropout -> Linear). For a pair ``(x1, x2)`` the model
    returns the logit ``net(x1) - net(x2)``, and ``loss`` fits it to binary
    labels with BCE-with-logits. Presumably ``y == 1`` means ``x1`` was
    preferred; confirm against the data pipeline.

    Args:
        input_dim: Size of the last dimension of the ``x1``/``x2`` features.
        hidden_dim: Width of the hidden layer.
        dropout: Dropout probability applied after the hidden activation.
    """

    def __init__(self, input_dim, hidden_dim, dropout):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, batch):
        """Return the pairwise logit ``net(x1) - net(x2)``.

        Args:
            batch: Mapping with ``x1`` and ``x2`` feature tensors whose last
                dimension is ``input_dim``.

        Returns:
            The score difference with the trailing size-1 output dimension of
            ``net`` squeezed away.
        """
        return (self.net(batch["x1"]) - self.net(batch["x2"])).squeeze(-1)

    def loss(self, pred, batch):
        """Binary cross-entropy between logits ``pred`` and labels ``batch["y"]``.

        Non-floating labels (ints/bools) are cast to ``pred``'s dtype, because
        the BCE-with-logits op does not accept integer targets. Labels that
        carry one extra trailing size-1 dimension, e.g. ``(B, 1)`` against
        ``pred`` of shape ``(B,)``, are squeezed to match. This mirrors the
        squeeze in ``forward``. Any other shape mismatch is left for the loss
        op to report.
        """
        target = batch["y"]
        if not target.is_floating_point():
            target = target.to(pred.dtype)
        if target.dim() == pred.dim() + 1 and target.shape[-1] == 1:
            target = target.squeeze(-1)
        return nn.functional.binary_cross_entropy_with_logits(pred, target)


class LatentConformalBradleyTerry(L.LightningModule):
    """Jointly train three Bradley-Terry scorers with coupling penalties.

    ``bt1``, ``bt2`` and ``bt3`` are each fitted to their own dataset
    (``data1``, ``data2``, ``data3``). In addition, ``bt1`` and ``bt2`` are
    evaluated on ``data3``, and their residuals against ``bt3``
    (``u1 = z1_3 - z3``, ``u2 = z2_3 - z3``) are penalised:

    * ``l_orth = mean(u1 * u2) ** 2``, weighted by ``alpha``;
    * ``l_symm = mean(u1) ** 2 + mean(u2) ** 2``, weighted by ``beta``.

    Args:
        input_dim: Feature size given to each ``BradeleyTerryModel``.
        hidden_dim: Hidden width of each scorer.
        dropout: Dropout probability of each scorer.
        alpha: Weight of the orthogonality penalty ``l_orth``.
        beta: Weight of the mean-residual penalty ``l_symm``.
        lr: Adam learning rate.
    """

    def __init__(self, input_dim, hidden_dim, dropout, alpha, beta, lr):
        super().__init__()
        # Records the constructor arguments on ``self.hparams``; alpha, beta
        # and lr are read back from there in ``loss``/``configure_optimizers``.
        self.save_hyperparameters()

        self.bt1 = BradeleyTerryModel(input_dim, hidden_dim, dropout)
        self.bt2 = BradeleyTerryModel(input_dim, hidden_dim, dropout)
        self.bt3 = BradeleyTerryModel(input_dim, hidden_dim, dropout)

    def forward(self, batch):
        """Compute the pairwise logits that ``loss`` consumes.

        Args:
            batch: Mapping with ``data1``, ``data2`` and ``data3`` entries, each
                a pairwise batch accepted by ``BradeleyTerryModel.forward``
                (``x1``/``x2``; ``loss`` additionally reads ``y``).

        Returns:
            Dict with ``z1`` (bt1 on data1), ``z2`` (bt2 on data2), ``z1_3``
            (bt1 on data3), ``z2_3`` (bt2 on data3) and ``z3`` (bt3 on data3).
        """
        return {
            "z1": self.bt1(batch["data1"]),
            "z2": self.bt2(batch["data2"]),
            "z1_3": self.bt1(batch["data3"]),
            "z2_3": self.bt2(batch["data3"]),
            "z3": self.bt3(batch["data3"]),
        }

    def loss(self, batch, pred):
        """Combine the three BT losses with the coupling penalties.

        Note the argument order ``(batch, pred)``, which is the reverse of
        ``BradeleyTerryModel.loss``.

        Returns:
            Dict of scalar tensors: ``l1``, ``l2``, ``l3`` (per-model BCE),
            ``l_orth``, ``l_symm_1``, ``l_symm_2`` and the weighted total
            ``loss``.
        """
        l1 = self.bt1.loss(pred["z1"], batch["data1"])
        l2 = self.bt2.loss(pred["z2"], batch["data2"])
        l3 = self.bt3.loss(pred["z3"], batch["data3"])

        # Residuals of bt1/bt2 against bt3 on the shared data3 pairs. z3 is
        # detached so the penalties send no gradient into bt3; bt3 is trained
        # only through l3.
        z3 = pred["z3"].detach()
        u1 = pred["z1_3"] - z3
        u2 = pred["z2_3"] - z3

        # Squared mean of the elementwise product (an uncentred inner
        # product): pushes the two residual vectors towards orthogonality.
        l_orth = (u1 * u2).mean() ** 2

        # Squared mean of each residual: pushes each towards zero mean.
        l_symm_1 = u1.mean() ** 2
        l_symm_2 = u2.mean() ** 2
        l_symm = l_symm_1 + l_symm_2

        loss = l1 + l2 + l3 + self.hparams.alpha * l_orth + self.hparams.beta * l_symm

        return {
            'l1': l1,
            'l2': l2,
            'l3': l3,
            'l_orth': l_orth,
            'l_symm_1': l_symm_1,
            'l_symm_2': l_symm_2,
            'loss': loss
        }

    def _shared_step(self, batch, stage):
        """Run forward + loss, log each component as ``{stage}/<name>``, return the total."""
        losses = self.loss(batch, self(batch))
        for name, value in losses.items():
            self.log(
                f"{stage}/{name}",
                value,
                on_step=False,
                on_epoch=True,
                prog_bar=(name == "loss"),
            )
        return losses["loss"]

    def training_step(self, batch, batch_idx):
        """Return the total training loss; components are logged per epoch under ``train/``."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Return the total validation loss; components are logged per epoch under ``val/``."""
        return self._shared_step(batch, "val")

    @staticmethod
    def _score(model, batch):
        """Return per-example scores from one ``BradeleyTerryModel``.

        A tensor batch is treated as item features and scored row by row with
        ``model.net``. Its trailing size-1 output dimension is squeezed away.
        Any other batch (e.g. a mapping with ``x1``/``x2``) goes through
        ``model.forward``, which already squeezes its output. It is not
        squeezed again, so a batch of size 1 keeps a batch dimension instead
        of collapsing to a 0-d tensor.
        """
        if isinstance(batch, torch.Tensor):
            return model.net(batch).squeeze(-1)
        return model(batch)

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        """Score a prediction batch with the model(s) assigned to ``dataloader_idx``.

        Expects three predict dataloaders. Index 0 is scored by ``bt1``
        (``z1_hat``), index 1 by ``bt2`` (``z2_hat``), and index 2 by all
        three models (``z1_3_hat``, ``z2_3_hat``, ``z3_hat``).

        Raises:
            ValueError: if ``dataloader_idx`` is missing or not 0, 1 or 2.
        """
        if dataloader_idx is None:
            # NOTE(review): Lightning appears to pass dataloader_idx only when
            # several predict dataloaders are configured; with a single loader
            # this branch fires. Confirm this is the intended contract.
            raise ValueError("dataloader_idx must be specified")
        if dataloader_idx == 0:
            return {"z1_hat": self._score(self.bt1, batch)}
        if dataloader_idx == 1:
            return {"z2_hat": self._score(self.bt2, batch)}
        if dataloader_idx == 2:
            return {
                "z1_3_hat": self._score(self.bt1, batch),
                "z2_3_hat": self._score(self.bt2, batch),
                "z3_hat": self._score(self.bt3, batch),
            }
        raise ValueError("dataloader_idx must be 0, 1, or 2")

    def configure_optimizers(self):
        """Single Adam optimizer over all three scorers' parameters."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
