import json

import numpy as np
import pandas as pd
import torch
from nnsight import LanguageModel
from simple_parsing import ArgumentParser
from tqdm import tqdm
from sparsify import SparseCoder


def get_predicted_latents_topk(sentence_data, latents, active_latents_tensor, autoencoder_latents,k=128):
    """Swap in predicted activations for chosen latents, then keep the k largest.

    Rows of ``sentence_data`` whose ``"feature"`` value is contained in
    ``latents`` provide replacement values (column ``"activation"``). These
    are written into a copy of ``active_latents_tensor`` at the positions
    given by ``"feature"``; all other positions keep their original value.
    The returned tensor holds the ``k`` largest entries of that merged vector
    and zeros everywhere else.

    ``autoencoder_latents`` is only consulted for its dtype and device.
    ``active_latents_tensor`` itself is not modified.
    """
    target_dtype = autoencoder_latents.dtype
    target_device = autoencoder_latents.device

    # Rows of this sentence that belong to the latents we want to replace.
    chosen_rows = sentence_data.loc[sentence_data["feature"].isin(latents)]
    feature_ids = chosen_rows["feature"].tolist()
    predicted_values = chosen_rows["activation"].tolist()

    # Start from the real activations and overwrite the chosen positions.
    merged = active_latents_tensor.clone()
    merged[feature_ids] = torch.tensor(predicted_values,
                                       dtype=target_dtype,
                                       device=target_device)

    # Indices of the k largest values (fewer if the vector is shorter than k).
    keep = merged.argsort(descending=True)[:k]

    sparse = torch.zeros_like(merged, dtype=target_dtype, device=target_device)
    sparse[keep] = merged[keep]
    return sparse

def compute_loss_for_latents(predicted_latents_tensor, model, prompt, autoencoder, real_layer_activation, layer=15, device=None, sae_type=None):
    """Return the next-token loss when the MLP output is replaced by a decoded latent vector.

    The model is traced on ``prompt[:-1]``. At the last position, the output
    of ``model.model.layers[layer].mlp`` is overwritten with
    ``predicted_latents_tensor @ W_dec + b_dec`` plus the skip term
    ``real_layer_activation[-1] @ W_skip.mT``. The cross-entropy of the
    resulting last-position logits against ``prompt[-1]`` is returned.

    Args:
        predicted_latents_tensor: latent vector to decode.
        model: traced language model exposing ``trace``, ``model.layers`` and
            ``output`` (an nnsight ``LanguageModel`` in this file).
        prompt: sequence of token ids; the last one is the prediction target.
        autoencoder: object providing ``W_dec``, ``b_dec`` and ``W_skip``.
        real_layer_activation: activations whose last row feeds the skip
            connection.
        layer: index of the decoder layer to patch. Defaults to 15, the layer
            used everywhere else in this file.
        device: device for the target tensor. ``None`` (default) places it on
            the device of the logits.
        sae_type: accepted for interface compatibility; not used.

    Returns:
        The loss as a Python float.
    """
    with torch.no_grad():
        with model.trace(prompt[:-1]):
            # take the predicted latents vector, decode it and add the skip connection
            reconstructed = predicted_latents_tensor @ autoencoder.W_dec + autoencoder.b_dec
            # Match the skip weights to the activation dtype instead of
            # assuming float16, so the matmul also works for other precisions.
            skip = autoencoder.W_skip.to(real_layer_activation.dtype)
            reconstructed += real_layer_activation[-1, :] @ skip.mT
            model.model.layers[layer].mlp.output[0, -1, :] = reconstructed

            simulated_output = model.output.save()

    logits = simulated_output.logits[0, -1, :]
    target_device = logits.device if device is None else device
    loss = torch.nn.functional.cross_entropy(logits,
                                             torch.tensor(prompt[-1], device=target_device),
                                             reduction="mean")

    return loss.item()
def compute_loss_for_mlp(model, prompt,mean_activations, layer=15):
    """Return the next-token loss when the MLP hidden activations are mean-ablated.

    The model is traced on ``prompt[:-1]``. At the last position, the input of
    ``model.model.layers[layer].mlp.down_proj`` is overwritten with
    ``mean_activations``. The cross-entropy of the resulting last-position
    logits against ``prompt[-1]`` is computed.

    Args:
        model: traced language model exposing ``trace``, ``model.layers`` and
            ``output`` (an nnsight ``LanguageModel`` in this file).
        prompt: sequence of token ids; the last one is the prediction target.
        mean_activations: vector written into the ``down_proj`` input.
        layer: index of the decoder layer to patch. Defaults to 15, the value
            that was previously hard coded.

    Returns:
        A tuple ``(loss, logits)``: the loss as a Python float and the
        last-position logits moved to the CPU.
    """
    with torch.no_grad():
        with model.trace(prompt[:-1]):
            model.model.layers[layer].mlp.down_proj.input[0, -1, :] = mean_activations[:]

            simulated_output = model.output.save()

    last_logits = simulated_output.logits[0, -1, :]
    # Build the target on the logits' device rather than assuming "cuda".
    target = torch.tensor(prompt[-1], device=last_logits.device)
    loss = torch.nn.functional.cross_entropy(last_logits, target, reduction="mean")
    return loss.item(), last_logits.cpu()

def load_artifacts():
    """Read every input of the experiment from disk.

    The following files must already exist:

    * ``transcoder_activations.csv`` - simulator predictions gathered in a
      single CSV; its ``normalized_activation`` column is copied to
      ``activation``.
    * ``results/saved_tokens.json`` - the saved prompt tokens.
    * ``results/scores_detection_transcoder.csv`` and
      ``results/scores_fuzz_transcoder.csv`` - explanation scores, joined on
      ``latent``.
    * ``checkpoints/transcoder/layers.15.mlp`` - a trained sparse coder.

    Returns:
        ``(all_data, text_tokens, autoencoder, all_scores)`` where
        ``all_scores`` carries an ``f1`` column (average of the two score
        files' ``f1``) and is sorted by it in descending order.
    """
    # File names are hard coded for anonymity.
    all_data = pd.read_csv("transcoder_activations.csv")
    all_data["activation"] = all_data["normalized_activation"]

    with open("results/saved_tokens.json", "r") as handle:
        text_tokens = json.load(handle)

    # Local files generated by the scorer models.
    detection_scores = pd.read_csv("results/scores_detection_transcoder.csv")
    fuzz_scores = pd.read_csv("results/scores_fuzz_transcoder.csv")

    # Left join on the latent id, average both f1 columns, best latents first.
    all_scores = (
        detection_scores
        .merge(fuzz_scores, on="latent", how="left")
        .assign(f1=lambda frame: (frame["f1_x"] + frame["f1_y"]) / 2)
        .sort_values(by="f1", ascending=False)
    )

    checkpoint_root = "checkpoints/transcoder"
    hookpoint = "layers.15.mlp"
    autoencoder = SparseCoder.load_from_disk(f"{checkpoint_root}/{hookpoint}", device="cuda")

    return all_data, text_tokens, autoencoder, all_scores

def transcoder_reconstruction(autoencoder,activations):
    """Encode activations and return the dense vector of top-k latent activations.

    ``activations[0]`` is cast to float32 and passed to
    ``autoencoder.encode``. The returned ``top_acts`` are scattered along the
    last dimension into a zero tensor shaped like ``pre_acts`` at the
    positions ``top_indices``.

    Returns:
        The dense latent tensor, in the dtype and on the device of
        ``pre_acts``. No float16 conversion is applied: the previous
        ``.to(torch.float16)`` call discarded its result, so callers have
        always received the encoder's dtype.
    """
    encoder_output = autoencoder.encode(activations[0].to(torch.float32))
    autoencoder_latents = torch.zeros_like(encoder_output.pre_acts)
    autoencoder_latents.scatter_(-1, encoder_output.top_indices, encoder_output.top_acts)
    return autoencoder_latents

def main(args):
    """Measure next-token cross-entropy when transcoder latents are replaced.

    For every prompt the last token is held out as the prediction target.
    For each value in ``args.fraction`` three losses are collected per prompt:

    * ``loss_base`` - the unmodified model;
    * ``loss_reconstruction`` - the layer-15 MLP output at the last position
      replaced by the transcoder reconstruction of its real latents;
    * ``loss_prediction`` - depends on the flags, checked in this order:
      ``mean_ablation_mlp`` (mean-ablate the ``down_proj`` input),
      ``sample`` (random latents replaced by simulator predictions),
      ``mean_ablation_latents`` / ``mean_ablation`` (random latents replaced
      by their mean activation), otherwise the top-scoring latents replaced
      by simulator predictions.

    The per-prompt losses are printed in summary form and written to
    ``ce_loss_{name}.json``.
    """
    all_data, text_tokens, autoencoder, all_scores = load_artifacts()
    # SmolLM2 is used here. Pythia was used as well; the rest of the code is
    # similar but needs to be adjusted for Pythia.
    model = LanguageModel("HuggingFaceTB/SmolLM2-135M", device_map="cuda", dispatch=True,torch_dtype="float16")

    layer = 15
    device = "cuda"
    # 18432 is the number of latents for the transcoder
    num_transcoder_latents = 18432

    mlp_ablation = args.mean_ablation_mlp
    # ``mean_ablation`` may not be defined by the argument parser, so read it
    # defensively instead of raising AttributeError.
    latent_ablation = args.mean_ablation_latents or getattr(args, "mean_ablation", False)

    prompts = list(text_tokens.values())
    unique_latents = all_scores["latent"].unique()
    # Fill missing values BEFORE grouping: the per-sentence frames produced by
    # groupby would otherwise still contain the NaNs.
    all_data.fillna(0, inplace=True)
    sentence_groups = dict(tuple(all_data.groupby("sentence_idx")))

    mean_activations = None
    if mlp_ablation or latent_ablation:
        # MLP ablation takes precedence, as it does in the evaluation loop.
        if mlp_ablation:
            # nn.Linear exposes its input width as ``in_features``.
            width = model.model.layers[layer].mlp.down_proj.in_features
        else:
            width = num_transcoder_latents
        activations = torch.zeros((len(prompts), width))
        for i, prompt in tqdm(enumerate(prompts), total=len(prompts)):
            # collect the activations
            with torch.no_grad():
                with model.trace(prompt[:-1]):
                    if mlp_ablation:
                        real_layer_activation = model.model.layers[layer].mlp.down_proj.input[0].save()
                    else:
                        # Use the MLP input, the same activation the
                        # transcoder encodes in the evaluation loop below.
                        real_layer_activation = model.model.layers[layer].mlp.input[0].save()
            if mlp_ablation:
                activations[i, :] = real_layer_activation[-1]
            else:
                # compute the latent activations for the transcoder
                autoencoder_latents = transcoder_reconstruction(autoencoder, real_layer_activation)
                activations[i, :] = autoencoder_latents[-1]
        mean_activations = torch.mean(activations, dim=0).half()

    for fraction in args.fraction:
        loss_base = []
        loss_reconstruction = []
        loss_prediction = []

        # Number of latents to replace; at least one, never more than exist
        # (np.random.choice without replacement would fail otherwise).
        num_latents = min(max(int(fraction * len(unique_latents)), 1), len(unique_latents))
        # Top scoring latents (all_scores is sorted by score), in index order.
        top_scoring_latents = sorted(all_scores["latent"].iloc[:num_latents].tolist())

        for i, prompt in tqdm(enumerate(prompts), total=len(prompts)):
            sentence_data = sentence_groups[i]
            with torch.no_grad():
                with model.trace(prompt[:-1]):
                    # NOTE(review): the layout returned by ``.input[0]``
                    # depends on the nnsight version - confirm it is
                    # (seq, dim) for the version in use.
                    real_layer_activation = model.model.layers[layer].mlp.input[0].save()
                    output = model.output.save()

            # base loss for the model
            ba_loss = torch.nn.functional.cross_entropy(output.logits[0, -1, :],
                                                        torch.tensor(prompt[-1], device=device),
                                                        reduction="mean").item()

            autoencoder_latents = transcoder_reconstruction(autoencoder, real_layer_activation)
            # Real latents at the last position; must exist before the
            # reconstruction loss is computed from them.
            active_latents_tensor = autoencoder_latents.clone()[-1, :]

            # compute_loss_for_latents returns a single float.
            re_loss = compute_loss_for_latents(active_latents_tensor, model, prompt,
                                               autoencoder, real_layer_activation,
                                               layer=layer, device=device, sae_type="transcoder")

            if mlp_ablation:
                # Only the scalar loss is stored; the logits are not needed
                # and could not be serialised to JSON.
                loss, _ = compute_loss_for_mlp(model, prompt, mean_activations)
            else:
                if args.sample:
                    # sample the latents to be predicted
                    latents = np.random.choice(unique_latents, num_latents, replace=False)
                    latents.sort()
                    predicted_latents_tensor = get_predicted_latents_topk(
                        sentence_data, latents, active_latents_tensor, autoencoder_latents, k=128)
                elif latent_ablation:
                    # sample which latents are replaced by their mean activation
                    latents = np.random.choice(unique_latents, num_latents, replace=False)
                    predicted_latents_tensor = active_latents_tensor.clone()
                    predicted_latents_tensor[latents] = mean_activations[latents].to(
                        device=predicted_latents_tensor.device,
                        dtype=predicted_latents_tensor.dtype)
                else:
                    # use top scoring latents to be predicted
                    predicted_latents_tensor = get_predicted_latents_topk(
                        sentence_data, top_scoring_latents, active_latents_tensor,
                        autoencoder_latents, k=128)

                loss = compute_loss_for_latents(predicted_latents_tensor, model, prompt,
                                                autoencoder, real_layer_activation,
                                                layer=layer, device=device, sae_type="transcoder")
                del predicted_latents_tensor

            loss_prediction.append(loss)
            loss_reconstruction.append(re_loss)
            loss_base.append(ba_loss)

            del real_layer_activation, autoencoder_latents, output, active_latents_tensor
            # clear gpu memory
            torch.cuda.empty_cache()

        # save the results
        name = f"{fraction}"
        if args.sample:
            name += "_sample"
        if latent_ablation:
            name += "_mean_ablation"
        if mlp_ablation:
            name += "_mean_ablation_mlp"
        print("Name: ",name)
        print(f"Loss base: {np.mean(loss_base),np.median(loss_base)}",len(loss_base))
        print(f"Loss reconstruction: {np.mean(loss_reconstruction),np.median(loss_reconstruction)}",len(loss_reconstruction))
        print(f"Loss prediction: {np.mean(loss_prediction),np.median(loss_prediction)}",len(loss_prediction))

        with open(f"ce_loss_{name}.json", "w") as f:
            json.dump({"loss_base":loss_base,"loss_reconstruction":loss_reconstruction,"loss_prediction":loss_prediction},f)

    
if __name__ == "__main__":
    # Command-line entry point: parse the experiment flags and run main().
    parser = ArgumentParser()
    parser.add_argument("--fraction", type=float, nargs="+", default=[1])
    parser.add_argument("--sample", action="store_true")
    # main() reads ``args.mean_ablation``; the flag has to be declared here so
    # that the attribute exists on the parsed namespace.
    parser.add_argument("--mean_ablation", action="store_true")
    parser.add_argument("--mean_ablation_latents", action="store_true")
    parser.add_argument("--mean_ablation_mlp", action="store_true")
    args = parser.parse_args()

    main(args)