import torch
import os
import numpy as np
import argparse
from sae_model import VL_SAE, VL_SAE_COS, SAE_D, SAE_V, VL_SAE_CON
from tqdm import tqdm
from torch.cuda.amp import autocast
from llava_alignment_model_trainer import VisionTextAlignmentModel
def parse_args():
    """Parse the command-line options of the interpretation script.

    Returns:
        argparse.Namespace with the attributes ``topk``, ``hidden_ratio``,
        ``input_dim``, ``model_type``, ``sae_type``, ``device`` and
        ``num_targets``.
    """
    # (flag, type, default, help) -- registered in this order so the
    # generated ``--help`` output lists the options in the same sequence.
    option_specs = (
        ('--topk', int, 128, 'Top k features'),
        ('--hidden-ratio', int, 32, 'Hidden dimension ratio'),
        ('--input-dim', int, 512, 'Input dimension'),
        ('--model-type', str, 'llava_mean', 'Model type'),
        ('--sae-type', str, 'saev', 'SAE type'),
        ('--device', str, 'cuda:3', 'Device for computation'),
        ('--num-targets', int, 100, 'Number of target features to interpret'),
    )

    cli = argparse.ArgumentParser(description='VLSAE interpreting')
    for flag, value_type, fallback, blurb in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback, help=blurb)
    return cli.parse_args()

def main():
    """Dump the top-activating images and captions for random SAE features.

    Steps:
      1. Build the sparse autoencoder selected by ``--sae-type`` and load its
         checkpoint from ``./sae_weights``. For ``vlsae`` the vision/text
         alignment model is loaded as well and applied before encoding.
      2. Load the pre-computed text and image features.
      3. Draw ``--num-targets`` distinct latent indices at random (NumPy seed
         fixed to 42) and find, for each, the samples with the highest
         activation in both modalities.
      4. For every index that fired in both modalities, copy the top images
         into ``./interpret_images/.../<index>/`` and write the top captions
         to ``text_interpretation.txt`` in the same directory.

    Raises:
        ValueError: if ``--sae-type`` is not one of the supported values.
    """
    # Local import: only the saving step of this function needs it.
    import shutil

    args = parse_args()
    np.random.seed(42)
    torch.manual_seed(42)

    hidden_dim = args.input_dim * args.hidden_ratio
    device = args.device
    alignment_model = None

    sae_classes = {
        'vlsae': VL_SAE_COS,
        'saed': SAE_D,
        'saev': SAE_V,
        'vlsae_con': VL_SAE_CON,
    }
    if args.sae_type not in sae_classes:
        # Fail early with a clear message instead of an UnboundLocalError on
        # ``ckpt_path`` further down.
        raise ValueError(
            f"Unknown --sae-type {args.sae_type!r}; expected one of {sorted(sae_classes)}"
        )

    # The checkpoint file name embeds the SAE type verbatim.
    ckpt_path = f'./sae_weights/{args.model_type}_{args.sae_type}_{args.topk}_{args.hidden_ratio}_best.pth'
    autoencoder = sae_classes[args.sae_type](args.input_dim, hidden_dim, topk=args.topk).to(device)
    if args.sae_type == 'vlsae':
        alignment_model = VisionTextAlignmentModel(vision_dim=args.input_dim, text_dim=args.input_dim).to(device)
        # map_location keeps loading independent of the device the checkpoint
        # was saved from, consistent with the SAE checkpoint below.
        alignment_ckpt = torch.load('./llava_alignment_model_best.pt', map_location=device)
        alignment_model.load_state_dict(alignment_ckpt)

    ckpt = torch.load(ckpt_path, map_location=device)
    autoencoder.load_state_dict(ckpt)
    # NOTE(review): both models stay in their default (training) mode, as
    # before. If they contain dropout / batch-norm layers, `.eval()` is
    # probably wanted here -- confirm against the model definitions.

    # data loading
    embeddings_data = torch.load("../representation_collection/lvlms/activations/llava_cc3m_activations_model.layers.30_mean.pt")
    text_embeddings = torch.Tensor(np.stack(embeddings_data['text_features'], axis=0)).squeeze().half()
    vision_embeddings = torch.Tensor(np.stack(embeddings_data['image_features'], axis=0)).squeeze().half()
    image_paths = embeddings_data['image_file']
    texts = embeddings_data['text']

    def get_multiple_top_activations(target_indices, embeddings, references, top_k=10, batch_size=256, modality='vision'):
        """Find the strongest-activating samples for each target latent.

        Args:
            target_indices: iterable of latent indices to inspect.
            embeddings: tensor of input features, one row per sample.
            references: per-sample reference, indexable by sample position
                (image paths for ``modality='vision'``, captions otherwise).
            top_k: maximum number of samples kept per latent.
            batch_size: number of samples encoded per forward pass.
            modality: ``'vision'`` or anything else for text; selects the
                alignment branch and the ``encode_v`` / ``encode_t`` encoder.

        Returns:
            dict mapping each latent index (as ``int``) to a list of
            references ordered by decreasing activation. Latents whose
            highest activation is not positive are omitted. For vision the
            references are reduced to their file name; captions are
            returned unchanged.
        """
        target_ids = [int(idx) for idx in target_indices]
        column_index = torch.as_tensor(target_ids, dtype=torch.long, device=device)
        is_vision = modality == 'vision'

        batch_activations = []
        with torch.no_grad():
            for start in tqdm(range(0, len(embeddings), batch_size), desc="Activation Collection"):
                batch_embeddings = embeddings[start:start + batch_size].to(device)
                with autocast():
                    if alignment_model is not None:
                        if is_vision:
                            batch_embeddings, _, _, _ = alignment_model(vision_features=batch_embeddings)
                        else:
                            _, batch_embeddings, _, _ = alignment_model(text_features=batch_embeddings)
                    if hasattr(autoencoder, 'encode'):
                        latents = autoencoder.encode(batch_embeddings)
                    elif is_vision:
                        latents = autoencoder.encode_v(batch_embeddings)
                    else:
                        latents = autoencoder.encode_t(batch_embeddings)
                # Keep only the requested columns and move them to the CPU in
                # one transfer per batch.
                batch_activations.append(latents[:, column_index].cpu())

        if not batch_activations:
            return {}

        all_activations = torch.cat(batch_activations).float()
        # torch.topk raises if k exceeds the number of samples.
        k = min(top_k, all_activations.shape[0])
        if k < 1:
            return {}

        results = {}
        for column, target_idx in enumerate(target_ids):
            top_k_vals, top_k_indices = torch.topk(all_activations[:, column], k)
            # A latent that never fires carries nothing to interpret; leaving
            # it out of the result lets the caller skip it.
            if top_k_vals[0] > 0:
                picked = [references[i] for i in top_k_indices.tolist()]
                if is_vision:
                    # Image references are paths; only the file name is needed.
                    picked = [os.path.basename(path) for path in picked]
                # Captions are kept whole: taking a "basename" would cut off
                # everything before the last '/' in a caption.
                results[target_idx] = picked

        return results

    target_indices = np.random.choice(hidden_dim, size=args.num_targets, replace=False).tolist()

    image_results = get_multiple_top_activations(target_indices, vision_embeddings, image_paths, modality='vision')
    text_results = get_multiple_top_activations(target_indices, text_embeddings, texts, modality='text')

    output_root = f'./interpret_images/interpret_images_{args.sae_type}_{args.model_type}_{args.topk}'
    for target_idx in tqdm(target_indices, desc="Saving results"):
        # Only features that fired in both modalities are saved.
        if target_idx not in image_results or target_idx not in text_results:
            continue
        image_save_dir = os.path.join(output_root, str(target_idx))
        os.makedirs(image_save_dir, exist_ok=True)

        for image_name in image_results[target_idx]:
            image_path = os.path.join('../CC3M/cc3m_jpg', image_name)
            image_save_path = os.path.join(image_save_dir, image_name)
            try:
                # shutil.copy avoids the shell, so file names with spaces or
                # shell metacharacters are handled safely.
                shutil.copy(image_path, image_save_path)
            except OSError as err:
                # Best effort, like the previous `cp`: report and carry on.
                print(f'Could not copy {image_path}: {err}')

        with open(os.path.join(image_save_dir, 'text_interpretation.txt'), 'w', encoding='utf-8') as f:
            f.write('\n'.join(text_results[target_idx]))

# Run the interpretation pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
