import json
import shutil
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional

import matplotlib.colors as mcolors
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import colormaps  # type: ignore
from matplotlib import pyplot as plt
from scipy.stats import spearmanr
from tqdm import tqdm

from lm_understanding import metrics
from lm_understanding.question_template import TemplateDataset

# Shared plot colours: two samples (at 0.3 and 0.7) of matplotlib's viridis colormap.
viridis = colormaps['viridis']
blue_color, yellow_color = viridis(.3), viridis(.7)


# Task names; they double as directory names under the results/dataset/plot trees.
TASKS = [
    'aita',
    'moral_dilemmas',
    'harmful_requests',
    'misinformation',
    'situational_awareness',
    'llm_goals',
    'advanced_ai_risk',
    'power_seeking',
    'self_preservation',
    'strange_hypotheticals',
    'hiring_decisions',
    'sycophancy',
]

# Baselines shown in the result heatmaps (commented entries are currently disabled).
BASELINES = [
    'predict_average',
    'nearest_neighbor',
    'nearest_neighbor_three',
    'linear_regression',
    # 'distill_0_shot',
    # 'llm_15_shot',
]

# Earlier model selection, kept for reference:
# MODELS = ['flan-alpaca-gpt4-xl', 'flan-alpaca-xl', 'flan-alpaca-xxl', 'flan-alpaca-ul2', 'vicuna-7B-1.1-HF', 'vicuna-13B-1.1-HF', 'stable-vicuna-13B-HF', 'vicuna-13B-1.1-HF-cot']
MODELS = [
    'flan-alpaca-gpt4-xl',
    'vicuna-7b-v1.3',
    # 'vicuna-13b-v1.3',
    # 'stable-vicuna-13B-HF',
    # 'Llama-2-7b-chat-hf',
]

# Every model whose name does not contain the '-cot' marker.
MODELS_WITH_LOGITS = [name for name in MODELS if '-cot' not in name]

# Root directory that all plots and tables are written to.
SAVE_DIR = Path('plots')

# Metric columns for which baseline result heatmaps are produced.
BASELINE_METRICS = [
    'log_score',
    'tv_distance',
    'normalized_log_score',
    'kl_divergence',
    'accuracy',
]

# The single distributional shift that this script reports on.
DISTRIBUTIONAL_SHIFT = 'new_combinations'

def load_baseline_results_df() -> pd.DataFrame:
    """Collect all baseline result summaries under ``baseline_results/`` into one frame.

    The tree is walked as ``baseline_results/<model>/<task>/<shift>/<baseline>.json``.
    Only models listed in ``MODELS`` and tasks listed in ``TASKS`` are read. Every JSON
    file yields one row holding the model, task, shift and baseline names (taken from
    the path), all keys of the file's ``results`` mapping, and its
    ``templates_evaluated`` value.

    Non-directory entries at the model, task or shift level are skipped instead of
    being iterated, and each level is visited in sorted order so that the row order
    does not depend on the file system's listing order.

    Returns:
        One row per result file; an empty frame without columns if nothing matched.

    Raises:
        FileNotFoundError: if the ``baseline_results`` directory does not exist.
    """
    records = []
    model_dirs = sorted(
        d for d in Path('baseline_results').iterdir()
        if d.is_dir() and d.name in MODELS
    )
    for model_dir in tqdm(model_dirs, desc='Loading baseline results'):
        for task_dir in sorted(model_dir.iterdir()):
            if not task_dir.is_dir() or task_dir.name not in TASKS:
                continue
            for shift_dir in sorted(task_dir.iterdir()):
                if not shift_dir.is_dir():
                    # A stray file here would make iterdir() raise NotADirectoryError.
                    continue
                for results_file in sorted(shift_dir.iterdir()):
                    if results_file.suffix != '.json':
                        continue
                    summary = json.loads(results_file.read_text())
                    records.append(dict(
                        model=model_dir.name,
                        task=task_dir.name,
                        distributional_shift=shift_dir.name,
                        baseline=results_file.stem,
                        **summary['results'],
                        templates_evaluated=summary['templates_evaluated'],
                    ))
    return pd.DataFrame.from_records(records)



def load_model_behavior_df() -> pd.DataFrame:
    """Load per-question model behavior for the templates used in the baseline evaluation.

    For every model in ``MODELS`` and task in ``TASKS`` found under
    ``model_behavior_results/<model>/<task>/``, the template list is taken from the
    ``templates_evaluated`` entry of the matching ``predict_average`` baseline row.
    Each ``<shift>/final_behavior/<template>/model_behavior.csv`` whose template name
    is in that list is read and tagged with ``model``, ``distributional_shift`` and
    ``task`` columns.

    Returns:
        The concatenation of all CSVs that were read.

    Raises:
        AssertionError: if a shift directory does not contribute exactly 15 templates,
            or if one of the two reference models lists a different number of templates.
        ValueError: raised by ``pd.concat`` when no CSV was read at all.
    """
    expected_templates = 15
    # Models for which the baseline template list itself is length-checked.
    # The first entry used to be misspelled 'flan-alpaca-gpt-xl', so that model was never checked.
    length_checked_models = ['flan-alpaca-gpt4-xl', 'vicuna-7b-v1.3']

    dfs = []
    baseline_results = load_baseline_results_df()
    model_dirs = [m for m in Path('model_behavior_results').iterdir() if m.name in MODELS]
    for model_dir in tqdm(model_dirs, desc='Loading Model Behavior'):
        for task_dir in model_dir.iterdir():
            if task_dir.name not in TASKS:
                continue
            baseline_results_for_task = baseline_results[baseline_results.task == task_dir.name]
            baseline_results_for_task = baseline_results_for_task[baseline_results_for_task.baseline == 'predict_average']
            selected_templates = baseline_results_for_task[baseline_results_for_task.model == model_dir.name].templates_evaluated.tolist()
            if not selected_templates:
                continue
            # NOTE(review): the first matching row is used whatever its distributional
            # shift is -- confirm that all shifts evaluate the same templates.
            selected_templates = selected_templates[0]
            if model_dir.name in length_checked_models:
                assert len(selected_templates) == expected_templates, (
                    f'{model_dir.name}/{task_dir.name}: expected {expected_templates} '
                    f'evaluated templates, found {len(selected_templates)}'
                )
            for shift_dir in task_dir.iterdir():
                model_behavior_dir = shift_dir / 'final_behavior'
                if not model_behavior_dir.exists():
                    continue
                n_templates = 0
                for template_dir in model_behavior_dir.iterdir():
                    if not template_dir.is_dir() or template_dir.name == 'explanations':
                        continue
                    if template_dir.name not in selected_templates:
                        continue
                    df = pd.read_csv(template_dir / 'model_behavior.csv')
                    df['model'] = model_dir.name
                    df['distributional_shift'] = shift_dir.name
                    df['task'] = task_dir.name
                    dfs.append(df)
                    n_templates += 1
                assert n_templates == expected_templates, (
                    f'{model_behavior_dir}: expected {expected_templates} templates, '
                    f'loaded {n_templates}'
                )
    return pd.concat(dfs)


@dataclass
class DatasetConfig:
    """Identifies one dataset by model name, task name and distributional shift."""
    model: str
    task: str
    shift: str

    def __hash__(self) -> int:
        # Hash the field tuple so that instances can be used as dictionary keys.
        identity = (self.model, self.task, self.shift)
        return hash(identity)


def load_datasets() -> Dict[DatasetConfig, TemplateDataset]:
    """Load every dataset found at ``datasets/<model>/<task>/<shift>/final_dataset``.

    Models come from ``MODELS``, tasks from ``TASKS``, and the only shift considered is
    ``DISTRIBUTIONAL_SHIFT``. Combinations whose directory does not exist are left out.

    Returns:
        Mapping from the dataset's ``DatasetConfig`` to the loaded ``TemplateDataset``.
    """
    loaded: Dict[DatasetConfig, TemplateDataset] = {}
    shifts = [DISTRIBUTIONAL_SHIFT]
    for model in tqdm(MODELS, desc="Loading Datasets"):
        for shift in shifts:
            for task in TASKS:
                dataset_dir = Path(f'datasets/{model}/{task}/{shift}/final_dataset')
                if dataset_dir.exists():
                    config = DatasetConfig(model, task, shift)
                    loaded[config] = TemplateDataset.load(dataset_dir)
    return loaded


class DatasetDataRetriever:
    """Accessor for per-dataset statistics and dataset-creation metadata.

    All datasets returned by ``load_datasets()`` are loaded at construction time and
    stored in ``self.datasets``, keyed by ``DatasetConfig``.
    """

    def __init__(self) -> None:
        self.datasets = load_datasets()

    @property
    def dataset_configs(self):
        """The ``DatasetConfig`` keys of all loaded datasets, as a list."""
        return list(self.datasets)

    def average_words_per_question(self, dataset_config: DatasetConfig):
        """Return ``average_words_per_question`` of every template task in the dataset.

        Raises:
            KeyError: if no dataset was loaded for ``dataset_config``.
        """
        dataset = self.datasets[dataset_config]
        return [t.average_words_per_question for t in dataset.template_tasks]

    def fraction_words_changed(self, dataset_config: DatasetConfig):
        """Return ``fraction_words_changed_by_substitutions`` of every template task.

        Raises:
            KeyError: if no dataset was loaded for ``dataset_config``.
        """
        dataset = self.datasets[dataset_config]
        return [t.fraction_words_changed_by_substitutions for t in dataset.template_tasks]

    def dataset_metadata(self, dataset_config: DatasetConfig) -> Optional[Dict]:
        """Return the dataset's ``metadata``, or ``None`` if the dataset is missing or falsy."""
        dataset = self.datasets.get(dataset_config)
        return dataset.metadata if dataset else None

    def initial_filter_kept_fraction(self, dataset_config: DatasetConfig) -> Optional[float]:
        """Return ``metadata['initial_filter']['kept_fraction']``, or ``None`` without metadata."""
        metadata = self.dataset_metadata(dataset_config)
        if not metadata:
            return None
        return metadata['initial_filter']['kept_fraction']

    def total_templates(self, dataset_config: DatasetConfig):
        """Return the number of entries in ``metadata['initial_filter']['template_etvds']``.

        Returns ``None`` when the dataset has no metadata.
        """
        metadata = self.dataset_metadata(dataset_config)
        if not metadata:
            return None
        return len(metadata['initial_filter']['template_etvds'])

    def template_summaries(self, dataset_config: DatasetConfig) -> Optional[List[Dict]]:
        """Return the per-template summaries recorded by the baseline filtering step.

        Returns ``None`` when the dataset has no metadata.
        """
        metadata = self.dataset_metadata(dataset_config)
        if not metadata:
            return None
        return metadata['baseline_filtering_info']['template_baseline_results']['template_summaries']

    def template_baseline_filtering_scores(self, dataset_config: DatasetConfig) -> Optional[List[float]]:
        """Return one score per template summary.

        The score is the ``LinearRegression`` / ``kl_divergence`` / ``test`` entry stored
        under the config's model name. Returns ``None`` when there are no summaries.
        """
        template_summaries = self.template_summaries(dataset_config)
        if not template_summaries:
            return None
        return [s['baseline_results'][dataset_config.model]['LinearRegression']['kl_divergence']['test'] for s in template_summaries]

    def fraction_retained_at_threshold(self, dataset_config: DatasetConfig, threshold: float = 0.1):
        """Return the fraction of templates whose filtering score is strictly above ``threshold``.

        Returns ``None`` when no scores are available.
        """
        scores = self.template_baseline_filtering_scores(dataset_config)
        if not scores:
            return None
        return np.mean([s > threshold for s in scores])

    def best_retained_performance(self, dataset_config: DatasetConfig, kept: int = 15):
        """Return the lowest score among the ``kept`` highest-scoring templates.

        Higher scores are the ones retained (see ``fraction_retained_at_threshold``), so
        the retained set is the top ``kept`` scores and its lowest member is the
        ``kept``-th highest score overall. When fewer than ``kept`` scores exist, all of
        them count as retained and the minimum is returned.

        Returns ``None`` when no scores are available.

        Raises:
            ValueError: if ``kept`` is smaller than 1.
        """
        scores = self.template_baseline_filtering_scores(dataset_config)
        if not scores:
            return None
        if kept < 1:
            raise ValueError(f'kept must be at least 1, got {kept}')
        ranked = sorted(scores)
        # ranked[-kept] is the kept-th highest score; the previous -(kept - 1) index
        # pointed one rank too high and raised IndexError on short score lists.
        return ranked[-min(kept, len(ranked))]

    def sample_question(self, dataset_config: DatasetConfig) -> Optional[str]:
        """Return the first test question text of the dataset's first template task.

        Returns ``None`` when no dataset was loaded for ``dataset_config``.
        """
        dataset = self.datasets.get(dataset_config)
        if dataset is None:
            return None
        return dataset.template_tasks[0].questions['test'].question_text.tolist()[0]


def plot_dataset_info(datasets: DatasetDataRetriever, dataset_config: DatasetConfig):
    """Save two histograms describing the templates of one dataset.

    ``template_words_changed.png`` shows the fraction of template words changed and
    ``question_lengths.png`` shows the average words per question. Both are written to
    ``SAVE_DIR/datasets/<task>/<model>/<shift>/``.
    """
    out_dir = SAVE_DIR / 'datasets' / dataset_config.task / dataset_config.model / dataset_config.shift

    # Histogram 1: how much of each template the substitutions change.
    words_changed = datasets.fraction_words_changed(dataset_config)
    sns.histplot(words_changed, stat='probability')
    plt.xlim([0, 1])
    plt.xlabel('Fraction of Template Words Changed')
    plt.title(dataset_config.task)
    plt.savefig(out_dir / 'template_words_changed.png', bbox_inches="tight")
    plt.clf()

    # Histogram 2: question length in words.
    question_lengths = datasets.average_words_per_question(dataset_config)
    sns.histplot(question_lengths, stat='probability')
    plt.title(f'Question Lengths {dataset_config.task}')
    plt.xlabel('Words')
    plt.savefig(out_dir / 'question_lengths.png', bbox_inches="tight")
    plt.clf()


def plot_all_dataset_info(dataset_data: DatasetDataRetriever):
    """Call ``plot_dataset_info`` for every loaded dataset configuration."""
    configs = dataset_data.dataset_configs
    for config in tqdm(configs, desc='Plotting Dataset Info'):
        plot_dataset_info(dataset_data, config)


def filter_by_model(df: pd.DataFrame, model: str) -> pd.DataFrame:
    """Return the rows of ``df`` whose ``model`` column equals ``model``."""
    is_model = df.model == model
    return df[is_model]


def filter_by_shift(df: pd.DataFrame, distributional_shift: str) -> pd.DataFrame:
    """Return the rows of ``df`` whose ``distributional_shift`` column equals the given shift."""
    is_shift = df.distributional_shift == distributional_shift
    return df[is_shift]


def filter_by_task(df: pd.DataFrame, task: str) -> pd.DataFrame:
    """Return the rows of ``df`` whose ``task`` column equals ``task``."""
    is_task = df.task == task
    return df[is_task]


def filter_by_baseline(df: pd.DataFrame, baseline: str) -> pd.DataFrame:
    """Return the rows of ``df`` whose ``baseline`` column equals ``baseline``."""
    is_baseline = df.baseline == baseline
    return df[is_baseline]


class BehaviorDataRetriever:
    """Accessor for answer-probability columns of the loaded model behavior frame.

    The frame returned by ``load_model_behavior_df()`` is loaded at construction time
    and stored in ``self.behavior``.
    """

    def __init__(self) -> None:
        self.behavior = load_model_behavior_df()

    def _select(self, model: str, task: Optional[str] = None) -> pd.DataFrame:
        """Rows for ``model``; also restricted to ``task`` unless it is falsy or ``'overall'``."""
        data = filter_by_model(self.behavior, model)
        if task and task != 'overall':
            data = filter_by_task(data, task)
        return data

    def _template_frames(self, model: str, task: str) -> List[pd.DataFrame]:
        """One frame per ``template_id`` of ``model``/``task``, in order of first appearance."""
        data = filter_by_task(filter_by_model(self.behavior, model), task)
        return [data[data.template_id == template_id] for template_id in data.template_id.unique()]

    def positive_answer_probs(self, model: str, task: Optional[str] = None) -> np.ndarray:
        """Return the ``answer_prob`` column for ``model``.

        Args:
            model: model name to select.
            task: task name to select; ``None`` or ``'overall'`` keeps every task.
        """
        return self._select(model, task).answer_prob.to_numpy()

    def combined_answer_probs(self, model: str, task: Optional[str] = None) -> np.ndarray:
        """Return the ``total_answer_prob`` column for ``model``.

        Args:
            model: model name to select.
            task: task name to select; ``None`` or ``'overall'`` keeps every task.
        """
        return self._select(model, task).total_answer_prob.to_numpy()

    def mean_combined_answer_probs_by_template(self, model: str, task: str) -> np.ndarray:
        """Return the mean of ``total_answer_prob`` for each template of ``model``/``task``.

        This reads the same column as ``combined_answer_probs``; it previously referred
        to a ``combined_answer_prob`` column that no other accessor in this class uses.
        """
        frames = self._template_frames(model, task)
        return np.array([frame.total_answer_prob.mean() for frame in frames])

    def etvds_by_template(self, model: str, task: str) -> np.ndarray:
        """Return ``metrics.etvd`` of the ``answer_prob`` column for each template of ``model``/``task``."""
        frames = self._template_frames(model, task)
        return np.array([metrics.etvd(frame.answer_prob) for frame in frames])


def plot_combined_answer_prob_hist(behavior: BehaviorDataRetriever, model: str, task: str, shift: str):
    """Save a histogram of the combined answer probabilities of ``model`` on ``task``.

    Nothing is written when there is no data. ``shift`` is accepted but not used.
    The plot goes to ``SAVE_DIR/tasks/<task>/combined_answer_probs_<model>.png``.
    """
    probs = behavior.combined_answer_probs(model, task)
    if len(probs) == 0:
        return

    target = SAVE_DIR / 'tasks' / task / f'combined_answer_probs_{model}.png'
    sns.histplot(probs, stat='probability')
    plt.xlim([0, 1])
    plt.title(f'{task}, {model}')
    plt.xlabel("Prob of 'Yes' or 'No' Answer")
    plt.savefig(target, bbox_inches="tight")
    plt.clf()


def plot_positive_answer_prob_hist(behavior: BehaviorDataRetriever, model: str, task: str, shift: str):
    """Save a histogram of the positive answer probabilities of ``model`` on ``task``.

    Nothing is written when there is no data. ``shift`` is accepted but not used.
    The plot goes to ``SAVE_DIR/tasks/<task>/given_answer_probs_<model>.png``.
    """
    probs = behavior.positive_answer_probs(model, task)
    if len(probs) == 0:
        return

    target = SAVE_DIR / 'tasks' / task / f'given_answer_probs_{model}.png'
    sns.histplot(probs, stat='probability')
    plt.xlim([0, 1])
    plt.xlabel('Prob of Yes Answer')
    plt.title(f'{task}, {model}')
    plt.savefig(target, bbox_inches="tight")
    plt.clf()


def plot_combined_answer_prob_bar(behavior: BehaviorDataRetriever, model: str, shift: str):
    """Save a horizontal bar chart of the mean combined answer probability per task.

    Tasks without data get ``None`` as their value. Bars are sorted in descending
    order. ``shift`` is accepted but not used. The plot goes to
    ``SAVE_DIR/models/<model>/combined_answer_probs.png``.
    """
    value_column = 'Average Combined Answer Probs of Yes and No'

    per_task_probs = [behavior.combined_answer_probs(model, task) for task in TASKS]
    per_task_means = []
    for probs in per_task_probs:
        per_task_means.append(probs.mean() if len(probs) else None)

    table = pd.DataFrame.from_dict({value_column: per_task_means, 'Task': TASKS})
    table = table.sort_values(value_column, ascending=False)

    sns.barplot(table, y='Task', x=value_column, color='steelblue', orient='h')
    plt.title(model)
    plt.savefig(SAVE_DIR / 'models' / model / 'combined_answer_probs.png', bbox_inches="tight")
    plt.clf()


def plot_model_behavior_summary(behavior: BehaviorDataRetriever, model: str):
    """Write all behavior plots for one model.

    Both histograms are produced for every task plus the pseudo-task ``'overall'``,
    followed by the per-task bar chart.
    """
    tasks_and_overall = TASKS + ['overall']
    progress = tqdm(tasks_and_overall, desc=f'Plotting Behavior Summaries: {model}')
    for task in progress:
        plot_combined_answer_prob_hist(behavior, model, task, DISTRIBUTIONAL_SHIFT)
        plot_positive_answer_prob_hist(behavior, model, task, DISTRIBUTIONAL_SHIFT)
    plot_combined_answer_prob_bar(behavior, model, DISTRIBUTIONAL_SHIFT)


def plot_behavior_summary(behavior: BehaviorDataRetriever):
    """Write the behavior summary plots of every model in ``MODELS``."""
    progress = tqdm(MODELS, desc='Behavior Summary')
    for model_name in progress:
        plot_model_behavior_summary(behavior, model_name)


class ResultsDataRetriever:
    """Holds the baseline results frame, loaded once at construction."""

    def __init__(self) -> None:
        loaded = load_baseline_results_df()
        self.results = loaded


def _mean_metric_table(df: pd.DataFrame, metric: str, row_column: str, row_values: List[str],
                       col_column: str, col_values: List[str]) -> pd.DataFrame:
    """Build a table of mean ``metric`` values for each (row value, column value) pair.

    Pairs without any matching rows in ``df`` are left out, so the corresponding cell
    is missing (NaN) or, if nothing matched at all, the table is empty.
    """
    records = {}
    for row_value in row_values:
        row_df = df[df[row_column] == row_value]
        cells = {}
        for col_value in col_values:
            values = row_df.loc[row_df[col_column] == col_value, metric].to_numpy()
            if len(values):
                cells[col_value] = values.mean()
        records[row_value] = cells
    # from_dict puts the outer keys on the columns; transpose to get them on the rows.
    return pd.DataFrame.from_dict(records).T


def _save_heatmap(table: pd.DataFrame, metric: str, title: str, path: Path) -> None:
    """Draw ``table`` as an annotated heatmap and save it to ``path``; skip empty tables."""
    if table.empty:
        return
    sns.heatmap(table, annot=True, cbar_kws={'label': metric})
    plt.title(title)
    plt.savefig(path, bbox_inches="tight")
    plt.clf()


def plot_baseline_results(metric: str):
    """Save heatmaps of the mean baseline performance measured by ``metric``.

    Only results for ``DISTRIBUTIONAL_SHIFT`` are used. Three kinds of heatmap are written:

    * ``SAVE_DIR/overall_baseline_results_<metric>.png``: baselines x models, averaged
      over all tasks;
    * ``SAVE_DIR/tasks/<task>/baseline_results_<metric>.png``: baselines x models for
      each task;
    * ``SAVE_DIR/models/<model>/baseline_results_<metric>.png``: tasks x baselines for
      each model.

    A heatmap is skipped when its table is empty. This now also applies to the overall
    heatmap, which used to be drawn unconditionally.

    Args:
        metric: name of the result column to average.
    """
    print(f'Plotting Baseline Results: {metric}')
    df = load_baseline_results_df()
    df = df[df.distributional_shift == DISTRIBUTIONAL_SHIFT]

    # model vs. baselines, averaged over every task
    overall = _mean_metric_table(df, metric, 'baseline', BASELINES, 'model', MODELS)
    _save_heatmap(
        overall, metric, 'Average Performance Across Tasks',
        SAVE_DIR / f'overall_baseline_results_{metric}.png',
    )

    # model vs. baselines, one heatmap per task
    for task in TASKS:
        task_table = _mean_metric_table(df[df.task == task], metric, 'baseline', BASELINES, 'model', MODELS)
        _save_heatmap(
            task_table, metric, task,
            SAVE_DIR / 'tasks' / task / f'baseline_results_{metric}.png',
        )

    # baselines vs. tasks, one heatmap per model
    for model in MODELS:
        model_table = _mean_metric_table(df[df.model == model], metric, 'task', TASKS, 'baseline', BASELINES)
        _save_heatmap(
            model_table, metric, model,
            SAVE_DIR / 'models' / model / f'baseline_results_{metric}.png',
        )


def reset_dirs(dataset_data: DatasetDataRetriever):
    """Delete ``SAVE_DIR`` and recreate the directory layout the plotting functions write to.

    Creates ``tasks/<task>`` for every task plus ``'overall'``,
    ``datasets/<task>/<model>/<shift>`` for every loaded dataset configuration and
    ``models/<model>`` for every model.

    Deletion ignores errors, so parts of the old tree may survive it. Directories are
    therefore created with ``exist_ok=True`` rather than failing with ``FileExistsError``.
    """
    shutil.rmtree(SAVE_DIR, ignore_errors=True)
    SAVE_DIR.mkdir(parents=True, exist_ok=True)
    for task in TASKS + ['overall']:
        (SAVE_DIR / 'tasks' / task).mkdir(parents=True, exist_ok=True)
    for dataset_config in dataset_data.dataset_configs:
        dataset_dir = SAVE_DIR / 'datasets' / dataset_config.task / dataset_config.model / dataset_config.shift
        dataset_dir.mkdir(parents=True, exist_ok=True)
    for model in MODELS:
        (SAVE_DIR / 'models' / model).mkdir(parents=True, exist_ok=True)


def dataset_creation_row(config: DatasetConfig, dataset_data: DatasetDataRetriever, behavior: BehaviorDataRetriever):
    """Assemble one row of the dataset creation table for ``config``.

    The row combines dataset metadata (sample question, template counts, filter
    fractions, best retained linear regression score) with behavior statistics of
    ``config.model`` on ``config.task``.

    Returns:
        A dict mapping column name to value.
    """
    row = {}
    row['task'] = config.task
    row['sample_question'] = dataset_data.sample_question(config)
    row['total_templates'] = dataset_data.total_templates(config)
    row['initial_filter_fraction'] = dataset_data.initial_filter_kept_fraction(config)
    row['baseline_filter_fraction'] = dataset_data.fraction_retained_at_threshold(config)
    row['best_retained_linear_regression_performance'] = dataset_data.best_retained_performance(config)

    yes_no_probs = behavior.combined_answer_probs(config.model, config.task)
    row['total_yes_no_answer_prob'] = yes_no_probs.mean()

    # The same 'yes' probabilities feed both the mean and the etvd column.
    yes_probs = behavior.positive_answer_probs(config.model, config.task)
    row['mean_yes_answer_prob'] = yes_probs.mean()
    row['yes_answer_prob_etvd'] = metrics.etvd(yes_probs)  # type: ignore
    return row


def dataset_creation_table(dataset_data: DatasetDataRetriever, behavior: BehaviorDataRetriever):
    """Write one dataset creation CSV per model, with one row per task.

    Files are saved as ``SAVE_DIR/models/dataset_info_<model>.csv``.
    """
    for model in tqdm(MODELS, desc='Dataset creation table'):
        rows = []
        for task in TASKS:
            config = DatasetConfig(model, task, DISTRIBUTIONAL_SHIFT)
            rows.append(dataset_creation_row(config, dataset_data, behavior))
        table = pd.DataFrame.from_records(rows)
        table.to_csv(SAVE_DIR / 'models' / f'dataset_info_{model}.csv', index=False)


def _label_behavior_axis(ax, metric: str) -> None:
    """Set the x-axis label, limits and ticks that the behavior violin plots use for ``metric``."""
    if metric == 'etvds':
        ax.set_xlabel('Within Template Avg. Pairwise TVDist')
        ax.set_xlim(0, 0.5)
    elif metric == 'answer_probs':
        ax.set_xlabel("Probability of 'Yes'")
        ax.set_xlim(0, 1)
        ax.set_xticks(np.linspace(0, 1, 5))
    else:
        ax.set_xlabel("Total 'Yes' or 'No' Probability")
        ax.set_xlim(0, 1)
        ax.set_xticks(np.linspace(0, 1, 5))


def create_behavior_task_summary_plot(behavior: BehaviorDataRetriever):
    """Save two violin-plot summaries of model behavior across all tasks.

    ``overall_model_behavior_violin.png`` has one column per model and one row per
    metric, with tasks ordered by their mean value. The condensed
    ``overall_model_behavior_violin_condensed.png`` has one panel per metric with both
    models drawn as split violins. The metrics are the per-template etvds, the combined
    answer probabilities and the positive answer probabilities.

    Raises:
        AssertionError: if a model/task pair does not have 15 templates and
            15 * 550 answer probabilities.
    """
    # Two fixed models: the condensed plot draws split violins with a two-colour palette.
    models = ['flan-alpaca-gpt4-xl', 'vicuna-7b-v1.3']
    templates_per_task = 15
    questions_per_template = 550
    expected_rows = templates_per_task * questions_per_template

    data = []
    for task in tqdm(TASKS, desc='Create task summary plot'):
        for model in models:
            combined_answer_probs = behavior.combined_answer_probs(model, task)
            etvds = behavior.etvds_by_template(model, task)
            positive_answer_probs = behavior.positive_answer_probs(model, task)
            assert len(etvds) == templates_per_task
            assert len(combined_answer_probs) == expected_rows
            assert len(positive_answer_probs) == expected_rows
            data.append(dict(
                model=model,
                task=task,
                etvd=etvds.mean(),
                etvds=etvds,
                combined_answer_prob=combined_answer_probs.mean(),
                combined_answer_probs=combined_answer_probs,
                answer_probs=positive_answer_probs
            ))
    df = pd.DataFrame.from_records(data)

    # One long-format frame per metric: every array element becomes its own row.
    metric_columns = ['etvds', 'combined_answer_probs', 'answer_probs']
    dfs = {}
    for column in metric_columns:
        long_df = df.explode(column).melt(id_vars=['model', 'task'], value_vars=[column],
                                          var_name='metric', value_name='value')
        # explode() leaves an object column; the plots need numeric values.
        long_df['value'] = pd.to_numeric(long_df['value'])
        dfs[column] = long_df

    # Figure 1: rows are metrics, columns are models.
    fig, axs = plt.subplots(len(metric_columns), len(models), figsize=(10, 12))
    for col, model in enumerate(models):
        for row, metric in enumerate(metric_columns):
            ax = axs[row, col]
            metric_df = dfs[metric]
            model_df = metric_df[metric_df['model'] == model]
            task_order = model_df.groupby('task').mean(numeric_only=True).sort_values(by='value', ascending=False).index
            sns.violinplot(data=model_df, x='value', y='task', ax=ax, color=blue_color, inner=None, order=task_order, linewidth=0) # type: ignore
            ax.set_ylabel('')
            if metric == 'etvds':
                # The etvds panel is the top row, so it carries the column title.
                ax.set_title(f'{model}')
            _label_behavior_axis(ax, metric)

    plt.subplots_adjust(hspace=0.3, wspace=0.6)
    plt.savefig(SAVE_DIR / 'overall_model_behavior_violin.png', bbox_inches="tight")
    # NOTE: clf() clears this figure but does not close it.
    plt.clf()

    # Figure 2: one panel per metric, both models as split violins.
    condensed_metrics = ['combined_answer_probs', 'answer_probs', 'etvds']
    fig, axs = plt.subplots(1, len(condensed_metrics), figsize=(10, 3))
    for ax, metric in zip(axs, condensed_metrics):
        sns.violinplot(data=dfs[metric], x='value', y='task', hue='model', ax=ax, palette=[blue_color, yellow_color], inner=None, linewidth=0, legend=False, split=True) # type: ignore
        ax.set_ylabel('')
        # Whether a per-axes legend exists here depends on whether the installed
        # seaborn honours legend=False, so remove it only if it is present.
        axes_legend = ax.get_legend()
        if axes_legend is not None:
            axes_legend.remove()
        _label_behavior_axis(ax, metric)
        if metric != 'combined_answer_probs':
            # Only the first panel keeps the task names on its y axis.
            ax.set_yticks([])

    plt.subplots_adjust(hspace=0.3, wspace=0.2)
    handles, labels = axs[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper right', bbox_to_anchor=(0.9, 1.07))
    plt.savefig(SAVE_DIR / 'overall_model_behavior_violin_condensed.png', bbox_inches="tight")
    plt.clf()


def create_model_agreement_plot(behavior: BehaviorDataRetriever):
    """Plot how closely the first two models in ``MODELS`` agree on each task.

    For every task, the ``answer_prob`` values of both models on the templates they
    share are ordered by ``template_id`` and ``question_id`` and compared with the
    Spearman rank correlation and ``metrics.tv_distance_of_preds``. Tasks without
    shared templates are skipped.

    Three figures are saved to ``SAVE_DIR``: ``model_agreement_spearman.png`` (bar
    chart), ``model_agreement_scatter.png`` (Spearman vs. TV distance, one labelled
    point per task) and ``model_agreement_tv_dist.png`` (bar chart). Previously each
    figure was cleared without being saved, so the function had no output.

    Raises:
        IndexError: if ``MODELS`` has fewer than two entries.
    """
    model_a, model_b = MODELS[0], MODELS[1]
    sort_keys = ['template_id', 'question_id']

    entries = []
    for task in TASKS:
        # Restrict to this task first so that a template id reused by another task
        # can never leak into the comparison.
        task_data = filter_by_task(behavior.behavior, task)
        task_data_a = filter_by_model(task_data, model_a)
        task_data_b = filter_by_model(task_data, model_b)
        templates_b = set(task_data_b.template_id.unique())
        same_templates = [t for t in task_data_a.template_id.unique() if t in templates_b]
        if not same_templates:
            continue
        data_a = task_data_a[task_data_a.template_id.isin(same_templates)].sort_values(sort_keys).answer_prob
        data_b = task_data_b[task_data_b.template_id.isin(same_templates)].sort_values(sort_keys).answer_prob
        spearman = spearmanr(data_a, data_b).statistic # type: ignore
        tv_dist = metrics.tv_distance_of_preds(data_a, data_b) # type: ignore
        entries.append(dict(
            task=task,
            tv_dist=tv_dist,
            spearman=spearman
        ))
    if not entries:
        # Nothing to compare; an empty frame would have no columns to sort by.
        return

    df = pd.DataFrame.from_records(entries)
    SAVE_DIR.mkdir(parents=True, exist_ok=True)

    sns.barplot(df.sort_values('spearman', ascending=False), x='spearman', y='task', palette='viridis')
    plt.savefig(SAVE_DIR / 'model_agreement_spearman.png', bbox_inches="tight")
    plt.clf()

    ax = sns.scatterplot(data=df, x='spearman', y='tv_dist', hue='task', palette='viridis', s=100, legend=False)  # type: ignore
    for _, row in df.iterrows():
        ax.annotate(
            row['task'],
            (row['spearman'], row['tv_dist']),
            ha='center',
            textcoords="offset points",
            xytext=(0,5),
        )
    # The y limits run from .4 down to 0, so smaller TV distances are drawn higher up.
    plt.gca().invert_yaxis()
    plt.ylim(.4, 0)
    plt.xlim(0, 1)
    plt.savefig(SAVE_DIR / 'model_agreement_scatter.png', bbox_inches="tight")
    plt.clf()

    sns.barplot(df.sort_values('tv_dist', ascending=False), x='tv_dist', y='task')
    plt.savefig(SAVE_DIR / 'model_agreement_tv_dist.png', bbox_inches="tight")
    plt.clf()


def main():
    """Load all data, reset the output tree, then write every plot and table."""
    dataset_data = DatasetDataRetriever()
    behavior = BehaviorDataRetriever()

    # The output directories must exist before any plot is saved.
    reset_dirs(dataset_data)

    # create_model_agreement_plot(behavior)
    create_behavior_task_summary_plot(behavior)
    dataset_creation_table(dataset_data, behavior)
    plot_all_dataset_info(dataset_data)
    plot_behavior_summary(behavior)

    metric_progress = tqdm(BASELINE_METRICS, desc="Plotting Baseline Results")
    for metric in metric_progress:
        plot_baseline_results(metric)


if __name__ == '__main__':
    main()
