import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from torch.cuda.amp import autocast

class VisionTextAlignmentModel(nn.Module):
    """Two linear encoder/decoder pairs, one per modality.

    Each modality is projected into a space of size ``projection_dim`` by a
    single linear layer and mapped back to its original dimensionality by a
    second linear layer. Either modality may be omitted from a call, in which
    case the corresponding outputs are ``None``.
    """

    def __init__(self, vision_dim, text_dim, projection_dim=4096):
        super().__init__()
        # Encoders: modality space -> projection space.
        self.vision_projection = nn.Linear(
            in_features=vision_dim, out_features=projection_dim
        )
        self.text_projection = nn.Linear(
            in_features=text_dim, out_features=projection_dim
        )

        # Decoders: projection space -> modality space.
        self.vision_decoder = nn.Linear(
            in_features=projection_dim, out_features=vision_dim
        )
        self.text_decoder = nn.Linear(
            in_features=projection_dim, out_features=text_dim
        )

    @staticmethod
    def _apply_if_present(layer, tensor):
        """Run ``layer`` on ``tensor``, passing ``None`` straight through."""
        return None if tensor is None else layer(tensor)

    def encoder(self, vision_features=None, text_features=None):
        """Project whichever modalities were supplied.

        Returns:
            ``(vision_embed, text_embed)``; an entry is ``None`` when its
            input was ``None``.
        """
        return (
            self._apply_if_present(self.vision_projection, vision_features),
            self._apply_if_present(self.text_projection, text_features),
        )

    def decoder(self, vision_embed=None, text_embed=None):
        """Map embeddings back to their modality's feature space.

        Returns:
            ``(vision_recon, text_recon)``; an entry is ``None`` when its
            input was ``None``.
        """
        return (
            self._apply_if_present(self.vision_decoder, vision_embed),
            self._apply_if_present(self.text_decoder, text_embed),
        )

    def forward(self, vision_features=None, text_features=None):
        """Encode then decode the supplied modalities.

        Returns:
            ``(vision_embed, text_embed, vision_recon, text_recon)``.
        """
        embeddings = self.encoder(vision_features, text_features)
        reconstructions = self.decoder(*embeddings)
        return (*embeddings, *reconstructions)

def contrastive_loss(vision_embed, text_embed, temperature=0.07):
    """Symmetric in-batch contrastive loss between paired embeddings.

    Row ``i`` of ``vision_embed`` and row ``i`` of ``text_embed`` are treated
    as the matching pair; every other row in the batch acts as a negative.
    The loss is the mean of the vision->text and text->vision cross-entropies
    over the pairwise similarity matrix.

    NOTE: the embeddings are used as given -- no L2 normalisation is applied
    here -- so the logits are raw dot products divided by ``temperature``.
    Callers that want cosine similarity must normalise beforehand.

    Args:
        vision_embed: 2-D tensor, one row per sample.
        text_embed: 2-D tensor with the same number of rows and the same
            embedding width as ``vision_embed``.
        temperature: Divisor applied to the similarity matrix.

    Returns:
        Scalar loss tensor. Half-precision inputs yield a float32 loss.
    """
    # Raw dot products divided by a small temperature easily exceed the fp16
    # range (max ~65504), which turns the logits into inf and the loss into
    # NaN. Do the arithmetic in float32 whenever a low-precision tensor
    # arrives (e.g. from an autocast forward pass).
    low_precision = (torch.float16, torch.bfloat16)
    if vision_embed.dtype in low_precision:
        vision_embed = vision_embed.float()
    if text_embed.dtype in low_precision:
        text_embed = text_embed.float()

    sim_matrix = torch.matmul(vision_embed, text_embed.transpose(0, 1)) / temperature

    # The i-th vision row matches the i-th text row.
    labels = torch.arange(vision_embed.shape[0], device=vision_embed.device)

    vision_to_text = F.cross_entropy(sim_matrix, labels)
    text_to_vision = F.cross_entropy(sim_matrix.transpose(0, 1), labels)
    return (vision_to_text + text_to_vision) / 2.0

class AlignmentTrainer:
    """Train and evaluate a ``VisionTextAlignmentModel``.

    The objective is ``contrastive_loss + reconstruction_loss`` where the
    reconstruction term is the sum of the per-modality MSE losses.

    Attributes:
        config: Hyper-parameter dict (defaults overridden by ``config``).
        device: Device the model was placed on at construction time.
        model: The alignment model.
        optimizer: AdamW over all model parameters.
        scheduler: Cosine annealing schedule with ``T_max = num_epochs``;
            the caller is responsible for stepping it.
    """

    def __init__(self, vision_dim, text_dim, config=None, device=None):
        """Build the model, optimizer and scheduler.

        Args:
            vision_dim: Width of the vision feature vectors.
            text_dim: Width of the text feature vectors.
            config: Optional dict of overrides for the default settings.
            device: Optional device (string or ``torch.device``). Defaults to
                ``cuda:0`` when CUDA is available, otherwise CPU.
        """
        self.config = {
            'lr': 1e-4,
            'batch_size': 64,
            'num_epochs': 100,
            'warmup_steps': 100,
            'weight_decay': 0.01,
            'temperature': 0.07,
        }
        if config:
            self.config.update(config)

        # Resolve the device locally instead of relying on a module-level
        # global, so the class also works when this file is imported.
        if device is None:
            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

        self.model = VisionTextAlignmentModel(vision_dim, text_dim).to(self.device)
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config['lr'],
            weight_decay=self.config['weight_decay'],
        )
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=self.config['num_epochs'],
        )

    @staticmethod
    def _mean(total, count):
        """Return ``total / count``, or NaN when nothing was accumulated."""
        return total / count if count else float('nan')

    def _forward_losses(self, vision_features, text_features):
        """Run one forward pass and compute the loss terms for a batch.

        Returns:
            ``(loss, contrast_loss, recon_loss, vision_embed, text_embed)``
            with the embeddings already cast to float32.
        """
        # Follow the model's current placement so a caller that re-assigns or
        # moves ``self.model`` after construction is still honoured.
        reference = next(self.model.parameters())
        target_device = reference.device
        # Inputs may be stored in half precision; match the parameter dtype so
        # the forward pass also works where autocast is disabled (CPU).
        vision_features = vision_features.to(device=target_device, dtype=reference.dtype)
        text_features = text_features.to(device=target_device, dtype=reference.dtype)

        on_cuda = target_device.type == 'cuda'
        # NOTE(review): fp16 autocast is used without gradient scaling; small
        # gradients may underflow -- consider adding a GradScaler.
        with torch.autocast(device_type='cuda' if on_cuda else 'cpu', enabled=on_cuda):
            vision_embed, text_embed, vision_recon, text_recon = self.model(
                vision_features, text_features
            )

        # Loss arithmetic in float32: autocast outputs are fp16 on CUDA and
        # both the temperature-scaled logits and squared errors can overflow.
        vision_embed = vision_embed.float()
        text_embed = text_embed.float()
        contrast_loss = contrastive_loss(vision_embed, text_embed, self.config['temperature'])
        recon_loss = (
            F.mse_loss(vision_recon.float(), vision_features.float())
            + F.mse_loss(text_recon.float(), text_features.float())
        )
        loss = contrast_loss + recon_loss
        return loss, contrast_loss, recon_loss, vision_embed, text_embed

    def train_epoch(self, train_loader):
        """Run one optimisation pass over ``train_loader``.

        Args:
            train_loader: Iterable of ``(vision_features, text_features)``
                batches.

        Returns:
            Dict with ``total_loss``, ``contrast_loss`` and ``recon_loss``,
            each a sample-weighted mean over the epoch (NaN if the loader
            yielded no batches).
        """
        self.model.train()
        total_loss = 0.0
        total_contrast_loss = 0.0
        total_recon_loss = 0.0
        total_samples = 0

        for vision_features, text_features in train_loader:
            self.optimizer.zero_grad()
            loss, contrast_loss, recon_loss, _, _ = self._forward_losses(
                vision_features, text_features
            )
            loss.backward()
            self.optimizer.step()

            # Weight by batch size so a short final batch does not skew the
            # epoch average (same convention as ``evaluate``).
            batch_size = len(vision_features)
            total_loss += loss.item() * batch_size
            total_contrast_loss += contrast_loss.item() * batch_size
            total_recon_loss += recon_loss.item() * batch_size
            total_samples += batch_size

        return {
            'total_loss': self._mean(total_loss, total_samples),
            'contrast_loss': self._mean(total_contrast_loss, total_samples),
            'recon_loss': self._mean(total_recon_loss, total_samples),
        }

    def evaluate(self, val_loader):
        """Compute losses and in-batch retrieval accuracy without gradients.

        Accuracy counts, for each vision embedding, whether its highest
        dot-product match among the text embeddings of the same batch is its
        own pair.

        Args:
            val_loader: Iterable of ``(vision_features, text_features)``
                batches.

        Returns:
            Dict with ``avg_loss``, ``contrast_loss``, ``recon_loss`` and
            ``accuracy``, each a sample-weighted mean (NaN if the loader
            yielded no batches).
        """
        self.model.eval()
        total_loss = 0.0
        total_contrast_loss = 0.0
        total_recon_loss = 0.0
        total_correct = 0
        total_samples = 0

        with torch.no_grad():
            for val_vision_features, val_text_features in val_loader:
                loss, contrast_loss, recon_loss, vision_embed, text_embed = (
                    self._forward_losses(val_vision_features, val_text_features)
                )

                similarity = torch.matmul(vision_embed, text_embed.transpose(0, 1))
                predictions = similarity.argmax(dim=-1)
                labels = torch.arange(len(predictions), device=predictions.device)

                batch_size = len(predictions)
                total_loss += loss.item() * batch_size
                total_contrast_loss += contrast_loss.item() * batch_size
                total_recon_loss += recon_loss.item() * batch_size
                total_correct += (predictions == labels).sum().item()
                total_samples += batch_size

        return {
            'avg_loss': self._mean(total_loss, total_samples),
            'contrast_loss': self._mean(total_contrast_loss, total_samples),
            'recon_loss': self._mean(total_recon_loss, total_samples),
            'accuracy': self._mean(total_correct, total_samples),
        }

if __name__ == "__main__":
    # Prefer the first GPU when available, otherwise fall back to CPU.
    # `device` stays a module-level name so that code elsewhere in this file
    # that refers to it keeps working.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    config = {
        'lr': 5e-5,
        'batch_size': 2048,
        'num_epochs': 50,
        'temperature': 0.07
    }

    trainer = AlignmentTrainer(vision_dim=4096, text_dim=4096, config=config)
    trainer.model = trainer.model.to(device)

    # NOTE: torch.load unpickles the file -- only load activation dumps from a
    # trusted source. map_location='cpu' lets a dump written from a GPU
    # process be read on a machine without that GPU; batches are moved to the
    # training device later.
    embeddings_data = torch.load(
        "../representation_collection/lvlms/activations/llava_cc3m_activations_model.layers.30_mean.pt",
        map_location='cpu',
    )

    # Stack the per-sample features into one tensor per modality and keep
    # them in half precision.
    # NOTE(review): squeeze() removes every singleton dimension, so a file
    # holding a single sample would lose its sample axis -- confirm inputs.
    text_embeddings = torch.Tensor(np.stack(embeddings_data['text_features'], axis=0)).squeeze().half()
    image_embeddings = torch.Tensor(np.stack(embeddings_data['image_features'], axis=0)).squeeze().half()

    # Random 80/20 train/validation split (unseeded, so it differs per run).
    total_samples = len(text_embeddings)
    train_ratio = 0.8
    indices = np.random.permutation(total_samples)
    train_size = int(total_samples * train_ratio)

    train_indices = indices[:train_size]
    val_indices = indices[train_size:]

    train_text_embeddings = text_embeddings[train_indices]
    train_image_embeddings = image_embeddings[train_indices]
    val_text_embeddings = text_embeddings[val_indices]
    val_image_embeddings = image_embeddings[val_indices]

    # Each dataset item is an (image, text) pair, matching the
    # (vision_features, text_features) order the trainer unpacks.
    train_dataset = TensorDataset(train_image_embeddings, train_text_embeddings)
    val_dataset = TensorDataset(val_image_embeddings, val_text_embeddings)
    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False)

    best_val_loss = float('inf')
    for epoch in range(config['num_epochs']):
        train_metrics = trainer.train_epoch(train_loader)
        val_metrics = trainer.evaluate(val_loader)

        # The cosine schedule is advanced once per epoch.
        trainer.scheduler.step()

        print(f"\nEpoch [{epoch+1}/{config['num_epochs']}]")
        print("Training Metrics:")
        print(f"  Total Loss: {train_metrics['total_loss']:.4f}")
        print(f"  Contrast Loss: {train_metrics['contrast_loss']:.4f}")
        print(f"  Recon Loss: {train_metrics['recon_loss']:.4f}")
        print("Validation Metrics:")
        print(f"  Total Loss: {val_metrics['avg_loss']:.4f}")
        print(f"  Contrast Loss: {val_metrics['contrast_loss']:.4f}")
        print(f"  Recon Loss: {val_metrics['recon_loss']:.4f}")
        print(f"  Accuracy: {val_metrics['accuracy']*100:.2f}%")

        # Keep the weights with the lowest validation loss seen so far.
        if val_metrics['avg_loss'] < best_val_loss:
            best_val_loss = val_metrics['avg_loss']
            torch.save(trainer.model.state_dict(), './llava_alignment_model_best.pt')

        # Always overwrite the "last" checkpoint. The scheduler state is
        # included so a resumed run continues the learning-rate schedule
        # instead of restarting it.
        torch.save({
            'epoch': epoch,
            'model_state_dict': trainer.model.state_dict(),
            'optimizer_state_dict': trainer.optimizer.state_dict(),
            'scheduler_state_dict': trainer.scheduler.state_dict(),
            'train_metrics': train_metrics,
            'val_metrics': val_metrics,
        }, './llava_alignment_model_last.pt')