import sys
import os
import json
import numpy as np
import geopandas as gpd
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
import statistics

# Make the working directory plus two hard-coded locations importable.
# NOTE(review): both absolute paths are specific to one machine/environment;
# appending another interpreter's site-packages can shadow installed packages.
sys.path.extend([
    os.getcwd(),
    "/home/name/spatial/LLaVA_interp/",
    '/home/name/mambaforge/lib/python3.10/site-packages',
])

# Render all figure text in Arial.
plt.rcParams.update({'font.family': 'Arial'})

# Load lookup tables and raw results. JSON is UTF-8 by specification, so the
# encoding is given explicitly rather than relying on the platform locale
# (which would mis-decode non-ASCII country/language names on e.g. cp1252).
with open('supported_countries.json', 'r', encoding='utf-8') as f:
    geopandas_countries = json.load(f)

with open('lang_code_to_full_name.json', 'r', encoding='utf-8') as f:
    lang_code_to_full_name = json.load(f)

with open('language_map_cia.json', 'r', encoding='utf-8') as f:
    language_map_cia = json.load(f)

# Raw per-language measurements; each entry is read for 'camera3' and
# 'reference3' sequences further below.
with open('multilingual_ambiguity_raw.json', encoding='utf-8') as f:
    data = json.load(f)

# Calculate, per language, the mean of the element-wise ratios camera3 / reference3.
# Keys are upper-cased so they match the upper-case codes produced by the
# language-name lookup used when aggregating per country.
mean_ratios = {}
for raw_code, values in data.items():
    lang_code = raw_code.upper()
    camera3 = values['camera3']
    reference3 = values['reference3']

    # zip() would silently drop the tail of the longer list; mismatched lengths
    # mean the raw data is inconsistent, so fail loudly instead of averaging
    # a truncated sample.
    if len(camera3) != len(reference3):
        raise ValueError(
            f"{lang_code}: camera3 has {len(camera3)} values but "
            f"reference3 has {len(reference3)}"
        )

    # statistics.mean raises StatisticsError on an empty list; skip the
    # language (it then contributes to no country) instead of aborting the run.
    if not camera3:
        print(f"No samples for language: {lang_code}")
        continue

    ratios = [c / r for c, r in zip(camera3, reference3)]
    mean_ratios[lang_code] = statistics.mean(ratios)

def find_key_by_language(data, language):
    """Return the upper-cased key of the first entry whose ``"name"`` matches *language*.

    The comparison is case-insensitive (``str.lower`` on both sides). Entries
    are scanned in the mapping's iteration order, so the first match wins.

    Args:
        data: Mapping from a code to a dict that has a ``"name"`` entry.
        language: Name to search for.

    Returns:
        The matching key converted to upper case, or ``None`` when no entry matches.
    """
    # Lazy generator: only the first matching key is ever upper-cased.
    hits = (
        code.upper()
        for code, entry in data.items()
        if entry["name"].lower() == language.lower()
    )
    return next(hits, None)

# Load world geometry and drop the Antarctic regions from the map.
# NOTE(review): gpd.datasets is deprecated in newer geopandas releases (removed
# in 1.0); on those versions point read_file at a local Natural Earth copy.
world = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))
world = world[~world['name'].isin(['Antarctica', 'Fr. S. Antarctic Lands'])]

# Ensure the GeoDataFrame has a CRS. Only assign WGS84 when none is present:
# set_crs raises if the frame already carries a *different* CRS and
# allow_override is not set, and silently relabelling would be wrong anyway.
if world.crs is None:
    world = world.set_crs('EPSG:4326')

# Build one weighted mean ratio per country. Each country's languages are
# weighted by the values stored in language_map_cia[country]['languages'];
# languages without a computed mean are ignored.
# The colour-scale bounds come from all per-language means, not just the
# languages that end up contributing to some country.
country_data = []
max_mean_ratio = max(mean_ratios.values())
min_mean_ratio = min(mean_ratios.values())

for country in geopandas_countries:
    if country not in language_map_cia:
        print(f"Country not in language_map_cia: {country}")
        continue

    weighted_sum = 0
    weight_total = 0
    for language_name, share in language_map_cia[country]['languages'].items():
        code = find_key_by_language(lang_code_to_full_name, language_name)
        if not code or code not in mean_ratios:
            continue
        weighted_sum += mean_ratios[code] * share
        weight_total += share

    # Written as "not > 0" (rather than "<= 0") so a NaN total is also skipped.
    if not weight_total > 0:
        print(f"No data for country: {country}")
        continue

    country_data.append({
        'country': country,
        'mean_ratio': weighted_sum / weight_total,
    })

# Merge per-country values onto the world geometry. The left join keeps every
# country shape; shapes without a value get NaN in 'mean_ratio'.
# Explicit columns keep the merge valid when country_data is empty: a DataFrame
# built from [] has no 'country' column, which would make merge raise KeyError.
df_countries = pd.DataFrame(country_data, columns=['country', 'mean_ratio'])
world = world.merge(df_countries, left_on='name', right_on='country', how='left')

# Map values onto [0, 1] using the min/max of the per-language mean ratios.
norm = Normalize(vmin=min_mean_ratio, vmax=max_mean_ratio)

# Plot
fig, ax = plt.subplots(1, 1, figsize=(14, 7))  # wide 2:1 canvas for a world map
fig.patch.set_facecolor('#e6e8ec')

# Reversed viridis: the lowest value maps to the yellow end of the scale.
cmap = plt.get_cmap('viridis_r')

# Re-project to the Gall stereographic projection (PROJ 'gall') before drawing.
world = world.to_crs('+proj=gall +lon_0=0')

# Alternative projections that can be passed to to_crs above; this dict is not
# used by the current figure.
projections = {
    'Orthographic': '+proj=ortho +lat_0=0 +lon_0=0',
    'Azimuthal Equidistant': '+proj=aeqd +lat_0=0 +lon_0=0',
    'Robinson': '+proj=robin',
    'Eckert IV': '+proj=eck4',
    'Winkel Tripel': '+proj=wintri',
    'Mollweide': '+proj=moll +lon_0=0',
    'Gall-Peters': '+proj=cea +lon_0=0'
}

# Draw each country separately so it gets its own fill colour. Iterating the
# two needed columns directly avoids iterrows(), which builds a Series per row.
# Countries left without a value by the left merge (NaN) are filled light grey.
for geometry, value in zip(world.geometry, world['mean_ratio']):
    color = cmap(norm(value)) if pd.notnull(value) else 'lightgrey'
    gpd.GeoSeries([geometry]).plot(ax=ax, color=color, edgecolor='white', linewidth=0)

ax.set_axis_off()
plt.tight_layout()

# Horizontal colourbar above the map. It is driven by a standalone
# ScalarMappable that shares the map's cmap/norm, and only the two extreme
# values are labelled.
# NOTE(review): the mapped value is derived from camera3 / reference3 ratios;
# confirm that higher values really mean a stronger reference-frame preference,
# as the label states.
scalar_mappable = ScalarMappable(cmap=cmap, norm=norm)
scalar_mappable.set_array([])  # no data of its own; it only feeds the colourbar
cbar = fig.colorbar(
    scalar_mappable,
    ax=ax,
    orientation='horizontal',
    location='top',
    fraction=0.02,
    pad=0.01,
)
cbar.set_label('Preference towards reference frame over camera frame', fontsize=20, labelpad=10)

end_values = [min_mean_ratio, max_mean_ratio]
cbar.set_ticks(end_values)
cbar.set_ticklabels([f'{v:.2f}' for v in end_values], fontdict={'fontsize': 12})

tick_axis = cbar.ax.xaxis
tick_axis.set_tick_params(labelsize=12, direction='out', pad=5)
tick_axis.set_ticks_position('bottom')  # tick marks below the bar, label above it

# NOTE(review): confirm the 'ver4' suffix matches the projection selected for this run.
output_path = 'world_map_mean_ratio_0714_compact_ver4.png'
plt.savefig(output_path, bbox_inches='tight', dpi=300)

# Report every language's mean ratio, largest first (ties keep insertion order).
ranked = sorted(mean_ratios.items(), key=lambda pair: pair[1], reverse=True)
print("\nMean Ratios (Descending Order):")
for lang_code, mean_ratio in ranked:
    print(f"{lang_code}: {mean_ratio}")
