from pprint import pprint, PrettyPrinter
import csv
import matplotlib.pyplot as plt
import numpy as np
import os
# corpus_bleu4.py
from typing import List, Sequence, Tuple, Union
from sacrebleu.metrics import BLEU
import code

Text  = str
Cand  = Text
Ref   = Text
Refs  = Sequence[Text]
Pair  = Union[Tuple[Cand, Ref], Tuple[Cand, Refs]]


def _split_pairs(pairs: List[Pair]) -> Tuple[List[Cand], List[List[Union[Ref, None]]]]:
    """
    Split (candidate, reference(s)) pairs into SacreBLEU's input layout.

    Returns ``(cands, ref_streams)`` where ``ref_streams[j][i]`` is the j-th
    reference of the i-th candidate, i.e. the ``(n_refs, n_sents)`` layout
    that ``BLEU.corpus_score`` expects. Candidates with fewer references than
    the maximum are padded with ``None`` instead of every candidate being
    truncated to the shortest reference list (which is what a plain ``zip``
    transpose would do).

    Raises ValueError if a candidate has no reference at all.
    """
    from itertools import zip_longest  # only needed here

    cands, ref_lists = [], []
    for idx, (cand, refs) in enumerate(pairs):
        # A bare string is a single reference; any other sequence is a list of them.
        if isinstance(refs, Sequence) and not isinstance(refs, str):
            refs = list(refs)
        else:
            refs = [refs]
        if not refs:
            raise ValueError(f"compute_corpus_bleu4: pair {idx} has no reference")
        cands.append(cand)
        ref_lists.append(refs)

    # Transpose per-candidate reference lists into reference streams,
    # padding missing references with None rather than dropping them.
    ref_streams = [list(stream) for stream in zip_longest(*ref_lists, fillvalue=None)]
    return cands, ref_streams


def compute_corpus_bleu4(
        pairs: List[Pair],
        lowercase: bool = True,
        tokenize: str = "13a",
        smooth_method: str = "exp"
    ) -> float:
    """
    Compute **corpus BLEU‑4** (0–100 scale) for a list of text pairs.

    Parameters
    ----------
    pairs : List[Pair]
        (candidate, reference)  or  (candidate, [reference1, reference2, …]).
        The first element of each pair is scored as the hypothesis. Different
        candidates may have different numbers of references; missing ones are
        passed to SacreBLEU as ``None``.
    lowercase, tokenize, smooth_method
        Passed straight through to SacreBLEU; max_ngram is fixed to 4.

    Returns
    -------
    float
        BLEU‑4 × 100 (e.g. 27.3 instead of 0.273).

    Raises
    ------
    ValueError
        If `pairs` is empty or some candidate has no reference.
    """
    cands, ref_streams = _split_pairs(pairs)
    if not cands:
        # SacreBLEU cannot score an empty corpus; fail with a clear message.
        raise ValueError("compute_corpus_bleu4: `pairs` is empty")

    bleu4 = BLEU(
        max_ngram_order=4,
        lowercase=lowercase,
        tokenize=tokenize,
        smooth_method=smooth_method
    )
    return bleu4.corpus_score(cands, ref_streams).score

# Short aliases for Hugging Face model identifiers. Presumably used when
# naming result folders: the values match the "sent_model_name" /
# "model_name" strings that appear in the folder names below — confirm.
prettyParam = {
    "textattack/bert-base-uncased-yelp-polarity": "bertBase",
    "meta-llama/Meta-Llama-3.1-8B-Instruct": "llama-8b-In",
}
def create_param_folder(params, base_dir="results", returnNameOnly=True):
    """
    Build a folder path under `base_dir` whose name encodes `params`.

    Every ``key: value`` entry whose value is not ``None`` contributes a
    ``"{key}-{value}"`` component, in dict iteration order. Inside the value,
    ``/``, ``\\`` and spaces are replaced with ``_``; keys are used verbatim.
    Components are joined with ``_``. Entries whose value is ``None`` are
    omitted from the name entirely.

    Parameters
    ----------
    params : dict
        Parameter name -> value (values are converted with ``str``).
    base_dir : str
        Parent directory of the folder.
    returnNameOnly : bool
        When True (the default) only the path is computed and nothing is
        touched on disk. When False the directory is also created
        (``exist_ok=True``).

    Returns
    -------
    str
        ``os.path.join(base_dir, folder_name)``.
    """
    # Path separators and spaces in a value would nest or split the folder name.
    unsafe_to_underscore = str.maketrans({"/": "_", "\\": "_", " ": "_"})

    folder_name = "_".join(
        f"{key}-{str(val).translate(unsafe_to_underscore)}"
        for key, val in params.items()
        if val is not None  # None-valued parameters are left out of the name
    )

    folder_path = os.path.join(base_dir, folder_name)
    if not returnNameOnly:
        os.makedirs(folder_path, exist_ok=True)
    return folder_path



def read_csv_to_list_of_dicts(filepath):
    """
    Load a UTF-8 CSV file into memory as a list of row dictionaries.

    The first row is treated as the header. Every following row becomes one
    ``dict`` mapping column name to cell value, in file order. Values are
    strings; ``csv.DictReader`` fills cells missing from a short row with
    ``None``. The file is opened with ``newline=''``, as the ``csv`` module
    requires, so quoted fields that contain newlines are parsed correctly.

    :param filepath: Path to the CSV file to read.
    :return: List[Dict[str, str]] — one dict per data row (empty when the
             file holds only a header).
    """
    with open(filepath, mode='r', encoding='utf-8', newline='') as handle:
        # Materialise inside the `with` so the file is fully read before closing.
        rows = list(map(dict, csv.DictReader(handle)))
    return rows


# def analize(path, outputFolderPath=None):
    # if outputFolderPath is None:
        # outputFolderPath = os.path.dirname(path)

    # allData = read_csv_to_list_of_dicts(path)
    # correctPredData = [it for it in allData if it["Ground Truth Label"] == it["Orig Pred Label"]]
    # flippedPredData = [it for it in allData if it["Rephrased Pred Label"] != it["Orig Pred Label"]]
    # flippedOrigLabelData = [it for it in allData if it["Rephrased Pred Label"] != it["Ground Truth Label"]]

    
    # # Bleus analysis
    # bleus = [float(it["Bleu"]) for it in allData]
    # bleus.sort()
    # bleus = np.array(bleus)

    # plt.figure()
    # plt.hist(bleus)
    # plt.title(f"Min {bleus.min()} | Max {bleus.max()} \n Mean {bleus.mean()}")
    # plt.savefig(os.path.join(outputFolderPath, "bleuHist.png"))

    # plt.figure()
    # plt.title(f"Min {bleus.min()} | Max {bleus.max()} \n Mean {bleus.mean()}")
    # plt.plot(np.arange(len(bleus)), bleus)
    # plt.savefig(os.path.join(outputFolderPath, "bleuLine.png"))


    # print("---------------------------")
    # print(len(allData))
    # print(len(correctPredData))
    # print(len(flippedPredData))
    # print(len(flippedOrigLabelData))

def _plot_fraction_bars(runs, isCounted, title, fileName, outputFolderPath):
    """
    Save a bar chart with one bar per simulation run.

    Each bar's height is ``(#rows for which isCounted(row) is true) / (#rows)``
    for that run, i.e. a fraction in [0, 1].

    Parameters
    ----------
    runs : list of (rows, color, name)
        ``rows`` is the list of CSV row dicts of one simulation.
    isCounted : callable(dict) -> bool
        Predicate selecting the rows counted in the numerator.
    title, fileName : str
        Figure title and output file name inside `outputFolderPath`.
    """
    plt.figure()
    for allData, c, name in runs:
        count = sum(1 for it in allData if isCounted(it))
        plt.bar(name, float(count) / float(len(allData)), color=c, label=name)
    plt.ylim(0., 1.)
    plt.title(title)
    plt.xlabel("Simulation")
    # NOTE(review): bar heights are fractions in [0, 1] although the axis says "Percent".
    plt.ylabel("Percent")
    plt.tight_layout()
    plt.savefig(os.path.join(outputFolderPath, fileName))
    # Release the figure so repeated calls (e.g. parameter sweeps) don't accumulate open figures.
    plt.close()


def analize(paths, names=("Baseline", "Indexed"), colors=("tab:blue", "tab:orange"), outputFolderPath=None):
    """
    Compare one or more simulation result CSVs and save summary plots.

    Every CSV is read with `read_csv_to_list_of_dicts`. Its rows must provide
    the columns "Orig Text", "Rephrased Text", "Bleu", "Ground Truth Label",
    "Orig Pred Label" and "Rephrased Pred Label". The i-th CSV is drawn with
    colour ``colors[i]`` and labelled ``names[i]``. Paths, colours and names
    are paired with ``zip``, so only as many simulations as the shortest of
    the three sequences are plotted.

    Files written to `outputFolderPath`:
      * bleuHist.png  - histogram of each run's per-row "Bleu" column, with
        the run's corpus BLEU ("CB", from `getCorpusBleu`) in the legend
      * bleuLine.png  - the same per-row values, sorted ascending
      * origionalCorrectSentimentClassification.png
      * grountTruthRephrasedSentimentChange.png
      * origionalPredictionRephrasedSentimentChange.png
      * percentageNeutral.png

    Parameters
    ----------
    paths : sequence of str
        CSV files, one per simulation.
    names, colors : sequence of str
        Tuples by default (immutable, avoiding shared mutable defaults); any
        sequence works.
    outputFolderPath : str or None
        Output directory, created if missing. Defaults to the directory that
        contains ``paths[0]``, or "." if that path has no directory part.
    """
    if outputFolderPath is None:
        # os.path.dirname() returns "" for a bare file name and os.makedirs("")
        # raises, so fall back to the current directory.
        outputFolderPath = os.path.dirname(paths[0]) or "."

    os.makedirs(outputFolderPath, exist_ok=True)
    print(f"Output Dir: {outputFolderPath}")

    allDatas = [read_csv_to_list_of_dicts(path) for path in paths]
    # (rows, colour, name) per simulation; zip stops at the shortest input.
    runs = list(zip(allDatas, colors, names))

    # Corpus BLEU per simulation, keyed by name for the legends below
    # (duplicate names would overwrite each other).
    corpusBleus = {name: getCorpusBleu(allData) for allData, _, name in runs}

    # Histogram of the per-row BLEU scores.
    plt.figure()
    for allData, c, name in runs:
        bleus = np.array([float(it["Bleu"]) for it in allData])
        plt.hist(bleus, color=c, alpha=0.7, label=f"{name} | CB {corpusBleus[name]}")
    plt.title("BLeu Dist")
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(outputFolderPath, "bleuHist.png"))
    plt.close()

    # Per-row BLEU scores sorted ascending and plotted against their rank.
    plt.figure()
    for allData, c, name in runs:
        bleus = np.array(sorted(float(it["Bleu"]) for it in allData))
        plt.plot(np.arange(len(bleus)), bleus, color=c, label=f"{name} | CB {corpusBleus[name]}")
    plt.title("BLeu Line")
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(outputFolderPath, "bleuLine.png"))
    plt.close()

    # Share of original reviews whose predicted sentiment matches the ground truth.
    _plot_fraction_bars(
        runs,
        lambda it: it["Ground Truth Label"] == it["Orig Pred Label"],
        "Percentage of origional reviews with \nsentiment correctly classified",
        "origionalCorrectSentimentClassification.png",
        outputFolderPath,
    )

    # Share of rephrased reviews whose predicted sentiment differs from the ground truth.
    _plot_fraction_bars(
        runs,
        lambda it: it["Rephrased Pred Label"] != it["Ground Truth Label"],
        "Percentage of rephrased reviews that \nchanged sentiment from ground Truth",
        "grountTruthRephrasedSentimentChange.png",
        outputFolderPath,
    )

    # Rows whose original prediction was correct AND whose rephrased prediction
    # differs from it, divided by ALL rows.
    # NOTE(review): rows whose original prediction was wrong are never counted,
    # even if rephrasing changed that prediction - confirm this matches the title.
    _plot_fraction_bars(
        runs,
        lambda it: (it["Ground Truth Label"] == it["Orig Pred Label"]
                    and it["Rephrased Pred Label"] != it["Ground Truth Label"]),
        "Percentage of rephrased reviews that \nchanged sentiment from origional review prediction",
        "origionalPredictionRephrasedSentimentChange.png",
        outputFolderPath,
    )

    # Share of rephrased reviews predicted as "Neutral".
    _plot_fraction_bars(
        runs,
        lambda it: it["Rephrased Pred Label"] == "Neutral",
        "Percentage of rephrased reviews that \n are neutral",
        "percentageNeutral.png",
        outputFolderPath,
    )

    # TODO: plot confusion matrices of class predictions (ground truth vs.
    # rephrased prediction, or original vs. rephrased prediction) to see how
    # predictions change.


def getPerfMetric(data):
    """
    Return the flip percentage (0-100) of a simulation's rows.

    A row counts as flipped when its "Rephrased Pred Label" differs from its
    "Ground Truth Label".

    Parameters
    ----------
    data : list of dict
        CSV rows, e.g. from `read_csv_to_list_of_dicts`.

    Raises
    ------
    ValueError
        If `data` is empty (the percentage is undefined).
    """
    if not data:
        raise ValueError("getPerfMetric: cannot compute a flip percentage for an empty dataset")
    flipped = sum(1 for row in data if row["Rephrased Pred Label"] != row["Ground Truth Label"])
    return 100. * flipped / len(data)

def getCorpusBleu(data):
    """
    Return the corpus BLEU-4 score (0-100) of a simulation's rows.

    Each row contributes the pair ("Orig Text", "Rephrased Text"), which
    `compute_corpus_bleu4` scores as (candidate, reference). Scoring uses
    lowercasing, the "13a" tokenizer and "exp" smoothing.

    NOTE(review): BLEU is not symmetric, and here the ORIGINAL text is the
    candidate and the REPHRASED text the reference. Confirm this direction
    is intended.
    """
    bleuOptions = {"lowercase": True, "tokenize": "13a", "smooth_method": "exp"}
    textPairs = [(row["Orig Text"], row["Rephrased Text"]) for row in data]
    return compute_corpus_bleu4(textPairs, **bleuOptions)

# allSimPaths = [[simA1, simA2, simA3,..],[simB1, simB2, simB3,..]]
# lineLabels = [label1, label2, ...]
# linePoints = [point1, point2, ...]
# lineColors = [color1, color2, ...]
def analizeMulti(allSimPaths, lineLabels, linePoints, lineColors, outputFolder, pathSuffix="/data.csv", baselinePath=None):
    """
    Plot corpus BLEU (x) against flip percentage (y) for several sweeps.

    The figure is saved as ``<outputFolder>/bleuVsFlip.png``. One line is
    drawn per inner list of `allSimPaths`, with one point per simulation.
    The first point of each line is additionally marked with a circle and
    the last with a square. The y value is `getPerfMetric`, i.e. the share of
    rows whose rephrased prediction differs from the ground truth. The x
    value is `getCorpusBleu`.

    Parameters
    ----------
    allSimPaths : list of list of str
        Simulation folders per line. Each CSV is read from
        ``simPath + pathSuffix``. Empty inner lists are skipped.
    lineLabels, lineColors : list of str
        Legend label and colour per line. They are paired with `allSimPaths`
        via ``zip``, so the shortest of the three decides how many lines are
        drawn.
    linePoints : list
        Only printed verbatim in the x-axis label (e.g. the swept weights);
        it does not position any point.
    outputFolder : str
        Output directory, created if missing.
    pathSuffix : str
        Appended by plain string concatenation, hence the leading "/".
    baselinePath : str or None
        Optional simulation folder, plotted as a single black "X".
    """
    print(lineLabels)
    plt.figure()
    for simPaths, lineLabel, lineColor in zip(allSimPaths, lineLabels, lineColors):
        if not simPaths:
            # An empty sweep has no points; indexing [0]/[-1] below would raise IndexError.
            continue
        flippedPercentage = []
        corpusBleus = []
        for simPath in simPaths:
            # Load one simulation at a time so only its rows are held in memory.
            data = read_csv_to_list_of_dicts(simPath + pathSuffix)
            flippedPercentage.append(getPerfMetric(data))
            corpusBleus.append(getCorpusBleu(data))

        plt.plot(corpusBleus, flippedPercentage, color=lineColor, label=lineLabel, marker=".")
        # Mark the start (circle) and end (square) of the sweep.
        plt.scatter(corpusBleus[0], flippedPercentage[0], color=lineColor, marker="o")
        plt.scatter(corpusBleus[-1], flippedPercentage[-1], color=lineColor, marker="s")

    # baseline
    if baselinePath is not None:
        print("plotting baseline!!!!!!!!!!!!!!!!!!!!!!!")
        data = read_csv_to_list_of_dicts(baselinePath + pathSuffix)
        baselinePerf = getPerfMetric(data)
        baselineBleu = getCorpusBleu(data)
        plt.scatter([baselineBleu], [baselinePerf], marker="X", color="k", label="Baseline")
        print(baselinePerf, baselineBleu)

    plt.legend()
    plt.ylabel("Flipped Sentiment Percentage")
    plt.xlabel(f"Corpus Bleu\n\n Weight Values: {str(linePoints)}")
    os.makedirs(outputFolder, exist_ok=True)
    plt.tight_layout()
    plt.savefig(os.path.join(outputFolder, "bleuVsFlip.png"))
    # Release the figure; otherwise every call leaves one open in pyplot.
    plt.close()



# analize("/bazhlab/edelanois/llm/16/outputs/nonIndexed/outputNonIndexed.csv")
# analize("/bazhlab/edelanois/llm/16/outputs/indexed/outputIndexed.csv")

# analize(
    # [
    # "/bazhlab/edelanois/llm/16/outputs/dataSplit-train_numberSamples-100_model_name-llama-8b-In_sent_model_name-bertBase_indexApplicationType-standard_layerIdx-2_w-0.35/data.csv"
    # # "/bazhlab/edelanois/llm/16/outputs/nonIndexed/outputNonIndexed.csv"
    # # , "/bazhlab/edelanois/llm/16/outputs/indexed/outputIndexed.csv"
    # ]
    # , names=["Baseline", "Indexed"]
    # , colors=["tab:blue", "tab:orange"]
    # , outputFolderPath="./analysis/"
    # )

# llama 8b
# Individual Sim Analysis
# baselineExecuted = False # be sure to execute baseline w == 0. only one time
# for layerIdx in [1, 3, 5, 10, 15, 20, 25, 30, 31]:
    # for w in [float(w) for w in np.arange(0., 0.5, 0.05)]:
        # if w == 0. and baselineExecuted:
            # continue
        # baselineExecuted = True

        # params = {
            # "dataSplit": "train",
            # "numberSamples": 100,
            # "model_name": "llama-8b-In",
            # "sent_model_name": "bertBase",
            # "indexApplicationType": "standard",
            # "seed":0,
            # "layerIdx": layerIdx,
            # "w": w,
        # }

        # outputFolderPath = create_param_folder(params, base_dir="./llama8b")
        # print(outputFolderPath)
        # analize(
            # [
            # "/bazhlab/edelanois/llm/16/llama8b/dataSplit-train_numberSamples-100_model_name-llama-8b-In_sent_model_name-bertBase_indexApplicationType-standard_seed-0_layerIdx-1_w-0.0/data.csv",
            # f"{outputFolderPath}/data.csv",
            # ]
            # , names=["Baseline", f"Indexed | layer {layerIdx} | weight {w}"]
            # , colors=["tab:blue", "tab:orange"]
            # , outputFolderPath=outputFolderPath + "/plots/"
            # )

# # PAPER
# # llama 8b
# # Group Sim Analysis
# baselineExecuted = False # be sure to execute baseline w == 0. only one time
# allSimPaths = []
# lineLabels = []
# linePoints = None
# lineColors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', 'tab:brown', 'tab:pink', 'tab:gray', 'k']
# # for layerIdx in [1, 3, 5, 10, 15, 20, 25, 30, 31]:
# for layerIdx in [1, 3, 5, 10]:
    # allSimPaths.append([])
    # lineLabels.append(f"Layer {layerIdx}")
    # linePoints = []
    # # for w in [float(w) for w in np.arange(0., 0.5, 0.05)]:
    # for w in [float(w) for w in np.arange(0.2, 0.5, 0.025)] + [0.]:
        # if w < 0.21: # skip first few wieghts since they have no impact on performance
            # continue

        # if w == 0. and baselineExecuted:
            # continue
        # baselineExecuted = True

        # params = {
            # "dataSplit": "test",
            # "numberSamples": 1000,
            # "model_name": "llama-8b-In",
            # "sent_model_name": "bertBase",
            # "indexApplicationType": "standard",
            # "seed":0,
            # "layerIdx": layerIdx,
            # "w": w,
        # }

        # outputFolderPath = create_param_folder(params, base_dir="./llama8b")
        # if w != 0: # dont add baseline to analysis yet
            # allSimPaths[-1].append(outputFolderPath)
            # linePoints.append(f"{w:.2f}")
# baselinePath = "/bazhlab/edelanois/llm/16/llama8b/dataSplit-train_numberSamples-100_model_name-llama-8b-In_sent_model_name-bertBase_indexApplicationType-standard_seed-0_layerIdx-1_w-0.0/"
# analizeMulti(allSimPaths, lineLabels, linePoints, lineColors, "./llama8b/_plots/", baselinePath=baselinePath)
        

# # llama 70b
# baselineExecuted = False
# for layerIdx in [1, 5, 15, 30, 40, 50, 65, 75, 79]:
    # for w in [float(w) for w in np.arange(0., 0.5, 0.05)]:
        # if w == 0. and baselineExecuted:
            # continue
        # baselineExecuted = True

        # params = {
            # "dataSplit": "train",
            # "numberSamples": 100,
            # "model_name": "llama-70b-In",
            # "sent_model_name": "bertBase",
            # "indexApplicationType": "standard",
            # "seed":0,
            # "layerIdx": layerIdx,
            # "w": w,
        # }

        # outputFolderPath = create_param_folder(params, base_dir="./llama70b")
        # print(outputFolderPath)
        # analize(
            # [
            # "/bazhlab/edelanois/llm/16/llama70b/dataSplit-train_numberSamples-100_model_name-llama-70b-In_sent_model_name-bertBase_indexApplicationType-standard_seed-0_layerIdx-1_w-0.0/data.csv",
            # f"{outputFolderPath}/data.csv",
            # ]
            # , names=["Baseline", f"Indexed | layer {layerIdx} | weight {w}"]
            # , colors=["tab:blue", "tab:orange"]
            # , outputFolderPath=outputFolderPath + "/plots/"
            # )

# PAPER
# llama 70b
# Group Sim Analysis
#
# Module-level script (runs on import): for each layer, collect the result
# folders of a weight sweep, then plot corpus BLEU vs. flipped-sentiment
# percentage for all layers in one figure via analizeMulti.
baselineExecuted = False # meant to visit the w == 0. baseline only once; the w < 0.26 filter below means 0. is never reached
allSimPaths = [] # one list of simulation folders per layer
lineLabels = [] # legend label per layer
linePoints = None # reset per layer below; only the last layer's list is passed to analizeMulti
lineColors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', 'tab:brown', 'tab:pink', 'tab:gray', 'k']
# for layerIdx in [1, 5, 15, 30, 40, 50, 65, 75, 79]:
for layerIdx in [1, 5, 15, 30]:
    allSimPaths.append([])
    lineLabels.append(f"Layer {layerIdx}")
    # analizeMulti only prints linePoints in the x-axis label, so keeping the
    # last layer's list is fine while every layer sweeps the same weights.
    linePoints = []
    # for w in [float(w) for w in np.arange(0., 0.5, 0.05)]:
    # NOTE: folder names embed str(w) of these floats, so they must reproduce
    # the exact values used when the simulations were written.
    for w in [float(w) for w in np.arange(0.3, 0.5, 0.025)] + [0.]:
        # Skip weights below 0.26, since they have no impact on performance. The range
        # starts at 0.3, so this only drops the appended 0. baseline. That makes
        # the baseline branch below unreachable; the baseline is supplied via
        # baselinePath instead.
        if w < 0.26:
            continue

        if w == 0. and baselineExecuted:
            continue
        baselineExecuted = True

        params = {
            "dataSplit": "test",
            "numberSamples": 1000,
            # "model_name": "llama-70b-In",
            # create_param_folder drops None-valued entries, so the folder
            # name has no model_name component.
            "model_name": None,
            "sent_model_name": "bertBase",
            "indexApplicationType": "standard",
            "seed":0,
            "layerIdx": layerIdx,
            "w": w,
        }

        # Path only: returnNameOnly defaults to True, so nothing is created on disk.
        outputFolderPath = create_param_folder(params, base_dir="./llama70b")
        if w != 0: # dont add baseline to analysis yet (always true here, since 0. was skipped above)
            allSimPaths[-1].append(outputFolderPath)
            linePoints.append(f"{w:.2f}") # display label; 0.025 steps such as 0.325 are rounded to 2 decimals
# NOTE(review): the baseline comes from dataSplit=train / numberSamples=100,
# while the sweep above uses dataSplit=test / numberSamples=1000 - confirm
# this comparison is intended.
baselinePath = "/bazhlab/edelanois/llm/16/llama70b/dataSplit-train_numberSamples-100_sent_model_name-bertBase_indexApplicationType-standard_seed-0_layerIdx-1_w-0.0/"
analizeMulti(allSimPaths, lineLabels, linePoints, lineColors, "./llama70b/_plots/", baselinePath=baselinePath)