import sys
import json
import numpy as np
from language_map import DEEPL, GOOGLE_TRANSLATE
from supported_countries import supported_countries as geopandas_countries


import sys
import os
sys.path.append(os.getcwd())
sys.path.append("/home/name/spatial/LLaVA_interp/")
sys.path.append('/home/name/mambaforge/lib/python3.10/site-packages')
import geopandas as gpd
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# from data_generation.world_map.language_map import language_map
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
from matplotlib.colorbar import ColorbarBase
plt.rcParams['font.family'] = 'Arial'

# ---------------------------------------------------------------------------
# Countries supported on the map: dump them (sorted) for reference.
# ---------------------------------------------------------------------------
with open('supported_countries.json', 'w') as out_file:
    json.dump(sorted(geopandas_countries), out_file, indent=4)

# ---------------------------------------------------------------------------
# OUR DATA: language code -> {"name": full language name}.
# DEEPL is iterated as code -> name; GOOGLE_TRANSLATE is iterated as
# name -> code, and its codes are upper-cased. Google entries override DeepL
# entries that share the same key.
# ---------------------------------------------------------------------------
google_languages = {}
for full_name, code in GOOGLE_TRANSLATE.items():
    google_languages[code.upper()] = {"name": full_name}

lang_code_to_full_name = {code: {"name": name} for code, name in DEEPL.items()}
lang_code_to_full_name.update(google_languages)

with open('lang_code_to_full_name.json', 'w') as out_file:
    json.dump(lang_code_to_full_name, out_file, indent=4)

all_languages_from_our_data = [
    entry['name'] for entry in lang_code_to_full_name.values()
]

# ---------------------------------------------------------------------------
# CIA lookup: country -> {"languages": {...}} loaded from
# language_map_cia.json.
# ---------------------------------------------------------------------------
with open('language_map_cia.json') as in_file:
    language_map_cia = json.load(in_file)

all_countries_from_cia = list(language_map_cia)
all_languages_from_cia = [
    language
    for country_info in language_map_cia.values()
    for language in country_info['languages']
]

# Persist both flat lists (languages first, then countries).
with open('all_languages_from_cia.json', 'w') as out_file:
    json.dump(all_languages_from_cia, out_file, indent=4)

with open('all_countries_from_cia.json', 'w') as out_file:
    json.dump(all_countries_from_cia, out_file, indent=4)

print(all_countries_from_cia)
print()
print(all_languages_from_cia)
print()

# ---------------------------------------------------------------------------
# Coverage check 1: map countries absent from the CIA data, and CIA languages
# (of map countries) that our data has no entry for.
# ---------------------------------------------------------------------------
missing_countries = []
missing_languages = []
for country in sorted(geopandas_countries):
    if country not in all_countries_from_cia:
        missing_countries.append(country)
        continue
    missing_languages.extend(
        language_name
        for language_name in language_map_cia[country]['languages'].keys()
        if language_name not in all_languages_from_our_data
    )


print("Missing countries, safe to ignore:", missing_countries)
print()
print("Missing languages:", missing_languages)


# ---------------------------------------------------------------------------
# Coverage check 2: languages in our data that never appear in the CIA data.
# ---------------------------------------------------------------------------
missing_languages2 = [
    language
    for language in all_languages_from_our_data
    if language not in all_languages_from_cia
]

print()
print("Missing languages2:", missing_languages2, len(missing_languages2))

# Outline of the intended check (pseudo-code):
#   For all countries on the map:
#       For all languages in the country:
#           If the language is not in the languages_with_data list:
#               Print the language code and the country

# Convention measurements per language code; only the 'camera3' section of the
# file is used.
with open('multilingual_conventions.json') as in_file:
    data = json.load(in_file)['camera3']

def find_key_by_language(data, language):
    """Return the first key of *data* whose entry is named *language*.

    Args:
        data: Mapping of key -> dict; every value must contain a ``"name"``
            entry (a missing ``"name"`` raises ``KeyError``).
        language: Language name to look for (exact, case-sensitive match).

    Returns:
        The first matching key in iteration order, exactly as stored (no case
        conversion), or ``None`` when no entry matches.
    """
    # The per-entry debug print that used to live here was removed: it emitted
    # one line for every entry scanned on every lookup.
    for key, value in data.items():
        if value["name"] == language:
            return key
    return None


# Load world geometry from GeoPandas dataset
# NOTE(review): `gpd.datasets` is believed to be deprecated in GeoPandas 0.14
# and removed in 1.0 -- confirm the installed version still ships
# 'naturalearth_lowres'.
world = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))
# Exclude the Antarctic entries from the map.
world = world[~world['name'].isin(['Antarctica', 'Fr. S. Antarctic Lands'])]

# Prepare data for merging
country_data = []
# NOTE(review): the four trackers below are never updated in this section;
# they are kept only so the module-level names stay defined.
max_ratio = 0
min_ratio = float('inf')
max_country = None
min_country = None
# Running extrema of the per-country averages, used later for normalization.
max_mixed = 0
min_mixed = float('inf')
max_not_rotated = 0
min_not_rotated = float('inf')
max_rotated = 0
min_rotated = float('inf')


for country in geopandas_countries:
    if country not in language_map_cia:
        continue

    # language name -> ratio, as stored in language_map_cia
    language_ratios = language_map_cia[country]['languages']

    total_ratio = 0
    total_mixed = 0
    total_not_rotated = 0
    total_rotated = 0
    for lang, ratio in language_ratios.items():
        # full language name -> language code
        lang_code = find_key_by_language(lang_code_to_full_name, lang)
        if lang_code is None:
            print(f"Skipping {lang} in {country} 1")
            continue

        # `data` may key a language by its lower- or upper-case code;
        # the lower-case form wins when both are present.
        if data.get(lang_code.lower()) is not None:
            lang_code = lang_code.lower()
        elif data.get(lang_code.upper()) is not None:
            lang_code = lang_code.upper()
        else:
            print(f"Skipping {lang} in {country} 2")
            continue

        conventions = data[lang_code]
        total_mixed += conventions['mixed'] * ratio
        total_not_rotated += conventions['not_rotated'] * ratio
        total_rotated += conventions['rotated'] * ratio
        total_ratio += ratio

    # No language of this country contributed any weight: the weighted mean
    # is undefined, so skip the country instead of dividing by zero.
    if total_ratio == 0:
        print(f"Skipping {country}: no language with convention data")
        continue

    # Ratio-weighted mean over the languages that have convention data.
    total_mixed /= total_ratio
    total_not_rotated /= total_ratio
    total_rotated /= total_ratio

    max_mixed = max(max_mixed, total_mixed)
    min_mixed = min(min_mixed, total_mixed)
    max_not_rotated = max(max_not_rotated, total_not_rotated)
    min_not_rotated = min(min_not_rotated, total_not_rotated)
    max_rotated = max(max_rotated, total_rotated)
    min_rotated = min(min_rotated, total_rotated)

    country_data.append({
        'country': country,
        'mixed': total_mixed,
        'not_rotated': total_not_rotated,
        'rotated': total_rotated
    })


# Fail fast when nothing was aggregated (an empty list yields a DataFrame
# without a 'country' column).
df_countries = pd.DataFrame(country_data)
if 'country' not in df_countries.columns:
    raise KeyError("'country' column not found in df_countries DataFrame")

# `world` is intentionally NOT merged with df_countries here: the script
# merges once further down, after the per-country colors exist. Merging here
# as well made that later merge produce duplicated '_x' / '_y' columns.

# Normalize the values to [0, 1]
def normalize(value, min_value, max_value):
    """Linearly rescale *value* so that min_value -> 0 and max_value -> 1.

    Returns 0 when ``min_value`` equals ``max_value`` (degenerate range).
    Values outside the range are not clipped.
    """
    if max_value == min_value:
        return 0
    offset = value - min_value
    span = max_value - min_value
    return offset / span

# All three channels are scaled against ONE shared range (the overall minimum
# and maximum across mixed / rotated / not_rotated), not per channel.
# Computed once because it does not change inside the loop.
global_min = min(min_mixed, min_rotated, min_not_rotated)
global_max = max(max_mixed, max_rotated, max_not_rotated)

# The loop variable is deliberately not called `data`, so the module-level
# conventions dict `data` is not overwritten.
for entry in country_data:
    # Channel order is (mixed, rotated, not_rotated); it is used as an RGB
    # color when plotting below.
    entry['color'] = (
        normalize(entry['mixed'], global_min, global_max),
        normalize(entry['rotated'], global_min, global_max),
        normalize(entry['not_rotated'], global_min, global_max),
    )


# Check if 'country' column exists in the DataFrame
df_countries = pd.DataFrame(country_data)
if 'country' not in df_countries.columns:
    raise KeyError("'country' column not found in df_countries DataFrame")

# Merge with world data (left join: countries without data keep NaN)
world = world.merge(df_countries, left_on='name', right_on='country', how='left')

fig, ax = plt.subplots(1, 1, figsize=(20, 10))
fig.patch.set_facecolor('#e6e8ec')  # Light grey background

# Only countries that received a color are drawn.
for_polygons = world.dropna(subset=['color'])

# Plot each country with its corresponding RGB color
for_polygons.plot(
    ax=ax,
    color=for_polygons['color'].apply(lambda x: tuple(x)),
    edgecolor='white',
    linewidth=0.5,
)

# Reduce the margins
ax.margins(x=0.01)  # Set x margins to a low value

# Automatically adjust the subplot parameters
plt.tight_layout()

ax.set_axis_off()
plt.title('Conventions in Frame of Reference', fontsize=20)
plt.savefig('world_map_convention.png', bbox_inches='tight')

print("Missing countries, safe to ignore:", set(geopandas_countries) - set(df_countries['country']))

# Save country data sorted by mixed ratio (highest first)
country_data.sort(key=lambda row: row['mixed'], reverse=True)
with open('country_data_convention_sorted.json', 'w') as f:
    json.dump(country_data, f, indent=4)