import argparse
import torch
import torch.nn as nn
import yaml
from tqdm import tqdm
import transformers
from torch.utils.data import DataLoader
from tinyllava.data.dataset import make_supervised_data_module
from tinyllava.utils import *
from tinyllava.model.load_model import load_pretrained_model
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import json

def attention_scores(k_proj, q_proj):
    """Return softmax-normalised scaled dot-product scores of `q_proj` against `k_proj`.

    The last two dimensions of `k_proj` are swapped, `q_proj` is matrix-multiplied
    with the result, the product is divided by the square root of the size of
    `q_proj`'s last dimension, and a softmax is taken over the last dimension.

    NOTE: the parameter order is (k_proj, q_proj) -- keys first, queries second.
    Prefer keyword arguments at call sites to avoid swapping them by accident.
    """
    scale = q_proj.shape[-1] ** 0.5
    scaled_logits = torch.matmul(q_proj, k_proj.transpose(-1, -2)) / scale
    return scaled_logits.softmax(dim=-1)
    
def attn_svd(data, model, training_arguments):
    """Score every training sample by the singular values of an attention map.

    For each batch, the outputs of the language model's `q_proj` and `k_proj`
    modules are captured with forward hooks, converted to softmax attention
    maps by `attention_scores`, summed over the first 24 captured layers and
    cropped to a fixed token window. The singular values of the cropped map
    are stored per sample.

    Args:
        data: mapping with a 'train_dataset' and a 'data_collator'. Collated
            batches must provide 'unique_indices', 'input_ids' and 'images'.
        model: model to evaluate. It is modified in place: moved to CUDA, put
            in eval mode, converted to half precision, and
            `config.output_attentions` is set to True.
        training_arguments: object exposing `method` ("svd-full" or
            "partial_svd") and `k`. For "svd-full" with k >= 1 only the first
            k singular values are kept.

    Returns:
        dict mapping each sample's unique index to a float (when k == 1) or
        to a list of floats (otherwise).

    Raises:
        ValueError: if `method` is not recognised, or if fewer than 24
            q_proj/k_proj outputs were captured during a forward pass.
    """
    # Model/prompt specific constants. Presumably the language model has 24
    # decoder layers and the image tokens occupy positions 35..610 with the
    # text starting at 612 -- TODO confirm against the model and prompt template.
    num_layers = 24
    text_rows = slice(612, None)
    image_cols = slice(35, 611)

    model.cuda()
    model.eval()
    # Loop-invariant setup: doing this once is enough.
    model.half()
    model.config.output_attentions = True

    attn_svd_vals = {}

    train_loader = DataLoader(
        data['train_dataset'],
        batch_size=4,
        collate_fn=data['data_collator'],
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )

    # Per-batch buffers filled by the hooks, in module call order.
    q_outputs = []
    k_outputs = []

    def make_hook(store):
        """Build a forward hook that appends the module output to `store`."""
        def hook(module, hook_inputs, output):
            if len(store) < num_layers:
                store.append(output)
        return hook

    # Register the hooks exactly once. Registering them per batch (without
    # removing them) stacks duplicate hooks on every module, so each projection
    # output would be recorded several times from the second batch onwards.
    handles = []
    for name, module in model.named_modules():
        lowered = name.lower()
        if 'language_model.model' not in lowered:
            continue
        if 'q_proj' in lowered:
            handles.append(module.register_forward_hook(make_hook(q_outputs)))
        elif 'k_proj' in lowered:
            handles.append(module.register_forward_hook(make_hook(k_outputs)))

    try:
        with torch.no_grad():
            for inputs in tqdm(train_loader, total=len(train_loader)):
                sample_ids = inputs['unique_indices']
                input_ids = inputs['input_ids'].cuda()
                images = inputs['images'].to(dtype=torch.float16, device='cuda', non_blocking=True)

                q_outputs.clear()
                k_outputs.clear()

                # Only the hook side effects are needed; the model output is discarded.
                model(
                    input_ids=input_ids,
                    images=images,
                    output_attentions=True,
                    return_dict=True,
                    use_cache=False
                )

                if len(q_outputs) < num_layers or len(k_outputs) < num_layers:
                    raise ValueError(
                        f"Expected {num_layers} q_proj/k_proj outputs, captured "
                        f"{len(q_outputs)} and {len(k_outputs)}"
                    )

                attn_scores = 0
                for q_out, k_out in zip(q_outputs, k_outputs):
                    # Keyword arguments on purpose: attention_scores takes
                    # (k_proj, q_proj), so a positional (q, k) call silently
                    # computes softmax(K @ Q^T) instead of softmax(Q @ K^T).
                    attn_scores += attention_scores(k_proj=k_out, q_proj=q_out)
                cm_attn_scores = attn_scores[:, text_rows, image_cols]  # selecting the tokens
                q_outputs.clear()
                k_outputs.clear()

                if training_arguments.method == "svd-full":
                    metrics = torch.linalg.svdvals(cm_attn_scores.float())
                    if torch.isnan(metrics).any():
                        print("Found Nans !!!!")
                    if torch.isinf(metrics).any():
                        print("Found Infs !!!!")
                    if training_arguments.k >= 1:
                        metrics = metrics[:, :training_arguments.k]
                elif training_arguments.method == "partial_svd":
                    _, metrics, _ = torch.svd_lowrank(cm_attn_scores.float(), q=1)
                else:
                    raise ValueError("Incorrect value of method passed!")

                for sample_id, metric in zip(sample_ids, metrics):
                    # Tensor ids hash by identity and are not JSON-serialisable;
                    # unwrap them to plain Python scalars.
                    key = sample_id.item() if isinstance(sample_id, torch.Tensor) else sample_id
                    if training_arguments.k == 1:
                        attn_svd_vals[key] = metric.item()
                    else:
                        attn_svd_vals[key] = metric.tolist()

                # Drop references to GPU tensors before releasing cached blocks.
                del input_ids, images, attn_scores, cm_attn_scores, metrics
                torch.cuda.empty_cache()
    finally:
        # Always detach the hooks so later forward passes are unaffected.
        for handle in handles:
            handle.remove()

    return attn_svd_vals


def loss(data, model):
    """Compute the mean per-token cross-entropy loss of every training sample.

    Each batch is expanded with `model.prepare_inputs_labels_for_multimodal`,
    run through the model without gradients, and the next-token cross-entropy
    is averaged per sequence over the positions whose label is not -100.

    Args:
        data: mapping with a 'train_dataset' and a 'data_collator'. Collated
            batches must provide 'unique_indices', 'images', 'input_ids',
            'labels' and 'attention_mask' (optionally 'past_key_values' and
            'position_ids').
        model: model to evaluate; moved to CUDA and put in eval mode in place.

    Returns:
        dict mapping each sample's unique index to its loss as a Python float.
        A sequence with no supervised position yields NaN (0 / 0).
    """
    ignore_index = -100  # label value excluded from the loss

    model.cuda()
    model.eval()

    losses = {}

    train_loader = DataLoader(
        data['train_dataset'],
        batch_size=16,
        collate_fn=data['data_collator'],
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )
    with torch.no_grad():
        for step, inputs in tqdm(enumerate(train_loader), total=len(train_loader)):
            unique_indices = inputs.pop("unique_indices")

            assert "inputs_embeds" not in inputs
            images = inputs["images"].cuda()
            input_ids = inputs["input_ids"].cuda()
            labels = inputs["labels"].cuda()
            attention_mask = inputs["attention_mask"].cuda()
            past_key_values = inputs.get("past_key_values")
            position_ids = inputs.get("position_ids")
            # Both are moved to the GPU only when both are present.
            if past_key_values is not None and position_ids is not None:
                past_key_values = past_key_values.cuda()
                position_ids = position_ids.cuda()

            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
            ) = model.prepare_inputs_labels_for_multimodal(
                input_ids, position_ids, attention_mask, past_key_values, labels, images
            )

            inputs["input_ids"] = input_ids
            inputs["position_ids"] = position_ids
            inputs["attention_mask"] = attention_mask
            inputs["past_key_values"] = past_key_values
            inputs["inputs_embeds"] = inputs_embeds
            inputs["labels"] = labels

            outputs = model(**inputs)

            logits = outputs.logits.float()
            vocab_size = logits.size(-1)

            # Position t predicts token t + 1: drop the last logit and the first label.
            shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
            shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)

            token_level_loss = nn.functional.cross_entropy(
                shift_logits, shift_labels, ignore_index=ignore_index, reduction="none"
            )

            batch_size = labels.size(0)
            seq_length = labels.size(1) - 1
            per_sequence_loss = token_level_loss.view(batch_size, seq_length).sum(dim=1)
            active_tokens_per_sequence = (
                (shift_labels != ignore_index).view(batch_size, -1).sum(dim=1)
            )

            # Convert to plain floats so the result is JSON-serialisable and
            # no GPU tensors are kept alive in the returned dict.
            batch_values = (per_sequence_loss / active_tokens_per_sequence).tolist()
            for uid, value in zip(unique_indices, batch_values):
                key = uid.item() if isinstance(uid, torch.Tensor) else uid
                losses[key] = value

            # Drop references to GPU tensors before releasing cached blocks.
            del images, input_ids, labels, attention_mask, inputs_embeds
            del position_ids, past_key_values, inputs, outputs, logits
            del shift_logits, shift_labels, token_level_loss
            del per_sequence_loss, active_tokens_per_sequence
            torch.cuda.empty_cache()

            # Report progress from rank 0 only (and only in distributed runs).
            if step == 1 or (step != 0 and step % 10000 == 0):
                if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
                    print(f"***** Predict-Progress -- {step} DONE !")

    return losses

def main():
    """Compute attention-SVD scores and per-sample losses for one checkpoint.

    Parses `ModelArguments`, `DataArguments` and `TrainingArguments` from the
    command line, loads the model at `model_name_or_path`, and writes two JSON
    files next to the run directory (the part of the path before
    "/checkpoint-"):

        <run_dir>/attn_svd_files/checkpoint-<tag>_attn_svd.json
        <run_dir>/loss_files/checkpoint-<tag>loss.json
    """
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_arguments, data_arguments, training_arguments = parser.parse_args_into_dataclasses()

    disable_torch_init()
    model_path = os.path.expanduser(model_arguments.model_name_or_path)
    model, tokenizer, image_processor, context_len = load_pretrained_model(model_path)

    data_arguments.image_processor = image_processor
    data_arguments.is_multimodal = True

    data = make_supervised_data_module(tokenizer=tokenizer, data_args=data_arguments)

    # Derive output locations from the expanded path, so "~" resolves to the
    # same place the model was loaded from.
    path_parts = model_path.split("/checkpoint-")
    run_dir = path_parts[0]
    if len(path_parts) > 1:
        checkpoint_tag = path_parts[-1]
    else:
        # No "/checkpoint-" component: fall back to the directory name so the
        # file name never contains path separators.
        checkpoint_tag = os.path.basename(os.path.normpath(model_path))

    attn_svd_dir = os.path.join(run_dir, "attn_svd_files")
    loss_dir = os.path.join(run_dir, "loss_files")

    os.makedirs(attn_svd_dir, exist_ok=True)
    os.makedirs(loss_dir, exist_ok=True)

    attn_svd_file = os.path.join(attn_svd_dir, "checkpoint-" + checkpoint_tag + "_attn_svd.json")
    # NOTE: the loss file name has no underscore before "loss"; kept as is
    # because downstream readers may rely on the existing name.
    loss_file = os.path.join(loss_dir, "checkpoint-" + checkpoint_tag + "loss.json")

    all_checkpoints_attn_svd = attn_svd(data=data, model=model, training_arguments=training_arguments)
    # Persist immediately so a failure in the loss pass does not discard these results.
    with open(attn_svd_file, "w") as f:
        json.dump(all_checkpoints_attn_svd, f, indent=4)
    print(f"***** Attn_SVD values saved to {attn_svd_file}")

    all_checkpoints_loss = loss(data=data, model=model)
    with open(loss_file, "w") as f:
        json.dump(all_checkpoints_loss, f, indent=4)
    print(f"***** Loss values saved to {loss_file}")


if __name__ == '__main__':
    main()

