def analyze_predictions(predictions_df,
                        dataset_name,
                        rouge_scorer=None,
                        random_performance=None,
                        show_baseline=False,
                        threshold=4, # number of exact templates needed to be correct for "in-domain"
                        percentile=None,
                        rouge_thresh=None,
                        eval_method='rouge'):
    """Score one dataset's predictions and run the cross-domain analysis on them.

    Args:
        predictions_df: DataFrame of model predictions, forwarded to ``process_dataset``.
        dataset_name: Dataset to analyse; also used to derive ``task``.
        rouge_scorer: Scorer forwarded to ``process_dataset``; required when
            ``eval_method == 'rouge'``.
        random_performance: Accepted for backward compatibility; currently unused.
        show_baseline: Accepted for backward compatibility; currently unused.
        threshold: Forwarded to ``analyze_cross_domain_performance``.
        percentile: Forwarded to ``process_dataset`` (ROUGE quantile cutoff).
        rouge_thresh: Forwarded to ``process_dataset`` (fixed ROUGE cutoff).
        eval_method: Evaluation method forwarded to ``process_dataset``.

    Returns:
        The result dict produced by ``analyze_cross_domain_performance``.

    Raises:
        ValueError: if ``eval_method == 'rouge'`` and no ``rouge_scorer`` is given.
    """
    # for non-exact text eval using rouge
    if eval_method == 'rouge' and rouge_scorer is None:
        raise ValueError("Rouge scorer is required for 'rouge' evaluation method")

    # task label derived from the dataset name ('sensemaking' is checked first)
    task = next((name for name in ('sensemaking', 'esnli') if name in dataset_name), None)

    # process the dataset (identify in and cross domain exact templates)
    print(f"Processing {dataset_name} dataset using {eval_method} evaluation...")
    processed_data = process_dataset(
        predictions_df,
        dataset_name,
        rouge_scorer=rouge_scorer,
        percentile=percentile,
        rouge_thresh=rouge_thresh,
        eval_method=eval_method,
        task=task
    )

    # cross-domain performance
    print(f"Analyzing cross-domain performance for {dataset_name}...")
    analysis_results = analyze_cross_domain_performance(processed_data, threshold=threshold)
    # previously computed but never returned, so callers always received None
    return analysis_results
    

def analyze_cross_domain_performance(processed_dataset, threshold=4):
    """Split entities into cross-domain and baseline groups and report accuracy per setting.

    An entity is treated as *cross-domain* when it has at least one, but fewer than
    ``threshold``, correctly answered ``'exact'`` templates. For those entities every row is
    marked ``in_domain=1`` if its ``template_id`` is one of the entity's correctly answered
    exact templates, else ``in_domain=0``; ``'exact'`` rows marked ``in_domain=0`` get
    ``correct_exact`` forced to 0. All other entities (no correct exact template at all, or
    ``threshold`` or more) form the baseline group.

    Args:
        processed_dataset: DataFrame containing the columns ``entity_id``, ``dataset``,
            ``expected``, ``template_id``, ``variation_type`` and ``correct_exact`` (0/1).
        threshold: Entities with this many or more correct exact templates are excluded
            from the cross-domain group.

    Returns:
        dict with keys:
            ``cross_domain_entity_ids``: array of cross-domain entity ids.
            ``marked_data``: cross-domain rows with ``template_id_x`` (row template),
                ``template_id_y`` (list of correct exact templates) and ``in_domain``.
            ``preds_in_domain_baseline``: rows of all non-cross-domain entities.
            ``in_domain_performance`` / ``cross_domain_performance``: mean
                ``correct_exact`` per variation type, rounded to 2 decimals
                (NaN when the group is empty).
            ``baseline_performance``: unrounded mean ``correct_exact`` per variation type.
    """
    # filter for correct exact matches; collect, per entity, the list of correctly
    # answered exact template ids (only template_id is needed downstream)
    correct_exact_rows = processed_dataset[(processed_dataset.correct_exact == 1) &
                                           (processed_dataset.variation_type == 'exact')]
    preds_in_domain = (correct_exact_rows
                       .groupby(['entity_id', 'dataset', 'expected'])['template_id']
                       .agg(list)
                       .reset_index())

    # count correct templates for each entity (Series.map, unlike DataFrame.apply(axis=1),
    # never returns a DataFrame for an empty frame)
    preds_in_domain['correct_templates'] = preds_in_domain['template_id'].map(len)

    # keep entities with fewer than `threshold` correct templates -> cross-domain entities;
    # entities with no correct exact template never appear in this frame at all
    preds_in_domain = preds_in_domain[preds_in_domain['correct_templates'] < threshold]

    cross_domain_entity_ids = preds_in_domain['entity_id'].unique()
    print(f"Found {len(cross_domain_entity_ids)} cross-domain entity IDs")

    # split rows by entity: cross-domain entities vs. everything else (baseline)
    is_cross_domain_entity = processed_dataset['entity_id'].isin(cross_domain_entity_ids)
    preds_cross_domain = processed_dataset[is_cross_domain_entity]
    preds_in_domain_baseline = processed_dataset[~is_cross_domain_entity]

    # attach each entity's list of correct templates; the shared 'template_id' column
    # yields template_id_x (this row's template) and template_id_y (the list)
    marked_data = preds_cross_domain.merge(
        preds_in_domain[['template_id', 'entity_id']],
        on='entity_id'
    )
    marked_data['in_domain'] = [
        1 if template in correct_templates else 0
        for template, correct_templates in zip(marked_data['template_id_x'],
                                               marked_data['template_id_y'])
    ]

    # constraint that cross-domain exact matches must have correct_exact = 0
    exact_cross_domain_mask = (marked_data['variation_type'] == 'exact') & (marked_data['in_domain'] == 0)
    marked_data.loc[exact_cross_domain_mask, 'correct_exact'] = 0

    # performance metrics, per variation type present in the cross-domain rows
    variation_types = marked_data['variation_type'].unique()
    in_domain_rows = marked_data[marked_data['in_domain'] == 1]
    cross_domain_rows = marked_data[marked_data['in_domain'] == 0]

    in_domain_perf = {}
    cross_domain_perf = {}
    baseline_perf = {}

    print('### PERFORMANCE IN DOMAIN')
    print(f"length of df: {len(in_domain_rows)}")
    for var_type in variation_types:
        score = in_domain_rows[in_domain_rows['variation_type'] == var_type].correct_exact.mean()
        in_domain_perf[var_type] = np.round(score, 2)
        print(f'{var_type.capitalize()}: {in_domain_perf[var_type]}')
    print()

    print('### PERFORMANCE CROSS DOMAIN')
    print(f"length of df: {len(cross_domain_rows)}")
    for var_type in variation_types:
        score = cross_domain_rows[cross_domain_rows['variation_type'] == var_type].correct_exact.mean()
        cross_domain_perf[var_type] = np.round(score, 2)
        print(f'{var_type.capitalize()}: {cross_domain_perf[var_type]}')
    print()

    # baseline is reported unrounded and not printed
    for var_type in variation_types:
        baseline_perf[var_type] = preds_in_domain_baseline[
            preds_in_domain_baseline['variation_type'] == var_type
        ].correct_exact.mean()

    return {
        'cross_domain_entity_ids': cross_domain_entity_ids,
        'marked_data': marked_data,
        'preds_in_domain_baseline': preds_in_domain_baseline,
        'in_domain_performance': in_domain_perf,
        'cross_domain_performance': cross_domain_perf,
        'baseline_performance': baseline_perf
    }


def _normalize_final_answers(frame):
    """Reduce ``expected`` and ``prediction`` of *frame* to the bare final answer, in place.

    Each value is cut to its last sentence (``sent_tokenize``), then to the text after the
    last ``':'``, then to the text after the last ``'Final Answer'``; values that end up
    empty become the string ``'None'``. Returns *frame* for convenience.
    """
    def clean(text):
        text = sent_tokenize(text)[-1]
        text = text.split(':')[-1].strip()
        return text.split('Final Answer')[-1].strip()

    for column in ('expected', 'prediction'):
        frame[column] = frame[column].apply(clean).replace('', 'None')
    return frame


def _row_apply(frame, func):
    """Apply *func* to every row of *frame*, returning a value assignable to one column.

    ``DataFrame.apply(axis=1)`` on an empty frame returns a DataFrame, which cannot be
    assigned to a single column, so an empty list is returned in that case.
    """
    if frame.empty:
        return []
    return frame.apply(func, axis=1)


def _is_exact_match(row):
    """Return 1 if ``expected`` occurs in ``prediction`` and no answer-option list does, else 0."""
    prediction = row['prediction']
    if row['expected'] not in prediction:
        return 0
    # presumably rejects predictions that just echo the label options -- TODO confirm
    if 'positive, neutral, or negative' in prediction or 'positive or negative' in prediction:
        return 0
    return 1


def _combine_variants(main_dataset, semantic_variant):
    """Stack main and semantic rows (fresh index); a copy of *main_dataset* if there are no semantic rows."""
    if semantic_variant.empty:
        return main_dataset.copy()
    return pd.concat([main_dataset, semantic_variant], ignore_index=True)


def process_dataset(predictions_df, dataset_name, rouge_scorer=None, percentile=None, rouge_thresh=None, eval_method='rouge', task=None, num_entities=25):
    """Select one dataset's predictions, mark each row correct/incorrect, and assign entity ids.

    Rows whose ``dataset`` equals ``dataset_name`` form the main set; rows whose ``dataset``
    equals ``f"{dataset_name}_semantic"`` (if any) are appended after them.

    Args:
        predictions_df: DataFrame with at least ``dataset``, ``prediction`` and ``expected``
            columns. It is not modified.
        dataset_name: Name of the main dataset to select. If it contains ``'sensemaking'``
            or ``'esnli'``, answers are first reduced to their final-answer text.
        rouge_scorer: Object whose ``get_scores(prediction, expected)`` returns
            ``[{'rouge-2': {'f': ...}}]``; required when ``eval_method == 'rouge'``.
        percentile: Quantile of the combined ROUGE-2 F scores used as the cutoff; takes
            precedence over ``rouge_thresh``.
        rouge_thresh: Fixed ROUGE-2 F cutoff, used when ``percentile`` is None.
        eval_method: ``'rouge'`` (score strictly above the cutoff is correct),
            ``'exact_match'`` or ``'nli'`` (uses ``evaluate_nli_prediction``).
        task: Currently unused.
        num_entities: Row ``i`` of the combined frame gets ``entity_id = i % num_entities``.
            This assumes rows cycle through the entities in a fixed order -- TODO confirm
            with whoever produces ``predictions_df``.

    Returns:
        Combined DataFrame with a 0/1 ``correct_exact`` column, an ``entity_id`` column,
        ``dataset`` truncated at its first ``'_'``, and a ``rouge-2`` column for the rouge
        method (added as NaN for exact_match when absent).

    Raises:
        ValueError: unknown ``eval_method``; ``'rouge'`` without ``rouge_scorer`` or without
            ``percentile``/``rouge_thresh``; ``num_entities < 1``.
    """
    # validate arguments before doing any (possibly slow) scoring work
    if eval_method not in ('rouge', 'exact_match', 'nli'):
        raise ValueError("eval_method must be 'rouge', 'exact_match', or 'nli'")
    if eval_method == 'rouge':
        if rouge_scorer is None:
            raise ValueError("Rouge scorer is required for 'rouge' evaluation method")
        if percentile is None and rouge_thresh is None:
            raise ValueError("Either percentile or rouge_thresh must be provided")
    if num_entities < 1:
        raise ValueError("num_entities must be a positive integer")

    # .copy(): columns are assigned below, so work on independent frames rather than
    # slices of predictions_df (avoids SettingWithCopyWarning / chained assignment)
    main_dataset = predictions_df[predictions_df.dataset == dataset_name].copy()
    # from separate paraphrase setting artifact
    semantic_variant = predictions_df[predictions_df.dataset == f"{dataset_name}_semantic"].copy()

    if 'sensemaking' in dataset_name or 'esnli' in dataset_name:
        _normalize_final_answers(main_dataset)
        _normalize_final_answers(semantic_variant)

    if eval_method == 'rouge':
        def rouge2_f(row):
            return rouge_scorer.get_scores(row['prediction'], row['expected'])[0]['rouge-2']['f']

        for frame in (main_dataset, semantic_variant):
            frame['rouge-2'] = _row_apply(frame, rouge2_f)
        combined_dataset = _combine_variants(main_dataset, semantic_variant)

        scores = combined_dataset['rouge-2'].astype(float)
        # percentile wins over rouge_thresh; the quantile is computed once, not per row
        cutoff = scores.quantile(percentile) if percentile is not None else rouge_thresh
        combined_dataset['correct_exact'] = (scores > cutoff).astype('int64')

    elif eval_method == 'exact_match':
        for frame in (main_dataset, semantic_variant):
            frame['correct_exact'] = _row_apply(frame, _is_exact_match)
        combined_dataset = _combine_variants(main_dataset, semantic_variant)

        # for exact_match, we don't need rouge-2 column
        if 'rouge-2' not in combined_dataset.columns:
            combined_dataset['rouge-2'] = np.nan

    else:  # 'nli'
        for frame in (main_dataset, semantic_variant):
            frame['correct_exact'] = _row_apply(frame, evaluate_nli_prediction)
        combined_dataset = pd.concat([main_dataset, semantic_variant], ignore_index=True)

    # entity ids 0..num_entities-1, repeated by row position
    combined_dataset['entity_id'] = [i % num_entities for i in range(len(combined_dataset))]

    # normalize dataset names: keep the text before the first '_' ("x_semantic" -> "x")
    combined_dataset['dataset'] = combined_dataset['dataset'].apply(lambda name: name.split('_')[0])

    return combined_dataset