from datasets import load_dataset
from modelHelper import *
from nltk.tokenize import word_tokenize
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from transformers import pipeline
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import set_seed
import code
import nltk
import torch
import transformers
import os
import csv
import json
from sacrebleu.metrics import BLEU  # pip install sacrebleu
import numpy as np
import gc

# Download NLTK's "punkt_tab" data at import time (requires network access on first run);
# presumably needed by `word_tokenize`, which is imported above.
# NOTE(review): in this file `word_tokenize` is only referenced by the commented-out
# NLTK-based BLEU implementation; the active compute_bleu4 uses SacreBLEU, so this
# download may be removable — confirm no other caller relies on it.
nltk.download('punkt_tab')

def append_list_to_csv(row: list, filename: str):
    """
    Append ``row`` as one new line of the CSV file at ``filename``.

    The file is opened in append mode, so it is created when missing and any
    existing content is left untouched. No header row is written here; callers
    that want one must append it themselves.

    Args:
        row (list): Values making up the single CSV row to write.
        filename (str): Path of the CSV file to append to.
    """
    # newline='' hands line-ending control to the csv module (avoids blank rows on Windows).
    with open(filename, mode='a', encoding='utf-8', newline='') as handle:
        csv.writer(handle).writerow(row)

class SentimentClassification:
    """
    Wrap a Hugging Face sequence-classification checkpoint behind a simple
    ``classify_sentiment(text) -> (label, score)`` call.

    Only checkpoints listed in ``LABEL_MAPS`` are supported: the raw
    ``LABEL_<i>`` names emitted by the pipeline mean different things for
    different models, so each model needs its own explicit mapping.
    """

    # Raw pipeline label -> human-readable sentiment, per supported checkpoint.
    LABEL_MAPS = {
        # 3-way model: 0 = negative, 1 = neutral, 2 = positive.
        "cardiffnlp/twitter-roberta-base-sentiment": {
            "LABEL_0": "Negative",
            "LABEL_1": "Neutral",
            "LABEL_2": "Positive",
        },
        # Binary Yelp-polarity model: 0 = negative, 1 = positive.
        "textattack/bert-base-uncased-yelp-polarity": {
            "LABEL_0": "Negative",
            "LABEL_1": "Positive",
        },
    }

    def __init__(self, sent_model_name="textattack/bert-base-uncased-yelp-polarity"):
        """
        Load ``sent_model_name`` and build a truncating sentiment pipeline on it.

        Args:
            sent_model_name (str): Hugging Face model id; must be a key of
                ``LABEL_MAPS``.

        Raises:
            ValueError: if no label mapping is defined for ``sent_model_name``.
                This is checked before any weights are downloaded.
        """
        # Validate first so an unsupported model fails fast, before any download.
        if sent_model_name not in self.LABEL_MAPS:
            raise ValueError(
                f"No labelToSentiment defined for model {sent_model_name!r}; "
                "label names differ per model, so add the correct mapping to "
                "SentimentClassification.LABEL_MAPS first."
            )
        self.sent_model_name = sent_model_name
        # Per-instance copy so edits cannot leak into the shared class table.
        self.labelToSentiment = dict(self.LABEL_MAPS[sent_model_name])

        self.sent_tokenizer = AutoTokenizer.from_pretrained(sent_model_name)
        self.sent_model = AutoModelForSequenceClassification.from_pretrained(sent_model_name)
        self.sent_model.eval()
        # Reuse the already-loaded tokenizer/model rather than letting the pipeline
        # load a second copy of the same checkpoint. Inputs are truncated to 512 tokens.
        self.sentiment_task = pipeline(
            "sentiment-analysis",
            model=self.sent_model,
            tokenizer=self.sent_tokenizer,
            truncation=True,
            max_length=512,
        )

    def classify_sentiment(self, text: str) -> tuple[str, float]:
        """
        Classify a single piece of text.

        Args:
            text (str): Text to classify (truncated to 512 tokens by the pipeline).

        Returns:
            tuple[str, float]: the sentiment name mapped through
            ``labelToSentiment`` and the pipeline's score for that label.
        """
        top = self.sentiment_task(text)[0]  # one input string -> one prediction
        return self.labelToSentiment[top["label"]], top["score"]

# def compute_bleu4(reference: str, candidate: str) -> float: # old NLTK-based implementation, superseded by the SacreBLEU version below
    # """
    # Compute the BLEU-4 (cumulative 4-gram) score between a reference string
    # and a candidate string.

    # Args:
        # reference (str): The original text (reference).
        # candidate (str): The rephrased text (candidate).

    # Returns:
        # float: BLEU-4 score (in the range [0, 1]) * 100.
    # """
    # # 1. Tokenize on whitespace (or use word_tokenize for more robust tokenization)
    # ref_tokens = word_tokenize(reference.lower())
    # cand_tokens = word_tokenize(candidate.lower())

    # # 2. Define 4-gram weights for cumulative BLEU-4
    # weights = (0.25, 0.25, 0.25, 0.25)

    # # 3. Use a smoothing function to avoid zero scores on short sentences
    # smoothing_fn = SmoothingFunction().method4

    # # 4. Compute BLEU score
    # bleu_score = sentence_bleu(
        # [ref_tokens],         # list of reference token lists
        # cand_tokens,          # candidate token list
        # weights=weights,
        # smoothing_function=smoothing_fn,
    # )

    # return bleu_score * 100.
    

# Shared SacreBLEU scorer, created lazily on first use by _bleu4_scorer().
_BLEU4_SCORER = None


def _bleu4_scorer():
    """
    Return the module-wide sentence-level BLEU-4 scorer, creating it on first use.

    The scorer's configuration never changes, so building it once avoids
    re-creating the metric object for every sentence pair.
    """
    global _BLEU4_SCORER
    if _BLEU4_SCORER is None:
        _BLEU4_SCORER = BLEU(
            max_ngram_order=4,
            tokenize="13a",       # same default used by WMT leaderboards
            lowercase=True,       # case-insensitive matching
            smooth_method="exp",  # smoothing so short sentences don't collapse to 0
        )
    return _BLEU4_SCORER


def compute_bleu4(reference: str, candidate: str) -> float:
    """
    Compute the BLEU-4 (cumulative 4-gram) score between a reference string
    and a candidate string using SacreBLEU.

    Scoring is case-insensitive, uses the "13a" tokenizer and exponential
    smoothing.

    Args:
        reference (str): The original text (reference).
        candidate (str): The rephrased text (candidate).

    Returns:
        float: BLEU-4 score on the 0-to-100 scale.
    """
    # SacreBLEU expects a list of references for each hypothesis.
    return _bleu4_scorer().sentence_score(candidate, [reference]).score


# Short, filesystem-friendly aliases for Hugging Face model ids, used when
# building result-folder names and params.json. Add an entry for every model
# an experiment uses so its results get a short, stable folder name.
prettyParam = {
    "textattack/bert-base-uncased-yelp-polarity": "bertBase",
    "cardiffnlp/twitter-roberta-base-sentiment": "robertaTwitter",
    "meta-llama/Meta-Llama-3.1-8B-Instruct": "llama-8b-In",
    "meta-llama/Llama-3.3-70B-Instruct": "llama-70b-In",
}
def create_param_folder(params, base_dir="results"):
    """
    Create (or reuse) the output directory for one parameter configuration.

    Each ``key: value`` pair of ``params`` whose value is not ``None`` becomes
    a ``key-value`` segment, in dict order, and the segments are joined with
    ``_``. Forward slashes, backslashes and spaces inside values are replaced
    by ``_`` so every value stays within a single folder name.

    Args:
        params (dict): Experiment parameters.
        base_dir (str): Parent directory for all parameter folders.

    Returns:
        str: ``os.path.join(base_dir, <folder name>)``; the directory exists
        on return (missing parents are created, an existing one is reused).
    """
    unsafe_chars = ('/', '\\', ' ')

    def _clean(value):
        text = str(value)
        for ch in unsafe_chars:
            text = text.replace(ch, '_')
        return text

    segments = [
        f"{name}-{_clean(value)}"
        for name, value in params.items()
        if value is not None
    ]
    target = os.path.join(base_dir, "_".join(segments))
    os.makedirs(target, exist_ok=True)
    return target



# System prompt for the rephrasing task (kept verbatim so results stay comparable across runs).
_REPHRASE_SYSTEM_PROMPT = "You are a helpful assistant. The user will provide you with a review, respond only with a rephrased review (with no additional commentary) while keeping the original meaning.\n\n"


def _build_rephrase_messages(text):
    """Return the chat messages asking the LLM to rephrase one review."""
    return [
        {"role": "system", "content": _REPHRASE_SYSTEM_PROMPT},
        {"role": "user", "content": f"{text}"},
    ]


def _extract_reply(result):
    """
    Return the assistant reply text from one text-generation pipeline result.

    A batched chat call yields, per input, a list of generated sequences,
    while a single chat call yields the record dict directly. Both shapes are
    accepted; only the first generated sequence is used.
    """
    record = result[0] if isinstance(result, list) else result
    return record["generated_text"][-1]["content"]


def _batched(items, size):
    """Yield consecutive lists of at most ``size`` items from ``items`` (last one may be shorter)."""
    batch = []
    for item in items:
        batch.append(item)
        if len(batch) == size:
            yield batch
            batch = []
    if batch:
        yield batch


def _load_language_model(model_name, max_mem):
    """
    Load the causal LM and its tokenizer, configured for batched generation.

    Args:
        model_name (str): Hugging Face model id.
        max_mem (dict | None): per-device ``max_memory`` budget; when given the
            model is loaded in bfloat16 with sequential device placement.

    Returns:
        tuple: ``(tokenizer, model)``, with the model in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    if tokenizer.pad_token_id is None:
        # No pad token defined: reuse EOS so prompts can be padded into batches.
        tokenizer.pad_token_id = tokenizer.eos_token_id
        tokenizer.pad_token = tokenizer.eos_token
    # Decoder-only models must be left-padded for batched generation; with right
    # padding, shorter prompts get pad tokens between the prompt and the continuation.
    tokenizer.padding_side = "left"

    if max_mem is None:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",  # automatically place layers on available GPUs/CPU
        )
    else:  # explicit memory budget per device
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="sequential",  # fill GPU-0, then GPU-1, ...
            max_memory=max_mem,
            low_cpu_mem_usage=True,   # avoid a second copy of the state dict
        )
    model.eval()
    return tokenizer, model


def _rephrase_texts(generator, model, texts, layerIdx, indexToken, w, indexApplicationType, batch_size):
    """
    Rephrase ``texts`` with the LLM, steering layer ``layerIdx`` with ``indexToken``.

    When ``w > 0`` the layer's forward pass is patched through
    ``updateModelForwardPass`` for this call only; the original forward is
    restored even if patching or generation raises.

    Returns:
        list[str]: one rephrased review per input text, in input order.
    """
    original_forward = model.model.layers[layerIdx].forward  # save original forward pass
    try:
        if w > 0.:
            updateModelForwardPass(model, layerIdx, indexToken, w, indexApplicationType)
        outputs = generator(
            [_build_rephrase_messages(text) for text in texts],
            max_new_tokens=1024,
            batch_size=batch_size,
        )
    finally:
        model.model.layers[layerIdx].forward = original_forward  # restore original forward pass
    return [_extract_reply(result) for result in outputs]


def _rephrase_batch(generator, model, batch, layerIdx, negIndexToken, posIndexToken, w, indexApplicationType, batch_size):
    """
    Rephrase a batch of ``(text, label)`` pairs, returning replies in batch order.

    The steering vector depends on each sample's ground-truth label (label 1 =
    positive is pushed with the negative index token, anything else with the
    positive one), and a patched forward pass applies one vector to every
    sequence of a generation call, so samples are generated in per-label
    groups. Without steering (``w <= 0``) the whole batch is one call.
    """
    if w > 0.:
        groups = {}
        for position, (_, label) in enumerate(batch):
            groups.setdefault(label == 1, []).append(position)
    else:
        groups = {None: list(range(len(batch)))}

    rephrased = [None] * len(batch)
    for isPositive, positions in groups.items():
        # 1 means positive sample, so use the opposite index token from the ground truth.
        indexToken = negIndexToken if isPositive else posIndexToken
        replies = _rephrase_texts(
            generator, model, [batch[p][0] for p in positions],
            layerIdx, indexToken, w, indexApplicationType, batch_size,
        )
        for position, reply in zip(positions, replies):
            rephrased[position] = reply
    return rephrased


def _score_and_log(sc, text, rephrased, labelName, outputFile):
    """Score one original/rephrased pair (BLEU + sentiment), print it, and append it to ``outputFile``."""
    bleu = compute_bleu4(text, rephrased)
    origPredLabel, origPredScore = sc.classify_sentiment(text)
    rephrasedPredLabel, rephrasedPredScore = sc.classify_sentiment(rephrased)

    print("=== Sample ===")
    print(f"Original Review:\n {text}")
    print("---")
    print(f"Rephrased Review:\n {rephrased}")
    print("---")
    print(f"Label:\n {labelName}")
    print("---")
    print(f"Bleu:\n {bleu}")
    print("---")
    print(f"origPredLabel, origPredScore:\n {origPredLabel, origPredScore}")
    print("---")
    print(f"rephrasedPredLabel, rephrasedPredScore:\n {rephrasedPredLabel, rephrasedPredScore}")
    print()
    append_list_to_csv(
        [text, rephrased, labelName, origPredLabel, rephrasedPredLabel, origPredScore, rephrasedPredScore, bleu],
        outputFile,
    )


def _release_memory():
    """Run the garbage collector, then hand unused cached CUDA memory back to the driver."""
    gc.collect()  # fire finalizers so tensors owned by deleted objects are freed
    if torch.cuda.is_available():  # ipc_collect raises on CPU-only builds
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()


def runSimBatch(
    dataSplit="test" # train | test
    ,numberSamples=None # number of samples to include in dataset (None = whole split)
    ,sent_model_name="textattack/bert-base-uncased-yelp-polarity" # sentiment classification model
    ,model_name="meta-llama/Meta-Llama-3.1-8B-Instruct" # "meta-llama/Meta-Llama-3.1-8B-Instruct" 33 layers | "meta-llama/Llama-3.3-70B-Instruct" 81 layers
    ,layerIdx=2 # layer to index at
    ,w=0.35 # weight of index
    ,indexApplicationType="standard" # index application type: standard | add (typically use standard)
    ,base_dir="./outputs/"
    ,seed=0
    ,max_mem=None
    ,batch_size=32
    ):
    """
    Rephrase Yelp-polarity reviews with a steered LLM and log sentiment/BLEU results.

    Each review is rephrased by ``model_name`` while layer ``layerIdx`` is
    patched (when ``w > 0``) with an index vector obtained from
    ``get_last_token_representation`` on "be extremely negative" /
    "be extremely positive", chosen opposite to the review's ground-truth
    label. For every review a row is appended to ``<folder>/data.csv`` with
    the original and rephrased text, the ground-truth label, the sentiment
    classifier's label and score for both texts, and their BLEU-4 score. The
    run's parameters are written to ``<folder>/params.json``.

    Args:
        dataSplit (str): "train" or "test" split of ``yelp_polarity``.
        numberSamples (int | None): use only the first N shuffled samples.
        sent_model_name (str): sentiment classifier checkpoint.
        model_name (str): causal LM checkpoint used for rephrasing.
        layerIdx (int): decoder layer whose forward pass is patched.
        w (float): index weight; ``w <= 0`` disables steering (baseline).
        indexApplicationType (str): forwarded to ``updateModelForwardPass``.
        base_dir (str): parent directory for the per-configuration folder.
        seed (int): passed to ``transformers.set_seed``.
        max_mem (dict | None): optional per-device ``max_memory`` budget.
        batch_size (int): number of reviews generated per batch.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    set_seed(seed)
    params = {
        "dataSplit": dataSplit,
        "numberSamples": numberSamples,
        # Fall back to the raw id: a bare .get() yields None for unlisted models,
        # which create_param_folder drops, hiding the model from the folder name.
        "model_name": prettyParam.get(model_name, model_name),
        "sent_model_name": prettyParam.get(sent_model_name, sent_model_name),
        "indexApplicationType": indexApplicationType,
        "seed": seed,
        "layerIdx": layerIdx,
        "w": w,
    }
    outputFolderPath = create_param_folder(params, base_dir=base_dir)
    with open(os.path.join(outputFolderPath, "params.json"), "w") as jf:
        json.dump(params, jf, indent=4)
    outputFile = os.path.join(outputFolderPath, "data.csv")

    # Yelp polarity benchmark: each sample has 'text' and 'label' (0 = negative, 1 = positive).
    dataset = load_dataset("yelp_polarity", split=dataSplit)
    numberToLabel = {0: "Negative", 1: "Positive"}  # dataset integers -> label names
    # Fixed shuffle seed (independent of `seed`) so every run sees the same sample order.
    dataset = dataset.shuffle(seed=0)
    if numberSamples is not None:
        dataset = dataset.select(range(numberSamples))

    sc = SentimentClassification(sent_model_name=sent_model_name)
    tokenizer, model = _load_language_model(model_name, max_mem)
    # dtype/device placement are already fixed by _load_language_model.
    generator = transformers.pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
    )

    negIndexToken = get_last_token_representation(model, tokenizer, "be extremely negative", layerIdx)
    posIndexToken = get_last_token_representation(model, tokenizer, "be extremely positive", layerIdx)

    # Only write the header into a new/empty file so a re-run does not insert
    # a second header row in the middle of the data.
    if not os.path.exists(outputFile) or os.path.getsize(outputFile) == 0:
        append_list_to_csv(["Orig Text", "Rephrased Text", "Ground Truth Label", "Orig Pred Label", "Rephrased Pred Label", "Orig Pred Score", "Rephrased Pred Score", "Bleu"], outputFile)

    samples = (
        (sample["text"].strip(), sample["label"])
        for sample in tqdm(dataset, desc="Processing samples")
    )
    for batch in _batched(samples, batch_size):
        rephrasedTexts = _rephrase_batch(
            generator, model, batch, layerIdx, negIndexToken, posIndexToken,
            w, indexApplicationType, batch_size,
        )
        for (text, label), rephrased in zip(batch, rephrasedTexts):
            _score_and_log(sc, text, rephrased, numberToLabel[label], outputFile)

    # Drop every reference to the large models/tensors before releasing GPU memory.
    del generator, model, sc, tokenizer, negIndexToken, posIndexToken
    _release_memory()


if __name__ == "__main__":

    # Llama-3.1-8B sweep over steering layer and index weight.
    # Original grid: layerIdx in [1, 3, 5, 10, 15, 20, 25, 30, 31],
    #                w in np.arange(0., 0.5, 0.05).
    # The unsteered baseline (w == 0.) never patches the model, so its output
    # does not depend on layerIdx: run it exactly once.
    baselineExecuted = False

    for layerIdx in [1, 3, 5, 10]:
        for w in [float(w) for w in np.arange(0.2, 0.6, 0.025)] + [0.]:

            if w == 0.:
                if baselineExecuted:
                    continue
                # Mark only when the baseline actually runs; marking on every
                # iteration meant the trailing w == 0. entry was always skipped.
                baselineExecuted = True
            runSimBatch(
                dataSplit="test" # train | test
                ,numberSamples=None # None = use the whole split
                # ,sent_model_name="cardiffnlp/twitter-roberta-base-sentiment" # alternative classifier
                ,sent_model_name="textattack/bert-base-uncased-yelp-polarity" # sentiment classification model
                ,model_name="meta-llama/Meta-Llama-3.1-8B-Instruct" # 33 layers ("meta-llama/Llama-3.3-70B-Instruct" has 81)
                ,layerIdx=layerIdx # layer to index at
                ,w=w # weight of index
                ,indexApplicationType="standard" # standard | add (typically use standard)
                ,base_dir="./llama8b/"
                ,seed=0
                ,batch_size=32
                )

    # Llama-3.3-70B sweep (uncomment to run; adjust batch_size to fit memory).
    # Original grid: layerIdx in [1, 5, 15, 30, 40, 50, 65, 75, 79],
    #                w in np.arange(0., 0.5, 0.05).
    # baselineExecuted = False
    # for layerIdx in [1, 5, 15, 30]:
    #     for w in [float(w) for w in np.arange(0.3, 0.6, 0.025)] + [0.]:
    #         if w == 0.:
    #             if baselineExecuted:
    #                 continue
    #             baselineExecuted = True
    #         runSimBatch(
    #             dataSplit="test" # train | test
    #             ,numberSamples=None # None = use the whole split
    #             ,sent_model_name="textattack/bert-base-uncased-yelp-polarity"
    #             ,model_name="meta-llama/Llama-3.3-70B-Instruct" # 81 layers
    #             ,layerIdx=layerIdx
    #             ,w=w
    #             ,indexApplicationType="standard" # standard | add
    #             ,base_dir="./llama70b/"
    #             ,seed=0
    #             ,max_mem = { # host-specific max memory usage
    #                 0: "43GiB",   # A40 #0
    #                 1: "43GiB",   # A40 #1
    #                 2: "43GiB",   # A40 #2
    #                 3: "78GiB",   # A100 80 GB – leave ~2 GB cushion
    #                 "cpu": "0GiB" # disallow RAM off-load
    #             }
    #             )
