import os 
import gc
import json
from packaging import version
import pathlib

from tqdm import tqdm
import tokenizers
import transformers

import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader, Subset
from accelerate import PartialState
from accelerate.utils import gather_object

from ppl_llava_trainer import LLaVATrainer_Custom
from tinyllava.training_recipe import TrainingRecipeFactory
from tinyllava.utils import *
from tinyllava.model import *
from tinyllava.data.dataset import make_supervised_data_module

from mir_utils import calculate_fid_pytorch, calculate_fid_batched, batch_replace_outliers_with_median_l2_preferred

# True when the installed `tokenizers` package is version 0.14 or newer
# (note: the comparison is inclusive despite the "GREATER_THAN" name).
_TOKENIZERS_MIN_VERSION = '0.14'
IS_TOKENIZER_GREATER_THAN_0_14 = (
    version.parse(tokenizers.__version__) >= version.parse(_TOKENIZERS_MIN_VERSION)
)


def load_settings(model_arguments, data_arguments, training_arguments):
    """Sync tuning settings onto ``model_arguments`` and build component load kwargs.

    Mutates ``model_arguments`` in place:
    - copies the three ``tune_type_*`` fields from ``training_arguments``;
    - copies ``image_aspect_ratio`` from ``data_arguments``.

    Returns:
        dict with keys ``'llm'``, ``'vision_tower'`` and ``'connector'``, each
        mapping to the keyword-argument dict for that component's loader.
    """
    for field_name in ('tune_type_connector', 'tune_type_llm', 'tune_type_vision_tower'):
        setattr(model_arguments, field_name, getattr(training_arguments, field_name))
    model_arguments.image_aspect_ratio = data_arguments.image_aspect_ratio

    return {
        'llm': _load_llm_settings(model_arguments),
        'vision_tower': _load_vision_settings(model_arguments),
        'connector': _load_connector_settings(model_arguments),
    }

def _load_llm_settings(model_arguments):
    llm_args = {}
    llm_args['model_name_or_path'] = model_arguments.model_name_or_path
    llm_args['cache_dir'] = model_arguments.cache_dir
    llm_args['attn_implementation'] = model_arguments.attn_implementation # flash_attention_2 only supports torch.float16 and torch.bfloat16 dtypes
    return llm_args

def _load_vision_settings(model_arguments):
    vision_args = {}
    vision_args['model_name_or_path'] = model_arguments.vision_tower.split(':')[-1]
    if model_arguments.vision_tower2 != '':
        vision_args['model_name_or_path2'] = model_arguments.vision_tower2.split(':')[-1]
    return vision_args

def _load_connector_settings(model_arguments):
    connector_args = {}
    connector_args['connector_type'] = model_arguments.connector_type
    return connector_args


# Token-layout constants used by the metrics below. They were hard-coded in
# this script and depend on the prompt template / vision tower in use —
# TODO confirm they match the model being evaluated.
_MIR_NUM_IMAGE_TOKENS = 576         # span length taken from the IMAGE_TOKEN_INDEX position (MIR)
_ATTN_NUM_LAYERS = 24               # (q_proj, k_proj) pairs summed for attention_svd
_ATTN_TEXT_ROWS = slice(612, None)  # query rows kept from the summed attention map
_ATTN_IMAGE_COLS = slice(35, 611)   # key columns kept from the summed attention map
_MSA_PROMPT_LEN = 31                # tokens before the image span (last_layer_act)
_MSA_NUM_IMAGE_TOKENS = 729         # image-span length (last_layer_act)


def _load_model_and_tokenizer(model_arguments, data_arguments, training_arguments):
    """Build the model/tokenizer and attach image-processing settings to ``data_arguments``.

    If ``training_arguments.pretrained`` is truthy, the model is assembled through
    the training recipe (restored from ``pretrained_model_path`` when set, otherwise
    loaded component by component). Otherwise a fine-tuned checkpoint is loaded
    from ``model_arguments.model_name_or_path``.

    Side effects: sets ``data_arguments.image_processor`` and
    ``data_arguments.is_multimodal = True``.

    Returns:
        (model, tokenizer)
    """
    if training_arguments.pretrained:
        print("Loading Pre-Trained Model!")
        training_recipe = TrainingRecipeFactory(training_arguments.training_recipe)(training_arguments)
        model_args = load_settings(model_arguments, data_arguments, training_arguments)
        model_args = training_recipe.add_args(model_args)
        model_config = TinyLlavaConfig()
        model_config.load_from_config(model_arguments)
        model = TinyLlavaForConditionalGeneration(model_config)
        if training_arguments.pretrained_model_path is not None:
            # Restore a full pretrained checkpoint through the recipe.
            model = training_recipe.load(model, model_args)
        else:
            model.load_llm(**model_args['llm'])
            model.load_vision_tower(**model_args['vision_tower'])
            model.load_connector(**model_args['connector'])
        model = training_recipe(model)
        model.config.use_cache = False
        model.config.image_aspect_ratio = data_arguments.image_aspect_ratio
        tokenizer = model.tokenizer
        data_arguments.image_processor = model.vision_tower._image_processor
    else:
        print("Loading Fine-Trained Model!")
        disable_torch_init()
        model_path = os.path.expanduser(model_arguments.model_name_or_path)
        model, tokenizer, image_processor, _context_len = load_pretrained_model(model_path)
        data_arguments.image_processor = image_processor
    data_arguments.is_multimodal = True
    return model, tokenizer


def _validate_attention_svd_args(training_arguments):
    """Validate ``method``/``k`` for the attention_svd metric before any heavy work.

    Raises:
        ValueError: if ``method``/``k`` is None, ``method`` is unknown, or ``k`` is
            invalid for ``'svd-full'`` (must be >= 1, or -1 for all values).
    """
    method, k = training_arguments.method, training_arguments.k
    if method is None or k is None:
        raise ValueError("Arguments 'method' and 'k' cannot be None; 'method' has to be selected from ['svd-full', 'partial-svd']")
    if method == "svd-full":
        if k == -1:
            print("Using all the singular values.")
        elif not k >= 1:
            raise ValueError("Invalid value of 'k' is passed.")
    elif method != "partial-svd":
        raise ValueError("Incorrect value of 'method' passed!")


def _iter_shard_batches(distributed_state, data_module, batch_size):
    """Yield collated batches from this process's share of ``data_module['train_dataset']``.

    The dataset indices are split across processes with
    ``split_between_processes``; batches are produced in order (no shuffling).
    """
    dataset = data_module['train_dataset']
    with distributed_state.split_between_processes(list(range(len(dataset)))) as subset_indices:
        dataloader = DataLoader(
            Subset(dataset, subset_indices),
            batch_size=batch_size,
            shuffle=False,
            num_workers=8,
            pin_memory=True,
            collate_fn=data_module['data_collator'],
        )
        yield from tqdm(dataloader, total=len(dataloader))


def _register_named_forward_hooks(model, predicate, hook_fn):
    """Attach ``hook_fn(name, module, input, output)`` to every submodule whose
    lower-cased qualified name satisfies ``predicate``; return the removable handles."""
    handles = []
    for name, module in model.named_modules():
        if predicate(name.lower()):
            # Bind `name` as a default so each hook reports its own module name.
            handles.append(module.register_forward_hook(
                lambda mod, inp, out, n=name: hook_fn(n, mod, inp, out)))
    return handles


def _save_json(obj, output_dir, filename):
    """Write ``obj`` as indented JSON to ``output_dir/filename``."""
    save_path = os.path.join(output_dir, filename)
    print("File saved at: ", save_path)
    with open(save_path, "w") as f:
        json.dump(obj, f, indent=4)


def _compute_ppl(model, tokenizer, data_arguments, training_arguments):
    """Run evaluation with ``LLaVATrainer_Custom`` and save its perplexity values."""
    data_module = make_supervised_data_module(tokenizer=tokenizer,
                                              data_args=data_arguments,
                                              return_eval_dataset=True)
    trainer = LLaVATrainer_Custom(model=model,
                                  tokenizer=tokenizer,
                                  args=training_arguments,
                                  **data_module)
    print("Running Eval!")
    trainer.evaluate()
    trainer.save_ppl_values()


def _compute_mir(model, tokenizer, data_arguments, training_arguments):
    """Compute log10(|MIR|) per sample and save ``{id: value}`` as JSON per process."""
    import math

    data_module = make_supervised_data_module(tokenizer=tokenizer,
                                              data_args=data_arguments)
    print("Computing MIR values!")
    distributed_state = PartialState()
    device = distributed_state.device
    model.to(device)
    model.half()  # once, outside inference_mode, instead of on every batch

    mir_dict = {}
    with torch.inference_mode():
        for inputs in _iter_shard_batches(distributed_state, data_module,
                                          training_arguments.per_device_eval_batch_size):
            ids = inputs['unique_indices']
            input_ids = inputs['input_ids'].to(device)
            # Assumes exactly one IMAGE_TOKEN_INDEX per row, so entry i belongs to row i.
            image_start_idxs = torch.where(input_ids == IMAGE_TOKEN_INDEX)[1]
            images = inputs['images'].to(dtype=torch.float16, device=device, non_blocking=True)
            # output_attentions kept for numerical parity with earlier runs
            # (it can switch the attention implementation).
            outputs = model.generate(
                input_ids,
                images=images,
                do_sample=False,
                num_beams=1,
                max_new_tokens=1,
                output_attentions=True,
                output_hidden_states=True,
                return_dict_in_generate=True,
                use_cache=True,
            )

            # Only the first per-layer entry of the prompt step is used
            # (presumably the embedding output — TODO confirm). No squeeze():
            # it would drop the batch dimension when the batch size is 1.
            prompt_states = outputs.hidden_states[0][0]
            vision_features = torch.stack([
                prompt_states[row, start:start + _MIR_NUM_IMAGE_TOKENS, :]
                for row, start in enumerate(image_start_idxs)
            ])
            text_features = torch.stack([
                prompt_states[row, start + _MIR_NUM_IMAGE_TOKENS:, :]
                for row, start in enumerate(image_start_idxs)
            ])

            # Text-centric normalization: scale each sample by the inverse mean
            # L2 norm of its text-token features.
            scale_factor = 1. / text_features.norm(p=2, dim=-1).mean(1)
            scale_factor = scale_factor.unsqueeze(-1).unsqueeze(-1)
            vision_features = scale_factor * vision_features
            text_features = scale_factor * text_features

            # 3-Sigma Outlier Removal
            vision_features = batch_replace_outliers_with_median_l2_preferred(vision_features.float())
            text_features = batch_replace_outliers_with_median_l2_preferred(text_features.float())

            mirs = calculate_fid_batched(vision_features, text_features)
            for sample_id, mir in zip(ids, mirs):
                mir_dict[sample_id] = math.log10(mir.abs().item())

            del input_ids, images, outputs, prompt_states
            del vision_features, text_features, scale_factor, mirs
            gc.collect()
            torch.cuda.empty_cache()

    # process_index (global) rather than local_process_index so that
    # multi-node runs do not overwrite each other's files.
    _save_json(mir_dict, training_arguments.output_dir,
               f"mir_values_baseline_{distributed_state.process_index}.json")


def _attention_scores(k_proj, q_proj):
    """Return softmax(q_proj @ k_proj^T / sqrt(d)) over the last dimension.

    Assumes q_proj and k_proj outputs share their last dimension (no
    grouped-query attention) — TODO confirm for the model in use.
    """
    d_k = q_proj.shape[-1]
    scores = torch.matmul(q_proj, k_proj.transpose(-2, -1)) / (d_k ** 0.5)
    return torch.softmax(scores, dim=-1)


def _singular_value_metrics(attn, method, k):
    """Return per-sample singular values of ``attn`` according to ``method``/``k``."""
    attn = attn.float()
    if method == "svd-full":
        values = torch.linalg.svdvals(attn)
        return values if k == -1 else values[:, :k]
    if method == "partial-svd":
        # ill-defined input matrix for q>1 instead compute full-svd
        _, values, _ = torch.svd_lowrank(attn, q=k)
        return values
    raise ValueError("Incorrect value of 'method' passed!")


def _compute_attention_svd(model, tokenizer, data_arguments, training_arguments):
    """Sum text->image attention over the first LM layers and save its singular values."""
    data_module = make_supervised_data_module(tokenizer=tokenizer,
                                              data_args=data_arguments)
    print("Computing attention SVD values!")
    method, k = training_arguments.method, training_arguments.k
    distributed_state = PartialState()
    device = distributed_state.device
    model.to(device)
    model.half()
    model.config.output_attentions = True

    # Per-batch hook state: last q_proj output awaiting its k_proj partner,
    # number of layers accumulated, running sum of attention maps.
    state = {}

    def reset_state():
        state.update(pending_q=None, num_layers=0, attn_sum=None)

    def hook_fn(name, module, input, output):
        if 'q_proj' in name.lower():
            state['pending_q'] = output
            return
        if state['num_layers'] >= _ATTN_NUM_LAYERS or state['pending_q'] is None:
            return
        # Keyword arguments so query and key cannot be swapped by position.
        scores = _attention_scores(k_proj=output, q_proj=state['pending_q'])
        if state['attn_sum'] is None:
            state['attn_sum'] = scores
        else:
            state['attn_sum'] += scores
        state['num_layers'] += 1
        state['pending_q'] = None

    reset_state()
    handles = _register_named_forward_hooks(
        model,
        lambda n: 'language_model.model' in n and ('q_proj' in n or 'k_proj' in n),
        hook_fn,
    )
    attn_svd_list = {}
    try:
        with torch.inference_mode():
            for inputs in _iter_shard_batches(distributed_state, data_module,
                                              training_arguments.per_device_eval_batch_size):
                reset_state()
                ids = inputs['unique_indices']
                input_ids = inputs['input_ids'].to(device)
                images = inputs['images'].to(dtype=torch.float16, device=device, non_blocking=True)
                model(
                    input_ids=input_ids,
                    images=images,
                    output_attentions=True,
                    return_dict=True,
                    use_cache=False,
                )
                if state['num_layers'] < _ATTN_NUM_LAYERS:
                    raise RuntimeError(
                        f"Expected {_ATTN_NUM_LAYERS} q_proj/k_proj layer pairs, "
                        f"captured {state['num_layers']}.")
                cm_attn_scores = state['attn_sum'][:, _ATTN_TEXT_ROWS, _ATTN_IMAGE_COLS]
                metrics = _singular_value_metrics(cm_attn_scores, method, k)
                for sample_id, metric in zip(ids, metrics):
                    attn_svd_list[sample_id] = metric.item() if k == 1 else metric.tolist()
                # Release references early to prevent OOM.
                reset_state()
                del input_ids, images, cm_attn_scores, metrics
    finally:
        for handle in handles:
            handle.remove()

    _save_json(attn_svd_list, training_arguments.output_dir,
               f"fine-tuned_all_layers_attn_{method}_top-{k}_{distributed_state.process_index}.json")


def _msa_embedding(msa_act, language_lengths, batch_size):
    """Build per-sample embeddings from an o_proj activation.

    Concatenates the normalized tanh of the mean image-span activation with the
    normalized tanh of the mean language-span activation (zeros when the
    language length is 0), scaled by 1/sqrt(2). Returns a numpy array.
    """
    vis_end = _MSA_PROMPT_LEN + _MSA_NUM_IMAGE_TOKENS
    msa_act_v = torch.mean(msa_act[:, _MSA_PROMPT_LEN:vis_end], dim=1)
    msa_act_v_np = F.normalize(torch.tanh(msa_act_v), dim=-1).detach().cpu().numpy()

    msa_act_l_np = np.zeros_like(msa_act_v_np)
    for batch_idx in range(batch_size):
        lang_len = language_lengths[batch_idx]
        if lang_len != 0:
            msa_act_l = torch.mean(msa_act[batch_idx, vis_end:vis_end + lang_len], dim=0)
            msa_act_l_np[batch_idx] = F.normalize(torch.tanh(msa_act_l), dim=-1).detach().cpu().numpy()

    return np.concatenate([msa_act_v_np, msa_act_l_np], axis=-1) / np.sqrt(2)


def _compute_last_layer_act(model, tokenizer, data_arguments, training_arguments):
    """Save per-sample embeddings built from a layer-23 self_attn.o_proj activation."""
    # according to https://github.com/G-JWLee/COINCIDE_code/blob/master/COINCIDE_cluster/tinyllava/eval/score/coincide/extract_embed.py#L304
    data_module = make_supervised_data_module(tokenizer=tokenizer,
                                              data_args=data_arguments,
                                              last_layer_act=True)
    print("Computing Last-Layer Activations!")
    distributed_state = PartialState()
    device = distributed_state.device
    model.to(device)
    model.half()
    model.config.output_attentions = True

    captured = []

    def hook_fn(name, module, input, output):
        # Keep only the first matching activation per forward pass.
        if not captured:
            captured.append(output)

    handles = _register_named_forward_hooks(
        model,
        lambda n: 'language_model.model' in n and '23.self_attn.o_proj' in n,
        hook_fn,
    )
    msa_embed_list = {}
    try:
        with torch.inference_mode():
            for inputs in _iter_shard_batches(distributed_state, data_module,
                                              training_arguments.per_device_eval_batch_size):
                captured.clear()
                ids = inputs['unique_indices']
                input_ids = inputs['input_ids'].to(device)
                images = inputs['images'].to(dtype=torch.float16, device=device, non_blocking=True)
                model(
                    input_ids=input_ids,
                    images=images,
                    output_attentions=True,
                    return_dict=True,
                    use_cache=False,
                )
                if not captured:
                    raise RuntimeError("No '23.self_attn.o_proj' activation was captured.")
                msa_embed = _msa_embedding(captured[0], inputs['language_length'], len(ids))
                for sample_id, embed in zip(ids, msa_embed):
                    msa_embed_list[sample_id] = embed.tolist()
                # Release references early to prevent OOM.
                captured.clear()
                del input_ids, images, msa_embed
    finally:
        for handle in handles:
            handle.remove()

    _save_json(msa_embed_list, training_arguments.output_dir,
               f"last_layer_msa_embeds_{distributed_state.process_index}.json")


def train():
    """Parse CLI arguments, load the model and compute the selected ``ds_metric``.

    Supported metrics: 'ppl', 'mir', 'attention_svd', 'last_layer_act'. Results
    are written to ``training_arguments.output_dir`` (one JSON per process for
    the distributed metrics).

    Raises:
        ValueError: for a missing/unknown ``ds_metric`` or invalid attention_svd
            ``method``/``k`` — checked before the model is loaded.
    """
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_arguments, data_arguments, training_arguments = parser.parse_args_into_dataclasses()

    logger_setting(getattr(training_arguments, 'output_dir', None))

    runners = {
        'ppl': _compute_ppl,
        'mir': _compute_mir,
        'attention_svd': _compute_attention_svd,
        'last_layer_act': _compute_last_layer_act,
    }
    ds_metric = training_arguments.ds_metric
    if ds_metric is None:
        raise ValueError("ds_metric cannot be None. Please choose from ['ppl', 'mir', 'attention_svd', 'last_layer_act']")
    if ds_metric not in runners:
        raise ValueError(f"Invalid value: {ds_metric} for ds_metric passed, can only be chosen from ['ppl', 'mir', 'attention_svd', 'last_layer_act']")
    if ds_metric == "attention_svd":
        _validate_attention_svd_args(training_arguments)

    model, tokenizer = _load_model_and_tokenizer(model_arguments, data_arguments, training_arguments)
    log_trainable_params(model)

    runners[ds_metric](model, tokenizer, data_arguments, training_arguments)

if __name__ == '__main__':
    # Script entry point: parse CLI arguments and compute the requested metric.
    train()