
import torch
from torch import nn
import lightning.pytorch as pl
from ..metrics.correlation import compute_r2
import torch.nn.functional as F
from .mix import MixingMLP, MixingCNN, MixingKP, ScoringFunc
import numpy as np
import ipdb as pdb

class PCL(pl.LightningModule):
    """Contrastive latent-variable model trained on windows of ``lags + 1`` steps.

    An MLP encoder (``MixingMLP``) maps observations to ``z_dim`` latent
    components. One ``ScoringFunc`` per component scores that component's
    values across a window. The per-component scores are summed into one logit
    per window and trained with binary cross-entropy to separate positive
    windows (label 1) from negative windows (label 0).

    During validation, batches carrying ground-truth latents (``batch['s1']['yt']``)
    are used to log R² scores between recovered and true latent groups
    instead of the contrastive loss.
    """

    def __init__(self,
                 input_dim,
                 z_dim,
                 lags=2,
                 hidden_dims=64,
                 encoder_layers=3,
                 scoring_layers=3,
                 correlation='Pearson',
                 lr=0.001):
        """Build the encoder and the per-component scoring functions.

        Args:
            input_dim: Feature dimension of each time step of the input.
            z_dim: Number of latent components. One scoring function is built
                per component. For R² evaluation the latents are split into
                four groups of ``z_dim // 4`` columns; the last group also
                receives any remainder columns.
            lags: Number of lagged steps; each window has ``lags + 1`` steps.
            hidden_dims: Hidden width passed to each ``ScoringFunc``.
            encoder_layers: Number of layers passed to ``MixingMLP``.
            scoring_layers: Number of layers passed to each ``ScoringFunc``.
            correlation: Stored on the module; not used by the code in this class.
            lr: Learning rate for the Adam optimizer.
        """
        super().__init__()
        self.input_dim = input_dim
        self.z_dim = z_dim
        # Exact integer division (int(z_dim / 4) goes through float).
        self.z1_dim = self.z2_dim = self.z3_dim = self.z4_dim = z_dim // 4
        # Sizes of the ground-truth latent groups in batch['s1']['yt'].
        # NOTE(review): hard-coded; presumably must match the dataset — confirm.
        self.z_dim_true_list = [2, 2, 2, 2]
        self.L = lags
        self.lr = lr
        self.encoder = MixingMLP(input_dims=input_dim,
                                 z_dim=self.z_dim,
                                 num_layers=encoder_layers,
                                 negative_slope=0.2)

        # One scoring function per latent component; each sees that
        # component's values over the lags + 1 steps of a window.
        self.scoring_funcs = nn.ModuleList([
            ScoringFunc(input_dims=lags + 1,
                        hidden_dims=hidden_dims,
                        num_layers=scoring_layers) for _ in range(z_dim)]
        )

        self.loss_func = F.binary_cross_entropy_with_logits
        self.correlation = correlation

    def forward(self, x):
        """Return the encoder's latent representation of ``x``."""
        return self.encoder(x)

    def _contrastive_loss(self, batch):
        """Return the averaged positive/negative BCE loss for one batch.

        ``batch['pos']['x']`` and ``batch['neg']['x']`` are reshaped to
        ``[-1, L + 1, input_dim]`` windows and encoded together. Each window
        is scored by summing the per-component scoring functions. Positive
        windows are labelled 1 and negative windows 0; the two BCE terms are
        averaged with equal weight.
        """
        x_pos = batch['pos']['x'].view(-1, self.L + 1, self.input_dim)
        x_neg = batch['neg']['x'].view(-1, self.L + 1, self.input_dim)
        n_pos = x_pos.shape[0]
        n_neg = x_neg.shape[0]
        # Assumes the encoder output is [n_pos + n_neg, L + 1, >= z_dim];
        # component i is taken from the last axis.
        embeddings = self.encoder(torch.cat((x_pos, x_neg), dim=0))
        scores = 0
        for i, scoring_func in enumerate(self.scoring_funcs):
            scores = scores + scoring_func(embeddings[:, :, i])
        scores = scores.squeeze()
        ones = torch.ones(n_pos, device=x_pos.device)
        # Label count follows the negative batch so that unequal pos/neg
        # batch sizes do not cause a shape mismatch.
        zeros = torch.zeros(n_neg, device=x_pos.device)
        return 0.5 * (self.loss_func(scores[:n_pos], ones)
                      + self.loss_func(scores[n_pos:], zeros))

    def training_step(self, batch, batch_idx):
        """Compute, log (``train_loss``) and return the contrastive loss."""
        loss = self._contrastive_loss(batch)
        self.log("train_loss", loss)
        return loss

    def _log_r2(self, batch):
        """Log R² between recovered and ground-truth latent groups.

        Encodes ``batch['s1']['xt']`` and splits both the recovered latents
        and ``batch['s1']['yt']`` into groups along the last axis. The rows of
        each group are split into equal train/test halves; ``np.split`` raises
        if the row count is odd. For each group this logs ``r2{k}``
        (``compute_r2`` from recovered to true) and ``r2{k}h`` (true to
        recovered), plus the averages ``ave_r2`` and ``ave_r2h``.
        """
        embeddings = self.encoder(batch['s1']['xt'])
        zt_recon = embeddings.reshape(-1, self.z_dim).detach().cpu().numpy()
        zt_true = (batch['s1']['yt']
                   .reshape(-1, sum(self.z_dim_true_list))
                   .detach().cpu().numpy())

        # Split column indices are the cumulative sizes of all groups but the last.
        recon_cuts = np.cumsum([self.z1_dim, self.z2_dim, self.z3_dim])
        true_cuts = np.cumsum(self.z_dim_true_list[:-1])
        recon_groups = np.split(zt_recon, recon_cuts, axis=-1)
        true_groups = np.split(zt_true, true_cuts, axis=-1)

        # (train/test halves of recon, train/test halves of true) per group.
        halves = [(np.split(z_recon, 2, axis=0), np.split(z_true, 2, axis=0))
                  for z_recon, z_true in zip(recon_groups, true_groups)]

        # All forward scores are computed before all backward scores, as in
        # the original, in case compute_r2 depends on call order.
        r2 = [compute_r2(tr_recon, tr_true, te_recon, te_true)
              for (tr_recon, te_recon), (tr_true, te_true) in halves]
        r2h = [compute_r2(tr_true, tr_recon, te_true, te_recon)
               for (tr_recon, te_recon), (tr_true, te_true) in halves]

        for k, value in enumerate(r2, start=1):
            self.log(f"r2{k}", value)
        self.log("ave_r2", sum(r2) / len(r2))
        for k, value in enumerate(r2h, start=1):
            self.log(f"r2{k}h", value)
        self.log("ave_r2h", sum(r2h) / len(r2h))

    def validation_step(self, batch, batch_idx):
        """Log R² metrics for ground-truth batches, else the contrastive ``val_loss``."""
        # Check for 's1' first: contrastive-only batches (just 'pos'/'neg')
        # would otherwise raise KeyError before reaching the loss branch.
        if "s1" in batch and "yt" in batch["s1"]:
            self._log_r2(batch)
        else:
            self.log("val_loss", self._contrastive_loss(batch))

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)