import argparse
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import numpy as np
from sae_model import SAE, VL_SAE, VL_SAE_COS, SAE_D, SAE_V, VL_SAE_COS_LLaVA, VL_SAE_CON, VL_SAE_DIS
from torch.cuda.amp import autocast
from llava_alignment_model_trainer import VisionTextAlignmentModel

def _build_parser():
    """Return the command-line parser for the SAE training script."""
    cli = argparse.ArgumentParser(description='Train SAE model')
    # One entry per option: (flag, type, default, help text).
    # Defaults are passed to argparse as-is (argparse does not apply `type`
    # to non-string defaults), so e.g. --weight_decay defaults to the int 0.
    options = (
        ('--device', str, 'cuda:0', 'device to use'),
        ('--topk', int, 256, 'top k for sparse coding'),
        ('--hidden_ratio', int, 8, 'hidden dimension ratio'),
        ('--sae_type', str, 'saev', 'SAE model type'),
        ('--num_epochs', int, 100, 'number of epochs'),
        ('--warmup_epochs', int, 20, 'number of warmup epochs'),
        ('--batch_size', int, 512, 'batch size'),
        ('--alpha', float, 1e-2, 'weight for latent loss'),
        ('--initial_lr', float, 1e-4, 'initial learning rate'),
        ('--weight_decay', float, 0, 'weight decay for optimizer'),
        ('--patience', int, 10, 'patience for early stopping'),
        ('--train_ratio', float, 0.8, 'ratio of training data'),
    )
    for flag, arg_type, default, help_text in options:
        cli.add_argument(flag, type=arg_type, default=default, help=help_text)
    return cli


parser = _build_parser()
args = parser.parse_args()

# Short aliases for the options used throughout the rest of the script.
device, topk, hidden_ratio, sae_type = (
    args.device, args.topk, args.hidden_ratio, args.sae_type)

# Pre-extracted activations; the file name suggests layer-30, mean-pooled
# LLaVA activations on CC3M (not verified here). The loaded object is indexed
# by 'text_features' and 'image_features' below.
embeddings_data = torch.load("../representation_collection/llava_models/activations/llava_cc3m_activations_model.layers.30_mean.pt")

# Stack each feature list along a new leading axis, drop singleton dims and
# store as float16. Text is built first, then image.
# NOTE(review): squeeze() also drops the sample axis when there is exactly one
# sample, which would break `shape[1]` below.
text_embeddings, image_embeddings = (
    torch.Tensor(np.stack(embeddings_data[feature_key], axis=0)).squeeze().half()
    for feature_key in ('text_features', 'image_features')
)

print("Text_embeddings: {}, image embeddings: {}".format(text_embeddings.shape, image_embeddings.shape))

# SAE width: the hidden (dictionary) size is a multiple of the input width.
input_dim = text_embeddings.shape[1]
hidden_dim = input_dim * hidden_ratio

num_epochs, warmup_epochs, batch_size = args.num_epochs, args.warmup_epochs, args.batch_size
initial_lr, weight_decay, patience = args.initial_lr, args.weight_decay, args.patience

# Build the sparse autoencoder selected by --sae_type. Only 'vlsae' also
# loads a pre-trained VisionTextAlignmentModel; it is not passed to the
# optimizer, and both training and validation run it under torch.no_grad(),
# with its outputs (not the raw activations) feeding the SAE.
alignment_model = None
if sae_type == 'vlsae':
    autoencoder = VL_SAE_COS(input_dim, hidden_dim, topk=topk).to(device)
    alignment_model = VisionTextAlignmentModel(vision_dim=input_dim, text_dim=input_dim).to(device)
    # map_location lets the checkpoint load even if it was saved on a
    # different device than the one requested via --device.
    ckpt = torch.load('./llava_alignment_model_best.pt', map_location=device)
    alignment_model.load_state_dict(ckpt)
    # The alignment model is used for inference only; eval mode disables any
    # train-time behaviour (dropout, batch-norm statistics updates) it may have.
    alignment_model.eval()
elif sae_type == 'vlsae_dis':
    autoencoder = VL_SAE_DIS(input_dim, hidden_dim, topk=topk).to(device)
elif sae_type == 'saed':
    autoencoder = SAE_D(input_dim, hidden_dim, topk=topk).to(device)
elif sae_type == 'saev':
    autoencoder = SAE_V(input_dim, hidden_dim, topk=topk).to(device)
elif sae_type == 'vlsae_con':
    autoencoder = VL_SAE_CON(input_dim, hidden_dim, topk=topk).to(device)
else:
    # Fail fast instead of a NameError on `autoencoder` further down.
    raise ValueError(
        f"Unknown --sae_type {sae_type!r}; expected one of "
        "'vlsae', 'vlsae_dis', 'saed', 'saev', 'vlsae_con'."
    )

# Random train/validation split over sample indices. The permutation comes
# from NumPy's global RNG, so the split changes between runs unless a seed
# is set somewhere else.
total_samples = len(text_embeddings)
train_ratio = args.train_ratio
indices = np.random.permutation(total_samples)
train_size = int(total_samples * train_ratio)
train_indices, val_indices = np.split(indices, [train_size])

# The same indices select both modalities, so image/text pairs stay aligned.
train_text_embeddings, val_text_embeddings = (
    text_embeddings[train_indices], text_embeddings[val_indices])
train_image_embeddings, val_image_embeddings = (
    image_embeddings[train_indices], image_embeddings[val_indices])

# Reconstruction loss, optimiser, and a cosine schedule that spans only the
# post-warm-up epochs (warm-up itself is applied by hand in the training loop).
criterion = nn.MSELoss()
optimizer = optim.Adam(autoencoder.parameters(), lr=initial_lr, weight_decay=weight_decay)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs - warmup_epochs)

def validate(model, val_text_embeddings, val_image_embeddings, criterion, batch_size, device):
    """Return the sample-weighted mean reconstruction loss on the validation split.

    Each batch is optionally passed through the module-level ``alignment_model``
    (when it is not None) and cast to float16, mirroring the training loop.
    The per-batch loss is ``criterion(recon_v, v_in) + criterion(recon_t, t_in)``.

    Args:
        model: the autoencoder; called as ``model(v_in, t_in)`` and expected
            to return a 4-tuple whose first two items are the image and text
            reconstructions.
        val_text_embeddings: text features; its length defines the split size.
        val_image_embeddings: image features, index-aligned with the text ones.
        criterion: loss module applied to (reconstruction, input) pairs.
        batch_size: number of samples per forward pass.
        device: device the batches are moved to.

    Returns:
        The average loss over all validation samples (a Python float).

    Raises:
        ValueError: if the validation split is empty.

    The model is put in eval mode for the pass and always switched back to
    train mode afterwards, even if the forward pass raises.
    """
    num_samples = len(val_text_embeddings)
    if num_samples == 0:
        # Previously a ZeroDivisionError; e.g. --train_ratio 1.0 leaves no
        # validation data.
        raise ValueError("validation split is empty (check --train_ratio)")

    model.eval()
    total_loss = 0.0
    try:
        with torch.no_grad():
            for start in range(0, num_samples, batch_size):
                batch_embeddings_v = val_image_embeddings[start:start + batch_size].to(device)
                batch_embeddings_t = val_text_embeddings[start:start + batch_size].to(device)
                with autocast():
                    if alignment_model is not None:
                        batch_embeddings_v_in, batch_embeddings_t_in, _, _ = alignment_model(batch_embeddings_v, batch_embeddings_t)
                    else:
                        batch_embeddings_v_in, batch_embeddings_t_in = batch_embeddings_v, batch_embeddings_t
                    # Same float16 cast as the training loop, so train and
                    # validation losses are measured on identical inputs.
                    batch_embeddings_v_in = batch_embeddings_v_in.half()
                    batch_embeddings_t_in = batch_embeddings_t_in.half()
                    recon_v, recon_t, _, _ = model(batch_embeddings_v_in, batch_embeddings_t_in)
                loss = criterion(recon_v, batch_embeddings_v_in) + criterion(recon_t, batch_embeddings_t_in)
                # Weight by batch size so a partial last batch is not over-counted.
                total_loss += loss.item() * batch_embeddings_v.size(0)
    finally:
        model.train()
    return total_loss / num_samples

best_val_loss = float('inf')
patience_counter = 0
autoencoder.train()

# Checkpoints are written into ./sae_weights; create it up front so the first
# torch.save does not fail with FileNotFoundError.
os.makedirs('./sae_weights', exist_ok=True)

# NOTE: --alpha is currently unused; only the reconstruction loss is optimised.
num_train = len(train_text_embeddings)
for epoch in range(num_epochs):
    # Linear warm-up over the first `warmup_epochs` epochs, ending at
    # initial_lr. After that the cosine scheduler controls the LR; it is
    # stepped at the END of each epoch (see below).
    if epoch < warmup_epochs:
        lr = initial_lr * (epoch + 1) / warmup_epochs
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    # LR used for every optimiser step in this epoch (reported after the epoch).
    lr = optimizer.param_groups[0]['lr']

    epoch_loss = 0.0
    for i in tqdm(range(0, num_train, batch_size)):
        optimizer.zero_grad()

        batch_embeddings_v = train_image_embeddings[i:i + batch_size].to(device)
        batch_embeddings_t = train_text_embeddings[i:i + batch_size].to(device)
        with autocast():
            # The alignment model (if any) is not optimised here, so it runs
            # without building an autograd graph.
            with torch.no_grad():
                if alignment_model is not None:
                    batch_embeddings_v_in, batch_embeddings_t_in, _, _ = alignment_model(batch_embeddings_v, batch_embeddings_t)
                else:
                    batch_embeddings_v_in, batch_embeddings_t_in = batch_embeddings_v, batch_embeddings_t
            batch_embeddings_v_in = batch_embeddings_v_in.half()
            batch_embeddings_t_in = batch_embeddings_t_in.half()
            recon_v, recon_t, _, _ = autoencoder(batch_embeddings_v_in, batch_embeddings_t_in, mode='train')

        recon_loss = criterion(recon_v, batch_embeddings_v_in) + criterion(recon_t, batch_embeddings_t_in)

        # NOTE(review): float16 activations without a GradScaler can let small
        # gradients underflow; consider torch.cuda.amp.GradScaler if the SAE
        # parameters are float32.
        recon_loss.backward()
        optimizer.step()
        # Sample-weighted, like validate(); the last batch may be partial.
        epoch_loss += recon_loss.item() * batch_embeddings_v.size(0)

    val_loss = validate(autoencoder, val_text_embeddings, val_image_embeddings,
                        criterion, batch_size, device)

    # max(..., 1) avoids a division by zero on an empty training split.
    train_loss = epoch_loss / max(num_train, 1)
    print(f'Epoch [{epoch + 1}/{num_epochs}], LR: {lr}, '
          f'Train Loss: {train_loss:.4f}, '
          f'Val Loss: {val_loss:.4f}')

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save(autoencoder.state_dict(),
                   f'./sae_weights/llava_mean_{sae_type}_{topk}_{hidden_ratio}_best.pth')
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f'Early stopping triggered after epoch {epoch + 1}')
            break

    # Advance the cosine schedule only after this epoch's optimiser steps.
    # - The first post-warm-up epoch therefore trains at the peak LR.
    # - The last epoch does not run at eta_min (0).
    # - PyTorch expects scheduler.step() after optimizer.step().
    if epoch >= warmup_epochs:
        scheduler.step()


torch.save(autoencoder.state_dict(),
           f'./sae_weights/llava_mean_{sae_type}_{topk}_{hidden_ratio}_final.pth')