import os
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from ..rome.layer_stats import layer_stats
from ...util import nethook
from ...util.generate import generate_fast
from ...util.globals import *

from .compute_ks import compute_ks
from .compute_z import compute_z, get_module_input_output_at_words, find_fact_lookup_idx
from .sue_hparams import SUEHyperParams
import tqdm
import torch.nn.functional as F
# Module-level caches; they live for the lifetime of the process.
# Context templates built by get_context_templates(); stays None until the first call fills it.
CONTEXT_TEMPLATES_CACHE = None
# Second-moment statistics stored by get_cov(), keyed by (model_name, layer_name) and kept on CPU.
COV_CACHE = {}

import numpy as np


# @ torch.no_grad()
# def proximal_gradient_descent(X, Y, W, c, lambda_=0.1, learning_rate=0.0025, max_iter=200, tol=0.1):
#     """Proximal Gradient Descent for L1-regularized least squares problem."""
#     # Initialize
#     # hat_W = torch.zeros(W.shape).to(W.device)
#     hat_W = deepcopy(W)
#     # breakpoint()
#     # X_T_X = X @ X.T #[11008, 11008]
#     # X_T_Y = Y @ X.T #[4096, 11008]
#     # c_T_c = c.T @ c #[11008, 11008]
#     # X [11008, 1]
#     # Y [4096, 1]
#     # W [4096, 11008]
#     # c [11008, 11008]
#     # [4096, 11008]

#     for _ in tqdm.tqdm(range(max_iter)):
#         # hat_W_old = deepcopy(hat_W)
#         # Gradient step
#         # grad = 2 * (hat_W @ X_T_X - X_T_Y) + 2 * (hat_W @ c_T_c - (W @ c) @ c.T)
#         grad = (hat_W-W) @ c * 50 + (hat_W @ X - Y) @ X.T
#         # grad = 2 * (X_T_X @ hat_W - X_T_Y) + 2 * (c_T_c @ hat_W - c.T @ (c @ W))
#         hat_W = hat_W - learning_rate * grad
#         # Proximal step
#         hat_W = W - torch.sign(W - hat_W) * torch.maximum(torch.abs(W - hat_W) - lambda_, torch.tensor(0.0))
#         # print(np.linalg.norm(hat_W.cpu().numpy() - hat_W_old.cpu().numpy()))
#         print("item 1:\n", hat_W @ c - W, "item2: \n", torch.norm(hat_W @ X - Y), "item3: \n", torch.norm(hat_W-W, p=1))
#         # if np.linalg.norm(hat_W.cpu().numpy() - hat_W_old.cpu().numpy()) < tol:
#             # break
        

#     return hat_W

def proximal_gradient_descent(X, Y, W, c, lambda_=100000, learning_rate=0.0005, max_iter=50, k=1000):
    """
    Fit a sparse additive update ``delta_W`` such that ``(W + delta_W) @ X``
    approximates ``Y`` while ``c @ delta_W.T`` stays close to zero.

    Despite the name, each iteration is one Adam step on the sum of two MSE
    terms followed by hard thresholding that keeps only the largest-magnitude
    entries of ``delta_W``.

    :param X: Input columns; must be matmul-compatible with ``W`` (``W @ X``).
    :param Y: Target for ``(W + delta_W) @ X``.
    :param W: Weight matrix being edited. It is never modified in place.
    :param c: Matrix applied as ``c @ delta_W.T``; it is scaled by a fixed
        factor of 50 before use.
    :param lambda_: Scale of the L1 term. That term is only printed for
        monitoring; it is NOT part of the optimised objective.
    :param learning_rate: Adam learning rate.
    :param max_iter: Number of optimisation steps.
    :param k: Number of largest-magnitude entries of ``delta_W`` used to set
        the pruning threshold after every step. Values larger than the number
        of entries in ``W`` are clamped. Entries tied with the threshold are
        kept, so more than ``k`` entries may survive.
    :return: ``W + delta_W`` (still attached to the autograd graph of
        ``delta_W``).
    """
    # Fixed weighting of the `c` term relative to the fitting term.
    c = c * 50
    delta_W = torch.zeros_like(W, requires_grad=True)
    opt = torch.optim.Adam([delta_W], lr=learning_rate)
    mse_loss = torch.nn.MSELoss()
    l1_loss = torch.nn.L1Loss()
    # torch.topk raises if k exceeds the number of elements, so clamp it once.
    k = min(k, delta_W.numel())
    for _ in tqdm.tqdm(range(max_iter)):
        opt.zero_grad()
        # Fitting term: the edited weight should map X onto Y.
        loss1 = mse_loss((W + delta_W) @ X, Y)
        # Preservation term: push c @ delta_W.T towards zero. The zero target
        # is created directly with the right shape, dtype and device.
        preserved = c @ delta_W.T
        loss2 = mse_loss(preserved, torch.zeros_like(preserved))
        # Diagnostic only: reported below but deliberately not optimised.
        loss3 = l1_loss(delta_W + W, W) * lambda_
        output = loss1 + loss2
        output.backward()
        opt.step()
        with torch.no_grad():
            # Hard-threshold: zero every entry strictly below the k-th largest magnitude.
            magnitude = delta_W.abs()
            threshold = torch.topk(magnitude.flatten(), k=k)[0].min()
            delta_W.masked_fill_(magnitude < threshold, 0)

        print(loss1, loss2, loss3)
        print((torch.abs(delta_W) != 0).sum())
    return W + delta_W


def apply_sue_to_model(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: SUEHyperParams,
    copy=False,
    return_orig_weights=False,
    cache_template: Optional[str] = None,
    keep_original_weight=False,
    **kwargs
) -> Tuple[AutoModelForCausalLM, Dict[str, Any]]:
    """
    Returns a model with the desired changes.

    :param model: Model to edit (edited in place unless `copy` is true).
    :param tok: Tokenizer forwarded to `execute_sue`.
    :param requests: Edit requests forwarded to `execute_sue`.
    :param hparams: Hyper-parameters forwarded to `execute_sue`.
    :param copy: If true, will preserve the original model while creating a new one to edit.
        Note that you are responsible for deallocating the new model's memory to avoid leaks.
    :param return_orig_weights: If true, a clone of every overwritten parameter is
        collected before it is replaced.
    :param cache_template: Forwarded to `execute_sue`.
    :param keep_original_weight: The collected clones are only returned when this is
        true; otherwise an empty dict is returned even if `return_orig_weights` is set.
    :param kwargs: Accepted for interface compatibility; unused here.
    :return: (1) the updated model, (2) an original copy of the weights that changed
        (empty unless both `return_orig_weights` and `keep_original_weight` are true)
    :raises ValueError: If a computed weight matches the target parameter's shape
        neither directly nor transposed.
    """

    weights_copy = {}
    if copy:
        model = deepcopy(model)

    new_weight = execute_sue(model, tok, requests, hparams, cache_template=cache_template)

    with torch.no_grad():
        for w_name, weight_tuple in new_weight.items():
            w = nethook.get_parameter(model, w_name)
            # execute_sue returns 1-tuples of CPU tensors. Move the new weight to
            # wherever the target parameter lives instead of assuming a
            # particular CUDA device.
            weight_mat = weight_tuple[0].to(w.device)
            weight_mat = upd_matrix_match_shape(weight_mat, w.shape)

            if return_orig_weights and w_name not in weights_copy:
                weights_copy[w_name] = w.detach().clone()
            # Overwrite the whole parameter with the full new weight (not a delta).
            w[...] = weight_mat

    print(f"New weights successfully inserted into {list(new_weight.keys())}")

    if not keep_original_weight:
        weights_copy = {}

    return model, weights_copy


def execute_sue(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: SUEHyperParams,
    cache_template: Optional[str] = None,
) -> Dict[str, Tuple[torch.Tensor]]:
    """
    Computes the edited weight matrix for every layer in `hparams.layers`.

    Keys, target values and covariance statistics are gathered for all layers
    first, while the model is still unmodified. Each layer's weight is then
    re-fitted with `proximal_gradient_descent`.

    Invariant: model at beginning of function == model at end of function
    (the original weights are restored even if an error occurs mid-way).

    :param model: Model whose `hparams.rewrite_module_tmp` weights are edited.
    :param tok: Tokenizer matching `model`.
    :param requests: Edit requests; each needs "prompt", "subject" and
        "target_new". The list is deep-copied, so the caller's dicts are untouched.
    :param hparams: Hyper-parameters (layers, module template, mom2 settings, ...).
    :param cache_template: Accepted for interface compatibility; currently unused.
    :return: Mapping from weight name to a 1-tuple holding the full new weight
        matrix (not a delta) as a detached CPU tensor.
    """

    deltas = {}

    # Update target and print info
    requests = deepcopy(requests)
    for request in requests:
        if request["target_new"][0] != " ":
            # Space required for correct tokenization
            request["target_new"] = " " + request["target_new"]

        if "{}" not in request["prompt"]:
            assert request["subject"] in request["prompt"], (
                f"Subject:{request['subject']} do not exist in prompt: {request['prompt']}"
            )
            request["prompt"] = request["prompt"].replace(request["subject"], "{}")

    all_layers = hparams.layers

    # Retrieve weights that user desires to change. Only the layers that are
    # actually edited are looked up (and cloned below).
    weights = {}
    for layer in all_layers:
        weight_name = f"{hparams.rewrite_module_tmp.format(layer)}.weight"
        weights[weight_name] = nethook.get_parameter(model, weight_name)
    # Save old weights for future restoration
    weights_copy = {k: v.detach().clone() for k, v in weights.items()}

    context_templates = get_context_templates(model, tok)

    zs_list = []
    ks_list = []
    c_list = []
    print("Computing ks and zs for each layer.")
    for layer in tqdm.tqdm(all_layers):
        layer_ks = compute_ks(model, tok, requests, hparams, layer, context_templates).T
        # Targets are collected per layer; reusing one list across layers would
        # mix the targets of earlier layers into this layer's `zs`.
        z_list = [
            compute_z(model, tok, r, hparams, layer, context_templates)
            for r in requests
        ]
        zs_list.append(torch.stack(z_list, dim=1))
        ks_list.append(layer_ks)

        force_recompute = False
        cov = get_cov(
            model,
            tok,
            hparams.rewrite_module_tmp.format(layer),
            hparams.mom2_dataset,
            hparams.mom2_n_samples if not force_recompute else hparams.mom2_n_samples // 10,
            hparams.mom2_dtype,
            force_recompute=force_recompute,
            hparams=hparams,
        )
        c_list.append(cov)

    try:
        for idx, layer in enumerate(all_layers):
            layer_ks, zs, cov = ks_list[idx], zs_list[idx], c_list[idx]
            print(f"\n\nLAYER {layer}\n")

            weight_name = f"{hparams.rewrite_module_tmp.format(layer)}.weight"
            # Original author's shape notes: layer_ks [11008, 1], zs [4096, 1],
            # weight [4096, 11008], c [11008, 11008].
            upd_weight = proximal_gradient_descent(layer_ks, zs, weights_copy[weight_name], cov)

            print("orig norm", torch.linalg.norm(weights[weight_name]))
            print("upd norm", torch.linalg.norm(upd_weight))

            # Update model weights and record desired changes in `delta` variable
            with torch.no_grad():
                weights[weight_name][...] = upd_weight
                deltas[weight_name] = (
                    upd_weight.detach().cpu(),
                )

            # Clear GPU memory: drop every reference to this layer's tensors
            # (including the ones held by the lists) before emptying the cache.
            ks_list[idx] = zs_list[idx] = c_list[idx] = None
            del layer_ks, zs, cov, upd_weight
            torch.cuda.empty_cache()
    finally:
        # Restore state of original model
        with torch.no_grad():
            for k, v in weights.items():
                v[...] = weights_copy[k]

    print(f"Deltas successfully computed for {list(weights.keys())}")

    return deltas


def get_cov(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    layer_name: str,
    mom2_dataset: str,
    mom2_n_samples: int,
    mom2_dtype: str,
    inv: bool = False,
    force_recompute: bool = False,
    hparams=None,
) -> torch.Tensor:
    """
    Retrieves second-moment statistics for `layer_name` and returns them on
    `cuda:{hparams.device}`; the matrix inverse is returned instead when `inv` is true.

    The statistics are cached on CPU in `COV_CACHE`, keyed by (model name,
    layer name) only. A later call with a different dataset, sample count or
    dtype reuses the cached entry unless `force_recompute` is set.

    :param model: Model the statistics are collected from.
    :param tok: Tokenizer matching `model`.
    :param layer_name: Name of the module whose statistics are requested.
    :param mom2_dataset: Dataset name forwarded to `layer_stats`.
    :param mom2_n_samples: Sample size forwarded to `layer_stats`.
    :param mom2_dtype: Precision forwarded to `layer_stats`.
    :param inv: Return the inverse of the statistics matrix.
    :param force_recompute: Ignore the in-memory cache and call `layer_stats` again.
    :param hparams: Required despite the `None` default; `stats_dir` and `device`
        are read from it.
    :raises ValueError: If `hparams` is not provided.
    """

    if hparams is None:
        # Fail early with a clear message instead of an AttributeError further down.
        raise ValueError("get_cov requires `hparams` (it provides `stats_dir` and `device`).")

    model_name = model.config._name_or_path.replace("/", "_")
    key = (model_name, layer_name)

    print(f"Retrieving covariance statistics for {model_name} @ {layer_name}.")
    if key not in COV_CACHE or force_recompute:
        stat = layer_stats(
            model,
            tok,
            layer_name,
            hparams.stats_dir,
            mom2_dataset,
            to_collect=["mom2"],
            sample_size=mom2_n_samples,
            precision=mom2_dtype,
            hparams=hparams,
            force_recompute=force_recompute,
        )
        # Keep the cached copy on CPU; only the returned tensor lives on the GPU.
        COV_CACHE[key] = stat.mom2.moment().float().to("cpu")

    cov = COV_CACHE[key].to(f"cuda:{hparams.device}")
    return torch.inverse(cov) if inv else cov


def upd_matrix_match_shape(matrix: torch.Tensor, shape: torch.Size) -> torch.Tensor:
    """
    Return `matrix` oriented so that its shape equals `shape`.

    GPT-2 and GPT-J have transposed weight representations, so the transpose
    is tried when the matrix does not fit as given.

    :raises ValueError: If neither the matrix nor its transpose has `shape`.
    """

    # Already in the right orientation: hand back the very same tensor.
    if matrix.shape == shape:
        return matrix

    # Otherwise the transposed view is the only other acceptable layout.
    transposed = matrix.T
    if transposed.shape == shape:
        return transposed

    raise ValueError(
        "Update matrix computed by MEMIT does not match original weight shape. "
        "Check for bugs in the code?"
    )


def get_context_templates(model, tok):
    """
    Return the context templates, generating and caching them on first use.

    The result is a list of template lists: the bare "{}" template followed by
    one list of model-generated prefixes per (length, n_gen) configuration,
    each prefix ending in ". {}". The cache is module-wide, so every later
    call returns the templates produced by the first call.
    """
    global CONTEXT_TEMPLATES_CACHE

    # Fast path: templates were already generated earlier in this process.
    if CONTEXT_TEMPLATES_CACHE is not None:
        return CONTEXT_TEMPLATES_CACHE

    template_groups = [["{}"]]
    for length, n_gen in [(10, 5)]:  # Be careful about changing this.
        generated = generate_fast(
            model,
            tok,
            ["The", "Therefore", "Because", "I", "You"],
            n_gen_per_prompt=n_gen // 5,
            max_out_len=length,
        )
        group = []
        for text in generated:
            # Replace braces so the only "{}" left is the slot appended here.
            group.append(text.replace("{", " ").replace("}", " ") + ". {}")
        template_groups.append(group)

    CONTEXT_TEMPLATES_CACHE = template_groups
    print(f"Cached context templates {CONTEXT_TEMPLATES_CACHE}")

    return CONTEXT_TEMPLATES_CACHE
