import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
import clip
import numpy as np
from torchvision import transforms
from utils.evaluation import evaluate_retrieval_clip
class CLIPModel(nn.Module):
    """CLIP backbone wrapped with linear projection heads and a match head.

    The CLIP encoders are always run under ``torch.no_grad()``, so only the
    layers defined here (``image_projection``, ``text_projection``,
    ``match_head``) and ``logit_scale`` can receive gradients.
    """

    def __init__(self, clip_model_name="ViT-B/32", device=None):
        """Load the CLIP backbone and build the trainable heads.

        Args:
            clip_model_name: Model name passed to ``clip.load``.
            device: Device for the backbone and the heads. Defaults to CUDA
                when available, otherwise CPU.
        """
        super().__init__()

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.clip_model, self.preprocess = clip.load(clip_model_name, device=device)
        self.feature_dim = self.clip_model.visual.output_dim
        self.image_projection = nn.Linear(self.feature_dim, self.feature_dim)
        self.text_projection = nn.Linear(self.feature_dim, self.feature_dim)

        # Scores an (image, text) pair from the concatenated projected features.
        self.match_head = nn.Sequential(
            nn.Linear(self.feature_dim * 2, self.feature_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(self.feature_dim, 1)
        )

        # Keep the heads on the same device as the backbone so the model is
        # usable right after construction, without an extra .to(device).
        for head in (self.image_projection, self.text_projection, self.match_head):
            head.to(device)

        self.device = device
        self.logit_scale = nn.Parameter(
            torch.ones([], device=device) * np.log(1 / 0.07)
        )

    @staticmethod
    def _first_present(sample, keys):
        """Return ``sample[key]`` for the first of ``keys`` found in ``sample``.

        Raises:
            KeyError: If none of ``keys`` is present.
        """
        for key in keys:
            if key in sample:
                return sample[key]
        raise KeyError(f"sample must contain one of the keys {tuple(keys)}")

    def encode_image(self, images):
        """Encode images with the backbone and apply the image projection.

        Args:
            images: Either a list/tuple of images accepted by
                ``self.preprocess`` or an already preprocessed tensor batch.

        Returns:
            The projected (un-normalized) image features as float32.
        """
        if isinstance(images, (list, tuple)):
            images = torch.stack([self.preprocess(img) for img in images])
        # Tensor inputs may arrive on another device; align them with the model.
        pixel_batch = images.to(self.device)

        with torch.no_grad():
            image_features = self.clip_model.encode_image(pixel_batch)

        return self.image_projection(image_features.float())

    def encode_text(self, texts):
        """Tokenize and encode texts, then apply the text projection.

        Args:
            texts: A single string or a sequence of strings. Over-long texts
                are truncated by the tokenizer.

        Returns:
            The projected (un-normalized) text features as float32.
        """
        if isinstance(texts, str):
            texts = [texts]

        text_tokens = clip.tokenize(texts, truncate=True).to(self.device)
        with torch.no_grad():
            text_features = self.clip_model.encode_text(text_tokens)

        return self.text_projection(text_features.float())

    def forward(self, images, texts):
        """Compute pairwise match scores and L2-normalized embeddings.

        Returns:
            A tuple ``(match_scores, image_features_norm, text_features_norm)``.
            ``match_scores`` is produced from the *un-normalized* projected
            features of the i-th image and the i-th text.
        """
        image_features = self.encode_image(images)
        text_features = self.encode_text(texts)

        image_features_norm = F.normalize(image_features, p=2, dim=1)
        text_features_norm = F.normalize(text_features, p=2, dim=1)

        combined_features = torch.cat([image_features, text_features], dim=1)
        match_scores = self.match_head(combined_features)

        return match_scores, image_features_norm, text_features_norm

    def compute_loss(self, images, texts, match_labels=None, temperature=0.07):
        """Compute the symmetric contrastive loss plus a weighted match loss.

        Args:
            images: Image batch, see ``encode_image``.
            texts: Texts aligned with ``images`` (i-th text matches i-th image).
            match_labels: Optional binary targets for the match head. When
                omitted, every pair in the batch is treated as a positive.
            temperature: Unused; kept for backward compatibility. The logits
                are scaled by ``self.logit_scale.exp()`` instead.

        Returns:
            A dict with ``loss`` (contrastive + 0.1 * match), plus the
            individual ``contrastive_loss``, ``match_loss``, ``i2t_loss`` and
            ``t2i_loss`` terms.
        """
        match_scores, image_features, text_features = self(images, texts)

        logits_per_image = (image_features @ text_features.t()) * self.logit_scale.exp()
        logits_per_text = logits_per_image.t()

        # Take the batch size and device from the logits themselves so the
        # targets always agree with them, whatever container the inputs used.
        batch_size = logits_per_image.size(0)
        labels = torch.arange(batch_size, device=logits_per_image.device)

        loss_i2t = F.cross_entropy(logits_per_image, labels)
        loss_t2i = F.cross_entropy(logits_per_text, labels)
        contrastive_loss = (loss_i2t + loss_t2i) / 2

        match_logits = match_scores.squeeze(-1)
        if match_labels is None:
            match_targets = torch.ones_like(match_logits)
        else:
            match_targets = match_labels.float().to(match_logits.device)
        match_loss = F.binary_cross_entropy_with_logits(match_logits, match_targets)

        total_loss = contrastive_loss + 0.1 * match_loss

        return {
            "loss": total_loss,
            "contrastive_loss": contrastive_loss,
            "match_loss": match_loss,
            "i2t_loss": loss_i2t,
            "t2i_loss": loss_t2i
        }

    def extract_features(self, sample):
        """Return L2-normalized image and text embeddings for a sample dict.

        Args:
            sample: Mapping holding the images under ``"image"``, ``"images"``
                or ``"image_tensor"`` and the texts under ``"text_input"``,
                ``"text"``, ``"texts"`` or ``"captions"`` (first match wins).

        Returns:
            A dict with ``image_embeds`` and ``text_embeds``.

        Raises:
            KeyError: If the sample has no image key or no text key.
        """
        images = self._first_present(sample, ("image", "images", "image_tensor"))
        texts = self._first_present(
            sample, ("text_input", "text", "texts", "captions")
        )

        if isinstance(texts, str):
            texts = [texts]
        elif isinstance(texts, torch.Tensor):
            # Placeholder strings: a tensor cannot be tokenized as raw text.
            texts = [f"text_{i}" for i in range(texts.size(0))]

        image_features = F.normalize(self.encode_image(images), p=2, dim=1)
        text_features = F.normalize(self.encode_text(texts), p=2, dim=1)

        return {
            "image_embeds": image_features,
            "text_embeds": text_features
        }


class CLIPDataset(Dataset):
    """Image-caption dataset yielding transformed images with their captions.

    Each record in ``data`` is indexed with the keys ``image_path``,
    ``caption`` and ``image_id``.
    """

    def __init__(self, data, data_dir, transform=None):
        """Store the records and set up the image transform.

        Args:
            data: Indexable collection of records (see class docstring).
            data_dir: Directory that relative image paths are joined onto.
            transform: Callable applied to each RGB image. Defaults to a
                224x224 resize, tensor conversion and normalization.
        """
        self.data = data
        self.data_dir = data_dir

        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711]),
            ])
        self.transform = transform

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load, convert and transform the image of record ``idx``.

        Returns:
            A dict with ``image``, ``text`` (the caption) and ``image_id``.
        """
        item = self.data[idx]
        image_path = item['image_path']

        # Paths containing these markers are used as-is; all others are
        # resolved relative to data_dir.
        if 'gen' in image_path or 'cc3m_kb' in image_path:
            img_path = image_path
        else:
            img_path = os.path.join(self.data_dir, image_path)

        # Close the underlying file deterministically instead of leaving the
        # handle to the garbage collector; convert() returns a detached copy.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return {
            'image': image,
            'text': item['caption'],
            'image_id': item['image_id']
        }


def train_clip_model(data, test_data, data_dir, epochs=10, device=None, *,
                     batch_size=16, lr=1e-4, weight_decay=0.01,
                     num_workers=2, log_interval=20):
    """Train the projection and match heads of a fresh ``CLIPModel``.

    Args:
        data: Training records passed to ``CLIPDataset``.
        test_data: Currently unused; no evaluation is run during training.
        data_dir: Base directory passed to ``CLIPDataset``.
        epochs: Number of passes over the data; also the cosine schedule's
            ``T_max``.
        device: Training device. Defaults to CUDA when available, else CPU.
        batch_size: Mini-batch size for the data loader.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        num_workers: Number of data loader worker processes.
        log_interval: Print the batch losses every this many batches; a
            falsy value disables per-batch logging.

    Returns:
        A tuple ``(model, results_best)``. Because no evaluation is run,
        ``results_best`` holds only its initial zero recall values.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("Initializing CLIP model...")
    model = CLIPModel(device=device).to(device)

    dataset = CLIPDataset(data, data_dir)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )

    # Only the heads are optimized.
    # NOTE(review): model.logit_scale is an nn.Parameter but is not part of
    # this list, so it keeps its initial value - confirm that is intended.
    trainable_params = list(model.image_projection.parameters()) + \
                      list(model.text_projection.parameters()) + \
                      list(model.match_head.parameters())
    optimizer = optim.AdamW(trainable_params, lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    results_best = {
        'i2t_R@1': 0.0,
        't2i_R@1': 0.0,
        'i2t_R@5': 0.0,
        't2i_R@5': 0.0,
    }

    model.train()
    # The backbone is frozen (always run under no_grad), so keep it in
    # inference mode; train mode would only affect dropout / batch-norm
    # layers inside it.
    model.clip_model.eval()

    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0

        for i, batch in enumerate(dataloader):
            images = batch['image'].to(device)
            texts = batch['text']

            optimizer.zero_grad()
            loss_dict = model.compute_loss(images, texts)
            loss = loss_dict["loss"]
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            if log_interval and i % log_interval == 0:
                print(f"Epoch {epoch+1}/{epochs}, Batch {i}, "
                        f"Loss: {loss.item():.4f}, "
                        f"Contrastive: {loss_dict['contrastive_loss'].item():.4f}, "
                        f"Match: {loss_dict['match_loss'].item():.4f}")

        scheduler.step()

        if num_batches > 0:
            avg_loss = total_loss / num_batches
            print(f"CLIP Epoch {epoch+1}/{epochs}, Avg Loss: {avg_loss:.4f}")
        else:
            print(f"CLIP Epoch {epoch+1}/{epochs}, No valid batches")

    return model, results_best

def train_albef(clip_model, data, data_dir, epochs=10, device=None):
    """ALBEF-compatible entry point that trains a CLIP model instead.

    ``train_clip_model`` builds its own model and takes no model parameter;
    its first two positional slots are ``(data, test_data)``. Forwarding
    this function's arguments positionally would therefore feed
    ``clip_model`` in as the training data.

    Args:
        clip_model: Ignored when it is an ``nn.Module`` or ``None``. For
            backward compatibility with callers that use the positional
            layout ``(train_data, test_data, data_dir)``, any other value
            is treated as the training data and ``data`` as the test data.
        data: Training records (or test records in the legacy layout).
        data_dir: Base directory for relative image paths.
        epochs: Number of training epochs.
        device: Training device; chosen by ``train_clip_model`` when ``None``.

    Returns:
        Whatever ``train_clip_model`` returns: ``(model, results_best)``.
    """
    print("Using CLIP model instead of ALBEF...")

    if clip_model is None or isinstance(clip_model, nn.Module):
        train_data, test_data = data, None
    else:
        # Legacy positional usage: first argument is the training data.
        train_data, test_data = clip_model, data

    return train_clip_model(
        train_data, test_data, data_dir, epochs=epochs, device=device
    )
