import sys
import os
import json
import numpy as np
import geopandas as gpd
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
import statistics

# Set up paths
# Extend the module search path with the working directory and two
# machine-specific locations, in this order.
for extra_path in (
    os.getcwd(),
    "/home/name/spatial/LLaVA_interp/",
    '/home/name/mambaforge/lib/python3.10/site-packages',
):
    sys.path.append(extra_path)

# Use Arial for all text drawn by matplotlib.
plt.rcParams.update({'font.family': 'Arial'})

# Load necessary data
def _load_json(path):
    """Read the file at *path* as UTF-8 encoded JSON and return the decoded object.

    The encoding is given explicitly so the result does not depend on the
    platform's default locale encoding.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


# Iterated later as the list of country names to place on the map.
geopandas_countries = _load_json('supported_countries.json')

# Maps a language code to a record that has at least a "name" field.
lang_code_to_full_name = _load_json('lang_code_to_full_name.json')

# Maps a country name to a record whose "languages" entry maps
# language name -> numeric weight.
language_map_cia = _load_json('language_map_cia.json')

# Load the new data
# Maps a language code to a record with "camera3" and "reference3" sequences.
data = _load_json('multilingual_ambiguity_raw.json')

# Calculate standard deviations
# For every language, compute the sample standard deviation of the
# element-wise camera3 / reference3 ratios, keyed by upper-cased language code.
std_ratios = {}
for lang_code, values in data.items():
    lang_code = lang_code.upper()  # Always set lang_code to uppercase
    pairs = list(zip(values['camera3'], values['reference3']))

    # A zero reference value has no defined ratio; drop those pairs instead of
    # aborting the whole run with ZeroDivisionError.
    ratios = [c / r for c, r in pairs if r != 0]
    skipped = len(pairs) - len(ratios)
    if skipped:
        print(f"Skipped {skipped} zero-reference sample(s) for language: {lang_code}")

    # statistics.stdev raises StatisticsError with fewer than two data points.
    if len(ratios) < 2:
        print(f"Not enough data to compute a standard deviation for language: {lang_code}")
        continue

    std_ratios[lang_code] = statistics.stdev(ratios)

# Function to find language code by name
def find_key_by_language(data, language):
    """Return the upper-cased key of the first entry whose name is *language*.

    Each value in *data* must provide a "name" field; it is compared with
    *language* case-insensitively. Returns None when no entry matches.
    """
    matching_codes = (
        code.upper()
        for code, info in data.items()
        if info["name"].lower() == language.lower()
    )
    return next(matching_codes, None)

# Load world geometry
# NOTE(review): `gpd.datasets` is deprecated in recent GeoPandas releases and
# appears to be removed in 1.0 - confirm the installed version still ships it.
natural_earth_path = gpd.datasets.get_path('naturalearth_lowres')
world = gpd.read_file(natural_earth_path)

# Drop the two Antarctic entries from the map.
excluded_names = ['Antarctica', 'Fr. S. Antarctic Lands']
world = world[~world['name'].isin(excluded_names)]

# Prepare data for merging
# One record per country: the weighted mean of the per-language standard
# deviations, weighted by the value stored for each language of that country.
country_data = []
max_std_ratio = max(std_ratios.values())

for country in geopandas_countries:
    if country not in language_map_cia:
        print(f"Country not in language_map_cia: {country}")
        continue

    weighted_sum = 0
    weight_total = 0
    for lang, share in language_map_cia[country]['languages'].items():
        code = find_key_by_language(lang_code_to_full_name, lang)
        # Ignore languages with no known code or no computed deviation.
        if not code or code not in std_ratios:
            continue
        weighted_sum += std_ratios[code] * share
        weight_total += share

    if weight_total > 0:
        country_data.append({
            'country': country,
            'std_ratio': weighted_sum / weight_total,
        })
    else:
        print(f"No data for country: {country}")

# Merge data with world geometry
# Name the columns explicitly: a frame built from an empty list would
# otherwise have no 'country' column and the merge below would raise KeyError.
df_countries = pd.DataFrame(country_data, columns=['country', 'std_ratio'])

# Left join keeps every geometry; countries without data get NaN in 'std_ratio'.
world = world.merge(df_countries, left_on='name', right_on='country', how='left')

# Normalize the data
# Colour scale runs from 0 to the largest per-language standard deviation.
norm_std = Normalize(vmin=0, vmax=max_std_ratio)

# Plot
fig, ax = plt.subplots(1, 1, figsize=(20, 10))
fig.patch.set_facecolor('#e6e8ec')

cmap = plt.get_cmap('viridis')

# Draw all countries in one call instead of building a GeoSeries and an artist
# per row. Rows with a missing 'std_ratio' are drawn in light grey.
world.plot(
    column='std_ratio',
    cmap=cmap,
    vmin=norm_std.vmin,
    vmax=norm_std.vmax,
    missing_kwds={'color': 'lightgrey'},  # Default color for countries with missing data
    edgecolor='white',
    linewidth=0.5,
    ax=ax,
)
# Pin equal x/y scaling explicitly; GeoDataFrame.plot may otherwise choose a
# latitude-dependent aspect when the data is in a geographic CRS.
ax.set_aspect('equal')

ax.set_axis_off()
plt.tight_layout()

# Create custom colorbar
# A stand-alone ScalarMappable ties the colorbar to cmap/norm_std without
# needing an artist taken from the map itself.
mappable = ScalarMappable(norm=norm_std, cmap=cmap)
mappable.set_array([])

colorbar_options = {
    'ax': ax,
    'orientation': 'horizontal',
    'location': 'bottom',
    'fraction': 0.05,
    'pad': 0.1,
}
colorbar = fig.colorbar(mappable, **colorbar_options)

colorbar_label = 'Standard Deviation of Camera/Reference Frame Preference'
colorbar.set_label(colorbar_label, fontsize=20, labelpad=20)
colorbar.ax.tick_params(labelsize=12)

figure_title = 'World Map of Standard Deviations in Camera/Reference Frame Preference'
plt.title(figure_title, fontsize=24, pad=20)

# Write the finished figure to disk, trimmed to its contents.
plt.savefig('world_map_std_dev.png', bbox_inches='tight', dpi=300)

# Print sorted standard deviations
# Rank languages from the largest standard deviation to the smallest.
ranked = sorted(std_ratios.items(), key=lambda item: item[1], reverse=True)
sorted_std_ratios = dict(ranked)

print("\nStandard Deviations of Ratios (Descending Order):")
for lang_code, std_ratio in ranked:
    print(f"{lang_code}: {std_ratio}")