from src.models.modeling_gpt2 import ExtendedGPT2Config, ExtendedGPT2LMHeadModel
from src.models.modeling_gpt_neox import ExtendedGPTNeoXConfig, ExtendedGPTNeoXForCausalLM
from src.models.common import CausalLMConfig, CausalLM
from src.utils.manual_seed import manual_seed
from src.utils.send_email import send_email

from transformers import AutoTokenizer, DataCollatorForLanguageModeling
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader
import numpy as np
from typing import Union, NamedTuple, Literal
from einops import rearrange
import argparse
from tqdm import tqdm
import matplotlib.pyplot as plt


def stats_wo_outliers(data: np.ndarray, n_iqr=1.5, low=None, high=None, axis=None):
    """Return mean, std and max of ``data`` along ``axis`` after discarding outliers.

    A value is kept when it lies inside ``[Q1 - n_iqr * IQR, Q3 + n_iqr * IQR]``
    (quartiles computed along ``axis``) and, when given, inside ``[low, high]``.
    Rejected values are replaced by NaN so the nan-aware reductions skip them.

    Args:
        data: input array.
        n_iqr: width of the acceptance band, in multiples of the interquartile range.
        low: optional inclusive lower bound applied on top of the IQR filter.
        high: optional inclusive upper bound applied on top of the IQR filter.
        axis: axis to reduce over (``None`` reduces over everything).

    Returns:
        A ``Stats`` named tuple with fields ``mean``, ``std`` and ``max``.
    """
    class Stats(NamedTuple):
        mean: np.ndarray
        std: np.ndarray
        max: np.ndarray

    first_q, third_q = np.percentile(data, [25, 75], axis=axis, keepdims=True)
    band = n_iqr * (third_q - first_q)

    keep = (data >= first_q - band) & (data <= third_q + band)
    if low is not None:
        keep = keep & (data >= low)
    if high is not None:
        keep = keep & (data <= high)

    # NaN-out rejected entries so that nanmean/nanstd/nanmax ignore them.
    filtered = np.where(keep, data, np.nan)

    return Stats(
        mean=np.nanmean(filtered, axis=axis),
        std=np.nanstd(filtered, axis=axis),
        max=np.nanmax(filtered, axis=axis),
    )

def get_config_model_tokenizer(model_name: str) -> tuple[CausalLMConfig, CausalLM, AutoTokenizer]:
    """Load the extended config, model and tokenizer for ``model_name``.

    Supported names:
        * ``"gpt2"`` -> loaded from the ``"gpt2"`` checkpoint with the GPT-2 classes.
        * any name whose first ``-``-separated part is ``"pythia"`` -> loaded from
          ``"EleutherAI/<model_name>-deduped"`` with the GPT-NeoX classes.

    The model is moved to CUDA when available and switched to eval mode; the
    tokenizer's pad token is set to its EOS token. A one-line summary of the
    model geometry is printed.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if model_name == "gpt2":
        hub_id = "gpt2"
        config_cls, model_cls = ExtendedGPT2Config, ExtendedGPT2LMHeadModel
    elif model_name.split("-")[0] == "pythia":
        hub_id = f"EleutherAI/{model_name}-deduped"
        config_cls, model_cls = ExtendedGPTNeoXConfig, ExtendedGPTNeoXForCausalLM
    else:
        raise ValueError(f"Invalid model name: {model_name}")

    config = config_cls.from_pretrained(hub_id)
    model = model_cls.from_pretrained(hub_id, config=config)
    model.to(device)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(hub_id)
    # Use the EOS token as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    print(f"ModelName: {hub_id}, HeadSize: {config.n_embd // config.n_head}, NumLayer: {config.n_layer}",
          f"NumHead: {config.n_head} AttnPdrop: {config.attn_pdrop}", flush=True)

    return config, model, tokenizer

@torch.no_grad()
def get_queries_and_keys(n_data, batch_size, 
                         model_config: CausalLMConfig, model, tokenizer,
                         dataset_name, dataset_subset, dataset_split=None):
    """Run ``model`` over a text dataset and collect per-layer attention queries/keys.

    Args:
        n_data: number of examples kept (via ``train_test_split(train_size=n_data)``).
        batch_size: DataLoader batch size.
        model_config: config supplying ``n_layer`` and ``max_position_embeddings``.
        model: model called with ``output_attentions=["query", "key"]``; each entry of
            ``out.attentions`` is indexed with ``"query"`` / ``"key"``. Those tensors are
            assumed to be (batch, head, seq, head_dim), per the "b h l d" rearrange below.
        tokenizer: tokenizer; texts are padded/truncated to ``max_position_embeddings``.
        dataset_name, dataset_subset, dataset_split: forwarded to ``load_dataset``.

    Returns:
        ``(layerwise_queries, layerwise_keys)``: lists of length ``n_layer``. Each item is
        a CPU tensor of shape (n_head, n_tokens, head_dim) containing only the tokens
        that survived the padding/NaN filter.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Fixed seed: the dataset subsample and the DataLoader shuffle are reproducible.
    manual_seed(0)

    dataset_args = (dataset_name, )
    dataset_kwargs = {}
    if dataset_subset is not None:
        dataset_args += (dataset_subset, )
    if dataset_split is not None:
        dataset_kwargs["split"] = dataset_split
    dataset = load_dataset(*dataset_args, **dataset_kwargs)
    # NOTE(review): with dataset_split=None, load_dataset usually returns a DatasetDict,
    # which has no train_test_split; callers should pass an explicit split.
    dataset = dataset.train_test_split(train_size=n_data)["train"]
    print(f"Dataset: {dataset_name}, Subset: {dataset_subset}, Split: {dataset_split}", flush=True)

    n_layer = model_config.n_layer
    max_length = model_config.max_position_embeddings

    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=max_length)
    tokenized_datasets = dataset.map(tokenize_function, batched=True)
    columns_to_remove = set(tokenized_datasets.column_names) - {"input_ids", "attention_mask"}
    tokenized_datasets = tokenized_datasets.remove_columns(columns_to_remove)

    collator_fn = DataCollatorForLanguageModeling(tokenizer, mlm=False)
    dataloader = DataLoader(tokenized_datasets, collate_fn=collator_fn, 
                            batch_size=batch_size, shuffle=True)

    layerwise_queries = [[] for _ in range(n_layer)]
    layerwise_keys = [[] for _ in range(n_layer)]
    for batch in tqdm(dataloader, desc="Collecting queries and keys"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        out = model(input_ids, attention_mask=attention_mask, output_attentions=["query", "key"])

        # Padding mask broadcastable over (b, h, l, d). It is the same for every layer,
        # so build it once per batch (and it always exists for the `del` below).
        pad_mask = rearrange(attention_mask, "b l -> b 1 l 1") == 0

        for i, attn in enumerate(out.attentions):
            # Mark padding positions with NaN so they can be dropped after flattening.
            queries = attn["query"].masked_fill(pad_mask, float("nan"))
            keys = attn["key"].masked_fill(pad_mask, float("nan"))

            queries = rearrange(queries, "b h l d -> h (b l) d")
            keys = rearrange(keys, "b h l d -> h (b l) d")

            # Drop a token if its query contains NaN in any head (padding, or NaN produced
            # by the model itself). The same token columns are dropped from the keys.
            keep = ~torch.isnan(queries).any(dim=-1).any(dim=0)
            layerwise_queries[i].append(queries[:, keep, :].cpu())
            layerwise_keys[i].append(keys[:, keep, :].cpu())

        del input_ids, attention_mask, pad_mask, out
        torch.cuda.empty_cache()

    layerwise_queries = [torch.cat(layerwise_queries[i], dim=1) for i in range(n_layer)]
    layerwise_keys = [torch.cat(layerwise_keys[i], dim=1) for i in range(n_layer)]

    n_tokens = layerwise_queries[0].size(1)
    print(f"{n_tokens} tokens ({(n_tokens / n_data):.2f} per data) were collected.", flush=True)

    return layerwise_queries, layerwise_keys

@torch.no_grad()
def _compute_dof(lamb: float, n_trial: int, n_sample: int, 
                 queries_and_keys: torch.Tensor, 
                 kernel_type: Literal["exp", "gauss"]="exp"):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_head, n_tokens, head_dim = queries_and_keys.size()

    if kernel_type == "exp":
        indices = torch.randint(0, n_tokens, 
                                size=(n_head, n_trial * n_sample))
    elif kernel_type == "gauss":
        _sq_norm = torch.norm(queries_and_keys, dim=-1) ** 2 / (2 * (head_dim ** 0.5))
        _sq_norm_max = _sq_norm.amax(dim=-1, keepdim=True)
        weights = torch.exp(_sq_norm - _sq_norm_max)
        indices = torch.multinomial(weights, n_trial * n_sample, replacement=True)

    samples = torch.gather(queries_and_keys, 1, 
                           index=indices.unsqueeze(-1).expand(-1, -1, head_dim))
    samples = rearrange(samples, "h (t s) d -> h t s d", s=n_sample)
    samples = samples.to(device)

    if kernel_type == "exp":
        prod = torch.einsum("htnd,htmd->htnm", samples, samples)
        prod = prod / (head_dim ** 0.5)
        prod_max = prod.amax(dim=(-1, -2), keepdim=True)
        kernel = torch.exp(prod - prod_max)
    elif kernel_type == "gauss":
        prod = torch.einsum("htnd,htmd->htnm", samples, samples) / (head_dim ** 0.5)
        samples_sqrd_norm = torch.einsum("htnd,htnd->htn", samples, samples) / (2 * (head_dim ** 0.5))
        x_sqrd_norm = rearrange(samples_sqrd_norm, "h t n -> h t n 1")
        y_sqrd_norm = rearrange(samples_sqrd_norm, "h t n -> h t 1 n")
        kernel = torch.exp(prod - x_sqrd_norm - y_sqrd_norm)

    try:
        kernel_plus_eye = kernel + lamb * torch.eye(n_sample, device=device)
        dof_matrix = torch.linalg.solve(kernel_plus_eye, kernel).cpu().numpy()
        dof = np.trace(dof_matrix, axis1=-1, axis2=-2)
    except torch._C._LinAlgError:
        print("LinAlgError occurred. Setting DoF to n_sample.", flush=True)
        dof = np.ones((n_head, n_trial)) * n_sample
    
    return dof

@torch.no_grad()
def compute_dof(n_trial, n_sample, lambs: list[float], 
                layerwise_queries: list[torch.Tensor], 
                layerwise_keys: list[torch.Tensor],
                kernel_type: Literal["exp", "gauss"]="exp"):
    """Estimate outlier-robust DoF statistics for every layer, head and lambda.

    Queries and keys of a layer are concatenated along the token axis, so both kinds
    of vectors are sampled from one pool. For each lambda and layer, ``_compute_dof``
    yields an (n_head, n_trial) array. Per head, its trials are summarized with
    ``stats_wo_outliers(..., n_iqr=2, low=0.)``.

    Returns:
        ``(avgs, stds)``, each of shape (n_layer, n_head, len(lambs)).
    """
    pooled = [torch.cat([q, k], dim=1) for q, k in zip(layerwise_queries, layerwise_keys)]

    avg_per_lamb, std_per_lamb = [], []
    for lamb in lambs:
        print(f"<-- Lambda = {lamb} -->")
        layer_avgs, layer_stds = [], []
        for layer_idx in range(len(layerwise_queries)):
            dof = _compute_dof(lamb, n_trial, n_sample, pooled[layer_idx], kernel_type)
            per_head = " ".join(f"{d:.2f}" for d in dof.mean(axis=1))
            print(f"Layer {layer_idx}: {per_head}", flush=True)

            # Robust per-head summary over trials; negative DoF values are discarded.
            summary = stats_wo_outliers(dof, axis=1, n_iqr=2, low=0.)
            layer_avgs.append(summary.mean)
            layer_stds.append(summary.std)
        avg_per_lamb.append(np.stack(layer_avgs, axis=0))
        std_per_lamb.append(np.stack(layer_stds, axis=0))

    return np.stack(avg_per_lamb, axis=-1), np.stack(std_per_lamb, axis=-1)

def binary_search(y_array: np.ndarray, target_y: float):
    """Locate where the mean curve hits ``target_y`` and read every curve off there.

    ``y_array`` has shape (..., n_points). Averaging over all leading axes gives one
    curve over the index grid ``0..n_points-1``. The bisection assumes this curve is
    non-increasing. It finds the fractional index at which the linearly interpolated
    mean curve crosses ``target_y``. Every individual curve is then linearly
    interpolated at that same fractional index.

    If the mean curve starts below the target, the first column is returned. If it
    ends above the target, the last column is returned.

    Returns:
        Array of shape ``y_array.shape[:-1]``.
    """
    n_points = y_array.shape[-1]
    grid = np.arange(n_points)
    mean_curve = y_array.mean(axis=tuple(range(y_array.ndim - 1)))

    if mean_curve[0] < target_y:
        return y_array[..., 0]
    if mean_curve[-1] > target_y:
        return y_array[..., -1]

    lo, hi = 0, n_points - 1
    while hi - lo > 1e-6:
        mid = (lo + hi) / 2
        if np.interp(mid, grid, mean_curve) < target_y:
            hi = mid
        else:
            lo = mid

    return np.apply_along_axis(lambda curve: np.interp(lo, grid, curve), axis=-1, arr=y_array)


def _report_optimal_dims(avgs, stds, head_size, std_max, require_headwise_optimal):
    """Print the DoF interpolated at the lambda where the mean DoF equals each cost target."""
    for i in range(std_max + 1):
        # Pessimistic DoF estimate: average plus i standard deviations.
        shifted = avgs + i * stds
        for cost_dim in [head_size // 4, head_size // 2, head_size, head_size * 2, head_size * 4]:
            if require_headwise_optimal:
                # Reversed running max makes each curve non-increasing in lambda,
                # which binary_search assumes.
                headwise_avgs = np.maximum.accumulate(shifted[:, :, ::-1], axis=-1)[:, :, ::-1]
                headwise_optimal_dim = binary_search(headwise_avgs, cost_dim)

                print(f"--- Avg + {i} std, Cost {cost_dim}, Headwise ---")
                print(headwise_optimal_dim, flush=True)
                print(f"Resulting cost: {headwise_optimal_dim.mean()}", flush=True)

            layerwise_avgs = shifted.max(axis=1)
            layerwise_avgs = np.maximum.accumulate(layerwise_avgs[:, ::-1], axis=-1)[:, ::-1]
            layerwise_optimal_dim = binary_search(layerwise_avgs, cost_dim)

            print(f"--- Avg + {i} std, Cost {cost_dim}, Layerwise ---")
            print(layerwise_optimal_dim, flush=True)
            print(f"Resulting cost: {layerwise_optimal_dim.mean()}", flush=True)


def _plot_layerwise_max(lambs, avgs, stds, path):
    """Plot, per layer, the max over heads of avg + i*std (i = 0, 1, 2) against lambda."""
    fig, axes = plt.subplots(1, 3, figsize=(24, 8), sharex=True, sharey=True, dpi=100)
    titles = ["Max of the average", "Max of the average + std", "Max of the average + 2std"]
    for i, (ax, title) in enumerate(zip(axes, titles)):
        ax.set_title(title)
        for layer_idx, (headwise_avgs, headwise_stds) in enumerate(zip(avgs, stds)):
            layerwise_max = (headwise_avgs + i * headwise_stds).max(axis=0)
            ax.plot(lambs, layerwise_max, label=f"Layer {layer_idx}")
        ax.set_xscale("log")
        ax.set_xlabel("Lambda")
        ax.set_ylabel("DoF")
        ax.legend()
    fig.savefig(path)
    plt.close(fig)


def _plot_headwise(lambs, avgs, stds, n_layer, n_head, path):
    """Plot avg, avg + std and avg + 2std against lambda in an (n_layer x n_head) grid."""
    # squeeze=False keeps `axes` 2-D even when n_layer or n_head is 1.
    fig, axes = plt.subplots(n_layer, n_head, figsize=(3*n_head, 3.2*n_layer),
                             sharex=True, sharey="row", squeeze=False)
    for layer_idx, (headwise_avgs, headwise_stds) in enumerate(zip(avgs, stds)):
        for head_idx, (head_avgs, head_stds) in enumerate(zip(headwise_avgs, headwise_stds)):
            ax = axes[layer_idx, head_idx]
            for i in range(3):
                label = "Avg" if i == 0 else f"Avg + {i}std"
                ax.plot(lambs, head_avgs + i * head_stds, label=label)
            ax.set_title(f"Layer {layer_idx}, Head {head_idx}")
            ax.set_xscale("log")
            ax.set_xlabel("Lambda")
            ax.set_ylabel("DoF")
            ax.legend()
    fig.savefig(path)
    plt.close(fig)


def _plot_norm_histograms(layerwise_vectors, n_layer, n_head, path):
    """Histogram the L2 norms of every (layer, head) vector set in an (n_layer x n_head) grid."""
    fig, axes = plt.subplots(n_layer, n_head, figsize=(3*n_head, 3.2*n_layer),
                             sharex=True, sharey=True, squeeze=False)
    for layer_idx in range(n_layer):
        # Norms per layer: avoids stacking a full copy of every vector of every layer.
        norms = torch.norm(layerwise_vectors[layer_idx], dim=-1).numpy()
        for head_idx in range(n_head):
            ax = axes[layer_idx, head_idx]
            ax.hist(norms[head_idx], bins=20)
            ax.set_title(f"Layer {layer_idx}, Head {head_idx}")
            ax.set_xlabel("Norm")
            ax.set_ylabel("Frequency")
    fig.savefig(path)
    plt.close(fig)


def main(output_dir: str,
         n_trial, n_sample, n_data, lambs: Union[int, list],
         batch_size,  
         model_name, dataset_name, dataset_subset, dataset_split=None, 
         std_max: int = 0, require_headwise_optimal: bool = False):
    """Collect queries/keys, estimate kernel DoF for both kernels, and save reports/plots.

    For each kernel in ("exp", "gauss"), writes to ``output_dir``:
        * ``dof_<kernel>.npz``: ``lambs`` plus avg/std arrays of shape
          (n_layer, n_head, len(lambs)).
        * ``dof_<kernel>_layerwise_max.png`` and ``dof_<kernel>_layerwise_headwise.png``.
    Also writes ``query_norms.png`` and ``key_norms.png``.

    Args:
        output_dir: destination directory (created if missing).
        n_trial, n_sample: forwarded to ``compute_dof``.
        n_data, batch_size: forwarded to ``get_queries_and_keys``.
        lambs: a single ridge value or a list of them (sorted ascending).
        model_name: passed to ``get_config_model_tokenizer``.
        dataset_name, dataset_subset, dataset_split: dataset selection.
        std_max: report cost targets for avg + i*std with i in 0..std_max.
        require_headwise_optimal: additionally report per-head results.
    """
    import os

    if isinstance(lambs, (int, float)):
        lambs = [lambs]
    lambs = sorted(lambs)

    os.makedirs(output_dir, exist_ok=True)

    model_config, model, tokenizer = get_config_model_tokenizer(model_name)

    head_size = model_config.n_embd // model_config.n_head
    n_layer = model_config.n_layer
    n_head = model_config.n_head

    layerwise_queries, layerwise_keys = get_queries_and_keys(n_data, batch_size, 
                                                             model_config, model, tokenizer,
                                                             dataset_name, dataset_subset, dataset_split)

    # The collected vectors live on the CPU; release the model before the
    # memory-heavy kernel computations.
    del model, tokenizer
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    for kernel_type in ["exp", "gauss"]:
        print(f"=== Computing DoF with {kernel_type} kernel. ===")

        avgs, stds = compute_dof(n_trial, n_sample, lambs,
                                 layerwise_queries, layerwise_keys, kernel_type)

        # One archive per kernel: writing a single dof.npz inside this loop would let
        # the "gauss" pass overwrite the "exp" results.
        np.savez(f"{output_dir}/dof_{kernel_type}.npz",
                 lambs=lambs,
                 layerwise_headwise_lambwise_avgs=avgs,
                 layerwise_headwise_lambwise_stds=stds)

        _report_optimal_dims(avgs, stds, head_size, std_max, require_headwise_optimal)
        _plot_layerwise_max(lambs, avgs, stds,
                            f"{output_dir}/dof_{kernel_type}_layerwise_max.png")
        _plot_headwise(lambs, avgs, stds, n_layer, n_head,
                       f"{output_dir}/dof_{kernel_type}_layerwise_headwise.png")

    _plot_norm_histograms(layerwise_queries, n_layer, n_head, f"{output_dir}/query_norms.png")
    _plot_norm_histograms(layerwise_keys, n_layer, n_head, f"{output_dir}/key_norms.png")
        
if __name__ == "__main__":
    import traceback

    parser = argparse.ArgumentParser()
    parser.add_argument("--n_trial", type=int, default=10)
    parser.add_argument("--n_sample", type=int, default=1000)
    parser.add_argument("--n_data", type=int, default=1000)
    parser.add_argument("--lambs", type=float, nargs="+", default=[1e-3, 1e-2, 1e-1, 1])
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--model_name", type=str, default="gpt2")
    parser.add_argument("--dataset_name", type=str, default="wikipedia")
    parser.add_argument("--dataset_subset", type=str, default="20220301.en")
    parser.add_argument("--dataset_split", type=str, default="train")
    parser.add_argument("--output_dir", type=str, default="./logs/observe-original-model")
    # Expose main()'s reporting options; the defaults match main()'s own defaults.
    parser.add_argument("--std_max", type=int, default=0)
    parser.add_argument("--require_headwise_optimal", action="store_true")
    args = parser.parse_args()

    try:
        main(args.output_dir, args.n_trial, args.n_sample, args.n_data, args.lambs, args.batch_size, 
             args.model_name, args.dataset_name, args.dataset_subset, args.dataset_split,
             std_max=args.std_max, require_headwise_optimal=args.require_headwise_optimal)
    except Exception as e:
        # Notify, including the traceback, then re-raise with the original traceback intact.
        send_email("[ERROR] observe_original_model.py", 
                   "An error occurred in observe_original_model.py. "
                   + "Error message: \n" + str(e)
                   + "\n\nTraceback:\n" + traceback.format_exc())
        raise

    send_email("[DONE] observe_original_model.py", "The observe_original_model.py has been completed.")