import torch
import torch.nn as nn

import lightning.pytorch as L

class LinearBradeleyTerryModel(nn.Module):
    """Linear pairwise-comparison scorer.

    A single ``nn.Linear(input_dim, 1)`` maps each item's feature vector to a
    scalar score. ``forward`` returns the score difference between the two
    items of every pair, and ``loss`` treats that difference as the logit of
    a binary cross-entropy against ``batch["y"]``.

    Args:
        input_dim: Size of the last dimension of the feature tensors.
        initial_weight: Optional initial value for the linear weights. Any
            array-like with ``input_dim`` elements; ``None`` keeps the default
            ``nn.Linear`` initialisation. The bias is never overridden.
    """

    def __init__(self, input_dim, initial_weight):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1)

        if initial_weight is not None:
            # as_tensor accepts lists, arrays and tensors alike without the
            # copy-construct warning torch.tensor() emits for tensor inputs.
            weight = torch.as_tensor(initial_weight).reshape(1, input_dim)
            with torch.no_grad():
                self.linear.weight.copy_(weight)

        # Dropout sits *after* the linear layer, i.e. it acts on the scalar
        # score itself; it is only active while the module is in training mode.
        self.net = nn.Sequential(
            self.linear,
            nn.Dropout(0.1),
        )

    def forward(self, batch):
        """Return ``score(x1) - score(x2)`` for every pair in ``batch``.

        Args:
            batch: Mapping with tensors ``"x1"`` and ``"x2"`` whose last
                dimension is ``input_dim``.

        Returns:
            The score differences with the trailing size-1 dimension removed.
        """
        return (self.net(batch["x1"]) - self.net(batch["x2"])).squeeze(-1)

    def loss(self, pred, batch):
        """Binary cross-entropy between the logits ``pred`` and ``batch["y"]``.

        Integer or boolean label tensors are cast to ``pred``'s dtype, because
        ``binary_cross_entropy_with_logits`` only accepts floating-point
        targets. Floating-point targets are passed through unchanged.
        """
        target = batch["y"]
        if not target.is_floating_point():
            target = target.to(pred.dtype)
        return nn.functional.binary_cross_entropy_with_logits(pred, target)


class LLMLatentConformalBradleyTerry(L.LightningModule):
    """Jointly trains three linear Bradley-Terry scorers on three datasets.

    ``bt1``, ``bt2`` and ``bt3`` are fitted on ``batch["data1"]``,
    ``batch["data2"]`` and ``batch["data3"]`` respectively. On ``data3`` the
    deviations of ``bt1`` and ``bt2`` from a reference score ``z3`` are
    additionally penalised (see ``loss``).

    Args:
        input_dim: Feature dimension passed to every scorer.
        alpha: Weight of the ``l_orth`` penalty in the total loss.
        beta: Weight of the ``l_symm`` penalty in the total loss.
        lr: Learning rate of the Adam optimizer.
        initial_weight: Optional initial linear weights shared by all three
            scorers (``None`` keeps the default initialisation).
        oracle_z: If true, ``batch["data3"]["z"]`` is used as the reference
            score ``z3`` instead of the detached ``bt3`` prediction.
    """

    def __init__(self, input_dim, alpha, beta, lr, initial_weight=None, oracle_z=False):
        super().__init__()
        self.save_hyperparameters()

        self.bt1 = LinearBradeleyTerryModel(input_dim, initial_weight)
        self.bt2 = LinearBradeleyTerryModel(input_dim, initial_weight)
        self.bt3 = LinearBradeleyTerryModel(input_dim, initial_weight)

    def forward(self, batch):
        """Score every dataset with its own model, and ``data3`` with all three.

        Args:
            batch: Mapping with sub-batches ``"data1"``, ``"data2"`` and
                ``"data3"``, each in the format the scorers accept.

        Returns:
            Dict with ``z1`` (bt1 on data1), ``z2`` (bt2 on data2) and
            ``z1_3`` / ``z2_3`` / ``z3`` (bt1 / bt2 / bt3 on data3).
        """
        return {
            "z1": self.bt1(batch["data1"]),
            "z2": self.bt2(batch["data2"]),
            "z1_3": self.bt1(batch["data3"]),
            "z2_3": self.bt2(batch["data3"]),
            "z3": self.bt3(batch["data3"]),
        }

    def loss(self, batch, pred):
        """Compute the training objective and its diagnostic components.

        The optimised quantity is
        ``l1 + l2 + l3 + alpha * l_orth + beta * l_symm``; every other entry
        of the returned dict is reported only and does not enter ``loss``.

        Args:
            batch: Same mapping that was passed to ``forward``.
            pred: Output of ``forward`` for ``batch``.

        Returns:
            Dict of scalar tensors keyed by component name, with the total
            under ``"loss"``.
        """
        # Per-dataset fit of each scorer on its own data.
        l1 = self.bt1.loss(pred["z1"], batch["data1"])
        l2 = self.bt2.loss(pred["z2"], batch["data2"])
        l3 = self.bt3.loss(pred["z3"], batch["data3"])

        # NOTE(review): these pair pred["z3"] (bt3 evaluated on data3) with the
        # labels of data1 / data2. If the intent is "bt3 evaluated on data1 /
        # data2", forward() would have to provide those predictions - confirm.
        l3_1 = self.bt3.loss(pred["z3"], batch["data1"])
        l3_2 = self.bt3.loss(pred["z3"], batch["data2"])

        # Diagnostics: bt1 / bt2 scored against the data3 labels.
        l1_3 = self.bt1.loss(pred["z1_3"], batch["data3"])
        l2_3 = self.bt2.loss(pred["z2_3"], batch["data3"])

        z1 = pred["z1_3"]
        z2 = pred["z2_3"]
        # The reference is either supplied by the batch or bt3's own prediction;
        # the latter is detached so the penalties below do not update bt3.
        z3 = batch["data3"]["z"] if self.hparams.oracle_z else pred["z3"].detach()

        # Deviations of bt1 / bt2 from the reference on data3.
        u1 = z1 - z3
        u2 = z2 - z3

        # Squared batch mean of the product of the two deviations.
        l_orth = (u1 * u2).mean() ** 2

        # Squared batch mean of each deviation.
        l_symm_1 = u1.mean() ** 2
        l_symm_2 = u2.mean() ** 2
        l_symm = l_symm_1 + l_symm_2

        loss = l1 + l2 + l3 + self.hparams.alpha * l_orth + self.hparams.beta * l_symm

        return {
            'l1': l1,
            'l2': l2,
            'l3': l3,
            'l1_3': l1_3,
            'l2_3': l2_3,
            'l3_1': l3_1,
            # Previously computed but dropped; reported alongside l3_1.
            'l3_2': l3_2,
            'l_orth': l_orth,
            'l_symm_1': l_symm_1,
            'l_symm_2': l_symm_2,
            'loss': loss
        }

    def _shared_step(self, batch, stage):
        """Run forward + loss, log every component under ``stage/`` and return the total."""
        losses = self.loss(batch, self(batch))
        for name, value in losses.items():
            self.log(f'{stage}/{name}', value, on_step=False, on_epoch=True, prog_bar=(name == 'loss'))
        return losses["loss"]

    def training_step(self, batch, batch_idx):
        """Log all loss components as ``train/<name>`` and return the total loss."""
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Log all loss components as ``val/<name>`` and return the total loss."""
        return self._shared_step(batch, 'val')

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        """Return scorer outputs for one batch of the given predict dataloader.

        Dataloader 0 is scored by ``bt1``, dataloader 1 by ``bt2``, and
        dataloaders 2 and 3 by all three scorers (with differently named
        output keys).

        Raises:
            ValueError: If ``dataloader_idx`` is ``None`` or not in 0..3.
        """
        if dataloader_idx is None:
            raise ValueError("dataloader_idx must be specified")
        elif dataloader_idx == 0:
            return {
                "z1_hat": self.bt1(batch)
                }
        elif dataloader_idx == 1:
            return {
                "z2_hat": self.bt2(batch)
                }
        elif dataloader_idx == 2:
            return {
                "z1_3_hat": self.bt1(batch),
                "z2_3_hat": self.bt2(batch),
                "z3_hat": self.bt3(batch),
            }
        elif dataloader_idx == 3:
            return {
                "z1_4_hat": self.bt1(batch),
                "z2_4_hat": self.bt2(batch),
                "z3_4_hat": self.bt3(batch),
            }
        else:
            # Index 3 is handled above, so it belongs in the list of valid values.
            raise ValueError("dataloader_idx must be 0, 1, 2, or 3")

    def configure_optimizers(self):
        """Optimise all parameters with Adam at the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
