import plotly.express as px
import torch.nn.functional as F
from abstract_cf.text_generation.learned_abstraction import LearnedAbstractionPipeline
import pandas as pd 
import numpy as np


# TODO check that this is correct
def analyze_career_distribution(
    biographies,
    profession_classifier: "LearnedAbstractionPipeline",
    sort=False,
) -> pd.DataFrame:
    """Average the classifier's profession distribution over a set of biographies.

    Each biography is passed to ``profession_classifier.predict``. The returned
    ``logits`` are converted to probabilities with a softmax over the last
    dimension. Every biography then contributes one distribution, and those
    distributions are averaged with equal weight.

    Args:
        biographies: Iterable of inputs accepted by
            ``profession_classifier.predict``.
        profession_classifier: Object exposing ``predict(bio)``, which returns
            an object with a ``logits`` tensor, and ``id_to_label``, which is
            indexable by class index and yields a profession name.
        sort: If True, order rows by descending probability. The original row
            labels are kept in the index.

    Returns:
        DataFrame with columns ``profession`` and ``probability``, with one row
        per class.

    Raises:
        ValueError: If ``biographies`` is empty.
    """
    per_bio = []
    for bio in biographies:
        output = profession_classifier.predict(bio)
        # Upcast before softmax and numpy conversion. Half precision loses
        # accuracy in the softmax, and bfloat16 tensors cannot be converted
        # with .numpy() at all.
        probs = F.softmax(output.logits.detach().float(), dim=-1).cpu().numpy()
        # Accept logits of shape (C,) or (1, C); (1, C) is presumably a batch
        # of one. If several rows come back, average them so that each
        # biography carries equal weight in the final mean.
        per_bio.append(probs.reshape(-1, probs.shape[-1]).mean(axis=0))

    if not per_bio:
        raise ValueError("biographies must contain at least one biography")

    mean_distribution = np.mean(per_bio, axis=0)
    labels = [profession_classifier.id_to_label[i] for i in range(len(mean_distribution))]

    df = pd.DataFrame({
        'profession': labels,
        'probability': mean_distribution,
    })
    if sort:
        df = df.sort_values('probability', ascending=False)

    return df


def plot_career_distribution(df, title: str = 'Career Distribution'):
    """Render a profession probability table as a bar chart.

    Args:
        df: DataFrame with ``profession`` and ``probability`` columns, such as
            the one returned by ``analyze_career_distribution``.
        title: Figure title.

    Returns:
        The figure produced by ``plotly.express.bar``.
    """
    fig = px.bar(
        df,
        x='profession',
        y='probability',
        title=title,
        labels={'profession': 'profession', 'probability': 'probability'},
    )
    # Tilt the x-axis tick labels so that long profession names stay readable.
    fig.update_layout(xaxis_tickangle=-45)
    return fig

