import os
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from rome.layer_stats import layer_stats
from util import nethook
from util.generate import generate_fast
from util.globals import *

from .compute_ks import compute_ks
from .compute_z import compute_z, get_module_input_output_at_words, find_fact_lookup_idx
from .memit_hparams import MEMITHyperParams

# Module-level caches shared by every call made within this process.
# CONTEXT_TEMPLATES_CACHE is filled by get_context_templates() on its first
# call and returned unchanged afterwards; None means "not generated yet".
CONTEXT_TEMPLATES_CACHE = None
# COV_CACHE maps (model_name, layer_name) to the float CPU tensor of
# second-moment statistics stored by get_cov().
COV_CACHE = {}




def apply_memit_forgetting_to_model(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: MEMITHyperParams,
    edit_round=None,
    copy=False,
    return_orig_weights=False,
    cache_template: Optional[str] = None,
    cache_c=None,
    total_upd_matrix: Optional[Dict[int, torch.Tensor]] = None,
) -> Tuple[AutoModelForCausalLM, Any, Dict[int, torch.Tensor]]:
    """
    Apply a MEMIT edit in place while tracking the cumulative weight update.

    For every layer in `hparams.layers` (in order) the function solves

        (w * C + K K^T) X = (R K^T - w * D_prev C)^T

    where `w` is `hparams.mom2_update_weight`, `C` is the matrix returned by
    `get_cov`, `K` holds the keys from `compute_ks`, `R` is this layer's share
    of the residual between the target and current z vectors, and `D_prev` is
    the update already accumulated for this layer in `total_upd_matrix` (the
    term is skipped when there is none). `X` is reshaped to the weight's
    layout, added to the weight values snapshotted at the start of the call,
    and accumulated into `total_upd_matrix`.

    :param model: Model to edit. Edited in place unless `copy` is True.
    :param tok: Tokenizer matching `model`.
    :param requests: Edit requests. Each is read for "prompt", "subject",
        "target_new"["str"] and, when `cache_template` is given, "case_id".
        The caller's list is deep-copied and never mutated.
    :param hparams: MEMIT hyperparameters.
    :param edit_round: Currently unused; kept for interface compatibility.
    :param copy: If True, edit a deep copy of `model` instead of `model`.
    :param return_orig_weights: Currently unused; kept for interface
        compatibility.
    :param cache_template: Optional format string for the per-request z cache
        file; formatted with (z layer, clamp_norm_factor, case_id).
    :param cache_c: Returned unchanged. The code paths that used to read or
        update it are disabled in this function.
    :param total_upd_matrix: Cumulative updates from earlier calls, stored on
        CPU. NOTE: keys are positions within `hparams.layers`, not layer
        numbers. Updated in place; a new dict is created when None.
    :return: Tuple of (edited model, `cache_c`, `total_upd_matrix`).
    """
    if copy:
        model = deepcopy(model)

    if total_upd_matrix is None:
        total_upd_matrix = {}

    # Work on a private copy so the caller's request dicts are never mutated.
    requests = deepcopy(requests)
    for request in requests:
        target = request["target_new"]["str"]
        # A leading space is required for correct tokenization. `startswith`
        # also copes with an empty target, which indexing [0] would not.
        if not target.startswith(" "):
            request["target_new"]["str"] = " " + target
    for request in requests[:10]:
        print(
            f"MEMIT request sample: "
            f"[{request['prompt'].format(request['subject'])}] -> [{request['target_new']['str']}]"
        )

    # Parameters that will be rewritten, plus a snapshot of their values at
    # entry; every layer's new value is "snapshot + update".
    weight_names = [
        f"{hparams.rewrite_module_tmp.format(layer)}.weight"
        for layer in hparams.layers
    ]
    weights = {name: nethook.get_parameter(model, name) for name in weight_names}
    weights_copy = {name: w.detach().clone() for name, w in weights.items()}

    # Target z vectors are computed at the last edited layer, one per request.
    context_templates = get_context_templates(model, tok)
    z_layer = hparams.layers[-1]
    z_list = []

    for request in requests:
        # Retrieve the z vector if it is already stored in the cache
        cache_fname = (
            Path(
                str(cache_template).format(
                    z_layer, hparams.clamp_norm_factor, request["case_id"]
                )
            )
            if cache_template is not None
            else None
        )
        data_loaded = False
        if cache_fname is not None and cache_fname.exists():
            try:
                # Context manager closes the underlying archive handle.
                with np.load(cache_fname) as data:
                    z_list.append(torch.from_numpy(data["v_star"]).to("cuda"))
                data_loaded = True
            except Exception as e:
                # Best effort: an unreadable cache entry just gets recomputed.
                print(f"Error reading cache file due to {e}. Recomputing...")

        # Compute the z vector if it was not loaded from cache
        if not data_loaded:
            cur_z = compute_z(
                model,
                tok,
                request,
                hparams,
                z_layer,
                context_templates,
            )

            z_list.append(cur_z)

            if cache_fname is not None:
                cache_fname.parent.mkdir(exist_ok=True, parents=True)
                np.savez(
                    cache_fname,
                    **{
                        "v_star": cur_z.detach().cpu().numpy(),
                    },
                )
                print(f"Cached k/v pair at {cache_fname}")
    zs = torch.stack(z_list, dim=1)

    for i, layer in enumerate(hparams.layers):
        print(f"\n>> LAYER {layer}")

        # Keys for this layer under the current (partially edited) weights.
        layer_ks = compute_ks(model, tok, requests, hparams, layer, context_templates).T
        print(f"{layer_ks.shape[1]} K/V pairs")

        # Current output at the z layer, used to measure the remaining error.
        cur_zs = get_module_input_output_at_words(
            model, tok, z_layer,
            context_templates=[r["prompt"] for r in requests],
            words=[r["subject"] for r in requests],
            module_template=hparams.layer_module_tmp,
            fact_token_strategy=hparams.fact_token,
        )[1].T

        targets = zs - cur_zs
        print("z error:", torch.linalg.norm(targets, dim=0).mean())
        # Repeat each target column so targets line up with layer_ks columns
        # (presumably one key per context template per request - TODO confirm).
        repeat_factor = layer_ks.size(1) // targets.size(1)
        targets = targets.repeat_interleave(repeat_factor, dim=1)

        # Load the covariance matrix; converted to double once and reused.
        # `force_recompute` is a debugging switch that is currently always off.
        force_recompute = False
        cov = get_cov(
            model,
            tok,
            hparams.rewrite_module_tmp.format(layer),
            hparams.mom2_dataset,
            hparams.mom2_n_samples
            if not force_recompute
            else hparams.mom2_n_samples // 10,
            hparams.mom2_dtype,
            force_recompute=force_recompute,
        ).double()

        # Solve in double precision.
        layer_ks, targets = layer_ks.double(), targets.double()
        # Spread the remaining residual evenly over the layers still to edit.
        resid = targets / (len(hparams.layers) - i)

        for k__, v__ in total_upd_matrix.items():
            print(f'total upd matrix layer idx {k__} norm: {torch.linalg.norm(v__)}')

        lhs = hparams.mom2_update_weight * cov + layer_ks @ layer_ks.T
        rhs = resid @ layer_ks.T

        # Subtract the effect of the previous cumulative update for this slot.
        if i in total_upd_matrix:
            prev_update = total_upd_matrix[i].cuda().double()
            # Stored updates use the weight's own layout, which may be the
            # transpose of what the product with `cov` needs.
            if prev_update.shape[1] != cov.shape[0]:
                prev_update = prev_update.T
            rhs -= hparams.mom2_update_weight * prev_update @ cov
            del prev_update

        adj_k = torch.linalg.solve(lhs, rhs.T)

        weight_name = f"{hparams.rewrite_module_tmp.format(layer)}.weight"
        upd_matrix = upd_matrix_match_shape(adj_k, weights[weight_name].shape)

        with torch.no_grad():
            weights[weight_name][...] = weights_copy[weight_name] + upd_matrix.float()

        # Accumulate the numeric update on CPU, keyed by position in
        # hparams.layers.
        if i in total_upd_matrix:
            total_upd_matrix[i] += upd_matrix.detach().cpu()
        else:
            total_upd_matrix[i] = upd_matrix.detach().cpu()

        # Drop every reference to this layer's GPU tensors so that
        # empty_cache() can actually hand the memory back.
        del layer_ks, cur_zs, targets, resid, cov, lhs, rhs, adj_k, upd_matrix
        torch.cuda.empty_cache()

    print(f"MEMIT done. Layers updated: {list(total_upd_matrix.keys())}")
    return model, cache_c, total_upd_matrix





# def apply_memit_cumulative_to_model(
#     model: AutoModelForCausalLM,
#     tok: AutoTokenizer,
#     requests: List[Dict],
#     hparams: MEMITHyperParams,
#     copy=False,
#     return_orig_weights=False,
#     cache_template: Optional[str] = None,
#     cache_c=None,
#     total_upd_matrix: Optional[Dict[str, torch.Tensor]] = None,
# ) -> Tuple[AutoModelForCausalLM, Any, Dict[str, torch.Tensor]]:
#     """
#     Apply MEMIT and accumulate only the total update matrices (not raw K/V).
#     """
#     if copy:
#         model = deepcopy(model)

#     weights_copy = {}

#     if total_upd_matrix is None:
#         total_upd_matrix = {}

#     deltas, cache_c = execute_memit(model, tok, requests, hparams, cache_template=cache_template, cache_c = cache_c, total_upd_matrix=total_upd_matrix)

  
#     with torch.no_grad():
#         for w_name, (key_mat, val_mat) in deltas.items():
#             key_mat, val_mat = key_mat.to("cuda"), val_mat.to("cuda")
#             upd_matrix = key_mat @ val_mat.T
#             w = nethook.get_parameter(model, w_name)
#             upd_matrix = upd_matrix_match_shape(upd_matrix, w.shape)

#             # Save original weights if needed
#             if return_orig_weights and w_name not in weights_copy:
#                 weights_copy[w_name] = w.detach().clone()

#             # Apply update
#             w[...] += upd_matrix.float()

#             # === Accumulate numeric upd_matrix ===
#             if w_name in total_upd_matrix:
#                 total_upd_matrix[w_name] += upd_matrix.detach().cpu()
#             else:
#                 total_upd_matrix[w_name] = upd_matrix.detach().cpu()

#     print(f"New weights inserted into {list(deltas.keys())}")
#     return model, cache_c, total_upd_matrix





# def execute_memit(
#     model: AutoModelForCausalLM,
#     tok: AutoTokenizer,
#     requests: List[Dict],
#     hparams: MEMITHyperParams,
#     cache_template: Optional[str] = None,
#     cache_c = None,
#     total_upd_matrix = None,
# ) -> Dict[str, Tuple[torch.Tensor]]:
#     """
#     Executes the MEMIT update algorithm for the specified update at the specified layer
#     Invariant: model at beginning of function == model at end of function
#     """

#     deltas = {}

#     # Update target and print info
#     requests = deepcopy(requests)
#     for i, request in enumerate(requests):
#         if request["target_new"]["str"][0] != " ":
#             # Space required for correct tokenization
#             requests[i]["target_new"]["str"] = " " + request["target_new"]["str"]
#     for request in requests[:10]:
#         print(
#             f"MEMIT request sample: "
#             f"[{request['prompt'].format(request['subject'])}] -> [{request['target_new']['str']}]"
#         )

#     # Retrieve weights that user desires to change
#     weights = {
#         f"{hparams.rewrite_module_tmp.format(layer)}.weight": nethook.get_parameter(
#             model, f"{hparams.rewrite_module_tmp.format(layer)}.weight"
#         )
#         for layer in hparams.layers
#     }
#     # Save old weights for future restoration
#     weights_copy = {k: v.detach().clone() for k, v in weights.items()}

#     # Compute z for final layer
#     context_templates = get_context_templates(model, tok)
#     z_layer = hparams.layers[-1]
#     z_list = []

#     for request in requests:
#         # Retrieve k/v pair if already stored in cache
#         cache_fname = (
#             Path(
#                 str(cache_template).format(
#                     z_layer, hparams.clamp_norm_factor, request["case_id"]
#                 )
#             )
#             if cache_template is not None
#             else None
#         )
#         data_loaded = False
#         if (
#             cache_fname is not None  # Require cache template
#             and cache_fname.exists()  # Cache file must exist
#         ):
#             try:
#                 data = np.load(cache_fname)
#                 z_list.append(torch.from_numpy(data["v_star"]).to("cuda"))
#                 data_loaded = True
#             except Exception as e:
#                 print(f"Error reading cache file due to {e}. Recomputing...")

#         # Compute k/v pair if not loaded from cache
#         if not data_loaded:
#             cur_z = compute_z(
#                 model,
#                 tok,
#                 request,
#                 hparams,
#                 z_layer,
#                 context_templates,
#             )

#             z_list.append(cur_z)

#             if cache_fname is not None:
#                 cache_fname.parent.mkdir(exist_ok=True, parents=True)
#                 np.savez(
#                     cache_fname,
#                     **{
#                         "v_star": cur_z.detach().cpu().numpy(),
#                     },
#                 )
#                 print(f"Cached k/v pair at {cache_fname}")
#     zs = torch.stack(z_list, dim=1)

#     # Insert
#     for i, layer in enumerate(hparams.layers):
#         print(f"\n\nLAYER {layer}\n")

#         # Get current model activations
#         layer_ks = compute_ks(model, tok, requests, hparams, layer, context_templates).T
#         print(f"Writing {layer_ks.size(1)} key/value pair(s) into layer {layer}")

#         # Compute residual error
#         cur_zs = get_module_input_output_at_words(
#             model,
#             tok,
#             z_layer,
#             context_templates=[request["prompt"] for request in requests],
#             words=[request["subject"] for request in requests],
#             module_template=hparams.layer_module_tmp,
#             fact_token_strategy=hparams.fact_token,
#         )[1].T
#         targets = zs - cur_zs
#         print("z error", torch.linalg.norm(targets, dim=0).mean())

#         repeat_factor = (layer_ks.size(1) // targets.size(1))
#         targets = targets.repeat_interleave(repeat_factor, dim=1)

#         # Load covariance matrix
#         force_recompute = False
#         # force_recompute = layer != hparams.layers[0]
#         cov = get_cov(
#             model,
#             tok,
#             hparams.rewrite_module_tmp.format(layer),
#             hparams.mom2_dataset,
#             hparams.mom2_n_samples
#             if not force_recompute
#             else hparams.mom2_n_samples // 10,
#             hparams.mom2_dtype,
#             force_recompute=force_recompute,
#         )

#         # Compute update in double precision
#         layer_ks, targets = (
#             layer_ks.double(),
#             targets.double(),
#         )

#         print(cache_c)

#         resid = targets / (len(hparams.layers) - i)   # Distribute residual across layers

#         adj_k = torch.linalg.solve(
#             hparams.mom2_update_weight * cov.double() + cache_c[i,:,:].cuda().double() + layer_ks @ layer_ks.T,
#             resid @ layer_ks - hparams.mom2_update_weight * total_upd_matrix @ cov.double(),
#         )

#         upd_matrix = adj_k

#         # Adjust update matrix shape
#         weight_name = f"{hparams.rewrite_module_tmp.format(layer)}.weight"
#         upd_matrix = upd_matrix_match_shape(upd_matrix, weights[weight_name].shape)

#         print("orig norm", torch.linalg.norm(weights[weight_name]))
#         print("upd norm", torch.linalg.norm(upd_matrix))

#         # Update model weights and record desired changes in `delta` variable
#         with torch.no_grad():
#             weights[weight_name][...] = weights_copy[weight_name] + upd_matrix.float()
#             deltas[weight_name] = (
#                 adj_k.detach().cpu(),
#                 resid.detach().cpu(),
#             )

#         # Clear GPU memory
#         cov.cpu()
#         for x in [layer_ks, cur_zs, targets]:
#             x.cpu()
#             del x
#         torch.cuda.empty_cache()
#     for i, layer in enumerate(hparams.layers):
#         layer_ks = compute_ks(model, tok, requests, hparams, layer, context_templates).T
#         cache_c[i,:,:] += layer_ks.cpu() @ layer_ks.cpu().T
#     # Restore state of original model
#     with torch.no_grad():
#         for k, v in weights.items():
#             v[...] = weights_copy[k]

#     print(f"Deltas successfully computed for {list(weights.keys())}")

#     return deltas, cache_c


def get_cov(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    layer_name: str,
    mom2_dataset: str,
    mom2_n_samples: int,
    mom2_dtype: str,
    inv: bool = False,
    force_recompute: bool = False,
) -> torch.Tensor:
    """
    Retrieves second-moment ("covariance") statistics for one layer.

    The statistics are collected with `layer_stats` the first time a
    (model name, layer name) pair is requested, or whenever `force_recompute`
    is set, and are kept in `COV_CACHE` as a float CPU tensor for later calls.

    :param model: Model whose `config._name_or_path` names the cache entry.
    :param tok: Tokenizer forwarded to `layer_stats`.
    :param layer_name: Name of the module to collect statistics for.
    :param mom2_dataset: Dataset identifier forwarded to `layer_stats`.
    :param mom2_n_samples: Sample size forwarded to `layer_stats`.
    :param mom2_dtype: Precision forwarded to `layer_stats`.
    :param inv: If True, return the matrix inverse instead of the matrix.
    :param force_recompute: If True, ignore and overwrite the cached entry.
    :return: The cached matrix (or its inverse) on the CUDA device.
    """

    model_name = model.config._name_or_path.replace("/", "_")
    key = (model_name, layer_name)

    print(f"Retrieving covariance statistics for {model_name} @ {layer_name}.")
    if force_recompute or key not in COV_CACHE:
        stat = layer_stats(
            model,
            tok,
            layer_name,
            STATS_DIR,
            mom2_dataset,
            to_collect=["mom2"],
            sample_size=mom2_n_samples,
            precision=mom2_dtype,
            force_recompute=force_recompute,
        )
        # Keep the cached copy on CPU so it does not pin GPU memory.
        COV_CACHE[key] = stat.mom2.moment().float().to("cpu")

    cov = COV_CACHE[key].to("cuda")
    return torch.inverse(cov) if inv else cov


def upd_matrix_match_shape(matrix: torch.Tensor, shape: torch.Size) -> torch.Tensor:
    """
    Orient an update matrix so that it fits a weight of the given shape.

    GPT-2 and GPT-J store their weights transposed relative to each other, so
    the update may have to be transposed before it can be applied.

    :param matrix: Update matrix computed by MEMIT.
    :param shape: Shape of the weight the update will be added to.
    :return: `matrix` itself if it already fits, otherwise `matrix.T`.
    :raises ValueError: If neither orientation has the requested shape.
    """

    # Already in the right orientation: hand back the same object.
    if matrix.shape == shape:
        return matrix

    transposed = matrix.T
    if transposed.shape != shape:
        raise ValueError(
            "Update matrix computed by MEMIT does not match original weight shape. "
            "Check for bugs in the code?"
        )
    return transposed


def get_context_templates(model, tok):
    """
    Return the context templates, generating them on the first call only.

    The result is a list of template groups. The first group is the bare
    template "{}"; each further group is built from text produced by
    `generate_fast`, with any braces replaced by spaces and ". {}" appended.
    The list is stored in CONTEXT_TEMPLATES_CACHE and reused by later calls,
    whatever `model` and `tok` they pass.
    """
    global CONTEXT_TEMPLATES_CACHE

    # Cache hit: nothing to generate.
    if CONTEXT_TEMPLATES_CACHE is not None:
        return CONTEXT_TEMPLATES_CACHE

    template_groups = [["{}"]]
    # (max_out_len, total number of generations). Be careful about changing this.
    for max_len, total_gen in [(10, 5)]:
        generations = generate_fast(
            model,
            tok,
            ["The", "Therefore", "Because", "I", "You"],
            n_gen_per_prompt=total_gen // 5,
            max_out_len=max_len,
        )
        group = []
        for text in generations:
            # Braces in generated text would clash with str.format later on.
            sanitized = text.replace("{", " ").replace("}", " ")
            group.append(sanitized + ". {}")
        template_groups.append(group)

    CONTEXT_TEMPLATES_CACHE = template_groups
    print(f"Cached context templates {CONTEXT_TEMPLATES_CACHE}")

    return CONTEXT_TEMPLATES_CACHE