import json
import os
import statistics
import sys

import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.cm import ScalarMappable
from matplotlib.colorbar import ColorbarBase
from matplotlib.colors import ListedColormap
from matplotlib.colors import Normalize, rgb_to_hsv, hsv_to_rgb

# Extra import locations: the current working directory, the LLaVA_interp
# project checkout, and a mambaforge Python 3.10 site-packages directory.
# NOTE(review): these entries are appended only after every import at the top
# of the file has already run, and two of them are absolute paths for one
# user's machine; confirm they are still needed.
sys.path.extend([
    os.getcwd(),
    "/home/name/spatial/LLaVA_interp/",
    '/home/name/mambaforge/lib/python3.10/site-packages',
])

# Use Arial for all text drawn by matplotlib.
plt.rcParams['font.family'] = 'Arial'

def _load_json(path):
    """Read and parse the JSON file at *path*.

    The file is decoded explicitly as UTF-8 (the encoding JSON requires) rather
    than with the platform's locale default, which would garble non-ASCII
    country or language names on non-UTF-8 systems.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


# Country names to plot; they are later joined to the world polygons' 'name'.
geopandas_countries = _load_json('supported_countries.json')
# Language code -> entry whose "name" field holds the full language name.
lang_code_to_full_name = _load_json('lang_code_to_full_name.json')
# Country -> entry whose 'languages' maps full language name -> numeric weight.
language_map_cia = _load_json('language_map_cia.json')
# Language code -> entry with parallel 'camera3' and 'reference3' sequences.
data = _load_json('multilingual_ambiguity_raw.json')

# Per-language statistics of the element-wise camera3 / reference3 ratio,
# keyed by the upper-cased language code.
def _ratio_stats(camera, reference):
    """Return (mean, sample standard deviation) of the camera/reference ratios.

    Raises ValueError if the two sequences differ in length (zip would
    otherwise silently drop the extra values), if any reference value is zero
    (the ratio is undefined), or if there are fewer than two pairs
    (statistics.stdev needs at least two data points).
    """
    if len(camera) != len(reference):
        raise ValueError(
            f"camera3 has {len(camera)} values but reference3 has {len(reference)}"
        )
    if any(r == 0 for r in reference):
        raise ValueError("reference3 contains a zero, so the ratio is undefined")
    if len(camera) < 2:
        raise ValueError("at least two camera3/reference3 pairs are required")
    ratios = [c / r for c, r in zip(camera, reference)]
    return sum(ratios) / len(ratios), statistics.stdev(ratios)


average_ratios = {}
std_ratios = {}
for lang_code, values in data.items():
    try:
        avg_ratio, std_ratio = _ratio_stats(values['camera3'], values['reference3'])
    except ValueError as err:
        # Report and skip a malformed language instead of aborting the whole
        # script; it is then simply absent from average_ratios/std_ratios.
        print(f"Skipping language {lang_code}: {err}")
        continue
    average_ratios[lang_code.upper()] = avg_ratio
    std_ratios[lang_code.upper()] = std_ratio

def find_key_by_language(data, language):
    """Look up a language code from its full name.

    Scans *data* (language code -> entry with a ``"name"`` field) in iteration
    order and returns the first key whose entry's name equals *language*, or
    None when nothing matches.
    """
    matches = (code for code, entry in data.items() if entry["name"] == language)
    return next(matches, None)

# Load the Natural Earth low-resolution country polygons, then drop Antarctica
# and the French Southern and Antarctic Lands from the map.
# NOTE(review): geopandas.datasets was deprecated in geopandas 0.14 and removed
# in 1.0; on newer versions the Natural Earth file must be read from disk.
world = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))
world = world[~world['name'].isin(['Antarctica', 'Fr. S. Antarctic Lands'])]

# Prepare data for merging: one weighted (avg_ratio, std_ratio) per country.
def _weighted_language_stats(languages):
    """Combine a country's per-language statistics into one pair.

    *languages* is a country's ``language_map_cia[...]['languages']`` dict,
    mapping a full language name to a numeric weight (presumably the share of
    speakers -- confirm against language_map_cia.json). Every language that
    resolves to a code present in ``average_ratios`` contributes its mean and
    standard deviation multiplied by that weight; other languages are ignored.

    Returns (avg_ratio, std_ratio) as weighted averages, or None when the
    contributing weights do not sum to a positive number.
    """
    weighted_avg = 0
    weighted_std = 0
    total_weight = 0
    for language_name, weight in languages.items():
        lang_code = find_key_by_language(lang_code_to_full_name, language_name)
        if not lang_code:
            continue
        lang_code = lang_code.upper()
        if lang_code not in average_ratios:
            continue
        weighted_avg += average_ratios[lang_code] * weight
        weighted_std += std_ratios[lang_code] * weight
        total_weight += weight
    if total_weight > 0:
        return weighted_avg / total_weight, weighted_std / total_weight
    return None


country_data = []
for country in geopandas_countries:
    if country not in language_map_cia:
        print(f"Country not in language_map_cia: {country}")
        continue
    stats = _weighted_language_stats(language_map_cia[country]['languages'])
    if stats is None:
        print(f"No data for country: {country}")
        continue
    avg_ratio, std_ratio = stats
    country_data.append({
        'country': country,
        'avg_ratio': avg_ratio,
        'std_ratio': std_ratio,
    })

if not country_data:
    # Fail with a clear message; an empty DataFrame has no 'country' column,
    # so the merge that follows would otherwise raise an obscure KeyError.
    sys.exit("No supported country has ratio data; nothing to plot.")

# Colour-scale limits: the range of the weighted averages and the largest
# weighted standard deviation.
max_avg_ratio = max(entry['avg_ratio'] for entry in country_data)
min_avg_ratio = min(entry['avg_ratio'] for entry in country_data)
max_std_ratio = max(entry['std_ratio'] for entry in country_data)

# Left-join the per-country statistics onto the world polygons by name; polygons
# without a matching country keep NaN in the avg_ratio/std_ratio columns.
df_countries = pd.DataFrame(country_data)
world = world.merge(
    df_countries,
    how='left',
    left_on='name',
    right_on='country',
)

# Normalisers for the two quantities drawn on the map: the weighted average
# ratio drives the hue and the weighted standard deviation the saturation.
norm_avg = Normalize(vmin=min_avg_ratio, vmax=max_avg_ratio)
norm_std = Normalize(vmin=0, vmax=max_std_ratio)

# The HSV hue circle wraps around (hue 0 and hue 1 are both red), so spreading
# the average ratio over the full [0, 1] hue range painted the smallest and the
# largest values the same colour. Only the blue..red arc is used instead.
HUE_AT_MIN_RATIO = 2 / 3  # blue
HUE_AT_MAX_RATIO = 0.0  # red
MIN_SATURATION = 0.3  # floor so low-variation countries stay visibly coloured
COLOR_VALUE = 0.8  # constant HSV brightness


def _hue_for(avg_fraction):
    """Map a normalised average ratio in [0, 1] to an HSV hue (blue -> red).

    Works element-wise on NumPy arrays as well as on scalars.
    """
    return HUE_AT_MIN_RATIO + (HUE_AT_MAX_RATIO - HUE_AT_MIN_RATIO) * avg_fraction


def get_hsv_color(avg_ratio, std_ratio):
    """Return the RGB fill colour for a country.

    Hue encodes the average ratio (blue at the smallest, red at the largest);
    saturation encodes the standard deviation, floored at MIN_SATURATION;
    brightness is fixed at COLOR_VALUE.
    """
    h = _hue_for(norm_avg(avg_ratio))
    s = max(MIN_SATURATION, norm_std(std_ratio))
    return hsv_to_rgb((h, s, COLOR_VALUE))


# Plot
fig, ax = plt.subplots(1, 1, figsize=(20, 10))
fig.patch.set_facecolor('#e6e8ec')

for _, row in world.iterrows():
    if pd.notnull(row['avg_ratio']) and pd.notnull(row['std_ratio']):
        color = get_hsv_color(row['avg_ratio'], row['std_ratio'])
    else:
        color = 'lightgrey'  # Default color for countries with missing data
    gpd.GeoSeries([row['geometry']]).plot(ax=ax, color=color, edgecolor='white', linewidth=0.5)

ax.set_axis_off()
plt.tight_layout()

# Colorbar built from the same hue ramp as the map (at full saturation), so its
# two ends match the colours of the smallest and largest average ratio.
# NOTE(review): avg_ratio is camera3 / reference3; confirm that larger values
# really mean "preference towards reference frame" as the label states.
gradient = np.linspace(0, 1, 256)
colorbar_hsv = np.column_stack([
    _hue_for(gradient),
    np.ones_like(gradient),
    np.full_like(gradient, COLOR_VALUE),
])
colorbar_cmap = ListedColormap(hsv_to_rgb(colorbar_hsv))
sm = ScalarMappable(cmap=colorbar_cmap, norm=Normalize(vmin=0, vmax=1))
sm.set_array([])
cbar = fig.colorbar(sm, ax=ax, orientation='horizontal', location='top', fraction=0.02, pad=0.01)
cbar.set_label('Preference towards reference frame over camera frame', fontsize=20, labelpad=10)
cbar.set_ticks([0, 1])
cbar.set_ticklabels([f'{min_avg_ratio:.2f}', f'{max_avg_ratio:.2f}'], fontdict={'fontsize': 12})
cbar.ax.xaxis.set_tick_params(labelsize=12, direction='out', pad=5)
cbar.ax.xaxis.set_ticks_position('bottom')

# Save the finished figure at 300 dpi; bbox_inches='tight' trims the saved
# image to the area occupied by the drawn artists.
# (An optional in-figure caption saying that colour saturation indicates the
# standard deviation was drafted here but is disabled.)
plt.savefig('world_map_hsv.png', dpi=300, bbox_inches='tight')

def _print_descending(header, values_by_code):
    """Print *header*, then one "CODE: value" line per entry, largest first."""
    print(header)
    ranked = sorted(values_by_code.items(), key=lambda item: item[1], reverse=True)
    for code, value in ranked:
        print(f"{code}: {value}")


# Report the per-language statistics, highest first.
_print_descending("Average Ratios (Descending Order):", average_ratios)
_print_descending("\nStandard Deviations of Ratios (Descending Order):", std_ratios)