import os, json, random
import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import clip
import torch.nn.functional as F
from utils.niqe import calculate_niqe
from utils.easy import AlignParams
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def map_external(v_ext: torch.Tensor, p: AlignParams) -> torch.Tensor:
    """Project external embeddings through the learned alignment map.

    ``v_ext`` is standardised with ``p.mu_ext`` / ``p.scale_ext`` and the
    result is right-multiplied by ``p.W.T``.
    """
    standardised = (v_ext - p.mu_ext) / p.scale_ext
    projection = p.W.T
    return standardised @ projection

def preprocess_internal(v_int: torch.Tensor, p: AlignParams) -> torch.Tensor:
    """Standardise internal embeddings: subtract ``p.mu_int`` then divide by ``p.scale_int``."""
    centered = v_int - p.mu_int
    return centered / p.scale_int
class CLIPDataset(Dataset):
    """Map-style dataset that yields image/caption records.

    Each element of ``data`` is a dict with a ``'modality'`` key and an
    ``'image_id'`` key. Unless modality is ``'text_only'`` it must carry an
    ``'image_path'``; unless modality is ``'image_only'`` it must carry a
    ``'caption'``.
    """

    def __init__(self, data, data_dir, transform=None):
        # data: indexable sequence of sample dicts.
        # data_dir: root directory joined with relative image paths.
        # transform: optional callable applied to the loaded PIL image.
        self.data = data
        self.data_dir = data_dir
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return a dict with keys ``image``, ``text``, ``modality``, ``image_id``."""
        record = self.data[idx]
        modality = record['modality']

        if modality == 'text_only':
            # No image exists for this sample: emit a blank 224x224 placeholder,
            # as a zero tensor when a transform is configured, else a PIL image.
            if self.transform:
                image = torch.zeros(3, 224, 224)
            else:
                image = Image.new('RGB', (224, 224))
        else:
            rel_path = record['image_path']
            # Paths containing "cc3m_kb" are used verbatim; all others are
            # resolved relative to data_dir.
            if "cc3m_kb" in rel_path:
                full_path = rel_path
            else:
                full_path = os.path.join(self.data_dir, rel_path)
            image = Image.open(full_path).convert('RGB')
            if self.transform:
                image = self.transform(image)

        caption = "" if modality == 'image_only' else record['caption']

        return {
            'image': image,
            'text': caption,
            'modality': modality,
            'image_id': record['image_id'],
        }


def top_k_search(query_feat, kb_feats, kb_items, k=5):
    """Return the ``k`` knowledge-base entries most cosine-similar to ``query_feat``.

    Args:
        query_feat: query embedding, compared against every KB row with
            ``F.cosine_similarity`` along dim 1 (presumably shape (1, D) —
            confirm against callers).
        kb_feats: KB embeddings, either a 2-D tensor with one row per entry or
            an iterable of 1-D tensors.
        kb_items: mapping whose ``'meta'`` sequence is aligned with ``kb_feats``.
        k: maximum number of hits to return (applied as a slice).

    Returns:
        ``(metas, sims)``: the matching ``kb_items['meta']`` entries and their
        similarities as Python floats, highest first. Ties keep KB order.
    """
    if torch.is_tensor(kb_feats):
        feats = kb_feats
    else:
        rows = list(kb_feats)
        if not rows:
            return [], []
        feats = torch.stack(rows)
    if feats.shape[0] == 0:
        return [], []

    # One batched similarity and a single host transfer, instead of a
    # per-row `.item()` (which forces a device sync for every KB entry).
    sims = F.cosine_similarity(query_feat, feats).tolist()

    # sorted() is stable, so equal scores keep their KB order.
    order = sorted(range(len(sims)), key=sims.__getitem__, reverse=True)[:k]
    return [kb_items['meta'][i] for i in order], [sims[i] for i in order]


from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def _l2_normalize(feat):
    """Scale ``feat`` to unit L2 norm along its last dimension."""
    return feat / feat.norm(dim=-1, keepdim=True)


def _retrieve_and_rerank(query_feat, query_params, query_kb_feats,
                         pseudo_feat, pseudo_params, pseudo_kb_feats,
                         meta, k, alpha):
    """Return the KB ``meta`` entry that best matches a query after cross-modal re-ranking.

    Stage 1 ranks every KB row by cosine similarity between the preprocessed
    ``query_feat`` and ``query_kb_feats`` and keeps the top ``k``. Stage 2
    re-scores those candidates with
    ``alpha * stage1 + (1 - alpha) * cos(pseudo_feat, pseudo_kb_feats[i])``.
    Ties keep stage-1 order. Returns ``None`` when there is nothing to retrieve.
    """
    k = min(k, query_kb_feats.shape[0])
    if k <= 0:
        return None
    device = query_kb_feats.device

    # preprocess_internal runs on CPU as in the original code (the AlignParams
    # tensors presumably live there — confirm).
    v_query = preprocess_internal(query_feat.cpu(), query_params).to(device)
    sim_scores = F.cosine_similarity(v_query, query_kb_feats)
    topk_scores, topk_indices = torch.topk(sim_scores, k)

    v_pseudo = preprocess_internal(pseudo_feat.cpu(), pseudo_params).to(device)
    pseudo_scores = F.cosine_similarity(v_pseudo, pseudo_kb_feats[topk_indices])

    fused = [alpha * s + (1 - alpha) * p
             for s, p in zip(topk_scores.tolist(), pseudo_scores.tolist())]
    # max() returns the first maximum, matching a stable descending sort.
    best = max(range(k), key=fused.__getitem__)
    return meta[topk_indices[best].item()]


def _extract_quoted_caption(raw):
    """Return the text between the first and second single quote of ``raw``, or ``None``."""
    if not raw or len(raw) <= 5:
        return None
    parts = raw.split("'")
    if len(parts) < 2:
        # No quote at all: the original `split("'")[1]` raised IndexError here.
        return None
    return parts[1]


def _sample_captions(generator, img_path, prompt, n, max_chars=77):
    """Query the captioner ``n`` times; keep non-empty quoted captions shorter than ``max_chars``."""
    captions = []
    for _ in range(n):
        caption = _extract_quoted_caption(generator.image_to_text(img_path, prompt))
        if caption and len(caption) < max_chars:
            captions.append(caption)
    return captions


def _select_best_caption(captions, reference_caption, pseudo_txt_feat,
                         clip_model, device, lam):
    """Pick the caption maximising ``lam * BLEU + (1 - lam) * CLIP similarity``.

    BLEU is measured against the whitespace-tokenised ``reference_caption``.
    CLIP similarity is the dot product of unit-normalised CLIP text features
    with ``pseudo_txt_feat``. Returns ``None`` only if ``captions`` is empty.
    """
    with torch.no_grad():
        tokens = clip.tokenize(captions).to(device)
        text_features = _l2_normalize(clip_model.encode_text(tokens))
    clip_scores = (pseudo_txt_feat @ text_features.T.float()).reshape(-1).tolist()

    smoothing = SmoothingFunction().method1
    # sentence_bleu expects a *list* of tokenised references.
    references = [reference_caption.split()]
    best_score, best_text = float('-inf'), None
    for caption, clip_score in zip(captions, clip_scores):
        bleu = sentence_bleu(references, caption.split(), smoothing_function=smoothing)
        score = lam * bleu + (1 - lam) * clip_score
        if score > best_score:
            best_score, best_text = score, caption
    return best_text


def _select_best_image(images, pseudo_img_feat, clip_model, processor, device, lam):
    """Pick the image maximising ``lam * quality + (1 - lam) * CLIP cosine similarity``.

    Quality is the NIQE score (lower is better) mapped to
    ``max(0, 10 - niqe) / 10``. Returns ``None`` if ``images`` is empty.
    """
    best_score, best_img = float('-inf'), None
    for img in images or ():
        img_tensor = processor(img).unsqueeze(0).to(device)
        with torch.no_grad():
            image_feat = _l2_normalize(clip_model.encode_image(img_tensor))
            cosine_sim = F.cosine_similarity(image_feat.float(), pseudo_img_feat).item()
        # NOTE(review): for a PIL image np.array gives an HWC uint8 array in the
        # image's own channel order (RGB here); confirm utils.niqe expects that.
        niqe_score = calculate_niqe(np.array(img), crop_border=0, params_path='./utils')
        quality = max(0, 10 - niqe_score) / 10.0
        score = lam * quality + (1 - lam) * cosine_sim
        if score > best_score:
            best_score, best_img = score, img
    return best_img


def augment_dataset_img_txt_easy(generator,
                                 clip_model,
                                 mlp_model,
                                 processor,
                                 data,
                                 data_dir,
                                 img_save_dir,
                                 data_save_dir,
                                 internal_kb,
                                 params_img,
                                 params_txt,
                                 device,
                                 k=5,
                                 n=10,
                                 lambda_score=0.5):
    """Turn unimodal samples into generated multimodal pairs and save them as JSON.

    For every ``data['image_only']`` item, a caption is generated. The best
    retrieved KB caption is used as an in-context example for
    ``generator.image_to_text``, which is queried ``n`` times. The candidate
    with the highest BLEU/CLIP score is kept.

    For every ``data['text_only']`` item, an image is generated. The best
    retrieved KB image path is passed to ``generator.text_to_image(path, text, n)``.
    The candidate with the highest NIQE/CLIP score is saved as
    ``{img_save_dir}/generated_{idx}.png``.

    Retrieval uses ``internal_kb['image_embeds']``, ``internal_kb['text_embeds']``
    and the aligned ``internal_kb['meta']`` records (see ``_retrieve_and_rerank``).

    Args:
        data_save_dir: path of the JSON *file* the result is written to.
        k: number of KB candidates kept before re-ranking (clamped to KB size).
        n: number of generations requested per sample.
        lambda_score: NOTE(review) accepted but currently unused; both selection
            stages use the fixed weight ``lam`` below.

    Returns:
        Dict with keys ``'image_only'``, ``'text_only'`` (left empty) and
        ``'multimodal'`` (the generated pairs).
    """
    alpha = 0.5  # weight of the same-modality retrieval score during re-ranking
    lam = 0.3    # weight of BLEU / NIQE quality vs. CLIP similarity when selecting

    augmented_data = {
        'image_only': [],
        'text_only': [],
        'multimodal': []
    }
    clip_model.eval()
    mlp_model.eval()

    # Loop-invariant: move the KB embeddings to the device once, not per sample.
    kb_img_feats = internal_kb['image_embeds'].to(device)
    kb_txt_feats = internal_kb['text_embeds'].to(device)
    kb_meta = internal_kb['meta']

    # ---- image_only -> multimodal: generate a caption for each image ----
    for item in data['image_only']:
        img_path = os.path.join(data_dir, item['image_path'])
        with Image.open(img_path) as pil_img:
            rgb_img = pil_img.convert("RGB")
        with torch.no_grad():
            image = processor(rgb_img).unsqueeze(0).to(device)
            image_feat = _l2_normalize(clip_model.encode_image(image))
            pseudo_txt_feat = _l2_normalize(
                mlp_model(img_emb=image_feat.float(), direction='img_to_text'))
            best_item = _retrieve_and_rerank(
                image_feat, params_img, kb_img_feats,
                pseudo_txt_feat, params_txt, kb_txt_feats,
                kb_meta, k, alpha)
        if best_item is None:
            continue

        reference_caption = best_item['caption']
        prompt = f"Please write two captions of the image. Caption 1: '{reference_caption}'. Caption 2:"
        generated_texts = _sample_captions(generator, img_path, prompt, n)
        if not generated_texts:
            continue

        best_text = _select_best_caption(generated_texts, reference_caption,
                                         pseudo_txt_feat, clip_model, device, lam)
        augmented_data['multimodal'].append({
            'image_id': item['image_id'],
            'image_path': item['image_path'],
            'caption': best_text,
            'modality': 'multimodal',
            'generated': True,
            'original_modality': 'image_only'
        })

    # ---- text_only -> multimodal: generate an image for each caption ----
    os.makedirs(img_save_dir, exist_ok=True)
    idx = 0
    for item in data['text_only']:
        # Character-level cut, presumably to stay within CLIP's 77-token context.
        text = item['caption'][:77]
        with torch.no_grad():
            text_tokens = clip.tokenize(text).to(device)
            text_feat = _l2_normalize(clip_model.encode_text(text_tokens))
            pseudo_img_feat = _l2_normalize(
                mlp_model(text_emb=text_feat.float(), direction='text_to_img'))
            best_item = _retrieve_and_rerank(
                text_feat, params_txt, kb_txt_feats,
                pseudo_img_feat, params_img, kb_img_feats,
                kb_meta, k, alpha)
        if best_item is None:
            continue

        generated_imgs = generator.text_to_image(best_item['image_path'], text, n)
        best_img = _select_best_image(generated_imgs, pseudo_img_feat,
                                      clip_model, processor, device, lam)
        if best_img is None:
            continue

        output_path = f"{img_save_dir}/generated_{idx}.png"
        best_img.save(output_path)
        # Advance the file counter (the original bumped `n` instead, so every
        # sample overwrote generated_0.png and the generation count kept growing).
        idx += 1

        augmented_data['multimodal'].append({
            'image_id': item['image_id'],
            'image_path': output_path,
            'caption': text,
            'modality': 'multimodal',
            'generated': True,
            'original_modality': 'text_only'
        })

    print(f"Augmented dataset: {len(augmented_data['multimodal'])} multimodal samples")

    with open(data_save_dir, 'w', encoding='utf-8') as f:
        json.dump(augmented_data, f, ensure_ascii=False, indent=2)

    return augmented_data

