import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

# Define the encoder network: two Linear -> ReLU -> BatchNorm1d stages (no GRU here)
class CPCEncoder(nn.Module):
    """Feature encoder: two ``Linear -> ReLU -> BatchNorm1d`` stages.

    Args:
        input_dim: Size of the last (feature) dimension of the input.
        latent_dim: Accepted for signature compatibility; not used by this
            module (the output width is ``hidden_dim``).
        hidden_dim: Width of both linear layers and of the output.
    """

    def __init__(self, input_dim, latent_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)

    def forward(self, x):
        """Encode ``x`` along its last dimension.

        Args:
            x: Tensor of shape ``(N, input_dim)`` or ``(..., input_dim)``
                with extra leading dimensions (e.g. ``(batch, seq, input_dim)``).

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``hidden_dim``.
        """
        # BatchNorm1d treats dim 1 as the channel axis, but Linear puts the
        # features on the LAST axis. For inputs with more than two dims,
        # collapse every leading dim into one so the features sit on dim 1,
        # then restore the original layout afterwards.
        lead_shape = None
        if x.dim() > 2:
            lead_shape = x.shape[:-1]
            x = x.flatten(0, -2)

        x = self.bn1(self.relu(self.fc1(x)))
        x = self.bn2(self.relu(self.fc2(x)))

        if lead_shape is not None:
            x = x.reshape(*lead_shape, x.shape[-1])
        return x

# Define the predictor network
class CPCPredictor(nn.Module):
    """Single linear map from ``latent_dim`` back to ``latent_dim``.

    Args:
        latent_dim: Input and output feature size of the linear layer.
        hidden_dim: Accepted for signature compatibility; not used here.
    """

    def __init__(self, latent_dim, hidden_dim):
        super().__init__()
        # Attribute name is kept so existing state_dict keys still load.
        self.predictor = nn.Linear(in_features=latent_dim, out_features=latent_dim)

    def forward(self, x):
        """Apply the linear layer to the last dimension of ``x``."""
        projected = self.predictor(x)
        return projected

# Define the CPC model
class CPC(nn.Module):
    """CPC model: per-step encoder, GRU over the steps, linear head.

    Args:
        input_dim: Feature size of each input step.
        latent_dim: Size of the returned representation.
        hidden_dim: Width of the encoder output and of the GRU state.
    """

    def __init__(self, input_dim, latent_dim, hidden_dim):
        super().__init__()
        self.encoder = CPCEncoder(input_dim, latent_dim, hidden_dim)
        self.gru = nn.GRU(input_size=hidden_dim, hidden_size=hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        """Summarise each input sequence into one latent vector.

        Args:
            x: Either ``(batch, seq, input_dim)`` — a batch of sequences — or
                a 2-D tensor ``(rows, input_dim)``.

        Returns:
            ``(batch, latent_dim)`` for 3-D input. For 2-D input the result is
            a single ``(latent_dim,)`` vector.
        """
        if x.dim() == 3:
            # The encoder's BatchNorm1d layers need the features on dim 1, so
            # fold the time axis into the batch axis for encoding and unfold
            # it again before handing the sequence to the batch_first GRU.
            batch, steps = x.shape[0], x.shape[1]
            feats = self.encoder(x.reshape(batch * steps, x.shape[2]))
            feats = feats.reshape(batch, steps, feats.shape[-1])
        else:
            # NOTE(review): nn.GRU interprets a 2-D input as ONE unbatched
            # sequence, so the rows are consumed as consecutive time steps,
            # not as independent samples. Pass (batch, 1, input_dim) if each
            # row should yield its own latent vector — confirm against callers.
            feats = self.encoder(x)

        _, h_n = self.gru(feats)
        # Single-layer, unidirectional GRU: the leading dim of h_n is 1.
        h_n = h_n.squeeze(0)
        return self.fc(h_n)

# Contrastive loss function
def contrastive_loss(x, y, temperature=0.07):
    """Cross-entropy over temperature-scaled cosine similarities of ``x`` and ``y``.

    Both inputs are L2-normalised along their last dimension, every row of
    ``x`` is scored against every row of ``y``, and row ``i`` of ``x`` is
    trained to pick column ``i`` (the matching row of ``y``) as its target.

    Args:
        x: Query embeddings; the first dimension gives the number of targets.
        y: Key embeddings, transposed with ``.t()`` before the product.
        temperature: Divisor applied to the similarity matrix.

    Returns:
        The mean cross-entropy loss as a scalar tensor.
    """
    x_unit = nn.functional.normalize(x, dim=-1)
    y_unit = nn.functional.normalize(y, dim=-1)

    similarity = x_unit @ y_unit.t()
    scaled = similarity / temperature

    # Positive pair for row i is column i.
    targets = torch.arange(x_unit.size(0), device=x_unit.device)
    return nn.functional.cross_entropy(scaled, targets)

