import numpy as np
from plotly import graph_objects as go
import pandas as pd
import plotly.express as px

def _single_value(frame, eval_column, dataset, metric):
    """Return the one `metric` value in `frame` whose `eval_column` equals `dataset`.

    ``Series.item`` raises ValueError unless exactly one such row exists.
    """
    return frame.loc[frame[eval_column] == dataset, metric].item()


def _normalized_memorization(training_value, baseline, higher_is_better):
    """Return the gain of `training_value` over `baseline`, relative to the room the baseline leaves.

    For higher-is-better metrics the room is ``1 - baseline``; for
    lower-is-better metrics it is ``baseline`` (distance to 0). Returns 0 when
    the training value does not beat the baseline or when there is no room
    (which also avoids a division by zero).
    """
    if higher_is_better:
        if baseline == 1 or training_value < baseline:
            return 0
        return (training_value - baseline) / (1 - baseline)
    if baseline == 0 or training_value > baseline:
        return 0
    return 1 - training_value / baseline


def get_start_of_memorization(df, 
                              training_dataset, 
                              eval_dataset, 
                              metric,
                              eval_column='eval_dataset',
                              columns_preserve_variance=None,
                              approach="contextual_memorization",
                              memorization_threshold=None,
    ):
    """Find the epoch at which memorization starts and score memorization per group.

    Rows of ``df`` are restricted to the two datasets named in ``eval_column``.
    For every epoch, a memorization gap is computed from the mean ``metric`` of
    each dataset. The gap is positive when the training dataset beats a baseline:

    * ``"contextual_memorization"``: the baseline is the best per-epoch mean on
      ``eval_dataset`` over all epochs (max for ``"correct"``, min otherwise).
    * ``"counterfactual_memorization"``: the baseline is the ``eval_dataset``
      mean at the same epoch.
    * ``"recollection_memorization"``: the baseline is ``memorization_threshold``.
      When that is None, it is 0.95 for ``"correct"`` and 0.2 for
      ``"target_token_negative_log_prob"``.

    Memorization starts at the epoch right after the last epoch whose gap is
    not positive, or at the first epoch if every gap is positive. If the last
    epoch itself is not positive, the last epoch is returned and
    ``memorization_has_started`` is False.

    A normalized score is then computed for every combination of
    ``columns_preserve_variance`` and ``epoch`` that has rows for both datasets.
    Duplicate rows within a combination are averaged first.

    * For ``"correct"`` the score is ``(train - baseline) / (1 - baseline)``.
    * For the negative log-prob it is ``1 - train / baseline``.
    * The score is 0 when the training value does not beat the baseline or the
      baseline leaves no room.
    * For ``"recollection_memorization"`` the score is binarized to 1/0.
    * Scores of epochs before the start of memorization are set to NaN.

    Args:
        df: Evaluation results. Must contain ``eval_column``, ``'epoch'``,
            ``metric`` and every column of ``columns_preserve_variance``.
        training_dataset: Value of ``eval_column`` marking training-data rows.
        eval_dataset: Value of ``eval_column`` marking reference rows.
        metric: ``"correct"`` (higher is better) or
            ``"target_token_negative_log_prob"`` (lower is better).
        eval_column: Column naming the dataset each row was evaluated on.
        columns_preserve_variance: Columns kept separate when computing the
            per-group score. None or empty yields one score per epoch.
        approach: One of the three approaches listed above.
        memorization_threshold: Baseline override for
            ``"recollection_memorization"``; ignored by the other approaches.

    Returns:
        Tuple ``(df_computed_memorization, epoch_start_of_memorization,
        best_contextual_recollection, memorization_has_started)``.

        * The DataFrame has a ``memorization`` column followed by the
          ``columns_preserve_variance`` columns and ``epoch``.
        * ``best_contextual_recollection`` is the fixed baseline, or None for
          ``"counterfactual_memorization"``.

    Raises:
        AssertionError: A required column or dataset is missing, or
            ``metric``/``approach`` is unsupported.
        ValueError: An epoch lacks a row for a dataset that its gap needs.
    """
    preserve_columns = [] if columns_preserve_variance is None else list(columns_preserve_variance)

    assert eval_column in df.columns
    assert 'epoch' in df.columns
    assert training_dataset in df[eval_column].unique()
    assert eval_dataset in df[eval_column].unique()
    assert metric in df.columns
    # Only these two metrics have a defined direction below; any other metric
    # would otherwise fail later with an UnboundLocalError.
    assert metric in ("correct", "target_token_negative_log_prob"), f"unsupported metric: {metric}"
    assert approach in ["contextual_memorization", "counterfactual_memorization", "recollection_memorization"]
    assert all(column in df.columns for column in preserve_columns)

    higher_is_better = metric == "correct"
    # contextual/recollection compare against one fixed baseline value;
    # counterfactual compares against the eval dataset of the same epoch/group.
    fixed_baseline = approach != "counterfactual_memorization"

    df = df[df[eval_column].isin([training_dataset, eval_dataset])]
    df_epoch = df.groupby([eval_column, 'epoch']).aggregate({metric: 'mean'}).reset_index()

    # 1. Fixed baseline ("best contextual recollection").
    best_contextual_recollection = None
    if approach == "contextual_memorization":
        eval_means = df_epoch.loc[df_epoch[eval_column] == eval_dataset, metric]
        best_contextual_recollection = eval_means.max() if higher_is_better else eval_means.min()
    elif approach == "recollection_memorization":
        if memorization_threshold is not None:
            best_contextual_recollection = memorization_threshold
        else:
            best_contextual_recollection = 0.95 if higher_is_better else 0.2

    # 2. Per-epoch gap on the dataset means; memorization means gap > 0.
    epochs = np.sort(df_epoch['epoch'].unique())
    above_threshold = []
    for epoch in epochs:
        rows = df_epoch[df_epoch['epoch'] == epoch]
        training_value = _single_value(rows, eval_column, training_dataset, metric)
        if fixed_baseline:
            baseline = best_contextual_recollection
        else:
            baseline = _single_value(rows, eval_column, eval_dataset, metric)
        gain = training_value - baseline if higher_is_better else baseline - training_value
        above_threshold.append(gain > 0)

    below_positions = np.flatnonzero(~np.asarray(above_threshold, dtype=bool))
    memorization_has_started = True
    if below_positions.size == 0:
        epoch_start_of_memorization = epochs[0]
    elif below_positions[-1] < len(epochs) - 1:
        # Memorization holds from the epoch after the last non-positive gap on.
        epoch_start_of_memorization = epochs[below_positions[-1] + 1]
    else:
        epoch_start_of_memorization = epochs[-1]
        memorization_has_started = False

    # 3. Normalized memorization per (preserve-variance columns, epoch) group.
    group_columns = preserve_columns + ['epoch']
    # Average duplicate rows so each group holds at most one value per dataset;
    # `.item()` below would otherwise raise on them.
    per_group = (
        df.groupby(group_columns + [eval_column], observed=True)[metric]
        .mean()
        .reset_index()
    )

    records = []
    for key, group in per_group.groupby(group_columns, observed=True):
        # Only groups evaluated on both datasets can be compared.
        if group[eval_column].nunique() != 2:
            continue
        training_value = _single_value(group, eval_column, training_dataset, metric)
        if fixed_baseline:
            baseline = best_contextual_recollection
        else:
            baseline = _single_value(group, eval_column, eval_dataset, metric)
        memorization = _normalized_memorization(training_value, baseline, higher_is_better)
        if approach == "recollection_memorization":
            memorization = 1 if memorization > 0 else 0
        # Older pandas yields scalar keys for single-column groupbys.
        key = key if isinstance(key, tuple) else (key,)
        records.append({"memorization": memorization, **dict(zip(group_columns, key))})

    # Explicit columns keep the frame well-formed even when no group qualified.
    df_computed_memorization = pd.DataFrame(records, columns=["memorization"] + group_columns)

    # Before the start of memorization, the degree of memorization is NaN.
    df_computed_memorization["memorization"] = df_computed_memorization["memorization"].where(
        df_computed_memorization["epoch"] >= epoch_start_of_memorization
    )

    return df_computed_memorization, epoch_start_of_memorization, best_contextual_recollection, memorization_has_started




# def best_generalization(fig, 
#                         color,
#                         df, 
#                         eval_dataset, 
#                         metric,
#                         eval_column='eval_dataset',
#                         add_best_generalization_line_horizontal=False, 
#                         horizontal_line_legend=None,
#                         show_legend=False 
#                         ):

#     assert metric in ["target_token_negative_log_prob", "correct"]
#     assert eval_column in df.columns
#     assert eval_dataset in df[eval_column].unique()
#     assert "epoch" in df.columns
#     df = df[df[eval_column].isin([eval_dataset])].copy()
#     df = df.groupby([eval_column, 'epoch']).aggregate({metric: 'mean'}).reset_index()


#     best_generalization_value = None
#     best_epoch = None
#     if metric == "target_token_negative_log_prob":
#         best_generalization_value = df[df[eval_column] == eval_dataset][metric].min()
#         best_epoch = df[(df[eval_column] == eval_dataset) & (df[metric] == best_generalization_value)]['epoch'].min()

#     elif metric == "correct":
#         best_generalization_value = df[df[eval_column] == eval_dataset][metric].max()
#         best_epoch = df[(df[eval_column] == eval_dataset) & (df[metric] == best_generalization_value)]['epoch'].min()        
        
#     else:
#         raise ValueError()

#     # print(best_epoch, best_generalization_value)

#     if best_epoch == None:
#         return fig, best_epoch, best_generalization_value

#     if add_best_generalization_line_horizontal:
#         if horizontal_line_legend == None:
#             horizontal_line_legend = eval_dataset
        
#         if show_legend:
#             # (horizontal line) add a trace for best_language_generalization
#             fig.add_trace(
#                 go.Scatter(
#                     x=[epoch for epoch in df['epoch'].unique()],
#                     y=[best_generalization_value for epoch in df['epoch'].unique()],
#                     name=horizontal_line_legend,
#                     marker=dict(
#                         color=color,
#                         size=10,
#                     ),
#                 )
#             )
#         else:
#             fig.add_hline(
#                 y=best_generalization_value,
#                 line_dash='dot', 
#                 line_color=color,
#                 annotation_text=f"{round(best_generalization_value, 2)}", 
#                 annotation_position="left",
#                 annotation_font=dict(
#                     color=color,
#                     size=10
#                 ),
#             )

#     if False:
#         # vertical line
#         fig.add_vline(
#             x=best_epoch,
#             line_dash='dot', 
#             line_color=color,
#             annotation_text=f"{int(best_epoch)}", 
#             annotation_position="bottom",
#             annotation_font=dict(
#                 color=color,
#                 size=10
#             ),
#         )

#     return fig, best_epoch, best_generalization_value



def get_arrow_line(fig, x0, y0, x1, y1, color, annotation_text):
    """Draw a double-headed arrow between two data points and label its midpoint.

    The arrow is built from two opposite single-headed annotations. In each one,
    the head sits at ``(x, y)`` and the tail at ``(ax, ay)``, all in the x/y
    axis data coordinates. A third, arrow-less annotation carries
    ``annotation_text`` at the midpoint on a semi-opaque white background.

    Args:
        fig: Figure to draw on; it is modified in place.
        x0, y0: First end point.
        x1, y1: Second end point.
        color: Colour of both arrow shafts and heads.
        annotation_text: Label placed halfway between the two points.

    Returns:
        The same ``fig`` object.
    """
    shared_arrow = dict(
        xref="x", yref="y",
        axref="x", ayref="y",
        text="",  # the arrows themselves carry no label
        showarrow=True,
        arrowhead=2,
        arrowsize=1,
        arrowwidth=2,
        arrowcolor=color,
    )

    # First the arrow pointing at (x1, y1), then the one pointing back at (x0, y0).
    head_tail_pairs = (((x1, y1), (x0, y0)), ((x0, y0), (x1, y1)))
    for (head_x, head_y), (tail_x, tail_y) in head_tail_pairs:
        fig.add_annotation(x=head_x, y=head_y, ax=tail_x, ay=tail_y, **shared_arrow)

    mid_x = (x0 + x1) / 2
    mid_y = (y0 + y1) / 2
    fig.add_annotation(
        x=mid_x,
        y=mid_y,
        text=annotation_text,
        showarrow=False,
        font=dict(size=10),
        font_color="black",
        bgcolor="white",
        opacity=0.85,
    )
    return fig


def plot_vline_with_conflict(fig, 
                             v_line_dict, 
                             width=0.5, 
                             line_dash='dot', 
                             font_size=10,
                             annotation_position='top',
    ):
    """Draw one labelled vertical line per epoch, spreading out lines that share an epoch.

    ``v_line_dict`` maps an epoch (x position) to a list of
    ``(color, memorization_has_started)`` pairs. Each label shows the integer
    epoch, prefixed with ``"> "`` when memorization is considered not started.

    * With a single pair, one line in that colour carries the label. The prefix
      is used when the pair's flag is falsy.
    * With several pairs, a fully transparent line carries the label, with the
      prefix decided by the first pair's flag only. The individual coloured
      lines are spread evenly over ``width`` centred on the epoch.

    Labels are gray, on a semi-opaque white background, and shifted 20 px
    towards the plot interior (down for ``'top'``, up otherwise).

    Returns:
        The same ``fig`` object, modified in place.
    """
    label_yshift = -20 if annotation_position == 'top' else 20

    for epoch, entries in v_line_dict.items():
        started_count = sum(flag for _, flag in entries)
        crowded = len(entries) > 1

        if crowded:
            # NOTE(review): only the first entry's flag decides the "> " prefix
            # here, unlike the single-entry case; confirm that is intended.
            label_started = entries[0][1]
            label_line_color = 'rgba(255, 0, 0, 0)'
        else:
            label_started = started_count
            label_line_color = entries[0][0]

        prefix = '' if label_started else '> '
        fig.add_vline(
            x=epoch,
            line_dash=line_dash,
            line_color=label_line_color,
            annotation_text=f"{prefix}{int(epoch)}",
            annotation_font=dict(color='gray', size=font_size),
            annotation_position=annotation_position,
            annotation_yshift=label_yshift,
            annotation_bgcolor="white",
            annotation_opacity=0.85,
        )

        if crowded:
            positions = np.linspace(epoch - width / 2, epoch + width / 2, len(entries))
            for x_position, (line_color, _) in zip(positions, entries):
                fig.add_vline(x=x_position, line_dash=line_dash, line_color=line_color)

    return fig



# def get_start_of_memorization(fig, 
#                               color, 
#                               df, 
#                               training_dataset, 
#                               eval_dataset, 
#                               metric,
#                               eval_column='eval_dataset',
#                               vertical_line_or_star="vertical_line",
#                               columns_preserve_variance=[],
#                               our_approach=True,
#     ):
#     assert training_dataset in df[eval_column].unique()
#     assert eval_column in df.columns
#     assert eval_dataset in df[eval_column].unique()
#     assert metric in df.columns
#     assert vertical_line_or_star in ["vertical_line", "star"]
#     if len(columns_preserve_variance) > 0:
#         assert all(column in df.columns for column in columns_preserve_variance)
    

    
#     df = df[df[eval_column].isin([training_dataset, eval_dataset])].copy()
#     df_original = df[[eval_column, 'epoch', metric] + columns_preserve_variance].copy()
#     df = df.groupby([eval_column, 'epoch']).aggregate({metric: 'mean'}).reset_index()

    
#     if our_approach:
#         # print("Comparing with current best")
#         df = df.sort_values(['epoch'])

#         # split
#         df_train = df[df[eval_column] == training_dataset].copy()
#         df_eval = df[df[eval_column] == eval_dataset].copy()

        
#         if metric == "correct":
#             df_eval[metric] = df_eval[metric].expanding().max()
#         else:
#             df_eval[metric] = df_eval[metric].expanding().min()

#         # merge
#         df = pd.concat([df_train, df_eval], axis=0)


#         if len(columns_preserve_variance) > 0:
#             df_train = df_original[df_original[eval_column] == training_dataset].copy()
#             df_eval = df_original[df_original[eval_column] == eval_dataset].copy()

#             # print(df_eval)

#             list_df_eval = []
#             for key, df_item in df_eval.groupby(columns_preserve_variance):
#                 df_item = df_item.sort_values(['epoch'])
#                 if metric == "correct":
#                     df_item[metric] = df_item[metric].expanding().max()
#                 else:
#                     df_item[metric] = df_item[metric].expanding().min()
#                 # print()
#                 # print(df_item)
#                 list_df_eval.append(df_item)
            
#             # merge
#             df_original = pd.concat([df_train] + list_df_eval, axis=0)
        
#         else:
#             df_original = df.copy()

           
    
    
#     epochs = df['epoch'].unique()
#     epochs = np.sort(epochs)

    
    
    
#     list_training_recollection = []
#     threshold_computed = []
#     computed_memorization = []
#     optimal_epoch = None
#     training_recollection_at_start_of_memorization = None

#     epochs_higher_than_threshold = []

#     # when the two differs
#     threshold = 0
#     # if metric == "target_token_negative_log_prob":
#     #     threshold = 1
#     # elif metric == "correct":
#     #     threshold = 0
#     # else:
#     #     raise ValueError(metric)


#     # compare with threshold
#     for epoch in epochs:
#         df_item = df[df['epoch'] == epoch]

#         if metric == "correct":
#             memorization = df_item[df_item[eval_column] == training_dataset][metric].item() - df_item[df_item[eval_column] == eval_dataset][metric].item()
#         elif metric == "target_token_negative_log_prob":
#             memorization = df_item[df_item[eval_column] == eval_dataset][metric].item() - df_item[df_item[eval_column] == training_dataset][metric].item()

#         if memorization > threshold:
#             epochs_higher_than_threshold.append(True)
#         else:
#             epochs_higher_than_threshold.append(False)
            
#         threshold_computed.append(memorization)
#         list_training_recollection.append(df_item[df_item[eval_column] == training_dataset][metric].item())


        
#     memorization_has_started=True
#     if our_approach:
#         if metric == "target_token_negative_log_prob":
#             best_generalization = df[df[eval_column] == eval_dataset][metric].min()
#             optimal_epoch = df[(df[eval_column] == eval_dataset) & (df[metric] == best_generalization)]['epoch'].min()

#         elif metric == "correct":
#             best_generalization = df[df[eval_column] == eval_dataset][metric].max()
#             optimal_epoch = df[(df[eval_column] == eval_dataset) & (df[metric] == best_generalization)]['epoch'].min()        
        
#         training_recollection_at_start_of_memorization = df[(df[eval_column] == training_dataset) & (df['epoch'] == optimal_epoch)][metric].item()        

#     else:    
#         # counterfactual memorization
#         epochs_higher_than_threshold = np.array(epochs_higher_than_threshold)
#         if not np.all(epochs_higher_than_threshold): # in at least one epoch, memorization is below threshold
#             # find the last epoch after which memorization is above threshold
#             last_epoch_below_threshold = np.where(epochs_higher_than_threshold == False)[0][-1]
#             # print("Last epoch below threshold:", last_epoch_below_threshold)
#             if last_epoch_below_threshold < len(epochs) - 1:
#                 optimal_epoch = epochs[last_epoch_below_threshold+1]
#                 training_recollection_at_start_of_memorization = list_training_recollection[last_epoch_below_threshold+1]
#             else:
#                 optimal_epoch = epochs[-1]
#                 training_recollection_at_start_of_memorization = list_training_recollection[-1]
#                 memorization_has_started = False
                
#         else:
#             optimal_epoch = epochs[0]
#             training_recollection_at_start_of_memorization = list_training_recollection[0]
            



#     # compute quantitative memorization
#     for key, df_item in df_original.groupby(columns_preserve_variance + ['epoch']):
#         if df_item[eval_column].nunique() != 2:
#             continue 
#         if metric == "correct":
#             if df_item[df_item[eval_column] == eval_dataset][metric].item() == 1:
#                 memorization = 0
#             elif df_item[df_item[eval_column] == training_dataset][metric].item() < df_item[df_item[eval_column] == eval_dataset][metric].item():
#                 memorization = 0
#             else:
#                 memorization = (df_item[df_item[eval_column] == training_dataset][metric].item() - df_item[df_item[eval_column] == eval_dataset][metric].item()) / (1 - df_item[df_item[eval_column] == eval_dataset][metric].item())
#         elif metric == "target_token_negative_log_prob":
#             if df_item[df_item[eval_column] == eval_dataset][metric].item() == 0:
#                 memorization = 0
#             elif df_item[df_item[eval_column] == training_dataset][metric].item() > df_item[df_item[eval_column] == eval_dataset][metric].item():
#                 memorization = 0
#             else:
#                 memorization = 1 - df_item[df_item[eval_column] == training_dataset][metric].item() / df_item[df_item[eval_column] == eval_dataset][metric].item()

#         memorization_results = {
#             "memorization": memorization
#         }
#         for i, column in enumerate(columns_preserve_variance + ['epoch']):
#             memorization_results[column] = key[i]

#         computed_memorization.append(memorization_results) 

#     df_computed_memorization = pd.DataFrame(computed_memorization)
    
#     # before the start of memorization, degree of memorization is 0
#     df_computed_memorization['memorization'] = df_computed_memorization.apply(
#         lambda x: x['memorization'] if x['epoch'] >= optimal_epoch else np.nan, axis=1
#         # lambda x: x['memorization'] if x['epoch'] >= optimal_epoch else 0, axis=1
#     )
    





#     if fig is not None:
#         if optimal_epoch is not None:
#             if vertical_line_or_star == "vertical_line":
#                 fig.add_vline(
#                     x=optimal_epoch,
#                     line_dash='dot', 
#                     line_color=color,
#                     annotation_text=f"{int(optimal_epoch)}" if memorization_has_started else f"> {int(optimal_epoch)}", 
#                     annotation_position="bottom",
#                     annotation_font=dict(
#                         color="gray",
#                         size=10,
#                     ),
#                 )

#             else:
            
#                 fig.add_scatter(
#                     x=[optimal_epoch],
#                     y=[training_recollection_at_start_of_memorization],
#                     mode='markers',
#                     marker=dict(size=10, color=color, symbol='x'),
#                     showlegend=True,
#                     name=f"({int(optimal_epoch)}, {round(training_recollection_at_start_of_memorization, 2)})" if memorization_has_started else f"> {int(optimal_epoch)}",
#                 )

    
    
#     return fig, df_computed_memorization, optimal_epoch

