import os
import json
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np

def _latest_eval_json(eval_dir):
    """Return the path of the numerically highest ``<int>.json`` file in *eval_dir*.

    JSON files whose stem is not an integer (e.g. ``summary.json``) are ignored.
    Returns ``None`` when *eval_dir* does not exist or contains no such file.
    """
    if not os.path.isdir(eval_dir):
        return None
    numbered = []
    for name in os.listdir(eval_dir):
        stem, ext = os.path.splitext(name)
        if ext != '.json':
            continue
        try:
            numbered.append((int(stem), name))
        except ValueError:
            continue  # not a numbered checkpoint file
    if not numbered:
        return None
    return os.path.join(eval_dir, max(numbered)[1])


def _mean_ci(values, confidence_z):
    """Return ``(mean, ci_half_width)`` of *values* along axis 0.

    The half-width is ``confidence_z`` times the standard error, which uses the
    sample (``ddof=1``) standard deviation; callers guarantee >= 2 samples.
    """
    mean = np.mean(values, axis=0)
    se = np.std(values, axis=0, ddof=1) / np.sqrt(len(values))
    return mean, confidence_z * se


def _plot_encodings_2d(result, dims, confidence_z):
    """Scatter each word's encodings on two dimensions with its mean and CI ellipse."""
    from matplotlib.patches import Ellipse

    plt.figure(figsize=(10, 8))
    ax = plt.gca()

    for word, vectors in result.items():
        vectors = np.asarray(vectors)
        if vectors.shape[0] < 2:
            continue  # a confidence interval needs at least two samples

        selected = vectors[:, list(dims)]
        mean_vec, ci_radius = _mean_ci(selected, confidence_z)

        # Individual encodings, then the mean point.
        plt.scatter(selected[:, 0], selected[:, 1], alpha=0.3)
        plt.scatter(mean_vec[0], mean_vec[1], color='red', edgecolor='black', s=100)

        # Axis-aligned ellipse whose half-axes are the per-dimension CI radii,
        # so each dimension's uncertainty is shown at its own scale.
        ax.add_patch(Ellipse((mean_vec[0], mean_vec[1]), 2 * ci_radius[0], 2 * ci_radius[1],
                             color='black', fill=False, alpha=0.5))

        plt.text(mean_vec[0], mean_vec[1], word, fontsize=10, ha='center', va='center')

    plt.xlabel(f"Dimension {dims[0]}")
    plt.ylabel(f"Dimension {dims[1]}")
    plt.title("Word-wise Encodings: Means and Confidence Intervals (2D)")
    plt.grid(True)
    plt.tight_layout()
    plt.show()


def _plot_encodings_1d(result, dim, confidence_z):
    """Plot each word's mean encoding on one dimension with a CI error bar."""
    plt.figure(figsize=(10, 6))

    # Words are placed at x = their index in sorted order; skipped words leave a gap.
    for i, (word, vectors) in enumerate(sorted(result.items())):
        vectors = np.asarray(vectors)
        if vectors.shape[0] < 2:
            continue  # a confidence interval needs at least two samples

        mean_val, ci_radius = _mean_ci(vectors[:, dim], confidence_z)
        plt.errorbar(
            i, mean_val, yerr=ci_radius, fmt='o', color='blue',
            ecolor='lightgray', elinewidth=3, capsize=5
        )
        plt.text(i, mean_val, word, ha='center', va='bottom', fontsize=9)

    plt.xticks([])
    plt.xlabel("Words")
    plt.ylabel(f"Dimension {dim}")
    plt.title("Word-wise Encodings: Mean and 95% CI (1D)")
    plt.grid(True, axis='y')
    plt.tight_layout()
    plt.show()


def analyze_runs(base_dir, key='main', plot=False, dims=(0,)):
    """Report best-score correlations across all run directories under *base_dir*.

    Every sub-directory of *base_dir* whose name does not contain ``'best'`` is
    treated as a run. For each run, the numerically highest
    ``<run>/eval_training/<int>.json`` is loaded, and
    ``data[key]['sub_results']['best_scores']`` is read:

    - three scores are reported as valence / arousal / dominance correlations;
    - otherwise the first score is reported as a single correlation.

    Runs without a usable evaluation JSON are reported and skipped.

    Args:
        base_dir: Directory containing one sub-directory per run.
        key: Top-level key of the evaluation JSON to analyze (e.g. ``'main'``,
            ``'train'``).
        plot: If true, plot ``data[key]['encoded_by_target']`` (word -> list of
            encoding vectors) for each run.
        dims: Encoding dimension indices to plot; one index gives a 1-D
            error-bar plot, two give a 2-D scatter with CI ellipses.

    Returns:
        Dict mapping statistic name (``'valence'``, ``'arousal'``,
        ``'dominance'`` or ``'correlation'``) to the per-run values, in run order.

    Raises:
        ValueError: If *plot* is true and *dims* does not hold 1 or 2 indices.
    """
    if plot and len(dims) not in (1, 2):
        raise ValueError("dims must contain 1 or 2 dimension indices.")
    confidence_z = 1.96  # two-sided 95% normal quantile

    # Sorted for deterministic output; os.listdir order is arbitrary.
    run_dirs = sorted(
        os.path.join(base_dir, name)
        for name in os.listdir(base_dir)
        if 'best' not in name and os.path.isdir(os.path.join(base_dir, name))
    )

    stats = defaultdict(list)
    for run_dir in run_dirs:
        eval_dir = os.path.join(run_dir, 'eval_training')
        json_path = _latest_eval_json(eval_dir)
        if json_path is None:
            # Skip only this run so the remaining runs are still aggregated.
            print(f'No JSON files found in {eval_dir}')
            continue

        with open(json_path, encoding='utf-8') as fh:
            data = json.load(fh)

        main_data = data[key]
        best_scores = main_data['sub_results']['best_scores']

        if len(best_scores) == 3:
            valence_cor, arousal_cor, dominance_cor = best_scores
            print(f'Valence correlation: {valence_cor}')
            print(f'Arousal correlation: {arousal_cor}')
            print(f'Dominance correlation: {dominance_cor}')

            stats['valence'].append(valence_cor)
            stats['arousal'].append(arousal_cor)
            stats['dominance'].append(dominance_cor)
        else:
            cor = best_scores[0]
            print(f'Correlation: {cor}')
            stats['correlation'].append(cor)

        if plot:
            result = main_data['encoded_by_target']
            if len(dims) == 2:
                _plot_encodings_2d(result, dims, confidence_z)
            else:
                _plot_encodings_1d(result, dims[0], confidence_z)

    # Print overall statistics
    print("\nOverall Statistics:")
    for name, values in stats.items():
        print(f"{name.capitalize()} - Mean: {np.mean(values):.4f}, Std: {np.std(values):.4f}, Max: {np.max(values):.4f}, Min: {np.min(values):.4f}")

    return dict(stats)


if __name__ == '__main__':
    # Activation of the analyzed models; 'gelu' runs live under the same layout.
    act = 'tanh'
    # Data-source suffix shared by the experiment directory names.
    source = 'cf'

    # Ad-hoc single-target inspection (with plots), enable when needed:
    # print('Test')
    # analyze_runs(f'results/experiments/gradiend/arousal/distilbert-base-cased/arousal_tanh_supervised_{source}', plot=True)

    # (section header, experiment directory, JSON key to analyze)
    evaluations = (
        ('Train', f'results/experiments/gradiend/emotion-3/distilbert-base-cased/{act}_3_bin_supervised_{source}', 'train'),
        ('Test', f'results/experiments/gradiend/emotion-3/distilbert-base-cased/{act}_3_supervised_{source}', 'main'),
    )
    for header, experiment_dir, json_key in evaluations:
        print(header)
        analyze_runs(experiment_dir, key=json_key)
