import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import clip
from eva_vit import create_eva_vit_g
from processors.blip_processors import Blip2ImageTrainProcessor
from predefined_concepts import concepts 



class ConceptDataset(Dataset):
    """Image dataset pairing each image with CLIP concept probabilities.

    On construction, the dataset either reuses supplied per-image concept
    probabilities (``precomputed_probs``) or scores every image in
    ``image_dir`` against ``concepts``. Scoring uses CLIP ViT-L/14@336px,
    and the result is then pickled to disk.

    Items are ``(image, probs)``:

    - ``image`` is the RGB PIL image, passed through ``transform`` when one
      is given.
    - ``probs`` is that image's tensor of softmax probabilities over
      ``concepts``.
    """

    # Prepended to each file name to form the key in ``image_probs`` (and in
    # the saved pickle).
    KEY_PREFIX = "flickr8k-images/"
    # CLIP logits are divided by this before the softmax; a value < 1 sharpens
    # the distribution.
    TEMPERATURE = 0.5
    DEFAULT_SAVE_PATH = os.path.join(
        "./prepare_concept/precomputed", "flickr8k_precomputed_probs.pkl"
    )

    def __init__(self, image_dir, concepts, transform=None, precomputed_probs=None,
                 *, batch_size=64, save_path=None):
        """Index ``image_dir`` and obtain concept probabilities for its images.

        Args:
            image_dir: Directory of images. Every regular file in it is
                treated as an image; sub-directories are skipped.
            concepts: Concept text prompts, passed to ``clip.tokenize``.
            transform: Optional callable applied to each PIL image in
                ``__getitem__``.
            precomputed_probs: Optional mapping ``KEY_PREFIX + file name ->
                probs tensor``, or a path to a pickle of such a mapping (e.g.
                one written by a previous run). When given, CLIP is not
                loaded and nothing is written to disk. A pickle path must come
                from a trusted source: unpickling can execute arbitrary code.
            batch_size: Number of images scored per CLIP forward pass.
            save_path: Where to pickle freshly computed probabilities.
                Defaults to ``DEFAULT_SAVE_PATH``. Missing parent directories
                are created.
        """
        self.image_dir = image_dir
        # Sorted so dataset indices are reproducible (os.listdir order is
        # arbitrary); non-files are skipped because they cannot be opened as
        # images.
        self.image_files = sorted(
            name for name in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, name))
        )
        self.concepts = concepts
        self.transform = transform

        if torch.cuda.is_available():
            self.device = f"cuda:{torch.cuda.current_device()}"
        else:
            self.device = "cpu"

        if precomputed_probs is not None:
            self.image_probs = self._load_probs(precomputed_probs)
            return

        self.image_probs = self._compute_probs(batch_size)
        self._save_probs(self.DEFAULT_SAVE_PATH if save_path is None else save_path)

    @staticmethod
    def _load_probs(precomputed_probs):
        """Return the probability mapping from a mapping or a pickle path."""
        if isinstance(precomputed_probs, (str, os.PathLike)):
            import pickle
            # NOTE(security): pickle.load executes arbitrary code embedded in
            # the file; only pass files produced by a trusted run.
            with open(precomputed_probs, "rb") as f:
                return pickle.load(f)
        return dict(precomputed_probs)

    def _compute_probs(self, batch_size):
        """Score every image against ``self.concepts`` with CLIP.

        Returns:
            dict mapping ``KEY_PREFIX + file name`` to a CPU tensor of softmax
            probabilities over the concepts.
        """
        from tqdm import tqdm

        print("=====Precomputing probabilities")
        clip_model, self.clip_preprocess = clip.load("ViT-L/14@336px", device=self.device)
        text = clip.tokenize(self.concepts).to(self.device)
        image_probs = {}
        try:
            for start in tqdm(range(0, len(self.image_files), batch_size)):
                batch_files = self.image_files[start:start + batch_size]
                batch_images = []
                for img_name in batch_files:
                    # The context manager closes the file handle that
                    # Image.open keeps open for lazy loading.
                    with Image.open(os.path.join(self.image_dir, img_name)) as img:
                        batch_images.append(self.clip_preprocess(img.convert("RGB")))
                # Stack on the host and transfer once per batch.
                batch = torch.stack(batch_images).to(self.device)
                with torch.no_grad():
                    logits_per_image, _ = clip_model(batch, text)
                    probs = (logits_per_image / self.TEMPERATURE).softmax(dim=-1).cpu()
                for img_name, row in zip(batch_files, probs):
                    # clone() gives each row its own compact storage. A bare
                    # view would keep (and pickle) the whole batch tensor.
                    image_probs[self.KEY_PREFIX + img_name] = row.clone()
        finally:
            # Drop the model and tokens even on failure so GPU memory can be
            # returned by empty_cache().
            del clip_model, text
            torch.cuda.empty_cache()
        return image_probs

    def _save_probs(self, save_path):
        """Pickle ``self.image_probs`` to ``save_path``, creating parent dirs."""
        import pickle

        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        with open(save_path, "wb") as f:
            pickle.dump(self.image_probs, f)
        print(f"Precomputed probability has been saved to: {save_path}")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, probs)`` for the ``idx``-th file in ``image_files``."""
        img_name = self.image_files[idx]
        # convert() returns a fully loaded copy, so the file can be closed
        # before the image is used.
        with Image.open(os.path.join(self.image_dir, img_name)) as img:
            image = img.convert("RGB")
        probs = self.image_probs[self.KEY_PREFIX + img_name]
        if self.transform:
            image = self.transform(image)
        return image, probs


def main():
    """Build the Flickr8k concept dataset and print its per-image probabilities.

    Constructing ``ConceptDataset`` without precomputed probabilities runs the
    CLIP pass over every image in ``image_dir``. As a side effect, it also
    writes the probabilities pickle.
    """
    image_dir = "/pretrained/lavis_cache/flickr8k/images/flickr8k-images"
    vis_processor = Blip2ImageTrainProcessor()
    print("=====Loading model")
    train_dataset = ConceptDataset(image_dir, concepts, transform=vis_processor)
    print(train_dataset.image_probs)


# Only run when executed as a script. Importing this module (e.g. to reuse
# ConceptDataset) then does not trigger the expensive CLIP precompute.
if __name__ == "__main__":
    main()