import geopandas as gpd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

# Load the low-resolution Natural Earth country polygons bundled with geopandas.
# NOTE(review): `gpd.datasets` is deprecated (geopandas 0.14) and removed in
# geopandas 1.0; on newer versions download Natural Earth "Admin 0 - Countries"
# (1:110m) yourself and pass that file to gpd.read_file instead.
world = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))

# In this dataset France and Norway carry the placeholder ISO code '-99'
# instead of 'FRA' / 'NOR', so any ISO-based lookup silently misses them.
# Patch those two rows by country name; rows that already have a real code
# are left untouched.
_iso_fixes = {'France': 'FRA', 'Norway': 'NOR'}
_needs_iso_fix = world['name'].isin(list(_iso_fixes)) & (world['iso_a3'] == '-99')
world.loc[_needs_iso_fix, 'iso_a3'] = world.loc[_needs_iso_fix, 'name'].map(_iso_fixes)

# List of supported languages
supported_languages = ["AR", "BG", "CS", "DA", "DE", "EL", "EN-GB", "EN-US", "ES", "ET", 
                       "FI", "FR", "HU", "ID", "IT", "JA", "KO", "LT", "LV", "NB", "NL", 
                       "PL", "PT-BR", "PT-PT", "RO", "RU", "SK", "SL", "SV", "TR", "UK", "ZH"]

# Give each language a distinct color by sampling the 'nipy_spectral' colormap
# at evenly spaced positions. Sampling is restricted to [0.05, 0.90] because the
# colormap starts at pure black and ends in pale gray. Sampling the full range
# (as before) made the first language ("AR") black, which matches the black
# country borders, and the last ("ZH") near-white, which is hard to tell apart
# from the white used for unmapped countries.
colors = plt.get_cmap('nipy_spectral')
_cmap_lo, _cmap_hi = 0.05, 0.90
_n_langs = len(supported_languages)
color_dict = {
    # max(..., 1) guards against division by zero if only one language is listed.
    lang: colors(_cmap_lo + (_cmap_hi - _cmap_lo) * i / max(_n_langs - 1, 1))
    for i, lang in enumerate(supported_languages)
}

# Which countries (ISO 3166-1 alpha-3 codes) are painted in each language's color.
language_to_country = {
    "EN-GB": ["GBR"],
    "EN-US": ["USA"],
    "FR": ["FRA"],
    "DE": ["DEU"],
    "ES": ["ESP"],
    "RU": ["RUS"],
    "ZH": ["CHN"],
    "JA": ["JPN"],
    "IT": ["ITA"],
    "AR": ["SAU", "DZA", "EGY"],
    "BG": ["BGR"],
    "CS": ["CZE"],
    "DA": ["DNK"],
    "EL": ["GRC"],
    "ET": ["EST"],
    "FI": ["FIN"],
    "HU": ["HUN"],
    "ID": ["IDN"],
    "KO": ["KOR"],
    "LT": ["LTU"],
    "LV": ["LVA"],
    "NB": ["NOR"],
    "NL": ["NLD"],
    "PL": ["POL"],
    "PT-BR": ["BRA"],
    "PT-PT": ["PRT"],
    "RO": ["ROU"],
    "SK": ["SVK"],
    "SL": ["SVN"],
    "SV": ["SWE"],
    "TR": ["TUR"],
    "UK": ["UKR"],
}

# Invert the mapping above into ISO code -> language code.
country_to_language = {}
for lang, iso_codes in language_to_country.items():
    for iso in iso_codes:
        country_to_language[iso] = lang


def _fill_color_for(iso_code):
    """Return the language color for *iso_code*, or 'white' if it has no language."""
    language = country_to_language.get(iso_code)
    return color_dict.get(language, 'white')


# Per-country fill color, looked up via the country's ISO alpha-3 code.
world['color'] = world['iso_a3'].map(_fill_color_for)

# Draw every country filled with its language color (or white if unmapped).
fig, ax = plt.subplots(1, 1, figsize=(20, 10))
world.plot(ax=ax, color=world['color'])

# Overlay country outlines in black to emphasize borders.
world.boundary.plot(ax=ax, linewidth=1, edgecolor='black')

# Legend: one color swatch per supported language, anchored just outside the
# right edge of the axes. Every language has an entry in color_dict (it is
# built from supported_languages), so no membership filter is needed.
from matplotlib.patches import Patch
legend_handles = [Patch(facecolor=color_dict[lang], label=lang) for lang in supported_languages]
ax.legend(handles=legend_handles, title='Languages', bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)

# Title; hide the lon/lat axes, which carry no information for this map.
ax.set_title('World Languages Map', fontsize=15)
ax.set_axis_off()

# Save via the explicit figure handle rather than pyplot's "current figure".
# bbox_inches='tight' expands the saved area so the out-of-axes legend is kept.
fig.savefig('/home/name/spatial/LLaVA_interp/data_generation/colored_world_map.png', bbox_inches='tight')
plt.close(fig)