import sys
import os
import json
import numpy as np
import geopandas as gpd
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, LinearSegmentedColormap
from matplotlib.cm import ScalarMappable
# from matplotlib.cm import get_cmap
import statistics

# Set up paths
# Make the current working directory plus the two hard-coded locations
# (the LLaVA_interp directory and the mambaforge site-packages) importable.
for extra_path in (
    os.getcwd(),
    "/home/name/spatial/LLaVA_interp/",
    '/home/name/mambaforge/lib/python3.10/site-packages',
):
    sys.path.append(extra_path)

# Use Arial as the font family for all figure text.
plt.rcParams['font.family'] = 'Arial'

def _load_json(path):
    """Read *path* as UTF-8 text and return the decoded JSON value."""
    # JSON is UTF-8 by specification; relying on the platform default
    # encoding can mis-decode non-ASCII names on non-UTF-8 locales.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


# Load necessary data
# Iterated later as the list of country names to look up.
geopandas_countries = _load_json('supported_countries.json')

# Maps a language code to a record whose "name" field is the language name.
lang_code_to_full_name = _load_json('lang_code_to_full_name.json')

# Maps a country to a record whose 'languages' field maps language -> weight.
language_map_cia = _load_json('language_map_cia.json')

# Load the new data
# Maps a language code to a record with 'camera3' and 'reference3' sequences.
data = _load_json('multilingual_ambiguity_raw.json')

# Calculate mean ratios and standard deviations
# For every language, form the element-wise camera3 / reference3 ratios and
# store their mean and sample standard deviation under the upper-cased code.
mean_ratios = {}
std_ratios = {}
for lang_code, values in data.items():
    lang_code = lang_code.upper()  # Always set lang_code to uppercase
    camera3 = values['camera3']
    reference3 = values['reference3']

    # A zero reference value would raise ZeroDivisionError and abort the
    # whole script, so such pairs are dropped (and reported) instead.
    ratios = [c / r for c, r in zip(camera3, reference3) if r != 0]
    dropped = min(len(camera3), len(reference3)) - len(ratios)
    if dropped:
        print(f"Skipped {dropped} zero-reference sample(s) for language: {lang_code}")

    if not ratios:
        # statistics.mean() raises StatisticsError on empty input; leave the
        # language out so the later `in mean_ratios` check treats it as
        # having no data.
        print(f"No usable samples for language: {lang_code}")
        continue

    mean_ratios[lang_code] = statistics.mean(ratios)
    # statistics.stdev() needs at least two samples; with a single sample
    # there is no spread to measure, so 0.0 is recorded.
    std_ratios[lang_code] = statistics.stdev(ratios) if len(ratios) > 1 else 0.0

def find_key_by_language(data, language):
    """Return the upper-cased key whose entry is named *language*, else None.

    Each value in *data* must be a mapping with a "name" field. Names are
    compared case-insensitively and the first match in iteration order wins.
    """
    matches = (
        code.upper()
        for code, info in data.items()
        if info["name"].lower() == language.lower()
    )
    return next(matches, None)

# Load world geometry
# NOTE(review): the gpd.datasets helper is deprecated in newer GeoPandas
# releases (believed removed in 1.0) — confirm the pinned version.
world = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))

# Leave the two Antarctic entries off the map.
excluded_names = ['Antarctica', 'Fr. S. Antarctic Lands']
world = world[~world['name'].isin(excluded_names)]

# Prepare data for merging
# Per-language extremes; used afterwards for the colour/alpha normalisation
# and the colorbar ticks.
country_data = []
max_mean_ratio = max(mean_ratios.values())
min_mean_ratio = min(mean_ratios.values())
max_std_ratio = max(std_ratios.values())

# For each country, average the per-language mean and std, weighted by each
# language's weight in language_map_cia. Only languages that resolve to a
# code present in mean_ratios contribute.
for country in geopandas_countries:
    if country not in language_map_cia:
        print(f"Country not in language_map_cia: {country}")
        continue

    weighted_mean = 0
    weighted_std = 0
    weight_sum = 0
    for language_name, weight in language_map_cia[country]['languages'].items():
        code = find_key_by_language(lang_code_to_full_name, language_name)
        if not code or code not in mean_ratios:
            continue
        weighted_mean += mean_ratios[code] * weight
        weighted_std += std_ratios[code] * weight
        weight_sum += weight

    if weight_sum > 0:
        # Renormalise by the weight that was actually covered.
        country_data.append({
            'country': country,
            'mean_ratio': weighted_mean / weight_sum,
            'std_ratio': weighted_std / weight_sum
        })
    else:
        print(f"No data for country: {country}")

# Merge data with world geometry
# Declare the columns explicitly: when no country produced data the list is
# empty, and a column-less frame would make the merge fail on the missing
# 'country' key instead of simply leaving every country without data.
df_countries = pd.DataFrame(country_data, columns=['country', 'mean_ratio', 'std_ratio'])
# Left join keeps every geometry row; unmatched countries get NaN ratios.
world = world.merge(df_countries, left_on='name', right_on='country', how='left')

# Normalize the data
# The mean scale spans the per-language min/max; the std scale starts at 0.
norm_mean = Normalize(vmin=min_mean_ratio, vmax=max_mean_ratio)
norm_std = Normalize(vmin=0, vmax=max_std_ratio)

# Create custom colormap (using one of the previously suggested color schemes)
# n_bins and colors are only needed by the commented-out custom colormap below.
n_bins = 100
colors = ['#0000FF', '#FF0000']  # Vibrant Blue to Vibrant Red
# cmap = LinearSegmentedColormap.from_list('custom', colors, N=n_bins)
cmap = plt.get_cmap('viridis_r')

# Plot
fig, ax = plt.subplots(1, 1, figsize=(20, 10))
fig.patch.set_facecolor('#e6e8ec')  # Light grey background

# Draw one country at a time so each gets its own RGBA fill: hue encodes the
# mean ratio, opacity encodes the standard deviation.
for _, country_row in world.iterrows():
    mean_value = country_row['mean_ratio']
    std_value = country_row['std_ratio']

    if pd.isnull(mean_value) or pd.isnull(std_value):
        # Default color for countries with missing data
        fill = 'lightgrey'
    else:
        hue = cmap(norm_mean(mean_value))[:3]
        # Alpha runs from 1.0 at zero std down to 0.3 at the maximum std; the
        # 0.3 exponent is a non-linear transformation that makes even small
        # deviations fade the colour noticeably.
        fade = np.power(norm_std(std_value), 0.3)
        opacity = 0.3 + 0.7 * (1 - fade)
        fill = hue + (opacity,)

    shape = gpd.GeoSeries([country_row['geometry']])
    shape.plot(ax=ax, color=fill, edgecolor='white', linewidth=0)


# # # Create custom colormap
# # colors = ['#FFA500', '#800080']  # Orange to Purple
# n_bins = 100
# colors = ['#0000FF', '#FF0000']  # Blue to Red
# colors = ['#FFFF00', '#008080']  # Yellow to Teal
# colors = ['#F5F5DC', '#006400']  # Beige to Dark Green
# colors = ['#ADD8E6', '#8B0000']  # Light Blue to Dark Red
# colors = ['#DC143C', '#32CD32']  # Crimson to Lime Green

# # cmap = LinearSegmentedColormap.from_list('custom', colors, N=n_bins)
# cmap = LinearSegmentedColormap.from_list('custom', colors, N=n_bins)
# # cmap = plt.get_cmap('viridis_r')

# # Plot
# fig, ax = plt.subplots(1, 1, figsize=(20, 10))
# fig.patch.set_facecolor('#e6e8ec')

# for idx, row in world.iterrows():
#     if pd.notnull(row['mean_ratio']) and pd.notnull(row['std_ratio']):
#         rgb = cmap(norm_mean(row['mean_ratio']))[:3]
#         alpha = 1 - norm_std(row['std_ratio'])  # Invert alpha so higher std is more transparent
#         color = rgb + (alpha,)
#     else:
#         color = 'lightgrey'  # Default color for countries with missing data
#     gpd.GeoSeries([row['geometry']]).plot(ax=ax, color=color, edgecolor='white', linewidth=0)

# Hide the map axes and tighten the layout before the colorbar is attached.
ax.set_axis_off()
plt.tight_layout()

# Create custom colorbar for mean ratios
# The mappable carries no data of its own; it only provides cmap + norm.
sm_mean = ScalarMappable(norm=norm_mean, cmap=cmap)
sm_mean.set_array([])
cbar_mean = fig.colorbar(
    sm_mean,
    ax=ax,
    orientation='horizontal',
    location='top',
    fraction=0.02,
    pad=0.01,
)
cbar_mean.set_label(
    'Preference towards reference frame over camera frame',
    fontsize=20,
    labelpad=10,
)

# Only the two ends of the scale are labelled, to two decimals.
tick_values = [min_mean_ratio, max_mean_ratio]
tick_labels = [f'{value:.2f}' for value in tick_values]
cbar_mean.set_ticks(tick_values)
cbar_mean.set_ticklabels(tick_labels, fontdict={'fontsize': 12})

# Put the ticks on the bottom edge of the colorbar, pointing outwards.
colorbar_axis = cbar_mean.ax.xaxis
colorbar_axis.set_tick_params(labelsize=12, direction='out', pad=5)
colorbar_axis.set_ticks_position('bottom')

# # Add a note about the meaning of the color scale and transparency
# plt.text(0.5, 0.95, 'Orange indicates preference for camera frame, Purple for reference frame', 
#          horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, fontsize=12)
# plt.text(0.5, 0.92, 'More transparent colors indicate higher standard deviation', 
#          horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, fontsize=12)

# Write the finished figure to disk.
plt.savefig('world_map_mean_ratio_with_std_alpha.png', bbox_inches='tight', dpi=300)

# Print sorted mean ratios
# Languages ordered from the highest mean ratio to the lowest.
ranked_pairs = sorted(mean_ratios.items(), key=lambda pair: pair[1], reverse=True)
sorted_mean_ratios = dict(ranked_pairs)
print("\nMean Ratios (Descending Order):")
for lang_code, mean_ratio in sorted_mean_ratios.items():
    print(f"{lang_code}: {mean_ratio}")