import shutil
from pathlib import Path
from typing import Dict, List

import hydra
import pandas as pd
import torch
from safetensors import safe_open
from safetensors.torch import save_file
from salience.utils import get_gpu_stats
from tqdm import tqdm

from lm_understanding.data import model_behavior_path
from lm_understanding.models import make_full_model_name
from lm_understanding.models.hf_model import (T5_MODELS, create_tokenizer,
                                              get_input_ids, model_class)
from lm_understanding.question_template import TemplateModelBehavior


def get_max_memory(stats, max_mem_buffer: int = 5, util_gpu_thresh: int = 90, verbose=True):
    """
    Get the max memory per GPU for loading a model

    stats: GPU memory and usage statistics, keyed by GPU index. Every entry
        must provide 'mem_free' (free memory, in MiB) and 'util_gpu'.
    max_mem_buffer: headroom subtracted from the free memory of each GPU, in MiB
    util_gpu_thresh: if gpu-util is greater than this value, allocate 0 memory to this gpu
    verbose: print the resulting budget before returning it

    Returns a dict mapping each GPU index to a string such as "1234MiB".
    """

    max_memory = {}
    for i, gpu in stats.items():
        if gpu['util_gpu'] > util_gpu_thresh:
            # Busy GPU: do not place any part of the model on it.
            max_memory[i] = "0MiB"
        else:
            # Clamp at zero so that a GPU with less free memory than the
            # buffer is never handed a negative budget (e.g. "-2MiB").
            max_memory[i] = f"{max(gpu['mem_free'] - max_mem_buffer, 0)}MiB"

    if verbose: print(max_memory)

    return max_memory



def get_model(
    model_name: str = None,
    auto_device_map: bool = False,
    device: int = None,
    device_map = None,
):
    """
    Get model, put in eval mode.

    The model is always loaded with output_attentions=True. At most one
    placement strategy may be selected:

    model_name: model name, expanded with make_full_model_name before loading
    auto_device_map: load with device_map='auto'. When more than one CUDA
        device is visible, a per-GPU memory budget computed from the current
        GPU statistics is passed as max_memory.
    device: load the model, then move the whole model to this device
    device_map: explicit device map forwarded to from_pretrained

    If none of the three is given the model is loaded without any explicit
    placement (previously this case failed with an UnboundLocalError).
    """

    model_name = make_full_model_name(model_name)

    load_kwargs = {'output_attentions': True}
    target_device = None

    if auto_device_map:
        assert device is None and device_map is None
        load_kwargs['device_map'] = 'auto'
        if torch.cuda.device_count() > 1:
            stats = get_gpu_stats(verbose=True)
            load_kwargs['max_memory'] = get_max_memory(stats, verbose=True)
    elif device is not None:
        assert device_map is None
        target_device = device
    elif device_map is not None:
        load_kwargs['device_map'] = device_map

    model = model_class(model_name).from_pretrained(model_name, **load_kwargs)
    if target_device is not None:
        model.to(target_device)
    model.eval()

    return model



def get_model_behavior_questions(
        path_model_behavior = None,
        split_name: str = 'train',
) -> pd.Series:
    """
    Load the question texts of one split from a saved model behavior.

    path_model_behavior: location of the saved TemplateModelBehavior (str or Path)
    split_name: value of the `split` column to keep

    Returns the `question_text` column of the rows whose split equals split_name.
    Raises ValueError if no path is given and FileNotFoundError if the path
    does not exist.
    """

    if path_model_behavior is None:
        raise ValueError("path_model_behavior must be provided")

    # Accept plain strings as well as Path objects.
    path_model_behavior = Path(path_model_behavior)
    if not path_model_behavior.exists():
        # FileNotFoundError is still an Exception, so existing handlers keep working.
        raise FileNotFoundError(f"Model behavior not found at {path_model_behavior}")

    df_model_behavior = TemplateModelBehavior.load(path_model_behavior).df
    df_questions = df_model_behavior[df_model_behavior.split == split_name].question_text

    return df_questions



def get_inputs(
        df_questions: pd.Series = None,
        prompt: str = None,
        tokenizer = None,
        nsamples: int = None,
        question_ids: List[int] = None,
        device: str = None,
        max_length: int = 350,
):
    """
    Build tokenized model inputs for a set of questions.

    Every selected question replaces the "[question]" placeholder in prompt;
    the resulting prompts are tokenized and padded to max_length.

    df_questions: question texts, indexed by question id
    prompt: prompt template containing the "[question]" placeholder
    tokenizer: tokenizer called on the list of prompts; must return a mapping of tensors
    nsamples: number of questions to draw at random (exclusive with question_ids)
    question_ids: index labels of the questions to use (exclusive with nsamples)
    device: device the tokenized tensors are moved to
    max_length: length each prompt is padded to. No truncation is requested,
        so prompts must not tokenize to more than max_length tokens.

    Returns (question_ids, prompts, inputs).
    Raises ValueError unless exactly one of nsamples / question_ids is given.
    """

    # Explicit check instead of assert so that it also runs under `python -O`.
    if (nsamples is None) == (question_ids is None):
        raise ValueError("Specify exactly one of nsamples (random sample) or question_ids")

    if nsamples is not None:
        sampled = df_questions.sample(nsamples)
        question_ids = sampled.index.tolist()
        questions = sampled.tolist()
    else:
        # .loc makes the lookup label-based regardless of the index dtype.
        questions = df_questions.loc[question_ids].tolist()

    prompts = [prompt.replace("[question]", q) for q in questions]

    inputs = tokenizer(prompts, return_tensors='pt', max_length=max_length, padding='max_length')
    inputs = {k: v.to(device) for k, v in inputs.items()}

    return question_ids, prompts, inputs



def view_outputs(
        outputs = None,
        tokenizer = None,
):
    """
    Debugging helper: print a readable summary of model outputs.

    Prints the decoded arg-max prediction for every batch element, followed
    by the softmax probabilities of the first token of "Yes" and of "No".

    outputs: model outputs containing 'logits'
    tokenizer: tokenizer used to decode ids and to look up the "Yes"/"No" ids
    """

    # Drop dimension 1 when it is a singleton.
    logits = outputs['logits'].squeeze(dim=1)

    predicted_ids = torch.argmax(logits, dim=1)
    print(tokenizer.batch_decode(predicted_ids))

    # First sub-token id of each answer word, in the order [Yes, No].
    answer_ids = [
        tokenizer.convert_tokens_to_ids(tokenizer.tokenize(word))[0]
        for word in ('Yes', 'No')
    ]

    print(torch.softmax(logits, dim=-1)[:, answer_ids])


def get_flap(
    outputs,
    ind_token: int,
    outputs_key: str = 'cross_attentions',
) -> torch.Tensor:
    """
    Get the final layer attention pattern (FLAP) of a single query token.

    outputs: Output dictionary from model.
    ind_token: Index of the query token whose attention row is returned, e.g.
        * 0 corresponding to the initial decoder input id for encoder-decoder models
        * -1 for the last token in decoder-only models
        * 0 corresponding to the CLS token for BERT classifier
    outputs_key: Key for attention pattern in outputs model
        * 'cross_attentions' for encoder-decoder
        * 'attentions' for decoder-only

    Returns a tensor of shape (batch, num_heads, sequence_len), assuming each
    per-layer attention tensor is laid out as
    (batch_size, num_heads, sequence_length, sequence_length).
    """

    # outputs[outputs_key] holds one attention tensor per layer; keep the last.
    final_layer_attention = outputs[outputs_key][-1]

    # Every batch element and head, a single query position, all key positions.
    return final_layer_attention[:, :, ind_token, :]
    


def get_grad(
    outputs,
    inputs_embeds,
    ind_token: int,
    retain_graph: bool = False,
) -> torch.Tensor:
    """
    Gradient of one logit entry with respect to the input embeddings.

    outputs: model outputs containing 'logits'; the entry of interest is
        selected along the last dimension
    inputs_embeds: tensor the gradient is taken with respect to
    ind_token: index into the last dimension of the logits
    retain_graph: forwarded to torch.autograd.grad

    Returns a tensor with the same shape as inputs_embeds.
    """

    target = outputs['logits'][..., ind_token]

    # grad_outputs of all ones: differentiate the sum of the selected logits.
    (grad_wrt_inputs,) = torch.autograd.grad(
        target,
        inputs_embeds,
        grad_outputs=torch.ones_like(target),
        retain_graph=retain_graph,
    )

    return grad_wrt_inputs



def get_grad_norm(
        inputs_grad: torch.Tensor,
        p: int = 2,
) -> torch.Tensor:
    """
    Reduce a gradient tensor to one norm value per position.

    inputs_grad: gradient tensor; the norm is taken over its last dimension
    p: order of norm
    """

    return inputs_grad.norm(p=p, dim=-1)



def get_ixg(
        inputs_grad: torch.Tensor,
        inputs_embeds: torch.Tensor,
) -> torch.Tensor:
    """
    Get IxG (input times gradient).

    Both arguments are 3-dimensional tensors, indexed (b, s, h) in the einsum
    below. The element-wise product is summed over the last axis, giving one
    value per (b, s) position.
    """

    # Dot product of embedding and gradient along the last (h) axis.
    return torch.einsum("b s h, b s h -> b s", inputs_embeds, inputs_grad)



def get_integrated_grad(
        model,
        inputs_embeds,
        attention_mask,
        ind_token: int,
        baseline_embed = None,
        steps: int = 50,
        kwargs = None,
):
    """
    Integrated gradients of one logit entry with respect to the input embeddings.

    Gradients are evaluated at steps+1 evenly spaced points on the straight
    line from baseline_embed to inputs_embeds, averaged with the trapezoidal
    rule and multiplied by (inputs_embeds - baseline_embed).

    model: callable accepting inputs_embeds / attention_mask keyword arguments
        and returning outputs that contain 'logits'
    inputs_embeds: input embeddings of shape (batch, seq_len, hidden_dim)
    attention_mask: attention mask forwarded to the model unchanged
    ind_token: index into the last dimension of the logits
    baseline_embed: start of the integration path, same shape as
        inputs_embeds; defaults to all zeros
    steps: number of intervals the path is divided into (must be > 0)
    kwargs: extra keyword arguments for the model call; should specify
        decoder_input_ids if it is an encoder-decoder model

    Returns a tensor of shape (batch, seq_len, hidden_dim).
    Raises ValueError for invalid steps / baseline shape and RuntimeError if
    fewer than two path points produced a NaN-free gradient.
    """

    if steps <= 0:
        raise ValueError(f"steps must be positive, got {steps}")

    # A dict default would be shared between calls, hence the None sentinel.
    kwargs = {} if kwargs is None else kwargs

    if baseline_embed is None:
        baseline_embed = torch.zeros_like(inputs_embeds)
    if baseline_embed.shape != inputs_embeds.shape:
        raise ValueError(
            f"baseline_embed shape {tuple(baseline_embed.shape)} does not match "
            f"inputs_embeds shape {tuple(inputs_embeds.shape)}"
        )

    BATCH_SIZE, SEQ_LEN, HIDDEN_DIM = inputs_embeds.shape

    # Path points are built from detached endpoints so that each one is a
    # fresh leaf tensor: the gradient no longer depends on whether the caller's
    # inputs_embeds requires grad, and no graph is kept alive between steps.
    path_start = baseline_embed.detach()
    path_delta = inputs_embeds.detach() - path_start

    # Get gradients, one path point at a time (only one point is alive at once)
    inputs_grads = []

    for i in range(steps + 1):
        path_embed = (path_start + float(i) / steps * path_delta).requires_grad_(True)
        outputs = model(
            inputs_embeds=path_embed,
            attention_mask=attention_mask,
            **kwargs,
            output_attentions=False,
            )
        inputs_grad = get_grad(outputs, inputs_embeds=path_embed, ind_token=ind_token, retain_graph=False)

        # Release the forward pass before the next step, also when the step is skipped.
        del outputs
        torch.cuda.empty_cache()

        assert inputs_grad.shape == (BATCH_SIZE, SEQ_LEN, HIDDEN_DIM)
        if torch.isnan(inputs_grad).any():
            print(f"Skipping nan-containing grad at path index {i}.")
            continue
        inputs_grads.append(inputs_grad)

    # Sanity check: torch.autograd.grad must not have populated parameter grads.
    for param in model.parameters(): assert param.grad is None

    if len(inputs_grads) < 2:
        raise RuntimeError(
            f"Only {len(inputs_grads)} of {steps + 1} path points produced a NaN-free "
            "gradient; at least 2 are needed to integrate."
        )

    # Integrate along path (trapezoidal rule over neighbouring surviving points)
    n_kept = len(inputs_grads)
    inputs_grads = torch.stack(inputs_grads, dim=0)
    assert inputs_grads.shape == (n_kept, BATCH_SIZE, SEQ_LEN, HIDDEN_DIM)
    inputs_grads = (inputs_grads[:-1] + inputs_grads[1:]) / 2.0
    integrated_grad = (inputs_embeds - baseline_embed) * inputs_grads.mean(dim=0)

    assert integrated_grad.shape == (BATCH_SIZE, SEQ_LEN, HIDDEN_DIM)

    return integrated_grad



def split_list(lst, batch_size):
    """
    Split lst into consecutive chunks of at most batch_size items.

    All chunks have exactly batch_size items except possibly the last one,
    which holds the remainder. An empty input yields an empty list.

    lst: sequence to split
    batch_size: maximum chunk length, must be >= 1

    Raises ValueError if batch_size is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Slicing past the end simply yields a shorter final chunk, so no padding
    # is needed and items that happen to be None are preserved.
    return [lst[i:i + batch_size] for i in range(0, len(lst), batch_size)]



def tensor_to_dict(tensor: torch.Tensor, ids: List[int], move_to_cpu: bool = True) -> Dict[str, torch.Tensor]:
    """
    Split a batched tensor into a dict with one entry per batch element.

    tensor: tensor whose first dimension is the batch dimension
    ids: one id per batch element; str(id) becomes the dictionary key
    move_to_cpu: copy every entry to the CPU. This avoids "Some tensors share
        memory, this will lead to duplicate memory on disk and potential
        differences when loading them again" when the dict is saved. The copy
        is forced even if the tensor already lives on the CPU, because
        otherwise the entries would still be views of one shared storage.

    Returns a dict mapping str(id) to the detached slice of tensor for that
    batch element (batch dimension removed).
    Raises ValueError if len(ids) does not match the batch dimension.
    """

    n_batch = len(ids)
    if tensor.shape[0] != n_batch:
        raise ValueError(f"Got {n_batch} ids for a tensor with batch dimension {tensor.shape[0]}")

    return_dict = {}
    # unbind removes the batch dimension and, unlike tensor_split, also
    # handles an empty batch.
    for row_id, row in zip(ids, tensor.detach().unbind(dim=0)):
        if move_to_cpu:
            row = row.to("cpu", copy=True)
        return_dict[str(row_id)] = row
    return return_dict



def collate_batches(path_save, salience_method):
    """
    Merge the per-batch safetensors files of one salience method into one file.

    Reads every `*.safetensors` file in path_save / salience_method and writes
    the union of their tensors to path_save / f'{salience_method}.safetensors'.
    The per-batch files are left in place.

    path_save: directory containing one sub-directory per salience method
    salience_method: name of the sub-directory to collate

    Raises ValueError if the same key appears in more than one batch file.
    """

    salience_aggregated = {}

    # Sorted so that files are visited in a reproducible order; the glob keeps
    # stray non-safetensors files in the directory from breaking the merge.
    for path_file in sorted((path_save / salience_method).glob("*.safetensors")):

        with safe_open(path_file, framework="pt", device="cpu") as f:
            for key in f.keys():
                if key in salience_aggregated:
                    raise ValueError(
                        f"Duplicate key {key!r} in {path_file} while collating {salience_method}"
                    )
                salience_aggregated[key] = f.get_tensor(key)

    save_file(salience_aggregated, path_save / f'{salience_method}.safetensors')



@hydra.main(config_path='../../config', config_name='get_salience.yaml', version_base='1.2')
def main(config):
    """
    Compute salience maps for every training question of one template.

    For each batch of questions the following maps are computed with respect
    to the "Yes" logit and written to
    path_save / <method> / batch_num_<n>.safetensors:

        * flap: final layer attention pattern, averaged over heads
        * inputs_grad_norm_yes: norm of the input gradient
        * ixg_norm_yes: input x gradient
        * inputs_intgrad_norm_yes: norm of the integrated gradient
        * ixig_norm_yes: input x integrated gradient

    Afterwards the per-batch files of each method are collated into a single
    <method>.safetensors file.

    Config keys read: verbose, model.name, model.prompt, task.name,
    template_id, batch_size and, optionally, start_batch (index of the first
    batch to process, default 0; useful to resume an interrupted run).
    """

    if config.verbose: print(config)

    model_path = make_full_model_name(config.model.name)

    path_proj = Path(__file__).parent.parent
    path_model_behavior = path_proj / model_behavior_path(config) / config.template_id

    path_save = path_proj / "scripts/salience/salience_map_results" / config.model.name / config.task.name / config.template_id

    # Single source of truth for the output directories, the per-batch files
    # and the final collation.
    salience_methods = [
        "flap",
        "inputs_grad_norm_yes",
        "ixg_norm_yes",
        "inputs_intgrad_norm_yes",
        "ixig_norm_yes",
    ]
    for salience_method in salience_methods:
        # parents=True also creates path_save itself
        (path_save / salience_method).mkdir(exist_ok=True, parents=True)

    df_train = get_model_behavior_questions(path_model_behavior, split_name='train')
    question_idss = split_list(df_train.index.tolist(), config.batch_size)

    model = get_model(model_path, auto_device_map=True)
    if config.verbose: print(f"{model.hf_device_map=}")

    # DEVICE should be that of first layer of model; inputs should go on same device.
    # Read it from the input embedding instead of assuming "cuda:0"; keep the
    # old value as a fallback if the weights are not materialised (meta device).
    embed_device = model.get_input_embeddings().weight.device
    DEVICE = embed_device if embed_device.type != "meta" else torch.device("cuda:0")

    tokenizer = create_tokenizer(model_path)

    index_yes = get_input_ids(['Yes'], tokenizer).pop()

    # Previously a hard-coded debugging value (22) silently skipped the first
    # batches; it is now an optional config entry that defaults to no skipping.
    start_batch = config.get("start_batch", 0)

    for batch_num, question_ids in enumerate(tqdm(question_idss, desc="Getting salience patterns")):

        if batch_num < start_batch: continue

        model.zero_grad()

        _, _, inputs = get_inputs(
            df_train,
            config.model.prompt,
            tokenizer,
            question_ids=question_ids,
            device=DEVICE,
            )
        inputs_embeds = model.get_input_embeddings()(inputs['input_ids'])

        if model_path in T5_MODELS:
            # Encoder-decoder: feed a single decoder token with id 0 and read
            # the cross attention of that token.
            kwargs = {'decoder_input_ids': torch.zeros(inputs['input_ids'].shape[0], 1, dtype=torch.long, device=DEVICE)}
            ind_token = 0
            outputs_key = 'cross_attentions'
        else:
            # Otherwise: read the self attention of the last position.
            kwargs = {}
            ind_token = -1
            outputs_key = 'attentions'

        outputs = model(
            inputs_embeds=inputs_embeds,
            attention_mask=inputs['attention_mask'],
            output_attentions=True,
            **kwargs,
            )

        flap = get_flap(outputs, ind_token, outputs_key)
        flap = flap.mean(axis=1) # mean over attention heads

        inputs_grad_yes = get_grad(outputs, inputs_embeds, index_yes, retain_graph=False)
        inputs_grad_norm_yes = get_grad_norm(inputs_grad_yes)
        ixg_norm_yes = get_ixg(inputs_grad_yes, inputs_embeds)

        # Free the forward pass before the (memory hungry) integrated gradients.
        del outputs
        torch.cuda.empty_cache()

        inputs_intgrad_yes = get_integrated_grad(model, inputs_embeds, inputs['attention_mask'], index_yes, steps=50, kwargs=kwargs)
        inputs_intgrad_norm_yes = get_grad_norm(inputs_intgrad_yes)
        ixig_norm_yes = get_ixg(inputs_intgrad_yes, inputs_embeds)

        salience_maps = {
            "flap": flap,
            "inputs_grad_norm_yes": inputs_grad_norm_yes,
            "ixg_norm_yes": ixg_norm_yes,
            "inputs_intgrad_norm_yes": inputs_intgrad_norm_yes,
            "ixig_norm_yes": ixig_norm_yes,
        }
        for salience_method, salience in salience_maps.items():
            save_file(tensor_to_dict(salience, question_ids),
                      path_save / salience_method / f"batch_num_{batch_num}.safetensors")

    if config.verbose: print("Collating batches...")
    for salience_method in salience_methods:
        collate_batches(path_save, salience_method)
    


    
# Script entry point: main is wrapped by @hydra.main, so it is called without arguments.
if __name__ == '__main__':
    main()