
import os
import logging
from typing import Dict

import torch
import transformers

from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
import math


# Process rank consulted by rank0_print, which prints only when this equals 0.
# NOTE(review): nothing in this section assigns it beyond this None default —
# presumably the training entry point sets it at startup; confirm.
local_rank = None


def set_wandb_dir(abs_path):
    """Point the Weights & Biases run, cache and config directories at ``abs_path``.

    ``WANDB_DIR`` is set to ``abs_path`` itself; the cache and config directories
    become its ``.cache/`` and ``.config/`` sub-paths. Only environment variables
    are modified — no directories are created here.
    """
    env_updates = {
        'WANDB_DIR': abs_path,
        'WANDB_CACHE_DIR': os.path.join(abs_path, ".cache/"),
        'WANDB_CONFIG_DIR': os.path.join(abs_path, ".config/"),
    }
    os.environ.update(env_updates)


def print_model_size(model, bytes_per_param=4):
    """Print the approximate parameter memory of ``model`` in MB.

    The size is ``sum(numel) * bytes_per_param / 1024**2`` over
    ``model.named_parameters()``. A model with no parameters reports 0.00 MB.

    Args:
        model: object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
        bytes_per_param: bytes assumed per parameter element; the default 4
            corresponds to fp32 weights (use 2 for fp16/bf16).
    """
    total_params = sum(param.numel() for _, param in model.named_parameters())
    # Convert the element count to bytes before scaling to MB; computed after the
    # loop so an empty model still produces a value.
    total_size = total_params * bytes_per_param / (1024 ** 2)
    print(f"Total Model Size: {total_size:.2f} MB")


def print_trainable_layers(model):
    """Print the name and shape of every parameter whose ``requires_grad`` is set.

    Each line has the form ``<name> <shape> requires_grad``; frozen parameters are skipped.
    """
    trainable = ((n, p) for n, p in model.named_parameters() if p.requires_grad)
    for param_name, tensor in trainable:
        print(param_name, tensor.shape, "requires_grad")

def print_param_device(model):
    """Print the device of the first parameter under each distinct 3-level name prefix.

    Parameters are grouped by the first three dot-separated components of their
    name (e.g. ``model.layers.0``). Only the first parameter seen in each group is
    printed, as ``<name> <device>``, in ``named_parameters()`` order.
    """
    seen_prefixes = set()
    for name, param in model.named_parameters():
        # Tuple so the prefix is hashable: O(1) set lookup instead of scanning a list.
        prefix = tuple(name.split(".")[:3])
        if prefix not in seen_prefixes:
            seen_prefixes.add(prefix)
            print(name, param.device)

def print_layers(model, save_path="./llava/backbone/detr_branch"):
    """Dump parameter names and shapes of ``model`` to two JSON files.

    Writes ``layers_<N>.json`` (every parameter) and ``trainable_layers_<N>.json``
    (only parameters with ``requires_grad``) into ``save_path``. ``<N>`` is the
    total number of parameters in both file names. Shapes are stored as lists of
    ints.

    Args:
        model: object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
        save_path: output directory, created if missing; defaults to the
            location that was previously hard-coded.
    """
    import json

    layers = {}
    trainable_layers = {}
    for name, param in model.named_parameters():
        shape = list(param.shape)
        layers[name] = shape
        if param.requires_grad:
            trainable_layers[name] = shape

    # Previously open() failed with FileNotFoundError if the directory did not exist.
    os.makedirs(save_path, exist_ok=True)
    layers_file = os.path.join(save_path, f'layers_{len(layers)}.json')
    # Both files use the total parameter count as suffix, so they pair up by name.
    trainable_file = os.path.join(save_path, f'trainable_layers_{len(layers)}.json')
    with open(layers_file, 'w', encoding='utf-8') as f:
        json.dump(layers, f, indent=4)
    with open(trainable_file, 'w', encoding='utf-8') as f:
        json.dump(trainable_layers, f, indent=4)
    print(f"save layers to {layers_file}")


def _adapter_tag(adapter_path):
    """Return the run tag of a projector checkpoint: its parent folder name after the last 'chat-'.

    e.g. ``.../llava-llama-2-7b-chat-DETR-pretrain-1000-tune-1/mm_projector.bin``
    -> ``DETR-pretrain-1000-tune-1``. Folder names without ``'chat-'`` are returned unchanged.

    Raises:
        ValueError: if ``adapter_path`` has no parent folder component.
    """
    parts = adapter_path.split('/')
    if len(parts) < 2:
        raise ValueError(f"adapter path has no parent folder: {adapter_path!r}")
    return parts[-2].split('chat-')[-1]


def get_answers_file_name(args, model_name, pretrain_mm_mlp_adapter=None, vm_pretrain_mm_mlp_adapter=None):
    """Build the answers ``.jsonl`` path for an evaluation run.

    The path is ``args.question_file`` with its extension replaced by a suffix made of:
    * the model size (``-7b`` / ``-13b``) and ``_lora`` if found in ``model_name``;
    * batch size and seed;
    * the two adapter tags (only when both adapter paths are given);
    * the CFG value.

    Args:
        args: namespace with ``question_file``, ``batch_size``, ``seed`` and ``cfg``.
        model_name: model identifier inspected for ``'7b'``, ``'13b'`` and ``'lora'``.
        pretrain_mm_mlp_adapter: projector checkpoint path; its parent folder names the run.
        vm_pretrain_mm_mlp_adapter: second projector checkpoint path, tagged the same way.

    Returns:
        e.g. ``data/q.json`` -> ``data/q-7b-bs1-s0-cfg1.5.jsonl``.

    Raises:
        ValueError: if an adapter path has no parent folder.
    """
    special_name = ""
    if pretrain_mm_mlp_adapter is not None and vm_pretrain_mm_mlp_adapter is not None:
        special_name = f"{_adapter_tag(pretrain_mm_mlp_adapter)}-{_adapter_tag(vm_pretrain_mm_mlp_adapter)}"

    file_name = ""
    if '7b' in model_name:
        file_name += '-7b'
    elif '13b' in model_name:
        file_name += '-13b'
    if 'lora' in model_name:
        file_name += '_lora'
    file_name += f'-bs{args.batch_size}'
    file_name += f'-s{args.seed}'
    if special_name:
        file_name += f'-{special_name}'

    cfg_suffix = f'-cfg{args.cfg}' if args.cfg is not None else '-cfg'
    # Strip only the file's own extension: str.replace('.json', ...) also matched
    # inside '.jsonl' and left extension-less names untouched (aliasing the input).
    root, _ = os.path.splitext(args.question_file)
    return f'{root}{file_name}{cfg_suffix}.jsonl'

def split_list(lst, n):
    """Split a list into at most ``n`` contiguous chunks of ``ceil(len(lst) / n)`` items.

    The last chunk may be shorter. Fewer than ``n`` chunks come back when the list is
    short, e.g. 5 items with n=4 -> sizes 2, 2, 1. An empty list yields ``[]``.

    Raises:
        ValueError: if ``n`` is less than 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")
    # Ceiling division; max(1, ...) keeps range()'s step non-zero for an empty list.
    chunk_size = max(1, math.ceil(len(lst) / n))
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    """Return chunk ``k`` of ``split_list(list(lst), n)``.

    ``lst`` may be any iterable; it is materialised with ``list()`` first. Because
    ``split_list`` can produce fewer than ``n`` chunks, an index with
    ``len(chunks) <= k < n`` gets an empty list instead of an IndexError. This makes
    every ``k`` in ``range(n)`` valid.
    """
    lst = list(lst)
    chunks = split_list(lst, n)
    if len(chunks) <= k < n:
        return []
    return chunks[k]


def update_mm_projector(model, mm_projector_path):
    """Load multimodal projector weights from a checkpoint file into ``model``.

    Loading works as follows:
    * every entry whose key starts with ``model.mm_projector.`` is selected;
    * the prefix is stripped;
    * the result goes into ``model.get_model().mm_projector`` via a strict
      ``load_state_dict``;
    * other keys in the file are ignored.

    This covers a single-Linear projector (``weight``/``bias``) and also multi-layer
    projectors (``0.weight``, ``2.bias``, ...).

    Returns:
        The same ``model`` object, updated in place.

    Raises:
        KeyError: if the checkpoint contains no ``model.mm_projector.`` entries.
    """
    prefix = "model.mm_projector."
    # torch.load unpickles the file; only use it on trusted checkpoints.
    loaded_weights = torch.load(mm_projector_path, map_location="cpu")
    projector_state = {
        key[len(prefix):]: value
        for key, value in loaded_weights.items()
        if key.startswith(prefix)
    }
    if not projector_state:
        raise KeyError(f"no '{prefix}*' weights found in {mm_projector_path}")
    model.get_model().mm_projector.load_state_dict(projector_state)
    print(f"updated mm_projector weights from {mm_projector_path}")
    return model


def rank0_print(*args):
    """Forward ``args`` to ``print`` only when the module-level ``local_rank`` is 0.

    Any other value of ``local_rank`` (including ``None``) suppresses the output.
    """
    if local_rank != 0:
        return
    print(*args)


def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached CPU copy of ``param``, gathering it first if it is a DeepSpeed param.

    Parameters that carry a ``ds_id`` attribute are treated as DeepSpeed ZeRO
    parameters and copied inside ``zero.GatheredParameters``. Anything else is
    simply detached, moved to CPU and cloned.

    Args:
        param: tensor or parameter to copy.
        ignore_status: suppress the warning logged for DeepSpeed params whose
            ``ds_status`` is ``ZeroParamStatus.NOT_AVAILABLE``.
        name: label used in that warning.

    Returns:
        A CPU tensor that does not share storage with ``param``.
    """
    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
            # Message now matches the condition that triggers it ('==', not '!=').
            logging.warning("%s: param.ds_status == ZeroParamStatus.NOT_AVAILABLE: %s",
                            name, param.ds_status)
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    """Select LoRA parameters (and optionally biases) and return detached CPU copies.

    Args:
        named_params: iterable of ``(name, tensor)`` pairs, e.g. ``model.named_parameters()``.
        bias: which bias tensors to include alongside the ``lora_`` parameters:
            * ``"none"``: no biases;
            * ``"all"``: every parameter whose name contains ``"bias"``;
            * ``"lora_only"``: only the bias ``<prefix>bias`` of a module that has a
              ``<prefix>lora_*`` parameter.

    Returns:
        dict mapping parameter name -> tensor copied through ``maybe_zero_3``.

    Raises:
        NotImplementedError: for any other ``bias`` value.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                lora_bias_names.add(k.split("lora_")[0] + "bias")
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Iterate items (the original unpacked dict keys) and test each bias's own
        # name (the original reused a stale loop variable).
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError(f"unsupported bias mode: {bias!r}")
    return {k: maybe_zero_3(v, name=k) for k, v in to_return.items()}


def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    """Return CPU copies of every parameter whose name does not contain ``"lora_"``.

    Args:
        named_params: iterable of ``(name, tensor)`` pairs, e.g. ``model.named_parameters()``.
        require_grad_only: when true, keep only tensors with ``requires_grad`` set.

    Returns:
        dict mapping parameter name -> tensor copied via ``maybe_zero_3`` with the
        status warning suppressed.
    """
    candidates = dict((pname, tensor) for pname, tensor in named_params if "lora_" not in pname)
    if require_grad_only:
        candidates = {pname: tensor for pname, tensor in candidates.items() if tensor.requires_grad}
    gathered = {}
    for pname, tensor in candidates.items():
        gathered[pname] = maybe_zero_3(tensor, ignore_status=True).cpu()
    return gathered


def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    """Return CPU copies of parameters whose name contains any of ``keys_to_match``.

    Args:
        named_params: iterable of ``(name, tensor)`` pairs, e.g. ``model.named_parameters()``.
        keys_to_match: substrings; a parameter is kept if any occurs in its name.

    Returns:
        dict mapping parameter name -> tensor copied via ``maybe_zero_3`` with the
        status warning suppressed.
    """
    def _matches(param_name):
        return any(key_match in param_name for key_match in keys_to_match)

    selected = dict((pname, tensor) for pname, tensor in named_params if _matches(pname))
    gathered = {}
    for pname, tensor in selected.items():
        gathered[pname] = maybe_zero_3(tensor, ignore_status=True).cpu()
    return gathered



def find_all_linear_names(model):
    """Return the distinct last name components of all ``torch.nn.Linear`` submodules.

    For a submodule named ``model.layers.0.self_attn.q_proj`` the collected name is
    ``q_proj``. ``lm_head`` is excluded (original note: "needed for 16-bit").
    The result is presumably used as LoRA target modules — confirm at call sites.
    """
    leaf_names = {
        qualified.rsplit('.', 1)[-1]
        for qualified, submodule in model.named_modules()
        if isinstance(submodule, torch.nn.Linear)
    }
    leaf_names.discard('lm_head')
    return list(leaf_names)


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
                                   output_dir: str):
    """Collects the state dict and dump to disk.

    Three modes:
    * ``trainer.args.tune_mm_mlp_adapter`` set: save only parameters matching
      ``mm_projector`` (plus ``embed_tokens``/``embed_in`` when ``use_im_start_end``)
      and the model config. The weights go to:
      - ``<parent>/mm_projector/<checkpoint-N>.bin`` for ``checkpoint-*`` folders;
      - ``<output_dir>/mm_projector.bin`` otherwise.
      Weights are written only when ``local_rank`` is 0 or -1.
    * ``trainer.deepspeed`` set: synchronize CUDA and delegate to ``trainer.save_model``.
    * Otherwise: copy the state dict to CPU and call ``trainer._save`` when
      ``trainer.args.should_save``.
    """
    if getattr(trainer.args, "tune_mm_mlp_adapter", False):
        # Only save Adapter
        keys_to_match = ['mm_projector']
        if getattr(trainer.args, "use_im_start_end", False):
            keys_to_match.extend(['embed_tokens', 'embed_in'])

        # Collected on every rank before the rank check below; the helper goes through
        # zero.GatheredParameters, which presumably must run on all ranks under ZeRO-3.
        weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
        trainer.model.config.save_pretrained(output_dir)

        # normpath drops a trailing separator, which would otherwise make the folder
        # name '' and silently skip the checkpoint-* branch.
        normalized_dir = os.path.normpath(output_dir)
        current_folder = os.path.basename(normalized_dir)
        parent_folder = os.path.dirname(normalized_dir)
        if trainer.args.local_rank in (0, -1):
            if current_folder.startswith('checkpoint-'):
                mm_projector_folder = os.path.join(parent_folder, "mm_projector")
                os.makedirs(mm_projector_folder, exist_ok=True)
                torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
            else:
                torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin'))
        return

    if trainer.deepspeed:
        torch.cuda.synchronize()
        trainer.save_model(output_dir)
        return

    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Add special tokens to ``tokenizer`` and resize ``model``'s embeddings to match.

    The embeddings are always resized to ``len(tokenizer)``. When tokens were
    actually added, the new rows of both the input and output embedding matrices
    are overwritten in place (through ``.weight.data``) with the mean of that
    matrix's pre-existing rows.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    added = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))
    if added <= 0:
        return

    embedding_layers = (model.get_input_embeddings(), model.get_output_embeddings())
    matrices = [layer.weight.data for layer in embedding_layers]
    # Compute both means before writing, so each uses only the original rows.
    means = [matrix[:-added].mean(dim=0, keepdim=True) for matrix in matrices]
    for matrix, mean in zip(matrices, means):
        matrix[-added:] = mean

