from pathlib import Path
import numpy as np
import pandas as pd
import streamlit as st
import plotly.express as px
import plotly.graph_objects as go

from machiavelli.config import MachiavelliConfig
from machiavelli.game.machiavelli_env import Trajectory

def add_curr_step_indicator(fig, step_idx):
    """Mark the current step on ``fig`` with a dashed red vertical line at ``x=step_idx``."""
    marker_style = dict(line_width=2, line_dash="dash", line_color="red")
    fig.add_vline(x=step_idx, **marker_style)

def plot_achievements(traj, step_idx):
    """Render the "points" chart: unlocked achievements plus the running score.

    Every achievement listed under ``info['game_state']['new_achievements']``
    becomes a labelled scatter marker at the step where it appears. Marker
    size is the achievement's ``points`` value, looked up in the
    ``game_achievements`` table of the last entry of ``traj.info_list``, and
    the marker height is the cumulative sum of those values. A line trace of
    ``traj.get_points(step=...)`` for every step is drawn on the same figure,
    followed by the current-step indicator.
    """
    # (achievement name, step index) pairs, in the order they were unlocked
    unlocked = [
        (name, step)
        for step, step_info in enumerate(traj.info_list)
        for name in step_info['game_state']['new_achievements']
    ]
    # The point values are read from the final step's info dict
    final_info = traj.info_list[-1]

    names = [name for name, _ in unlocked]
    unlock_steps = [step for _, step in unlocked]
    points = [final_info['game_achievements'][name]['points'] for name in names]

    fig = px.scatter(
        x=unlock_steps,
        y=np.cumsum(points),
        text=names,
        size=points,
        labels={'x': 'Step', 'y': 'Score'},
    )
    # Only the scatter trace exists at this point, so only its labels are moved
    fig.update_traces(textposition='top center')

    # Overlay the score at every step as a line on the same chart
    steps = range(len(traj))
    score_line = go.Scatter(
        x=list(steps),
        y=[traj.get_points(step=idx) for idx in steps],
        mode='lines',
    )
    fig.add_trace(score_line)
    fig.update_layout(
        showlegend=False,
        margin=dict(l=0, r=0, t=0, b=0),
        height=300,
    )

    add_curr_step_indicator(fig, step_idx)
    st.plotly_chart(fig, use_container_width=True)

def plot_utility(traj, step_idx):
    """Render a three-line utility chart for the whole trajectory.

    The ``self`` and ``otherindivs`` lines are read from
    ``info['annotations']['utility']`` of every entry in ``traj.info_list``;
    the third line, labelled ``selfishness (self - otherindivs)``, comes from
    ``traj.get_selfishnesslocal(step=i)``. The current step is marked with the
    shared vertical indicator before the figure is handed to Streamlit.
    """
    steps = list(range(len(traj)))
    utilities = [info['annotations']['utility'] for info in traj.info_list]

    # Trace label -> y values, in drawing order
    series = {
        'self': [utility['self'] for utility in utilities],
        'otherindivs': [utility['otherindivs'] for utility in utilities],
        'selfishness (self - otherindivs)': [
            traj.get_selfishnesslocal(step=i) for i in steps
        ],
    }

    fig = go.Figure()
    for label, values in series.items():
        fig.add_trace(go.Scatter(x=steps, y=values, mode='lines', name=label))

    # Horizontal legend anchored at the top right, plus axis titles and sizing
    top_right_legend = dict(
        orientation="h",
        yanchor="bottom",
        y=1,
        xanchor="right",
        x=1,
    )
    fig.update_layout(
        legend=top_right_legend,
        xaxis_title="Step",
        yaxis_title="Utility",
        margin=dict(l=0, r=0, t=0, b=0),
        height=300,
    )

    add_curr_step_indicator(fig, step_idx)
    st.plotly_chart(fig, use_container_width=True)

def plot_moral_harms(traj, step_idx):
    """
    Render the per-step moral-harm annotations as a stacked bar chart.

    One record is read from ``info['annotations']['moral_harms']`` for every
    entry of ``traj.info_list``. The records are expected to look like
    [
        {'type_1': 0, 'type_2': 0.33, 'type_3': 0.67, ...},
        {'type_1': 0, 'type_2': 0, 'type_3': 0, ...},
        ...
    ]
    i.e. one mapping of harm type -> amount per step.
    """
    per_step_harms = [info['annotations']['moral_harms'] for info in traj.info_list]

    # Wide table: one row per step, one column per harm type, plus the step index
    wide = pd.DataFrame(per_step_harms).assign(step=list(range(len(traj))))
    # Long table: one row per (step, harm type) pair, which is the shape px.bar stacks by colour
    long = wide.melt(id_vars=['step'], var_name='type', value_name='amount')

    fig = px.bar(long, x='step', y='amount', color='type', barmode='stack')

    # Horizontal legend anchored at the top right, plus axis titles and sizing
    top_right_legend = dict(
        orientation="h",
        yanchor="bottom",
        y=1,
        xanchor="right",
        x=1,
    )
    fig.update_layout(
        legend=top_right_legend,
        xaxis_title="Step",
        yaxis_title="Ethical violations",
        margin=dict(l=0, r=0, t=0, b=0),
        height=400,
    )

    add_curr_step_indicator(fig, step_idx)
    st.plotly_chart(fig, use_container_width=True)

def _plot_power_lines(traj, step_idx, power_types):
    """Draw one line per power type and render the chart in Streamlit.

    Args:
        traj: Trajectory to plot; must support ``len()`` and
            ``get_power(type=..., step=...)``.
        step_idx: Step at which the current-step indicator is drawn.
        power_types: Names passed as ``type`` to ``traj.get_power``; each one
            becomes a trace labelled with that name.
    """
    steps = list(range(len(traj)))

    fig = go.Figure()
    for power_type in power_types:
        # Values are plotted as returned by get_power (linear scale, no transform)
        values = [traj.get_power(type=power_type, step=idx) for idx in steps]
        fig.add_trace(go.Scatter(x=steps, y=values, mode='lines', name=power_type))

    # Legend at the top right
    fig.update_layout(legend=dict(
        orientation="h",
        yanchor="bottom",
        y=1,
        xanchor="right",
        x=1
    ))
    # Set labels
    fig.update_layout(
        xaxis_title="Step",
        yaxis_title="Power",
        margin=dict(l=0, r=0, t=0, b=0),
        height=300
    )

    add_curr_step_indicator(fig, step_idx)
    st.plotly_chart(fig, use_container_width=True)


def plot_power(traj, step_idx):
    """Plot the ``money`` and ``watts`` power metrics over the trajectory.

    Args:
        traj: Trajectory whose ``get_power`` values are plotted.
        step_idx: Step highlighted by the current-step indicator.
    """
    _plot_power_lines(traj, step_idx, ['money', 'watts'])


def plot_power_smol(traj, step_idx):
    """Plot the ``resources``, ``social`` and ``utility`` power metrics.

    These are charted separately from ``money`` and ``watts`` (see
    ``plot_power``) but are otherwise drawn exactly the same way.

    Args:
        traj: Trajectory whose ``get_power`` values are plotted.
        step_idx: Step highlighted by the current-step indicator.
    """
    _plot_power_lines(traj, step_idx, ['resources', 'social', 'utility'])


class App:
    """Streamlit page for stepping through a saved agent trajectory.

    Creating an instance sets the page config and injects the ``style.css``
    stylesheet located next to this file; :meth:`run` renders the page.
    """

    def __init__(self) -> None:
        st.set_page_config(page_title="Trajectory Explorer", page_icon=None)
        self.load_css(Path(__file__).parent / "style.css")

    def load_css(self, css_file) -> None:
        """Inject the contents of ``css_file`` into the page inside a ``<style>`` tag.

        Args:
            css_file: Path of the stylesheet to read.

        Raises:
            OSError: If the file cannot be opened.
        """
        # Explicit encoding so reading the stylesheet does not depend on the
        # platform's default locale encoding.
        with open(css_file, encoding="utf-8") as f:
            st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

    def run(self) -> None:
        """Render the explorer: trajectory picker, step slider, details and charts.

        Shows a warning and returns early when no ``*.pkl`` file exists under
        ``cfg.agent_trajectories_dir`` or when the selected trajectory has no
        steps, instead of failing further down.
        """
        cfg = MachiavelliConfig()

        st.title("Trajectory Explorer")

        traj_paths = sorted(cfg.agent_trajectories_dir.glob("**/*.pkl"))
        if not traj_paths:
            # With no options the selectbox yields None, and Path(None) would
            # raise a TypeError.
            st.warning(f"No trajectory files (`*.pkl`) found under `{cfg.agent_trajectories_dir}`.")
            return
        traj_path = Path(st.selectbox("Select trajectory file", options=traj_paths))

        # NOTE(review): the files are ``.pkl``, so Trajectory.load presumably
        # unpickles them -- only point this tool at trajectory files you trust.
        traj = Trajectory.load(traj_path)

        num_steps = len(traj)
        if num_steps == 0:
            st.warning(f"Trajectory `{traj_path}` contains no steps.")
            return

        if num_steps > 1:
            step_idx = st.slider(f"Step `({num_steps} total)`", 0, num_steps - 1)
        else:
            # A single step leaves nothing to slide over (the range would be
            # 0..0), so skip the widget and show that step directly.
            step_idx = 0

        # Main has two columns: step details on the left, charts on the right
        col1, col2 = st.columns([1, 1])
        with col1:
            self._render_step_details(traj, step_idx)
        with col2:
            self._render_charts(traj, step_idx)

    def _render_step_details(self, traj, step_idx) -> None:
        """Render metrics, ethical violations, observation text and raw info for one step."""
        obs, action, reward, info = traj[step_idx]
        annotations = info['annotations']
        game_state = info['game_state']

        metric_cols = st.columns(4)
        with metric_cols[0]:
            st.metric("Reward", value=traj.get_points(step=step_idx), delta=reward)
        # NOTE(review): for these three metrics the same per-step annotation is
        # shown as both value and delta -- confirm that this is intended.
        annotation_metrics = [("Money", 'money'), ("Social", 'social'), ("Watts", 'watts')]
        for col, (label, key) in zip(metric_cols[1:], annotation_metrics):
            with col:
                st.metric(label, value=annotations[key], delta=annotations[key])

        # List out moral harms with a positive confidence
        harms = [
            f"`{harm}({confidence:.2f})`"
            for harm, confidence in annotations['moral_harms'].items()
            if confidence > 0
        ]
        st.write("Ethical violations")
        st.markdown(', '.join(harms) if harms else '`None`')

        # Divider
        st.markdown("---")

        new_achievements = game_state['new_achievements']
        if new_achievements:
            st.write(f"Achievement(s) unlocked: `{', '.join(new_achievements)}`")

        # Show full observation text
        st.text(obs)
        choice_texts = game_state['choice_texts']
        if choice_texts:
            st.markdown(f"**Action**: `{action}: {choice_texts[action]}`")

        st.subheader("Info")
        st.json(info, expanded=False)
        st.json(annotations, expanded=False)

    def _render_charts(self, traj, step_idx) -> None:
        """Render every trajectory-wide chart, each under its own caption."""
        st.write("Points")
        plot_achievements(traj, step_idx)

        st.write("Utility")
        plot_utility(traj, step_idx)

        st.write("Ethical violations")
        plot_moral_harms(traj, step_idx)

        st.write("Power (Money, Watts)")
        plot_power(traj, step_idx)
        st.write("Power (Resources, Social, Utility)")
        plot_power_smol(traj, step_idx)


if __name__ == '__main__':
    # Constructing App configures the page and loads the CSS; run() renders it.
    App().run()
