import asyncio
import copy
import glob
import json
import os
import pickle
import time
from functools import partial

import numpy as np
import orjson
import pandas as pd
import torch
from safetensors.numpy import load_file
from simple_parsing import ArgumentParser
from tqdm import tqdm
from transformers import AutoTokenizer

from delphi.clients import Offline
from delphi.config import ConstructorConfig, SamplerConfig
from delphi.latents import Example, Latent, LatentDataset, LatentRecord
from delphi.pipeline import Pipe, Pipeline, process_wrapper
from full_sim.special_simulator import Simulator

def build_feature_record(selected_tokens ,activation,explanation=None,order=-1,sentence_idx=-1,feature=0):
    """Wrap one token window as a single-example scoring record for a latent.

    Args:
        selected_tokens: token window (tensor) used as the record's only test
            example.
        activation: value written at the window's final position; every other
            position is zero.
        explanation: explanation text attached to the record.
        order: stored on the record as ``order`` (callers pass the latent's
            rank among active latents, or -1 when inactive).
        sentence_idx: stored on the record as ``sentence_idx``.
        feature: latent index, wrapped as ``Latent("layer", feature)``.

    Returns:
        A ``LatentRecord`` with ``explanation``, ``test``, ``sentence_idx`` and
        ``order`` assigned.
    """
    # Only the final position carries a value; the real activations for the
    # rest of the window don't matter here, so they are zero-filled.
    per_token = torch.zeros_like(selected_tokens, dtype=torch.float32)
    per_token[-1] = activation
    single_example = Example(selected_tokens, per_token)

    record = LatentRecord(Latent("layer", feature))
    record.explanation = explanation
    record.test = [single_example]
    record.sentence_idx = sentence_idx
    record.order = order
    return record

def construct_windows(sentence_idx, tokens, locations, activation, window_size=32, context_length=256):
    """Pick a random token window from a sentence and the latents active at its end.

    The window end ``window_idx`` is drawn with ``np.random.randint`` from
    ``[window_size + 1, context_length - window_size)`` and the window is
    ``tokens[sentence_idx][window_idx - window_size:window_idx]``.

    Args:
        sentence_idx: row of ``tokens`` to sample from.
        tokens: indexable collection of token sequences (presumably a tensor of
            shape (n_sentences, context_length) — TODO confirm).
        locations: integer tensor whose columns are read as
            (sentence index, token position, latent id).
        activation: activation value for each row of ``locations``.
        window_size: number of tokens in the window (default 32).
        context_length: sequence length the window end is drawn within
            (default 256).

    Returns:
        ``(selected_window, sorted_active_features, sorted_activation_at_token)``:
        the window, plus the latents active at the window's last token and
        their activations, both ordered by ASCENDING activation.

    Raises:
        ValueError: if ``context_length`` is too small to fit the window.
    """
    low, high = window_size + 1, context_length - window_size
    if high <= low:
        raise ValueError(
            f"window_size={window_size} does not fit in context_length={context_length}"
        )
    window_idx = np.random.randint(low, high)
    selected_window = tokens[sentence_idx][window_idx - window_size:window_idx]

    in_sentence = locations[:, 0] == sentence_idx
    sentence_locations = locations[in_sentence]
    sentence_activation = activation[in_sentence]

    # The slice stops before window_idx, so the window's last token is window_idx - 1.
    at_last_token = sentence_locations[:, 1] == window_idx - 1
    active_features = sentence_locations[at_last_token][:, 2]
    activation_at_token = sentence_activation[at_last_token]

    # Ascending order: position 0 is the weakest active latent.
    index_sort = torch.argsort(activation_at_token)
    return selected_window, active_features[index_sort], activation_at_token[index_sort]

def create_records(sentence_idx, tokens, locations, activation, all_explanations,latents, SCORES_FOLDER):
    """Build one scoring record per explained latent for a random window of a sentence.

    Draws a window with ``construct_windows``. For every latent in ``latents``
    that has an entry in ``all_explanations``, it builds a record via
    ``build_feature_record``. Each record carries the latent's activation on
    the window's last token (0 if inactive) and its rank ``order`` among the
    active latents (-1 if inactive). It also ensures the per-sentence output
    folder ``SCORES_FOLDER + str(sentence_idx)`` exists.

    Args:
        sentence_idx: sentence to sample the window from.
        tokens, locations, activation: forwarded to ``construct_windows``.
        all_explanations: mapping from latent id (as ``str``) to explanation.
        latents: iterable of latent ids (ints or numeric strings).
        SCORES_FOLDER: output prefix; the folder name is appended directly.

    Returns:
        List of records, in the order of ``latents`` (unexplained ones skipped).
    """
    selected_window, sorted_active_features, sorted_activation_at_token = construct_windows(
        sentence_idx, tokens, locations, activation
    )

    feature_records = []
    for latent in latents:
        latent = int(latent)
        key = str(latent)
        # Skip unexplained latents before doing any per-latent tensor work.
        if key not in all_explanations:
            continue
        # A distinct name keeps the `activation` parameter (a tensor) intact.
        if latent in sorted_active_features:
            # Index into the ascending-activation ordering from construct_windows.
            order = torch.where(sorted_active_features == latent)[0][0].item()
            latent_activation = sorted_activation_at_token[order]
        else:
            latent_activation = 0
            order = -1
        feature_records.append(
            build_feature_record(
                selected_window,
                explanation=all_explanations[key],
                activation=latent_activation,
                order=order,
                sentence_idx=sentence_idx,
                feature=latent,
            )
        )
    os.makedirs(SCORES_FOLDER + f"{sentence_idx}", exist_ok=True)
    return feature_records


async def feature_generator(pre_feature_records, tokens, locations, activation, all_explanations,latents, number_of_sentences, start_sentence, SCORES_FOLDER):
    """Asynchronously yield scoring records for a range of sentences.

    Sentences ``start_sentence .. start_sentence + number_of_sentences - 1``
    are processed in order, in one of two modes:

    * ``pre_feature_records`` truthy: draw one window per sentence with
      ``construct_windows``. Yield a *copy* of every pre-built record with its
      ``test`` example, ``sentence_idx`` and ``order`` set for that window.
    * otherwise: yield the records built by ``create_records``.

    Either way the per-sentence output folder ``f"{SCORES_FOLDER}{sentence_idx}"``
    is created before its records are yielded.
    """
    for sentence_idx in range(start_sentence, start_sentence + number_of_sentences):
        if not pre_feature_records:
            for feature_record in create_records(
                sentence_idx, tokens, locations, activation, all_explanations, latents, SCORES_FOLDER
            ):
                yield feature_record
            continue

        os.makedirs(f"{SCORES_FOLDER}{sentence_idx}", exist_ok=True)
        selected_window, sorted_active_features, sorted_activation_at_token = construct_windows(
            sentence_idx, tokens, locations, activation
        )
        for base_record in pre_feature_records:
            # NOTE(review): relies on the record's latent supporting int(); confirm.
            latent = int(base_record.latent)
            # Separate name so the `activation` tensor survives for later sentences.
            if latent not in sorted_active_features:
                latent_activation = 0
                order = -1
            else:
                order = torch.where(sorted_active_features == latent)[0][0].item()
                latent_activation = sorted_activation_at_token[order]
            # The real activations don't matter; only the last position is set.
            activations = torch.zeros_like(selected_window, dtype=torch.float32)
            activations[-1] = latent_activation
            # Yield a shallow copy instead of mutating the shared record. Earlier
            # yields may still be in flight in the pipeline, and the postprocess
            # step reads sentence_idx/order from the record it receives.
            feature_record = copy.copy(base_record)
            feature_record.test = [Example(selected_window, activations)]
            feature_record.sentence_idx = sentence_idx
            feature_record.order = order
            yield feature_record
# Post-processing hook applied to each scorer result (the scorer is a Simulator, bound as `fuzz_scorer` in main).
def scorer_postprocess(result,output_folder,quantile_stats):
    """Write one scorer result as JSON to ``{output_folder}{sentence_idx}/{feature}.txt``.

    Args:
        result: scorer output; ``result.score[0]`` provides ``predicted_quantile``,
            ``activation``, ``expected_quantile``. ``result.record`` provides
            ``sentence_idx``, ``order`` and ``latent``.
        output_folder: output prefix; the sentence index is appended directly.
        quantile_stats: mapping from latent id (as ``str``) to the activation
            scale produced by ``get_quantile_stats``.

    Returns:
        ``result`` unchanged.

    If the quantile stats are missing (latents that never fired have none) or
    the score fields cannot be converted, a row with ``-1`` placeholders is
    written instead of crashing the pipeline.
    """
    score = result.score[0]
    predicted_quantile = score.predicted_quantile
    activation = score.activation.item()
    expected_quantile = score.expected_quantile
    sentence_idx = result.record.sentence_idx
    order = result.record.order
    # Latent id = text after the last "latent" in the record's string form.
    feature = str(result.record.latent).split("latent")[-1]

    try:
        quantile_info = quantile_stats[feature]
        # Presumably expected_quantile is on the 0-9 scale of the 10-entry
        # quantile list, so /9 maps it onto [0, max activation].
        predicted_activation = quantile_info[-1] * expected_quantile / 9
        saving_result = {
            "feature": int(feature),
            "predicted_quantile": float(predicted_quantile),
            "activation": float(activation),
            "order": int(order),
            "expected_quantile": float(expected_quantile),
            "predicted_activation": float(predicted_activation)
        }
    except (KeyError, TypeError, ValueError, RuntimeError):
        print(f"Error with {sentence_idx} {feature}")
        saving_result = {
            "feature": int(feature),
            "predicted_quantile": -1,
            "activation":  float(activation),
            "order": int(order),
            "expected_quantile": -1,
            "predicted_activation": -1
        }

    folder = f"{output_folder}{sentence_idx}"
    # Don't rely on the record producer having created the folder.
    os.makedirs(folder, exist_ok=True)
    with open(f"{folder}/{feature}.txt", "wb") as f:
        f.write(orjson.dumps(saving_result))
    return result
    
def merge_scores(SCORES_FOLDER):
    """Merge the per-latent JSON score files of one folder into a CSV.

    Every ``*.txt`` file in ``SCORES_FOLDER`` is read as one JSON object (as
    written by ``scorer_postprocess``). The rows are written to
    ``f"{SCORES_FOLDER}all_data.csv"`` and the ``.txt`` files are then deleted.

    NOTE: the CSV path is plain string concatenation. A folder given without
    a trailing separator (as ``main`` passes it) therefore produces a sibling
    file such as ``.../transcoder/0all_data.csv``. This is kept so existing
    outputs stay where they are.

    If there are no ``.txt`` files but the CSV already exists (a re-run after
    a previous merge), the existing CSV is returned rather than overwritten
    with an empty table.

    Returns:
        A ``pandas.DataFrame`` with one row per score file.
    """
    csv_path = f"{SCORES_FOLDER}all_data.csv"
    # Sorted so the row order is deterministic across runs and filesystems.
    all_files = sorted(glob.glob(os.path.join(SCORES_FOLDER, "*.txt")))
    if not all_files and os.path.exists(csv_path):
        try:
            return pd.read_csv(csv_path)
        except pd.errors.EmptyDataError:
            return pd.DataFrame()

    all_results = []
    for file in all_files:
        with open(file, "r", encoding="utf-8") as f:
            all_results.append(json.load(f))

    all_data = pd.DataFrame(all_results)
    all_data.to_csv(csv_path, index=False)
    # Inputs are removed only after the merged CSV has been written.
    for file in all_files:
        os.remove(file)
    return all_data

def get_quantile_stats(locations,activations,high_scores, device=None, progress=True):
    """Compute a 10-point activation scale for each requested latent.

    For every latent in ``high_scores`` with at least one recorded activation,
    the result holds ``[0]`` followed by 9 evenly spaced values from that
    latent's minimum to its maximum activation. Latents with no activations
    are omitted.

    Args:
        locations: integer tensor whose column 2 holds each row's latent id.
        activations: activation value for each row of ``locations``.
        high_scores: iterable of latent ids (ints or numeric strings). Each is
            used verbatim as a key of the result.
        device: torch device for the sort/search; defaults to CUDA when
            available, else CPU.
        progress: show a tqdm progress bar.

    Returns:
        dict mapping each latent key to a list of 10 floats.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    high_scores = list(high_scores)

    # Only the latent-id column is needed; avoid copying all of `locations`.
    latent_ids = locations[:, 2].to(device)
    index_sort = torch.argsort(latent_ids)
    sorted_ids = latent_ids[index_sort]
    sorted_activations = activations.to(device)[index_sort]

    # Locate every latent's [start, end) slice independently, so the result
    # neither depends on the order of high_scores nor on gaps between ids.
    wanted = torch.tensor(
        [int(latent) for latent in high_scores], dtype=sorted_ids.dtype, device=device
    )
    starts = torch.searchsorted(sorted_ids, wanted).tolist()
    ends = torch.searchsorted(sorted_ids, wanted + 1).tolist()

    quantile_stats = {}
    rows = zip(high_scores, starts, ends)
    if progress:
        rows = tqdm(rows, total=len(high_scores))
    for latent, start_idx, end_idx in rows:
        if end_idx <= start_idx:
            continue
        feature_activation = sorted_activations[start_idx:end_idx]
        low = feature_activation.min().item()
        high = feature_activation.max().item()
        quantile_stats[latent] = [0] + torch.linspace(low, high, 9).tolist()
    return quantile_stats



def main(args):
    """Score transcoder-latent explanations on random token windows.

    Steps:
    * Load the cached latent activations for ``layers.15.mlp`` from several
      safetensors shards.
    * Load the latent explanations and the per-latent quantile statistics
      (cached in a pickle).
    * Run the ``Simulator`` scorer over ``args.num_sentences`` sentences,
      starting at ``args.start_sentence``.
    * ``scorer_postprocess`` writes each score to disk; ``merge_scores`` then
      merges them per sentence.

    Args:
        args: parsed CLI namespace. Reads ``num_sentences``,
            ``start_sentence`` and ``use_examples``.
            NOTE(review): ``args.model_size`` is parsed but never read here.
            The scorer model and the output folder are hard-coded to 70b.
    """
    
    all_locations = []
    all_activations = []
    # Shard names are "<first>_<last>" latent ids. The first number is added
    # below to turn shard-local latent ids into global ids.
    ranges = ["0_3685","3686_7371","7372_11058","11059_14744","14745_18431"]
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
   
    tokens = None
    for valid_range in ranges:
        split_data = load_file(f"results/transcoder/latents/layers.15.mlp/{valid_range}.safetensors")
        activations = torch.tensor(split_data["activations"])
        locations = torch.tensor(split_data["locations"].astype(np.int32))
        # Column 2 is the latent id; shift it into the global id space.
        locations[:,2] = locations[:,2]+int(valid_range.split("_")[0])
        all_locations.append(locations)
        all_activations.append(activations)
        # Tokens are taken from the first shard only (presumably every shard
        # stores the same token array — TODO confirm).
        if tokens is None:
            tokens = torch.tensor(split_data["tokens"])


    locations = torch.cat(all_locations)
    activation = torch.cat(all_activations)
    with open(f"results/explanations_transcoder.json", "r") as f:
            # Mapping from latent id (string key) to explanation text.
            all_explanations = json.load(f)
   

    # String latent ids; these are also the keys of quantile_stats.
    latents = list(all_explanations.keys())

    # NOTE(review): the quantile cache is never invalidated. Delete the pickle
    # if the activations or the explanation set change.
    if os.path.exists(f"results/quantile_stats_transcoder.pkl"):
        with open(f"results/quantile_stats_transcoder.pkl", "rb") as f:
            quantile_stats = pickle.load(f)
    else:
        quantile_stats = get_quantile_stats(locations,activation,latents)
        with open(f"results/quantile_stats_transcoder.pkl", "wb") as f:
            pickle.dump(quantile_stats, f)
    
    if args.use_examples:
        sampler_cfg = SamplerConfig(n_examples_train=20,n_examples_test=0)
        constructor_cfg = ConstructorConfig(n_non_activating=0)
        hookpoints = ["layers.15.mlp"]
        # NOTE(review): rebinds `latents` from a list of string ids to a dict.
        # It is still passed to feature_generator below; it is only used there
        # when feature_records is empty.
        latents = {"layers.15.mlp":torch.arange(0,18432)}
        latent_dataset = LatentDataset(
            raw_dir=f"results/transcoder/latents",
            sampler_cfg=sampler_cfg,
            constructor_cfg=constructor_cfg,
            modules=hookpoints,
            latents=latents,
            tokenizer=tokenizer,
        )
    else:
        latent_dataset = None

    
    client = Offline("hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4",max_memory=0.85, max_model_len=2192,num_gpus=2,batch_size=800)
    SCORES_FOLDER= f"results/kl_div_activations/70b/transcoder/"
        
    # Seeds the np.random draws that pick each sentence's window.
    np.random.seed(42)

    fuzz_scorer = Simulator(client, tokenizer,log_prob=True,verbose=False)
    # Wrap the scorer so every result is written under SCORES_FOLDER by
    # scorer_postprocess (no folders are created at this point).
            
    scorer_pipe = Pipe(process_wrapper(fuzz_scorer, postprocess=partial(scorer_postprocess,output_folder=SCORES_FOLDER,quantile_stats=quantile_stats)))
    
    

    if latent_dataset is not None:
        feature_records = []
        # Reuse the cached records if present. NOTE(review): like the quantile
        # cache, this pickle is never invalidated.
        if os.path.exists(f"results/feature_records_transcoder.pkl"):
            with open(f"results/feature_records_transcoder.pkl", "rb") as f:
                feature_records = pickle.load(f)
        else:
            for record in tqdm(latent_dataset):
                # Presumably clears the examples to save memory before the
                # train list is copied over — TODO confirm.
                record.examples = []
                feature = Latent("layer", record.latent.latent_id)
                new_record = LatentRecord(feature)
                # NOTE(review): wraps record.train in an extra list; confirm
                # this is the shape the Simulator expects.
                new_record.train = [record.train]
                feature_records.append(new_record)
                del record
            with open(f"results/feature_records_transcoder.pkl", "wb") as f:
                pickle.dump(feature_records, f)
    else:
        feature_records = None
    pipeline = Pipeline(
        feature_generator(feature_records, tokens, locations, activation, all_explanations,latents, args.num_sentences, args.start_sentence, SCORES_FOLDER),
        scorer_pipe,
    )

    # 10000 is passed to Pipeline.run (presumably the concurrency limit — confirm).
    asyncio.run(pipeline.run(10000))
    all_data = []
    for i in tqdm(range(args.start_sentence,args.start_sentence+args.num_sentences)):
        data = merge_scores(f"{SCORES_FOLDER}{i}")
        all_data.append(data)
    # NOTE(review): the concatenated frame is never saved or returned, and
    # pd.concat raises ValueError when num_sentences == 0.
    all_data = pd.concat(all_data)
    
    
if __name__ == "__main__":
    parser = ArgumentParser(
        description="Score transcoder latent explanations on random token windows."
    )
    parser.add_argument(
        "--model_size", type=str, default="70b",
        help="Scorer model size label. NOTE: not currently read by main(); "
             "the model and output folder are hard-coded to 70b.",
    )
    parser.add_argument(
        "--num_sentences", type=int, default=1000,
        help="Number of consecutive sentences to score.",
    )
    parser.add_argument(
        "--start_sentence", type=int, default=0,
        help="Index of the first sentence to score.",
    )
    parser.add_argument(
        "--use_examples", action="store_true", default=False,
        help="Build records from the LatentDataset training examples "
             "instead of per-sentence explanation records.",
    )
    args = parser.parse_args()

    main(args)