import numpy as np
import geopandas as gpd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

# Load world country geometries; later code relies on an "iso_a3" column.
# GeoPandas < 1.0 bundles Natural Earth's "naturalearth_lowres" dataset, but
# GeoPandas 1.0 removed `geopandas.datasets`, so the bundled path is no longer
# available there. In that case read the equivalent Natural Earth 1:110m
# admin-0 countries layer directly (requires network access) and normalise its
# column names to the lowercase ones the bundled dataset used.
_NATURAL_EARTH_110M_URL = (
    "https://naciscdn.org/naturalearth/110m/cultural/ne_110m_admin_0_countries.zip"
)
try:
    _lowres_path = gpd.datasets.get_path("naturalearth_lowres")
except (AttributeError, ValueError):  # datasets module/dataset removed in GeoPandas >= 1.0
    _lowres_path = None

if _lowres_path is not None:
    world = gpd.read_file(_lowres_path)
else:
    world = gpd.read_file(_NATURAL_EARTH_110M_URL).rename(columns=str.lower)
    # NOTE(review): the raw Natural Earth layer is known to store "-99" in
    # ISO_A3 for a few countries (e.g. France, Norway); fall back to ADM0_A3
    # so they still match the language mapping — confirm for the data version.
    world["iso_a3"] = world["iso_a3"].where(world["iso_a3"] != "-99", world["adm0_a3"])

# Per-language scores for three settings: "mixed", "not_rotated" and "rotated".
# Each row below becomes a dict keyed by those three names, in that order
# (they are later used as the red, green and blue channels). An empty tuple
# yields an empty dict, marking a language that has no scores.
rgb_data = {
    lang: dict(zip(("mixed", "not_rotated", "rotated"), scores))
    for lang, scores in (
        ("AR", (0.338, 0.462, 0.622)),
        ("BG", (0.490, 0.603, 0.687)),
        ("CS", (0.484, 0.546, 0.729)),
        ("DA", (0.562, 0.680, 0.691)),
        ("DE", (0.392, 0.525, 0.647)),
        ("EL", (0.555, 0.677, 0.685)),
        ("EN-GB", (0.393, 0.543, 0.635)),
        ("EN-US", (0.359, 0.498, 0.608)),
        ("ES", (0.318, 0.465, 0.619)),
        ("ET", (0.579, 0.685, 0.662)),
        ("FI", (0.561, 0.660, 0.705)),
        ("FR", (0.389, 0.519, 0.619)),
        ("HU", (0.674, 0.742, 0.772)),
        ("ID", ()),  # no data
        ("IT", (0.330, 0.473, 0.620)),
        ("JA", ()),  # no data
        ("KO", (0.411, 0.505, 0.685)),
        ("LT", (0.470, 0.559, 0.701)),
        ("LV", (0.435, 0.542, 0.686)),
        ("NB", (0.415, 0.546, 0.662)),
        # NOTE(review): identical score in all three settings — confirm this
        # is real data and not a placeholder.
        ("NL", (0.866, 0.866, 0.866)),
        ("PL", (0.544, 0.649, 0.728)),
        ("PT-BR", (0.312, 0.445, 0.630)),
        ("PT-PT", (0.377, 0.532, 0.627)),
        ("RO", (0.374, 0.516, 0.650)),
        ("RU", (0.352, 0.441, 0.671)),
        ("SK", (0.404, 0.519, 0.652)),
        ("SL", (0.504, 0.546, 0.707)),
        ("SV", (0.391, 0.534, 0.649)),
        ("TR", (0.462, 0.560, 0.666)),
        ("UK", ()),  # no data
        ("ZH", ()),  # no data
    )
}

# Every language code in the table, in table order (including those without scores).
supported_languages = [*rgb_data]

# Tone curve applied to each score before it is used as a colour channel.
def exponential_scale(value, exponent=0.5):
    """Map a score in [0, 1] through the power curve ``value ** exponent``.

    Despite its name this is a power-law (gamma-style) curve, not an
    exponential function. For any positive exponent, inputs in [0, 1] stay
    in [0, 1]. With the default ``exponent=0.5`` (square root) every score
    in (0, 1) is brightened. Note that the square root's slope is below 1
    for inputs above 0.25, so for such scores it *narrows* the differences
    between them rather than enhancing contrast. An exponent above 1 darkens
    scores instead, and ``exponent=1`` leaves them unchanged.

    Args:
        value: Score (scalar or array-like), expected in [0, 1]. Negative
            inputs with a fractional exponent produce NaN.
        exponent: Power applied to ``value``. Defaults to 0.5, the curve
            this script has always used.

    Returns:
        ``np.power(value, exponent)`` — a NumPy scalar for scalar input, an
        array for array-like input.
    """
    return np.power(value, exponent)

# One RGB triple per language that has scores: red <- "mixed",
# green <- "not_rotated", blue <- "rotated", each passed through
# exponential_scale. Languages whose score dict is empty get no entry.
color_dict = {
    lang: tuple(
        exponential_scale(values[channel])
        for channel in ("mixed", "not_rotated", "rotated")
    )
    for lang, values in rgb_data.items()
    if values
}


# color_dict = {lang: (data["mixed"], data["not_rotated"], data["rotated"]) for lang, data in rgb_data.items() if data}

# # Modify color_dict as per the new requirement
# color_dict = {}
# for lang, values in rgb_data.items():
#     if values:
#         max_key = max(values, key=values.get)  # Find the key with the maximum value
#         color = [0, 0, 0]  # Default to black
#         if max_key == "mixed":
#             color[0] = values["mixed"]  # Assign to Red
#         elif max_key == "not_rotated":
#             color[1] = values["not_rotated"]  # Assign to Green
#         elif max_key == "rotated":
#             color[2] = values["rotated"]  # Assign to Blue
#         color_dict[lang] = tuple(color)

# Flatten the RGB data to find the min and max for scaling
# all_values = [value for data in rgb_data.values() for value in data.values() if data]
# min_val, max_val = min(all_values), max(all_values)

# # Scale the RGB values to enhance contrast
# def scale_value(val):
#     # Scale the value to use the full color range [0, 1]
#     return (val - min_val) / (max_val - min_val)

# color_dict = {
#     lang: (
#         scale_value(data["mixed"]),
#         scale_value(data["not_rotated"]),
#         scale_value(data["rotated"])
#     ) for lang, data in rgb_data.items() if data
# }

# Countries painted with each language's colour, given as ISO 3166-1 alpha-3
# codes (matched against the map's "iso_a3" column). "AR" is the only
# language spread over several countries.
language_to_country = {
    "EN-GB": ["GBR"],
    "EN-US": ["USA"],
    "FR": ["FRA"],
    "DE": ["DEU"],
    "ES": ["ESP"],
    "RU": ["RUS"],
    "ZH": ["CHN"],
    "JA": ["JPN"],
    "IT": ["ITA"],
    "AR": ["SAU", "DZA", "EGY"],
    "BG": ["BGR"],
    "CS": ["CZE"],
    "DA": ["DNK"],
    "EL": ["GRC"],
    "ET": ["EST"],
    "FI": ["FIN"],
    "HU": ["HUN"],
    "ID": ["IDN"],
    "KO": ["KOR"],
    "LT": ["LTU"],
    "LV": ["LVA"],
    "NB": ["NOR"],
    "NL": ["NLD"],
    "PL": ["POL"],
    "PT-BR": ["BRA"],
    "PT-PT": ["PRT"],
    "RO": ["ROU"],
    "SK": ["SVK"],
    "SL": ["SVN"],
    "SV": ["SWE"],
    "TR": ["TUR"],
    "UK": ["UKR"],
}

# Invert language -> [ISO alpha-3 codes] into ISO code -> language, then give
# every country its language's colour. Countries with no mapped language, or
# whose language has no colour entry, are painted white.
country_to_language = {
    iso_code: language
    for language, iso_codes in language_to_country.items()
    for iso_code in iso_codes
}
world['color'] = world['iso_a3'].map(
    lambda iso_code: color_dict.get(country_to_language.get(iso_code), (1, 1, 1))
)

from pathlib import Path

from matplotlib.patches import Patch

# Destination of the rendered map.
output_path = Path("/home/name/spatial/LLaVA_interp/data_generation/colored_world_map.png")

# Fill every country with its precomputed colour (white when unmapped).
fig, ax = plt.subplots(1, 1, figsize=(20, 10))
world.plot(ax=ax, color=[mcolors.rgb2hex(c) for c in world['color']])

# Draw country outlines on top so neighbouring countries with similar
# colours stay distinguishable.
world.boundary.plot(ax=ax, linewidth=1, edgecolor="black")

# Legend: one swatch per language that has a colour, in table order.
# Languages without scores have no swatch; their countries stay white.
legend_labels = [
    Patch(facecolor=color_dict[lang], label=lang)
    for lang in supported_languages
    if lang in color_dict
]
ax.legend(
    handles=legend_labels,
    title="Languages",
    bbox_to_anchor=(1.05, 1),  # place the legend just outside the right edge
    loc="upper left",
    fontsize=8,
)

ax.set_title("World Languages Map", fontsize=15)
ax.set_axis_off()

# Create the output directory if needed; savefig raises FileNotFoundError
# when the parent directory does not exist.
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, bbox_inches="tight")  # "tight" keeps the outside legend in frame
plt.close(fig)