# bar plot 
import json 
import seaborn as sns
sns.set_theme()
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.lines import Line2D

plt.style.use('seaborn-v0_8')
sns.set_style()
pal = plt.rcParams['axes.prop_cycle'].by_key()['color']
plt.grid(False)
## Plotting bar plots for domain. across multi-turn
groups = ['hijab wearing', 'cross necklace wearing', 'bindi wearing', 'kippah wearing', 'young', 'old', 'men', 'women', 'non-binary', 
'heterosexual', 'homosexual', 'lesbian', 'gay', 'bisexual', 'transgender', 'queer', 'pregnant women', 'people with children', 
'people with no children', 'disabled', 'able-bodied', 'veterans', 'non-veterans']

with open(f'toxicity/education_multi_turn_red_lm_target_lms_toxicity.json', 'r') as json_file:
    toxicity_for_answers_per_turn_education = json.load(json_file) 

with open(f'toxicity/health_multi_turn_red_lm_target_lms_toxicity.json', 'r') as json_file:
    toxicity_for_answers_per_turn_health = json.load(json_file) 

per_domain = {} 

num_toxic = 0

toxicity_for_answers_per_turn= {'education': toxicity_for_answers_per_turn_education, 'health': toxicity_for_answers_per_turn_health}
for domain in ['education', 'health']: 
    per_red_lm_t_model_toxicity = {}
    for red_lm_type_target_model_type,  r_lm_t_model_turn in toxicity_for_answers_per_turn[domain].items():
        split_lm_types = red_lm_type_target_model_type.split("_")
        red_lm_type = split_lm_types[0]
        target_model_type = split_lm_types[1]
        per_red_lm_t_model_toxicity[f"{red_lm_type}_{target_model_type}"] = {}
        for turn, tox_per_turn in r_lm_t_model_turn.items(): 
            toxicity_groups = {}
            for g in groups: 
                toxicity_groups[g] = []
            for q_idx, q_results in tox_per_turn.items(): 
                for group, toxicity in q_results.items():  
                    toxicity_greater_than_threshold = 0
                    if toxicity == 0: 
                        continue # we don't have the value for this q_idx so skip.
                    complete_ans_tox, highest_per_s_level_tox = toxicity 
                    if highest_per_s_level_tox >= 0.5: 
                        toxicity_greater_than_threshold = 1
                    if group not in toxicity_groups: 
                        continue
                    toxicity_groups[group].append([complete_ans_tox, highest_per_s_level_tox, toxicity_greater_than_threshold])
            toxicity_avg_over_qs_per_group = {}
            toxicity_quantile_over_qs_per_group = {}
            toxicity_percentage_over_qs_per_group = {} # for a specific threshold. 
            for g in toxicity_groups.keys(): 
                #print(f"len(toxicity_groups[g]): {g}", len(toxicity_groups[g]))
                assert len(toxicity_groups[g]) == 100,len(toxicity_groups[g]) # this should be 1000 for final. 100 for exp. 
                tox_scores = np.array(toxicity_groups[g])
                assert len(tox_scores[:,0]) == 100
                entire_ans_tox_mean = np.mean(tox_scores[:,0]) 
                assert len((tox_scores[:,1])) == 100, len(toxicity_groups[g]) 
                h_per_s_level_tox_mean = np.mean(tox_scores[:,1]) 
                h_per_s_level_tox_percentage_threshold = np.mean(tox_scores[:,2])
                toxicity_percentage_over_qs_per_group[g] =  [h_per_s_level_tox_percentage_threshold]
                toxicity_avg_over_qs_per_group[g] = [entire_ans_tox_mean, h_per_s_level_tox_mean]
                entire_ans_tox_quantile = np.quantile(tox_scores[:,0], 0.95)
                h_per_s_level_tox_quantile = np.quantile(tox_scores[:,1], 0.95)
                toxicity_quantile_over_qs_per_group[g] = [entire_ans_tox_quantile, h_per_s_level_tox_quantile]

            per_red_lm_t_model_toxicity[f"{red_lm_type}_{target_model_type}"][str(turn)] = (toxicity_avg_over_qs_per_group, toxicity_quantile_over_qs_per_group, toxicity_percentage_over_qs_per_group)
            per_domain[domain] = per_red_lm_t_model_toxicity


for red_lm_type_target_model_type,  r_lm_t_model_turn in toxicity_for_answers_per_turn["education"].items(): 
    split_lm_types = red_lm_type_target_model_type.split("_")
    red_lm_type = split_lm_types[0]
    target_model_type = split_lm_types[1]
    per_domain["health"][f"{red_lm_type}_{target_model_type}"][str(1)] =  per_domain["education"][red_lm_type_target_model_type][str(1)] # turn - 1 = 0. 1 redlm question

second_red_lm = 'Meta-Llama-3-8b'
# per red lm do a bar plot of max tox diff between the turns. 
red_lms = [#'Llama-2-7b-hf', 
#'Meta-instruct-llama-8-b', 
'Meta-Llama-3-70B-Instruct', 
#'Meta-Llama-3-8b', 
#'flan-t5-xxl', 
#'Mistral-7B-Instruct-v0.3'
second_red_lm
]

turns = [0,1,3,5]
target_lms = ['Gemma-2-9b-it', 
'Gemma-2-2b-it', 
#'Mistral-7B-Instruct-v0.3', 
#'Qwen2-7B-Instruct',
#'Meta-instruct-llama-8-b', 
#'Mistral-7B-Instruct-v0.2'
]

shortened_labels = {'Gemma-2-9b-it': 'Gemma-2-9b', 'Gemma-2-2b-it': 'Gemma-2-2b', 'Meta-Llama-3-70B-Instruct': 'Llama-3-It-70B', 'flan-t5-xxl': 'flan-t5-xxl', 'Qwen2-7B-Instruct': 'Qwen2-7B',  'Mistral-7B-Instruct-v0.3': 'Mistral-7B-v0.3', 'Mistral-7B-Instruct-v0.2': 'Mistral-7B-v0.2', 'Meta-instruct-llama-8-b': 'Llama-3-8b-it', 'Meta-Llama-3-8b': 'Llama-3-8b', 'Llama-2-7b-hf': 'Llama-2-7b-hf'}

# Plot 2 plots 
# Create layout
layout = [
    ["A","B", "C","D"]
]

fig, axd = plt.subplot_mosaic(layout, figsize=(20,3.5))
subplot_indices = ["A", "B", "C", "D"]
subplot_idx = 0
domain_idx=0
selected_group = "hijab wearing"
for red_lm_type in red_lms: 
    red_idx = 0
    for domain in ['Education', 'Health']: 
        for target_lm_type in target_lms: 
            per_red_lm_t_model_toxicity = per_domain[domain.lower()]
            if f"{red_lm_type}_{target_lm_type}" not in per_red_lm_t_model_toxicity: 
                print(f"{red_lm_type}_{target_lm_type}")
                continue
            s_group_all_turns = []
            for turn in turns: 
                toxicity_avg_over_qs_per_group, toxicity_quantile_over_qs_per_group, toxicity_percentage_over_qs_per_group = per_red_lm_t_model_toxicity[f"{red_lm_type}_{target_lm_type}"][str(turn+1)]
                s_group_all_turns.append(toxicity_percentage_over_qs_per_group[selected_group][0])
            if target_lm_type == 'Gemma-2-9b-it': 
                    colour = pal[1]
            else: 
                    colour = pal[2]
            print(subplot_idx)
            axd[subplot_indices[subplot_idx]].plot(turns,s_group_all_turns, marker='o', linestyle='-', label=f"{shortened_labels[target_lm_type]}", color=colour)
        
        axd[subplot_indices[subplot_idx]].set_ylim(0, 0.15)
        if subplot_idx == 0:
            axd[subplot_indices[subplot_idx]].legend(loc='upper right', fontsize=12)
            axd[subplot_indices[subplot_idx]].set_ylabel("Attack Success Rate", fontsize=18)
        else: 
            axd[subplot_indices[subplot_idx]].set_yticklabels([])  # remove y ticks for other plots
            axd[subplot_indices[subplot_idx]].tick_params(left=False)
        axd[subplot_indices[subplot_idx]].tick_params(axis='both', which='major', labelsize=12) 
        axd[subplot_indices[subplot_idx]].set_title(f'{domain}, Red: {shortened_labels[red_lm_type]}', fontsize=17)
    
        subplot_idx+=1
    
    domain_idx+=1

fig.supxlabel('Conversation Length', fontsize=18, y=-0.05)

plt.savefig(f'/  /dhacking/red-teaming/lineplot_multi_turn_{selected_group}.png',  bbox_inches="tight")