import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import clip
from utils.dataset import CLIPDataset
import os
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np

def tsne_visualization(mlp_model, clip_model, dataloader, device,save_path, max_samples=300):
    """Plot a joint 2-D t-SNE of image, text and image->text mapped features.

    For every batch, CLIP image and text embeddings are computed and
    L2-normalised, and the image embeddings are pushed through
    ``mlp_model`` with ``direction='img_to_text'``. The three feature sets
    are embedded together with t-SNE and drawn as a scatter plot that is
    written to ``save_path``.

    Args:
        mlp_model: Module called as
            ``mlp_model(img_emb=..., direction='img_to_text')``.
        clip_model: Model exposing ``encode_image`` and ``encode_text``.
        dataloader: Iterable of batches indexable by ``'image'`` (a tensor)
            and ``'text'`` (passed to ``clip.tokenize``).
        device: Device the batch tensors are moved to.
        save_path: Destination handed to ``savefig``.
        max_samples: Upper bound on the number of image/text pairs plotted.

    Raises:
        ValueError: If ``max_samples`` is smaller than 1 or the dataloader
            yields no batches.

    Note:
        Both models are switched to eval mode and left that way.
    """
    if max_samples < 1:
        raise ValueError("max_samples must be a positive integer")

    mlp_model.eval()
    clip_model.eval()

    img_feats, txt_feats, pseudo_txt_feats = [], [], []

    count = 0
    with torch.no_grad():
        for batch in dataloader:
            images = batch['image'].to(device)
            texts = clip.tokenize(batch['text'], truncate=True).to(device)

            image_features = clip_model.encode_image(images).float()
            text_features = clip_model.encode_text(texts).float()
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            mapped_txt_features = mlp_model(img_emb=image_features, direction='img_to_text')

            img_feats.append(image_features.cpu())
            txt_feats.append(text_features.cpu())
            pseudo_txt_feats.append(mapped_txt_features.cpu())

            count += image_features.size(0)
            if count >= max_samples:
                break

    if not img_feats:
        raise ValueError("dataloader yielded no batches; nothing to visualize")

    # The loop stops only after a whole batch, so trim the overshoot to keep
    # the max_samples bound exact.
    img_feats = torch.cat(img_feats, dim=0)[:max_samples].numpy()
    txt_feats = torch.cat(txt_feats, dim=0)[:max_samples].numpy()
    pseudo_txt_feats = torch.cat(pseudo_txt_feats, dim=0)[:max_samples].numpy()

    all_features = np.concatenate([img_feats, txt_feats, pseudo_txt_feats], axis=0)

    # t-SNE requires perplexity < number of points; clamp so small
    # evaluation sets do not make the fit fail.
    perplexity = min(30, all_features.shape[0] - 1)
    tsne = TSNE(n_components=2, perplexity=perplexity, learning_rate=200, random_state=42)
    tsne_result = tsne.fit_transform(all_features)

    # Rows are stacked as [image | text | mapped text], each of length n.
    n = img_feats.shape[0]
    tsne_img = tsne_result[:n]
    tsne_txt = tsne_result[n:2*n]
    tsne_pseudo_txt = tsne_result[2*n:]

    # Only path-like targets have a parent directory to create.
    if isinstance(save_path, (str, os.PathLike)):
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    fig = plt.figure(figsize=(10, 8))
    try:
        plt.scatter(tsne_img[:, 0], tsne_img[:, 1], c='blue', label='Image Feature', alpha=0.5, s=20)
        plt.scatter(tsne_txt[:, 0], tsne_txt[:, 1], c='green', label='Text Feature', alpha=0.5, s=20)
        plt.scatter(tsne_pseudo_txt[:, 0], tsne_pseudo_txt[:, 1], c='red', label='Pseudo Text (Mapped)', alpha=0.5, s=20)

        plt.title('t-SNE of Image, Text, and Mapped Text Features')
        plt.legend()
        plt.grid(True)
        plt.savefig(save_path)
    finally:
        # pyplot keeps every figure alive until it is closed; release it so
        # repeated calls do not accumulate open figures.
        plt.close(fig)
class ResidualMLPBlock(nn.Module):
    """Two-layer MLP with a skip connection: ``forward(x) = x + fc(x)``.

    ``fc`` is Linear -> ReLU -> Dropout(0.1) -> Linear, and every layer keeps
    the feature width at ``dim`` so the sum with the input is well defined.
    """

    def __init__(self, dim):
        super().__init__()
        layers = [
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(dim, dim),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the input plus the MLP's output for that input."""
        residual = self.fc(x)
        return x + residual
class BidirectionalMLP(nn.Module):
    """Pair of independent three-layer MLPs mapping between two embedding spaces.

    ``img_to_text`` and ``text_to_img`` share the same architecture
    (input_dim -> hidden_dim -> hidden_dim -> input_dim with ReLU and
    Dropout(0.1) after the first two linear layers) but own separate weights.
    """

    def __init__(self, input_dim=512, hidden_dim=1024):
        super().__init__()
        # Built in this order so parameter initialisation order is stable.
        self.img_to_text = self._build_mapper(input_dim, hidden_dim)
        self.text_to_img = self._build_mapper(input_dim, hidden_dim)

    @staticmethod
    def _build_mapper(input_dim, hidden_dim):
        """Create one direction's Linear/ReLU/Dropout stack."""
        stack = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, input_dim),
        ]
        return nn.Sequential(*stack)

    def forward(self, img_emb=None, text_emb=None, direction='img_to_text'):
        """Map an embedding in the requested direction.

        Args:
            img_emb: Input used when ``direction == 'img_to_text'``.
            text_emb: Input used when ``direction == 'text_to_img'``.
            direction: ``'img_to_text'`` or ``'text_to_img'``.

        Raises:
            ValueError: If ``direction`` is neither of the two names or the
                embedding required for that direction is ``None``.
        """
        if direction == 'img_to_text' and img_emb is not None:
            return self.img_to_text(img_emb)
        if direction == 'text_to_img' and text_emb is not None:
            return self.text_to_img(text_emb)
        raise ValueError("Invalid direction or missing embeddings")


class ResidualBidirectionalMLP(nn.Module):
    """Bidirectional mapper whose hidden stage is a ``ResidualMLPBlock``.

    Each direction is Linear(input_dim, hidden_dim) -> ReLU ->
    ResidualMLPBlock(hidden_dim) -> Linear(hidden_dim, input_dim); the two
    directions do not share weights.
    """

    def __init__(self, input_dim=512, hidden_dim=1024):
        super().__init__()

        def make_branch():
            # One direction's stack; called twice so each side owns its weights.
            parts = [
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                ResidualMLPBlock(hidden_dim),
                nn.Linear(hidden_dim, input_dim),
            ]
            return nn.Sequential(*parts)

        self.img_to_text = make_branch()
        self.text_to_img = make_branch()

    def forward(self, img_emb=None, text_emb=None, direction='img_to_text'):
        """Map an embedding in the requested direction.

        Args:
            img_emb: Input used when ``direction == 'img_to_text'``.
            text_emb: Input used when ``direction == 'text_to_img'``.
            direction: ``'img_to_text'`` or ``'text_to_img'``.

        Raises:
            ValueError: If ``direction`` is neither of the two names or the
                embedding required for that direction is ``None``.
        """
        if direction == 'img_to_text' and img_emb is not None:
            return self.img_to_text(img_emb)
        if direction == 'text_to_img' and text_emb is not None:
            return self.text_to_img(text_emb)
        raise ValueError("Invalid direction or missing embeddings")

def train_bidirectional_mlp(clip_model, train_data, test_data, data_dir, dataname, ratio, epochs=20, device=None):
    """Train a ``BidirectionalMLP`` between CLIP image and text embeddings.

    CLIP embeddings are computed without gradients and L2-normalised. The MLP
    is optimised with Adam (lr=1e-4) on the sum of two direct MSE losses
    (image->text, text->image) plus 0.5 times two cycle-consistency MSE
    losses (image->text->image, text->image->text). After every epoch the
    image->text mapping is scored with ``evaluate_mlp_mapping`` on the test
    loader, and every 10th epoch the weights are saved under
    ``./checkpoints/{dataname}/{ratio}/MLP_MSE``.

    Args:
        clip_model: Model exposing ``encode_image`` / ``encode_text``; it is
            put in eval mode and never updated.
        train_data: Object with a ``get`` method; its ``'multimodal'`` entry
            is handed to ``CLIPDataset`` for training.
        test_data: Same as ``train_data`` but used for per-epoch evaluation.
        data_dir: Second positional argument to ``CLIPDataset``.
        dataname: Used in the checkpoint directory path.
        ratio: Used in the checkpoint directory path.
        epochs: Number of training epochs.
        device: Device to train on; defaults to CUDA when available,
            otherwise CPU.

    Returns:
        The trained ``BidirectionalMLP``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    clip_model.eval()

    # Only the second element returned by clip.load is needed (it is used as
    # the dataset transform), so load once and share it between both datasets
    # instead of loading the whole model twice.
    preprocess = clip.load("ViT-B/32")[1]
    dataset = CLIPDataset(train_data.get('multimodal'), data_dir, transform=preprocess)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
    test_dataset = CLIPDataset(test_data.get('multimodal'), data_dir, transform=preprocess)
    # Evaluation does not need shuffling; a fixed order keeps batches stable.
    test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    mlp_model = BidirectionalMLP().to(device)
    optimizer = optim.Adam(mlp_model.parameters(), lr=1e-4)
    criterion = nn.MSELoss()

    for epoch in range(epochs):
        # evaluate_mlp_mapping() below switches the model to eval mode, so
        # training mode (and with it dropout) must be re-enabled every epoch.
        mlp_model.train()
        total_loss = 0

        for batch in dataloader:
            images = batch['image'].to(device)
            texts = clip.tokenize(batch['text'], truncate=True).to(device)

            with torch.no_grad():
                image_features = clip_model.encode_image(images)
                text_features = clip_model.encode_text(texts)

                image_features = image_features / image_features.norm(dim=-1, keepdim=True)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            # Cast once; the MLP and the loss targets both use float32.
            image_features = image_features.float()
            text_features = text_features.float()

            optimizer.zero_grad()

            pseudo_text = mlp_model(img_emb=image_features, direction='img_to_text')
            pseudo_img = mlp_model(text_emb=text_features, direction='text_to_img')

            loss_img_to_text = criterion(pseudo_text, text_features)
            loss_text_to_img = criterion(pseudo_img, image_features)

            # Cycle consistency: mapping there and back should recover the input.
            reconstructed_img = mlp_model(text_emb=pseudo_text, direction='text_to_img')
            reconstructed_text = mlp_model(img_emb=pseudo_img, direction='img_to_text')

            cycle_loss_img = criterion(reconstructed_img, image_features)
            cycle_loss_text = criterion(reconstructed_text, text_features)

            total_loss_batch = loss_img_to_text + loss_text_to_img + 0.5 * (cycle_loss_img + cycle_loss_text)
            total_loss_batch.backward()
            optimizer.step()

            total_loss += total_loss_batch.item()

        print(f"MLP Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(dataloader):.4f}")

        # Prints the per-epoch test metrics; the returned value is not used.
        evaluate_mlp_mapping(mlp_model, clip_model, test_dataloader, device)

        if (epoch + 1) % 10 == 0:
            dir_path = f'./checkpoints/{dataname}/{ratio}/MLP_MSE'
            os.makedirs(dir_path, exist_ok=True)
            save_path = os.path.join(dir_path, f"bidirectional_mlp_epoch_{epoch+1}.pt")
            torch.save(mlp_model.state_dict(), save_path)
            print(f"Saved model checkpoint at {save_path}")
    return mlp_model

import torch.nn.functional as F

def evaluate_mlp_mapping(mlp_model, clip_model, dataset, device):
    """Score the image->text mapping against the real CLIP text embeddings.

    For every batch, CLIP image and text embeddings are computed and
    L2-normalised, the image embeddings are mapped with
    ``mlp_model(img_emb=..., direction='img_to_text')``, and the mapped
    vectors are compared with the text embeddings by MSE and cosine
    similarity. Both metrics are accumulated over individual samples, so a
    shorter final batch is not over-weighted.

    Args:
        mlp_model: Mapping module under evaluation.
        clip_model: Model exposing ``encode_image`` / ``encode_text``.
        dataset: Iterable of batches indexable by ``'image'`` and ``'text'``
            (elsewhere in this file it is called with a ``DataLoader``).
        device: Device the batch tensors are moved to.

    Returns:
        Mean cosine similarity between mapped and real text embeddings, or
        ``nan`` when ``dataset`` yields nothing.

    Note:
        Both models are switched to eval mode and left that way; a caller
        that keeps training afterwards must call ``mlp_model.train()`` again.
    """
    mlp_model.eval()
    clip_model.eval()

    total_sq_err_img2text = 0.0
    total_cos_img2text = 0.0
    total_elements = 0
    total_samples = 0

    with torch.no_grad():
        for sample in dataset:
            image = sample['image'].to(device)
            text = clip.tokenize(sample['text'], truncate=True).to(device)

            img_feat = clip_model.encode_image(image).float()
            txt_feat = clip_model.encode_text(text).float()

            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
            txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)

            mapped_txt = mlp_model(img_emb=img_feat, direction='img_to_text')

            # Accumulate sums rather than per-batch means so the final
            # averages weight every sample equally.
            total_sq_err_img2text += F.mse_loss(mapped_txt, txt_feat, reduction='sum').item()
            total_elements += txt_feat.numel()

            cos_sim = F.cosine_similarity(mapped_txt, txt_feat)
            total_cos_img2text += cos_sim.sum().item()
            total_samples += cos_sim.numel()

    if total_samples == 0:
        # Nothing was evaluated; report that instead of dividing by zero.
        print("No samples to evaluate (img→text).")
        return float("nan")

    avg_mse = total_sq_err_img2text / total_elements
    avg_cos = total_cos_img2text / total_samples
    print(f"Avg MSE (img→text): {avg_mse:.4f}")
    print(f"Avg Cosine Sim (img→text): {avg_cos:.4f}")
    return avg_cos
