import matplotlib.cm
from sentence_transformers import SentenceTransformer
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
import csv
from tqdm import tqdm
from pathlib import Path
import json
import matplotlib.colors as mcolors
from scipy.stats import sem
import re


def main_cluster_programs(args):
    """Run the clustering and primitive analyses and write their outputs.

    Three independent passes are made over the input files, each with its
    own ASR cut-off (rows whose ASR is below the cut-off are dropped by
    ``load_data``):

    1. cut-off 0.4  -> t-SNE scatter plots (``plot_t_sne_by_group``)
    2. cut-off -1   -> ASR versus number of primitives
       (``analyse_primitive_counts``)
    3. cut-off 0.8  -> per-group primitive frequencies
       (``analyse_primitive_distribution``)

    Args:
        args: Parsed options. This function reads ``cluster_model``,
            ``output_dir``, ``program_name_column``, ``asr_column`` and
            ``plot_extension``; the object is also handed on to the helpers.
    """
    model_name = args.cluster_model
    out_path = Path(args.output_dir)
    name_col = args.program_name_column
    asr_col = args.asr_column  # column containing the ASR estimate

    out_path.mkdir(parents=True, exist_ok=True)

    def fetch(min_asr):
        # Every pass re-reads the inputs and recomputes the projection.
        return load_data(args, model_name, name_col, asr_col, min_asr)

    # Pass 1: clustering patterns among attacks at or above 0.4 ASR.
    asr_threshold = 0.4
    plot_t_sne_by_group(
        out_path,
        fetch(asr_threshold),
        title=f"T-sne projection of CodeBERT embeddings of Jailbreak Attacks >= {asr_threshold} ASR",
        output_file_name=f'cluster_by_group.{args.plot_extension}',
        args=args,
    )

    # Pass 2: primitive counts versus ASR, using a cut-off of -1.
    all_attacks = fetch(-1)
    analyse_primitive_counts(all_attacks, out_path, args)

    # Pass 3: which primitives appear in attacks at or above 0.8 ASR.
    asr_threshold = 0.8
    strong_attacks = fetch(asr_threshold)
    analyse_primitive_distribution(strong_attacks, out_path, asr_threshold, args)


def plot_t_sne_by_group(output_dir, df, title, output_file_name, args):
    """Save a two-panel scatter plot of the 2-D points stored in ``df``.

    The left panel draws one scatter series per distinct ``groups`` value,
    labelled ``"<group>(<number of points>)"``. The right panel draws every
    point coloured by its ``asr`` value on the "Reds" colormap, normalised
    to the fixed range 0..1, with a colour bar.

    Args:
        output_dir: Directory the figure is written to; combined with
            ``output_file_name`` via ``/``.
        df: DataFrame with the columns ``d1``, ``d2``, ``asr`` and ``groups``.
        title: Figure-level title.
        output_file_name: File name of the saved figure.
        args: Object exposing ``color_palette``; entry ``i`` colours the
            ``i``-th group in sorted order, so it needs at least as many
            entries as there are groups.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(30, 8))

    # groupby yields the groups in sorted key order, which keeps the
    # group -> palette index assignment stable between runs.
    for g_ix, (group, points) in enumerate(df.groupby('groups')):
        ax1.scatter(
            points['d1'],
            points['d2'],
            c=args.color_palette[g_ix],
            label=f"{group}({len(points)})"
        )
    ax1.legend(loc='upper left', bbox_to_anchor=(1.02, 1), borderaxespad=0.)

    norm = mcolors.Normalize(vmin=0, vmax=1)

    # Pass the colormap by name: matplotlib.cm.get_cmap is deprecated and no
    # longer available in recent Matplotlib releases.
    sc = ax2.scatter(
        df['d1'],
        df['d2'],
        c=df["asr"],
        cmap="Reds",
        norm=norm
    )

    ax1.set_aspect('equal')
    ax2.set_aspect('equal')

    ax1.set_title('Grouped Scatter Plot')
    ax2.set_title('ASR Scatter Plot')

    plt.colorbar(sc, ax=ax2, pad=0.01, fraction=0.046)

    plt.tight_layout()

    # The suptitle is added after tight_layout, so make room for it manually.
    fig.suptitle(title, fontsize=16)
    plt.subplots_adjust(top=0.85)

    plt.savefig(output_dir / output_file_name, bbox_inches='tight')
    # Release the figure; pyplot otherwise keeps it alive until exit.
    plt.close(fig)


def load_data(args, cluster_model, program_name_column, asr_column, asr_threshold):
    """Load attack programs from CSV files, embed them and project to 2-D.

    Rows whose ASR is below ``asr_threshold`` are skipped. The remaining
    program strings are encoded with a ``SentenceTransformer`` and reduced
    to two dimensions with t-SNE (fixed ``random_state=42``).

    Args:
        args: Object exposing ``input_files``, an iterable of CSV paths.
        cluster_model: Model name or path passed to ``SentenceTransformer``.
        program_name_column: CSV column holding the program text.
        asr_column: CSV column holding the ASR estimate (parsed as float).
        asr_threshold: Minimum ASR (inclusive) for a row to be kept.

    Returns:
        DataFrame with one row per kept program and the columns
        ``program``, ``d1``, ``d2`` (projection coordinates), ``asr`` and
        ``groups`` (``"<syn_algo>/<syn_examples_type>/<synthesis_target>"``).
        When fewer than two programs are kept, no embedding is computed and
        the coordinates are all zero.

    Raises:
        KeyError: If a required column is missing from a row.
        ValueError: If an ASR value cannot be parsed as a float.
    """
    programs = []
    asrs = []
    groups = []

    for input_file in args.input_files:
        # newline='' lets the csv module handle line endings itself, which is
        # required for quoted fields that contain embedded newlines.
        with open(input_file, 'r', newline='') as f:
            # Materialise the rows so tqdm knows the total for its progress bar.
            for row in tqdm(list(csv.DictReader(f))):
                asr = float(row[asr_column])
                if asr < asr_threshold:
                    continue

                programs.append(row[program_name_column])
                asrs.append(asr)
                groups.append(f"{row['syn_algo']}/{row['syn_examples_type']}/{row['synthesis_target']}")

    n_programs = len(programs)
    if n_programs < 2:
        # A projection of zero or one point carries no information, and t-SNE
        # cannot be fitted on it; skip the model load and use the origin.
        tsne_results = np.zeros((n_programs, 2))
    else:
        embedder = SentenceTransformer(cluster_model)
        embeddings = embedder.encode(programs, convert_to_numpy=True, batch_size=1024, show_progress_bar=True)

        # t-SNE needs perplexity < number of samples; clamp it so a strict
        # threshold that leaves only a handful of programs does not crash.
        # For six or more programs this is the original perplexity of 5.
        tsne = TSNE(n_components=2, learning_rate='auto',
                    init='random', perplexity=min(5, n_programs - 1), random_state=42)
        tsne_results = tsne.fit_transform(np.array(embeddings))

    return pd.DataFrame({
        'program': programs,
        'd1': tsne_results[:, 0],
        'd2': tsne_results[:, 1],
        'asr': asrs,
        'groups': groups,
    })


def analyse_primitive_counts(df, output_dir, args):
    """Relate the number of primitives in an attack to its ASR.

    The length of a program is the number of pieces obtained by splitting it
    on ``".then("``. The ASR values are bucketed by length, dumped to
    ``lengths.json`` and plotted as a bar chart of mean ASR per length with
    the standard error of the mean as error bars, saved as
    ``asr_vs_length.<args.plot_extension>``.

    Args:
        df: DataFrame with the columns ``program`` and ``asr``.
        output_dir: Directory the outputs are written to; combined with file
            names via ``/``.
        args: Object exposing ``plot_extension``.
    """
    # Do longer compositions perform better? Bucket the ASRs by length.
    length_asr_dict = {}
    for program, asr in zip(df['program'], df['asr']):
        # split() returns a single piece when the separator is absent, so a
        # lone primitive naturally has length 1.
        length = len(program.split(".then("))
        length_asr_dict.setdefault(length, []).append(asr)

    # Dump the raw buckets; the context manager makes sure the file is
    # flushed and closed.
    with open(output_dir / 'lengths.json', 'w') as f:
        json.dump(length_asr_dict, f)

    lengths = list(length_asr_dict)
    avg_asrs = [np.mean(values) for values in length_asr_dict.values()]
    stem_asrs = [sem(values) for values in length_asr_dict.values()]
    total_attack_count = sum(len(values) for values in length_asr_dict.values())

    fig = plt.figure(figsize=(10, 10))
    plt.bar(lengths, avg_asrs, yerr=stem_asrs, capsize=5, align='center')
    plt.xlabel("Number of Primitives in Attack")
    plt.ylabel("Mean ± Stem ASR")
    plt.title(f"ASR vs Number of Primitives for {total_attack_count} Attacks")
    plt.grid(axis='y')
    plt.savefig(output_dir / f'asr_vs_length.{args.plot_extension}')
    plt.close(fig)



def analyse_primitive_distribution(df,  output_dir, asr_threshold, args):
    """Plot, per group, how often each primitive occurs in the programs.

    For every distinct ``groups`` value the primitives of all its programs
    (as returned by ``get_primitives``) are counted and drawn as a bar chart
    ordered from most to least frequent. Each chart is saved as
    ``primitive_distribution_<group with '/' replaced by '_'>.<extension>``.

    Args:
        df: DataFrame with the columns ``groups`` and ``program``.
        output_dir: Directory the figures are written to; combined with file
            names via ``/``.
        asr_threshold: Threshold shown in the plot title only; ``df`` is not
            filtered here.
        args: Object optionally exposing ``plot_extension``; ``pdf`` is used
            when the attribute is absent.
    """
    # Use the same extension as the other plots of this module.
    plot_extension = getattr(args, 'plot_extension', 'pdf')

    primitives_per_group = {}
    for group, program in zip(df['groups'], df['program']):
        primitives_per_group.setdefault(group, []).extend(get_primitives(program))

    for group, group_primitives in primitives_per_group.items():
        distinct_primitive_counts = {}
        for primitive in group_primitives:
            distinct_primitive_counts[primitive] = distinct_primitive_counts.get(primitive, 0) + 1

        # Most frequent first; ties keep first-seen order (stable sort).
        ranked = sorted(distinct_primitive_counts.items(), key=lambda t: t[1], reverse=True)
        primitives = [name for name, _ in ranked]
        counts = [count for _, count in ranked]

        # '/' in the group label would otherwise be read as a path separator.
        gname = group.replace('/', '_')
        fig = plt.figure(figsize=(10, 10))
        plt.bar(range(len(primitives)), counts)
        plt.xticks(range(len(primitives)), labels=primitives, rotation=90)
        plt.title(f"Freq. of Primitives in Attacks with ASR>{asr_threshold} \n {gname}")
        plt.tight_layout()
        plt.savefig(output_dir / f"primitive_distribution_{gname}.{plot_extension}")
        # One figure is created per group; close it so they do not pile up.
        plt.close(fig)
    



def get_primitives(program):
    """Return the bare primitive names that make up ``program``.

    The program is split on ``".then("``. In each piece everything from the
    first ``(`` through the last ``)`` is removed, the substring
    ``"Decorator"`` is dropped and surrounding whitespace is trimmed.

    Args:
        program: Program text, e.g. ``"ADecorator(x).then(BDecorator())"``.

    Returns:
        List with one name per piece, in program order. A program without
        ``".then("`` yields a single-element list.
    """
    # Raw string: "\(" in a normal literal is an invalid escape sequence.
    # '.' does not match newlines, so only an argument list whose opening
    # and closing parentheses share a line is stripped.
    call_args = re.compile(r"\(.*\)\)?")
    return [
        call_args.sub("", piece).replace("Decorator", "").strip()
        for piece in program.split(".then(")
    ]