import os
import tqdm
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import pandas as pd
from datasets import load_dataset
from torch.utils.data import TensorDataset, DataLoader

import numpy as np

#define your neural net here:

class MLP(pl.LightningModule):
    """Predict one scalar per sample from an image-like tensor and a text embedding.

    The image-like input is passed through a two-stage convolution/pooling
    stack and flattened. The text embedding sequence is mean-pooled over
    dim 1. Both vectors are concatenated and fed to a fully connected head
    whose last layer has a single output unit.

    Args:
        input_size: Stored as ``self.input_size``; no layer in this class is
            sized from it.
        text_embedding_dim: Width of the pooled text embedding that is
            appended to the flattened image features.
        img_channels: Channel count expected by the first convolution.
        img_height: Input height used to size the first linear layer.
        img_width: Input width used to size the first linear layer.
        xcol: Stored as ``self.xcol``; not read inside this class
            (presumably a dataset column name -- confirm with the caller).
        ycol: Stored as ``self.ycol``; not read inside this class
            (presumably a dataset column name -- confirm with the caller).
    """

    def __init__(self, input_size, text_embedding_dim=2048, img_channels=4, img_height=64, img_width=64, xcol='emb', ycol='target_score'):
        super().__init__()
        self.input_size = input_size
        self.text_embedding_dim = text_embedding_dim
        self.xcol = xcol
        self.ycol = ycol

        # Convolutional feature extractor. Shape notes assume the default
        # 4x64x64 input. Modules are created in a fixed order so that the
        # Sequential indices (and therefore state_dict keys) stay stable.
        first_stage = [
            nn.Conv2d(img_channels, 32, kernel_size=3, stride=1, padding=1),  # -> 32x64x64
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # -> 32x32x32
        ]
        second_stage = [
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),  # -> 64x32x32
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=4),  # -> 64x8x8
        ]
        self.conv_layers = nn.Sequential(*first_stage, *second_stage, nn.Flatten())

        # The head consumes the flattened image features followed by the
        # pooled text embedding.
        flattened_image_size = self._calculate_image_flattened_size(img_height, img_width)
        total_input_size = flattened_image_size + text_embedding_dim

        # Fully connected head: (output width, dropout probability or None)
        # for every hidden layer, followed by a single-unit output layer.
        hidden_spec = ((1024, 0.2), (128, 0.2), (64, 0.1), (16, None))
        head_modules = []
        in_features = total_input_size
        for out_features, dropout_p in hidden_spec:
            head_modules.append(nn.Linear(in_features, out_features))
            head_modules.append(nn.ReLU())
            if dropout_p is not None:
                head_modules.append(nn.Dropout(dropout_p))
            in_features = out_features
        head_modules.append(nn.Linear(in_features, 1))
        self.layers = nn.Sequential(*head_modules)

    def _calculate_image_flattened_size(self, img_height, img_width):
        """Return the number of features produced by ``self.conv_layers``.

        The convolutions use kernel 3 / stride 1 / padding 1 and keep the
        spatial size, so only the two max-pool layers shrink it: first by a
        factor of 2, then by a factor of 4. The last convolution has 64
        output channels.
        """
        out_h, out_w = img_height, img_width
        for pool_factor in (2, 4):
            out_h //= pool_factor
            out_w //= pool_factor
        return 64 * out_h * out_w

    def forward(self, x, text_prompt):
        """Run the network and return the raw output of the final linear layer.

        Args:
            x: Image-like tensor fed to the convolution stack; expected to be
                ``(batch, img_channels, img_height, img_width)``.
            text_prompt: Text embedding tensor that is averaged over dim 1;
                expected to be ``(batch, seq_len, text_embedding_dim)`` so the
                pooled result is ``(batch, text_embedding_dim)``.

        Returns:
            The output of ``self.layers``; no activation is applied to it.
        """
        # Flattened convolutional features of the image-like input.
        flat_image = self.conv_layers(x)

        # Mean-pool the text embedding over its sequence dimension.
        pooled_text = text_prompt.mean(dim=1)

        # Join both feature vectors along the feature axis and apply the head.
        joint_features = torch.cat((flat_image, pooled_text), dim=-1)
        return self.layers(joint_features)

# class MLP_time_dependent(pl.LightningModule):
#     def __init__(self, input_size, text_embedding_dim, time_index_dim=1, img_channels=4, img_height=64, img_width=64, xcol='emb', ycol='target_score'):
#         super().__init__()
#         self.input_size = input_size
#         self.text_embedding_dim = text_embedding_dim
#         self.time_index_dim = time_index_dim
#         self.xcol = xcol
#         self.ycol = ycol
        
#         # Image feature processing through convolution layers
#         self.conv_layers = nn.Sequential(
#             nn.Conv2d(img_channels, 32, kernel_size=3, stride=1, padding=1),  # 4x64x64 -> 32x64x64
#             nn.ReLU(),
#             nn.MaxPool2d(kernel_size=2, stride=2),  # 32x64x64 -> 32x32x32
#             nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),  # 32x32x32 -> 64x32x32
#             nn.ReLU(),
#             nn.MaxPool2d(kernel_size=4, stride=4),  # 64x32x32 -> 64x8x8
#             nn.Flatten()
#         )

#         # Time step embedding
#         self.time_embedding = nn.Embedding(time_index_dim, 16)  # Time step embedding dimension is 16

#         # Total input size includes the flattened image features, time embedding, and text embedding
#         total_input_size = self._calculate_image_flattened_size(img_height, img_width) + 16 + text_embedding_dim

#         # MLP layers
#         self.layers = nn.Sequential(
#             nn.Linear(total_input_size, 1024),
#             nn.ReLU(),
#             nn.Dropout(0.2),
#             nn.Linear(1024, 128),
#             nn.ReLU(),
#             nn.Dropout(0.2),
#             nn.Linear(128, 64),
#             nn.ReLU(),
#             nn.Dropout(0.1),
#             nn.Linear(64, 16),
#             nn.ReLU(),
#             nn.Linear(16, 1)
#         )
    
#     def _calculate_image_flattened_size(self, img_height, img_width):
#         # Calculate the size after passing through the convolution layers
#         # Assume the image size is img_channels x img_height x img_width
#         # The two max pooling layers downsample by 2 and then by 4 (8x overall).
#         # NOTE: the `// 4` below does not match that and would need to be `// 8`.
#         return 64 * (img_height // 4) * (img_width // 4)  # The flattened size after 2 poolings

#     def forward(self, x, time_index, text_prompt):
#         """
#         x: Original input features (tensor)
#         time_index: Time index (tensor)
#         text_prompt: Indices of words in the text prompt (tensor)
#         """
        
#         import ipdb
#         ipdb.set_trace()
#         # Process the image through convolution layers
#         img_features = self.conv_layers(x)

#         # Process the time index through an embedding
#         time_emb = self.time_embedding(time_index.unsqueeze(1))

#         # Aggregate text embedding by averaging across the sequence dimension (77 tokens)
#         text_emb = text_prompt.mean(dim=1)  # shape [batch_size, embedding_dim]

#         # Concatenate the image features, time embedding, and text embedding
#         combined_input = torch.cat([img_features, time_emb, text_emb], dim=-1)
        
#         # Forward pass through MLP layers
#         output = self.layers(combined_input)
#         return output
