"""Build language-coverage reports and a world map of per-country results.

The script
  * cross-checks the languages we have results for (DEEPL / GOOGLE_TRANSLATE)
    against the CIA country -> language table in language_map_cia.json,
  * aggregates each country's camera3 / reference3 ratio from
    multilingual_ambiguity_all.json, weighted by the CIA per-language values,
  * draws the result on a world map (world_map_v3.png) and writes several
    JSON side files, all via paths relative to the current working directory.
"""
import sys
import json
import numpy as np
# Project-local modules. They are imported before the sys.path tweaks below,
# so they must already be importable (e.g. from the script's own directory).
from language_map import DEEPL, GOOGLE_TRANSLATE
from supported_countries import supported_countries as geopandas_countries


# NOTE(review): sys, json and numpy are imported a second time here, and numpy
# and ColorbarBase are not used anywhere in this script; candidates for cleanup.
import sys
import os
# NOTE(review): machine-specific paths (a user's home directory and a
# mambaforge Python 3.10 site-packages). Presumably these locate geopandas and
# friends on the author's machine; make them configurable before running elsewhere.
sys.path.append(os.getcwd())
sys.path.append("/home/name/spatial/LLaVA_interp/")
sys.path.append('/home/name/mambaforge/lib/python3.10/site-packages')
import geopandas as gpd
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# from data_generation.world_map.language_map import language_map
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
from matplotlib.colorbar import ColorbarBase
# Global font for all figure text. Matplotlib warns and falls back to its
# default font if Arial is not installed.
plt.rcParams['font.family'] = 'Arial'
# Persist the (sorted) list of countries the map module supports, for inspection.
with open('supported_countries.json', 'w') as countries_file:
    json.dump(sorted(geopandas_countries), countries_file, indent=4)

# OUR DATA
# Build code -> {"name": full language name} for every language we have
# translations for. DEEPL is unpacked as code -> name and GOOGLE_TRANSLATE as
# name -> code (its codes are upper-cased). When both define the same code,
# the Google entry's name wins because it is merged last.
deepl_languages = {code: {"name": name} for code, name in DEEPL.items()}
google_languages = {code.upper(): {"name": name} for name, code in GOOGLE_TRANSLATE.items()}
lang_code_to_full_name = {**deepl_languages, **google_languages}
with open('lang_code_to_full_name.json', 'w') as mapping_file:
    json.dump(lang_code_to_full_name, mapping_file, indent=4)

# Just the full names, in the same order as lang_code_to_full_name.
all_languages_from_our_data = [entry['name'] for entry in lang_code_to_full_name.values()]


# CIA lookup
# language_map_cia.json maps country -> {'languages': {language name: value}, ...}.
# The per-language values are presumably speaker shares; confirm against the data source.
with open('language_map_cia.json') as cia_file:
    language_map_cia = json.load(cia_file)

all_countries_from_cia = list(language_map_cia)
# Flattened across countries, so a language spoken in several countries repeats.
all_languages_from_cia = [
    language
    for country_info in language_map_cia.values()
    for language in country_info['languages']
]

# Save both lists to files for inspection (languages first, then countries).
for out_path, payload in (
    ('all_languages_from_cia.json', all_languages_from_cia),
    ('all_countries_from_cia.json', all_countries_from_cia),
):
    with open(out_path, 'w') as out_file:
        json.dump(payload, out_file, indent=4)

print(all_countries_from_cia)
print()
print(all_languages_from_cia)
print()

# Coverage check, part 1. Walk the map's countries in sorted order:
#   - countries with no CIA entry go to missing_countries (safe to ignore);
#   - for the others, every CIA language we have no results for goes to
#     missing_languages (once per country that lists it, so duplicates occur).
cia_country_set = set(all_countries_from_cia)
our_language_set = set(all_languages_from_our_data)

missing_countries = []
missing_languages = []
for country in sorted(geopandas_countries):
    if country not in cia_country_set:
        missing_countries.append(country)
        continue
    missing_languages.extend(
        language_name
        for language_name in language_map_cia[country]['languages']
        if language_name not in our_language_set
    )

print("Missing countries, safe to ignore:", missing_countries)
print()
print("Missing languages:", missing_languages)

# Coverage check, part 2: languages we have results for that never appear
# anywhere in the CIA data.
cia_language_set = set(all_languages_from_cia)
missing_languages2 = [
    language_name
    for language_name in all_languages_from_our_data
    if language_name not in cia_language_set
]

print()
print("Missing languages2:", missing_languages2, len(missing_languages2))

"""
For all countries on the map:
    For all languages in the country:
        If the language is not in the languages_with_data list:
            Print the language code and the country
"""

# Per-language model results, keyed by language code (lower- or upper-case).
with open('multilingual_ambiguity_all.json') as results_file:
    data = json.load(results_file)

def find_key_by_language(data, language):
    """Return the first key of *data* whose entry's ``"name"`` equals *language*.

    Args:
        data: Mapping of language code -> dict with at least a ``"name"`` key
            holding the full language name.
        language: Full language name to look up (exact, case-sensitive match).

    Returns:
        The matching key exactly as stored in *data* (its case is not
        normalized), or ``None`` if no entry has that name.
    """
    # Linear scan in insertion order, so the first matching code wins.
    # No per-entry printing: this runs once per (country, language) pair.
    return next(
        (code for code, entry in data.items() if entry["name"] == language),
        None,
    )


# Load world geometry from GeoPandas' bundled Natural Earth low-res dataset and
# drop the Antarctic territories from the map.
# NOTE(review): gpd.datasets was deprecated in geopandas 0.14 and removed in
# 1.0; this call needs an older geopandas (or load the Natural Earth file directly).
world = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))
world = world[world['name'] != 'Antarctica']
world = world[world['name'] != 'Fr. S. Antarctic Lands']

# Prepare data for merging.
# Each country's value is the weighted mean of camera3 / reference3 over its
# CIA-listed languages that we have results for, weighted by the CIA
# per-language values.
country_data = []
max_ratio = 0  # Running maximum of the per-country value (colorbar vmax)
min_ratio = float('inf')  # Running minimum of the per-country value (colorbar vmin)
max_country = None
min_country = None

for country in geopandas_countries:
    if country not in language_map_cia:
        continue
    languages_dict_of_country = language_map_cia[country]['languages']

    total_ratio = 0
    total_cam_ref = 0
    for lang, ratio in languages_dict_of_country.items():
        # The CIA table is keyed by full language name; map it back to a code.
        lang_code = find_key_by_language(lang_code_to_full_name, lang)
        if country == "India":
            print(lang, lang_code)
        if lang_code is None:
            print(f"Skipping {lang} in {country} 1")
            continue
        # Results may be stored under a lower- or upper-case code; prefer lower-case.
        if lang_code.lower() in data:
            lang_code = lang_code.lower()
        elif lang_code.upper() in data:
            lang_code = lang_code.upper()
        else:
            print(f"Skipping {lang} in {country} 2")
            continue
        # NOTE(review): a reference3 of 0 would raise ZeroDivisionError here; confirm it cannot occur.
        total_cam_ref += data[lang_code]['camera3'] / data[lang_code]['reference3'] * ratio
        total_ratio += ratio

    # None of this country's languages had results: leave it uncoloured on the
    # map instead of dividing by zero.
    if total_ratio == 0:
        print(f"Skipping {country}: none of its languages have data")
        continue
    total_cam_ref /= total_ratio

    # Track the extremes (and which country holds them) for the colour scale.
    if total_cam_ref > max_ratio:
        max_ratio, max_country = total_cam_ref, country
    if total_cam_ref < min_ratio:
        min_ratio, min_country = total_cam_ref, country

    country_data.append({
        'country': country,
        'ratio': total_cam_ref,
    })
        


# Attach each country's aggregated ratio to its map geometry. how='left' keeps
# every map country; those without an entry in country_data get a NaN ratio.
# NOTE(review): if country_data is empty, the merge has no 'country' column and
# min_ratio is still inf; this assumes at least one country was aggregated.
df_countries = pd.DataFrame(country_data)
world = world.merge(df_countries, left_on='name', right_on='country', how='left')
# The colour scale spans the observed min/max per-country ratio.
norm = Normalize(vmin=min_ratio, vmax=max_ratio)
cmap = plt.get_cmap('viridis_r')  # Reversed viridis: low ratios bright yellow, high ratios dark purple

# Other colormaps that can be swapped in above:
# viridis, inferno, plasma, magma, cividis, seismic, coolwarm, bwr, viridis_r, inferno_r, plasma_r, magma_r, cividis_r, seismic_r, coolwarm_r, bwr_r

fig, ax = plt.subplots(1, 1, figsize=(20, 10))
fig.patch.set_facecolor('#e6e8ec')  # Light grey figure background

# Only countries with a ratio are drawn; each polygon is filled with the
# colormap colour of its normalized ratio.
for_polygons = world.dropna(subset=['ratio'])
for_polygons.plot(ax=ax, color=cmap(norm(for_polygons['ratio'])), edgecolor='white', linewidth=0)
# linewidth=0 draws the polygons without country borders.


# Reduce the margins
ax.margins(x=0.01)  # Keep only a 1% horizontal margin around the data

# Adjust subplot parameters; note this runs before the colorbar is added below.
plt.tight_layout()

# Disabled: fill countries without data in light grey.
# # Plotting countries without data in light grey
# countries_no_data = world[world['ratio'].isna()]
# countries_no_data.plot(ax=ax, color='#d0d0d0', edgecolor='white', linewidth=0.2)

# Disabled: label the grey (no-data) countries at their centroids.
# Adding names only to the grey countries
# for idx, row in countries_no_data.iterrows():
#     x, y = row['geometry'].centroid.x, row['geometry'].centroid.y
#     ax.annotate(row['name'], xy=(x, y), horizontalalignment='center', verticalalignment='center', fontsize=6)

ax.set_axis_off()


# Stand-alone mappable so the colorbar uses the same cmap/norm as the fill colours.
sm = ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
cbar = fig.colorbar(sm, ax=ax, orientation='horizontal', location='top', fraction=0.02, pad=0.01)
# NOTE(review): the plotted value is camera3 / reference3 (camera over
# reference); confirm this label's direction matches the colour scale.
cbar.set_label('Preference towards reference frame over camera frame', fontsize=20, labelpad=10)

print(norm.vmax, norm.vmin)
print(max_country, min_country)
# Label only the two ends of the scale, with one decimal place.
cbar.set_ticks([norm.vmin, norm.vmax])
cbar.set_ticklabels([f'{norm.vmin:.1f}', f'{norm.vmax:.1f}'], fontdict={'fontsize': 12})  # Adjust font size for tick labels

# Tick label size/padding; ticks are drawn below the bar even though the bar sits above the map.
cbar.ax.xaxis.set_tick_params(labelsize=12, direction='out', pad=5)
cbar.ax.xaxis.set_ticks_position('bottom')

plt.savefig('world_map_v3.png', bbox_inches='tight')



# Summary output: repeat the coverage report, then rank and persist the
# per-country values.
print()
print("Missing countries, safe to ignore:", missing_countries)
print()

# Highest value first; list.sort is stable, so ties keep their original order.
country_data.sort(key=lambda row: row['ratio'], reverse=True)
print(country_data)
with open('country_data_sorted.json', 'w') as sorted_file:
    json.dump(country_data, sorted_file, indent=4)

# Every language name we have results for, each followed by ", " (the
# trailing separator is part of the existing output format).
lang_text = "".join(name + ", " for name in all_languages_from_our_data)
print()
print(lang_text)
print(len(all_languages_from_our_data))


# Collect the full names of all languages that have results in `data`,
# resolving each code against lang_code_to_full_name (lower-case key preferred,
# then upper-case).
all_lang = []
for lang in data:
    if lang.lower() in lang_code_to_full_name:
        lang = lang.lower()
    elif lang.upper() in lang_code_to_full_name:
        lang = lang.upper()
    else:
        # Report the unmatched code itself; there is no country in this loop.
        print(f"Skipping {lang}: no entry in lang_code_to_full_name")
        continue
    # Index with the key just resolved: a lower-case-only code has no
    # upper-case entry, and looking that up would yield None.
    all_lang.append(lang_code_to_full_name[lang]['name'])

# Each name followed by ", " (same output format as before).
lang_text = "".join(name + ", " for name in all_lang)
print(lang_text)
print(len(all_lang))
# Ghana
# Thailand
# Vietnam
# Bangladesh

# Timor-Leste
# Botswana
# Namibia
# Bhutan