import torch
import torch.nn as nn
import torch.optim as optim
from transformers import CLIPModel, CLIPProcessor
from torch.optim import Adam
from torch.nn import CrossEntropyLoss

class TemporalAttention(nn.Module):
    """Multi-head scaled dot-product attention over axis 1 of its inputs.

    The last axis (``embed_size`` features) is split into ``heads`` chunks of
    ``head_dim`` features. The same ``head_dim -> head_dim`` projection is
    applied to every chunk, so the query/key/value weights are shared across
    heads. The per-head outputs are concatenated back to ``embed_size``
    features and mixed by ``fc_out``.

    Args:
        embed_size: Feature width of the ``value``/``key``/``query`` tensors.
        heads: Number of attention heads; must divide ``embed_size`` exactly.

    Raises:
        ValueError: If ``embed_size`` is not divisible by ``heads``.
    """

    def __init__(self, embed_size, heads):
        super(TemporalAttention, self).__init__()
        self.heads = heads
        self.embed_size = embed_size
        self.head_dim = embed_size // heads

        # Raise instead of assert: asserts are stripped under `python -O`,
        # which would defer the failure to an opaque reshape error in forward().
        if self.head_dim * heads != embed_size:
            raise ValueError("Embedding size needs to be divisible by heads")

        # Shared per-head projections (applied to the trailing head_dim axis).
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, value, key, query):
        """Attend from each ``query`` position to the ``key``/``value`` positions.

        Args:
            value: Tensor reshapeable to (N, value_len, heads, head_dim).
            key: Tensor reshapeable to (N, key_len, heads, head_dim);
                ``key_len`` must equal ``value_len``.
            query: Tensor reshapeable to (N, query_len, heads, head_dim).

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        N = query.shape[0]
        value_len, key_len, query_len = value.shape[1], key.shape[1], query.shape[1]

        # Split the embedding into multiple heads.
        values = value.reshape(N, value_len, self.heads, self.head_dim)
        keys = key.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # energy[n, h, q, k] = <query_q, key_k> within head h.
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)
        # Scaled dot-product attention divides by sqrt(d_k), where d_k is the
        # width of each head's keys (head_dim), not the full embed_size;
        # dividing by sqrt(embed_size) over-flattens the softmax by sqrt(heads).
        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            N, query_len, self.heads * self.head_dim
        )

        return self.fc_out(out)


class SpatialAttention(TemporalAttention):
    """Attention block computing exactly what :class:`TemporalAttention` computes.

    It inherits ``__init__`` and ``forward`` unchanged, so an instance differs
    from a ``TemporalAttention`` only by holding its own, separately
    initialised parameters.

    NOTE(review): like its parent, it attends over axis 1 of its inputs. If a
    genuinely spatial (per-patch) attention is intended, it needs inputs laid
    out with patches on that axis. Confirm against the caller.
    """

class TextVideoRetrievalModel(nn.Module):
    """Encode video frames with CLIP and refine them with two attention blocks.

    Args:
        clip_model_name: Checkpoint passed to ``CLIPModel.from_pretrained``.
            The short OpenAI ``clip``-package names listed in
            ``_OPENAI_CLIP_ALIASES`` are translated to the corresponding
            Hugging Face hub ids first; any other value is used as-is.
        embed_size: Width of the attention layers. It must equal the CLIP
            checkpoint's ``projection_dim``, because the attention layers
            consume CLIP's projected image features. ``None`` (the default)
            reads it from the loaded config.
        heads: Number of attention heads; must divide ``embed_size``.

    Raises:
        ValueError: If an explicit ``embed_size`` differs from the
            checkpoint's ``projection_dim``.
    """

    # CLIPModel.from_pretrained expects a Hugging Face hub id or a local path;
    # these short names come from OpenAI's `clip` package and are mapped to the
    # equivalent hub checkpoints so existing call sites keep working.
    _OPENAI_CLIP_ALIASES = {
        "ViT-B/32": "openai/clip-vit-base-patch32",
        "ViT-B/16": "openai/clip-vit-base-patch16",
        "ViT-L/14": "openai/clip-vit-large-patch14",
        "ViT-L/14@336px": "openai/clip-vit-large-patch14-336",
    }

    def __init__(self, clip_model_name='ViT-L/14', embed_size=None, heads=8):
        super(TextVideoRetrievalModel, self).__init__()
        checkpoint = self._OPENAI_CLIP_ALIASES.get(clip_model_name, clip_model_name)
        self.clip_model = CLIPModel.from_pretrained(checkpoint)

        # The attention input is CLIP's projected image embedding, so its width
        # is fixed by the checkpoint; a mismatch would only surface later as a
        # reshape failure inside the attention layers.
        projection_dim = self.clip_model.config.projection_dim
        if embed_size is None:
            embed_size = projection_dim
        elif embed_size != projection_dim:
            raise ValueError(
                f"embed_size={embed_size} does not match the projection dimension "
                f"({projection_dim}) of CLIP checkpoint {checkpoint!r}"
            )

        self.temporal_attention = TemporalAttention(embed_size, heads)
        self.spatial_attention = SpatialAttention(embed_size, heads)

    def forward(self, textual_input, video_frames, return_similarity=False):
        """Encode frames with CLIP and sum the temporal and spatial attention outputs.

        Args:
            textual_input: Passed positionally to
                ``CLIPModel.get_text_features``. It is only encoded when
                ``return_similarity`` is True; the default output does not
                depend on it.
            video_frames: Iterated along its first axis. Each element is passed
                to ``CLIPModel.get_image_features``, and the resulting features
                are stacked on dim 1. NOTE(review): this yields
                (batch_size, num_frames, embed_size) only if the first axis is
                time, i.e. each element is one frame for the whole batch.
                Confirm the data loader's layout.
            return_similarity: If True, return text-to-video similarity logits
                instead of the per-frame features.

        Returns:
            By default, the element-wise sum of the temporal- and
            spatial-attention outputs (same shape as the stacked frame
            embeddings). With ``return_similarity=True``, a
            (text_batch, video_batch) matrix of cosine similarities between the
            text features and the frame-averaged video features, multiplied by
            ``exp(clip_model.logit_scale)``.
        """
        video_embeddings = torch.stack(
            [self.clip_model.get_image_features(frame) for frame in video_frames],
            dim=1,
        )
        temporal_features = self.temporal_attention(video_embeddings, video_embeddings, video_embeddings)
        spatial_features = self.spatial_attention(video_embeddings, video_embeddings, video_embeddings)
        combined_features = temporal_features + spatial_features  # Simple additive fusion

        if not return_similarity:
            return combined_features

        # Run the text encoder only when its output is actually consumed.
        text_embeddings = self.clip_model.get_text_features(textual_input)
        video_vectors = combined_features.mean(dim=1)  # Pool over the stacked frame axis
        text_vectors = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
        video_vectors = video_vectors / video_vectors.norm(dim=-1, keepdim=True)
        return self.clip_model.logit_scale.exp() * (text_vectors @ video_vectors.t())


def train(model, data_loader, optimizer, loss_fn, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Module called as ``model(textual_query, video_content)``.
        data_loader: Iterable of ``(textual_query, video_content, labels)``
            batches, each element supporting ``.to(device)``. It need not
            define ``__len__``; when it does not, progress lines show ``?`` as
            the total.
        optimizer: Optimizer updating ``model``'s parameters.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar tensor.
        device: Target device for every batch element.

    Returns:
        The arithmetic mean of ``loss.item()`` over the batches actually seen.

    Raises:
        ValueError: If ``data_loader`` yields no batches, because the average
            loss is undefined.
    """
    model.train()  # Enable training-mode behaviour (dropout, etc.)

    # len() is only needed for the progress message. Unsized loaders (e.g. a
    # DataLoader over an IterableDataset without __len__) raise TypeError.
    try:
        total_batches = len(data_loader)
    except TypeError:
        total_batches = "?"

    total_loss = 0.0
    num_batches = 0

    for batch_idx, (textual_query, video_content, labels) in enumerate(data_loader):
        textual_query = textual_query.to(device)
        video_content = video_content.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()  # Clear gradients from the previous step
        outputs = model(textual_query, video_content)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()  # Single host sync per batch
        total_loss += loss_value
        num_batches += 1
        if batch_idx % 100 == 0:  # Print every 100 batches
            print(f'Batch {batch_idx}/{total_batches}, Loss: {loss_value}')

    if num_batches == 0:
        raise ValueError("data_loader yielded no batches; cannot compute an average loss")

    # Divide by the batches actually processed rather than len(data_loader),
    # which may be absent or only an estimate for iterable-style loaders.
    avg_loss = total_loss / num_batches
    print(f'Average Loss: {avg_loss}')

    return avg_loss

# Run on the GPU when PyTorch reports one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the retrieval model with its default configuration and place it on `device`.
model = TextVideoRetrievalModel()
model = model.to(device)

# Adam over every registered parameter (including the wrapped CLIP model's), lr = 1e-4.
optimizer = Adam(model.parameters(), lr=1e-4)
loss_fn = CrossEntropyLoss()
