import pandas as pd

# Lookup table used to standardize country names.
#
# Keys are raw spellings (underscore-separated names, misspellings, alternate or
# historical names, and a few sub-national places); values are the canonical
# name each one is rewritten to.
# NOTE(review): the canonical names look like World Bank "Country Name" labels
# (e.g. 'Egypt, Arab Rep.', 'Korea, Rep.') -- confirm against the CSVs in use.
# NOTE(review): this mapping is not referenced by the functions visible in this
# section; presumably it is applied to the dataset's country column elsewhere.
country_name_mapping = {
    'Bahamas': 'Bahamas, The',
    'Bali': 'Indonesia',  # sub-national region folded into its country
    'Bolivia (Plurinational State of)': 'Bolivia',
    'Brunei': 'Brunei Darussalam',
    'Cape Verde': 'Cabo Verde',
    'Cayman_Islands.': 'Cayman Islands',  # key intentionally includes the trailing '.'
    'Congo-Brazzaville': 'Congo, Rep.',
    'Czech Republic': 'Czechia',
    "Côte d'Ivoire": "Cote d'Ivoire",  # accent stripped
    'Democratic Republic of the Congo': 'Congo, Dem. Rep.',
    'Dubai': 'United Arab Emirates',  # city/emirate folded into its country
    'East Timor': 'Timor-Leste',
    'Egypt': 'Egypt, Arab Rep.',
    'Federated States of Micronesia': 'Micronesia, Fed. Sts.',
    'Gambia': 'Gambia, The',
    'Guernsey': 'United Kingdom',  # NOTE(review): Crown Dependency counted as UK -- confirm intended
    'Guinea Bissau': 'Guinea-Bissau',
    'Hong Kong': 'Hong Kong SAR, China',
    'Iran': 'Iran, Islamic Rep.',
    'Ivory Coast': "Cote d'Ivoire",
    'Jersey': 'United Kingdom',  # NOTE(review): Crown Dependency counted as UK -- confirm intended
    'Kyrgyzstan': 'Kyrgyz Republic',
    'Laos': 'Lao PDR',
    'Macau': 'Macao SAR, China',
    'Mizoram': 'India',  # sub-national region folded into its country
    'Nazi state': 'Germany',  # historical label mapped to the present-day country
    'New_Zealand': 'New Zealand',
    'North Korea': "Korea, Dem. People's Rep.",
    'Northern Cyprus': 'Cyprus',
    'Northern_Cyprus': 'Cyprus',
    # NOTE(review): if the targets are World Bank names, this may need to be
    # 'West Bank and Gaza' -- verify against the CSV.
    'Palestinian Territories': 'Palestine',
    # Three observed misspellings of the same country.
    'Philipines': 'Philippines',
    'Phillipines': 'Philippines',
    'Phillippines': 'Philippines',
    'Russia': 'Russian Federation',
    'Saint Kitts and Nevis': 'St. Kitts and Nevis',
    'Saint Lucia': 'St. Lucia',
    'Saint Vincent and the Grenadines': 'St. Vincent and the Grenadines',
    # NOTE(review): Samoa and American Samoa are distinct entities; this entry
    # relabels the former as the latter -- confirm this is intended.
    'Samoa': 'American Samoa',
    'San_Marino': 'San Marino',
    'Sint Maarten': 'Sint Maarten (Dutch part)',
    'Slovakia': 'Slovak Republic',
    'South Korea': 'Korea, Rep.',
    'South_Africa': 'South Africa',
    'Sri_Lanka': 'Sri Lanka',
    'Syria': 'Syrian Arab Republic',
    'São Tomé and Príncipe': 'Sao Tome and Principe',  # accents stripped
    'The Gambia': 'Gambia, The',
    'Timor Leste': 'Timor-Leste',
    'Turkey': 'Turkiye',
    'U.S.A.': 'United States',
    'UK': 'United Kingdom',
    'USA': 'United States',
    'United States of America': 'United States',
    'United_Kingdom': 'United Kingdom',
    'Venezuela': 'Venezuela, RB',
    'Vietnam': 'Viet Nam',
    'Yemen': 'Yemen, Rep.',
    'Yishun, North Region, Singapore': 'Singapore'  # full address string folded into its country
}

def load_csv_as_dict(file_path, key_column, value_column, skiprows=0):
    """
    Read a CSV and build a ``{key: value}`` lookup from two of its columns.

    Args:
    file_path: Path (or file-like object) handed to ``pd.read_csv``.
    key_column (str): Column whose entries become the dictionary keys.
    value_column (str): Column whose entries become the dictionary values.
    skiprows (int): Leading rows to skip before the header (e.g. metadata lines).

    Returns:
    dict: One entry per row in which both columns are present. If a key
    occurs more than once, the last occurrence wins.
    """
    frame = pd.read_csv(file_path, skiprows=skiprows)

    # Keep only the two columns of interest and discard incomplete rows.
    pairs = frame[[key_column, value_column]]
    complete_pairs = pairs.dropna()

    lookup = complete_pairs.set_index(key_column)[value_column]
    return lookup.to_dict()

def create_land_area_population_dict(land_area_file, population_file, year_column='2021'):
    """
    Combine land-area and population CSVs into one per-country lookup.

    Args:
    land_area_file: Path to the land-area CSV.
    population_file: Path to the population CSV.
    year_column (str): Name of the column holding the values to read from both
        files. Defaults to '2021', the column that was previously hard-coded.

    Returns:
    dict: Maps each country name found in either file to a
    ``(land_area, population)`` tuple. A component is ``None`` when the
    country is absent (or has no value) in the corresponding file.
    """
    # Both files are keyed by the 'Country Name' column.
    land_area_by_country = load_csv_as_dict(land_area_file, 'Country Name', year_column)
    population_by_country = load_csv_as_dict(population_file, 'Country Name', year_column)

    # Union of the keys so that a country present in only one file is kept.
    all_countries = land_area_by_country.keys() | population_by_country.keys()

    return {
        country: (
            land_area_by_country.get(country),
            population_by_country.get(country),
        )
        for country in all_countries
    }

def clean_country_data(country_data_dict, dataset_df):
    """
    Drop every country that does not occur in the dataset.

    The dictionary is modified in place and the same object is returned.

    Args:
    country_data_dict (dict): Per-country data keyed by country name.
    dataset_df (DataFrame): Dataset whose 'iso_country' column lists the
        countries to keep.

    Returns:
    dict: ``country_data_dict`` restricted to countries present in
    ``dataset_df['iso_country']``.
    """
    # Countries that actually appear in the dataset.
    present = set(dataset_df['iso_country'].unique())

    # Collect the keys first so the dictionary is not resized while iterating it.
    absent = [name for name in country_data_dict if name not in present]
    for name in absent:
        country_data_dict.pop(name)

    return country_data_dict


def calculate_sample_size(country_data, total_samples=15000, alpha=0.3, beta=0.7, min_samplesize=3):
    """
    Allocate a per-country sample count from weighted population and land area.

    Each country's weight is ``alpha * population + beta * land_area``. Weights
    are normalized to sum to 1, multiplied by ``total_samples``, rounded, and
    finally raised to at least ``min_samplesize``.

    NOTE: population and land area are combined in their raw units, so the
    quantity with the larger magnitude dominates the weight. Because of the
    rounding and the minimum, the sizes need not add up to ``total_samples``.

    Args:
    country_data (dict): Maps country name to a ``(land_area, population)``
        pair. Missing or non-numeric entries are treated as 0.
    total_samples (int): Number of samples to distribute before the minimum
        is applied.
    alpha (float): Weight factor applied to population.
    beta (float): Weight factor applied to land area.
    min_samplesize (int): Lower bound for every country's sample size.

    Returns:
    DataFrame: One row per country with the columns 'Country', 'Land Area',
    'Population', 'Weight', 'Normalized Weight' and 'Country Sample Size'.
    An empty ``country_data`` yields an empty frame with these columns.
    """
    if not country_data:
        # Nothing to allocate; return a well-formed empty result instead of
        # failing while labelling the columns of an empty frame.
        return pd.DataFrame(columns=[
            'Country', 'Land Area', 'Population',
            'Weight', 'Normalized Weight', 'Country Sample Size',
        ])

    # One row per country, built directly in the final orientation.
    df_country = pd.DataFrame.from_dict(
        country_data, orient='index', columns=['Land Area', 'Population']
    )
    df_country = df_country.rename_axis('Country').reset_index()

    # Coerce to numbers; anything missing or unparsable counts as 0.
    for column in ('Land Area', 'Population'):
        df_country[column] = pd.to_numeric(df_country[column], errors='coerce').fillna(0)

    df_country['Weight'] = alpha * df_country['Population'] + beta * df_country['Land Area']

    total_weight = df_country['Weight'].sum()
    if total_weight != 0:
        df_country['Normalized Weight'] = df_country['Weight'] / total_weight
    else:
        # No usable data at all: avoid 0/0 (NaN cannot be cast to int below);
        # every country then falls back to the minimum sample size.
        df_country['Normalized Weight'] = 0.0

    rounded_sizes = (df_country['Normalized Weight'] * total_samples).round().astype(int)

    # Raise every allocation below the minimum up to the minimum.
    df_country['Country Sample Size'] = rounded_sizes.clip(lower=min_samplesize)

    return df_country


def generate_country_samples(land_area_file, population_file, df_dropped, total_samples=15100, alpha=0.3, beta=0.7, min_samplesize=3, year_column='2021'):
    """
    Main function to generate country sample sizes based on land area and population.

    Args:
    land_area_file: Path to the land-area CSV.
    population_file: Path to the population CSV.
    df_dropped (DataFrame): Dataset used to restrict the result to countries
        it actually contains.
    total_samples (int): Number of samples to distribute across countries.
    alpha (float): Weight factor applied to population.
    beta (float): Weight factor applied to land area.
    min_samplesize (int): Lower bound for every country's sample size.
    year_column (str): Column of both CSVs to read the values from.

    Returns:
    DataFrame: The per-country table produced by ``calculate_sample_size``.
    """
    # Step 1: Load land area and population for the requested year. The year
    # column is passed straight to the CSV loader so the caller's choice is
    # honoured rather than silently ignored.
    land_area_by_country = load_csv_as_dict(land_area_file, 'Country Name', year_column)
    population_by_country = load_csv_as_dict(population_file, 'Country Name', year_column)
    country_data = {
        country: (
            land_area_by_country.get(country),
            population_by_country.get(country),
        )
        for country in land_area_by_country.keys() | population_by_country.keys()
    }

    # Step 2: Remove countries not in df_dropped
    country_data = clean_country_data(country_data, df_dropped)

    # Step 3: Turn the remaining countries into sample sizes
    return calculate_sample_size(
        country_data, total_samples, alpha, beta, min_samplesize
    )

def sample_country_data(df_samplesizes, df_dataset, min_samplesize, country_column='iso_country', random_state=42):
    """
    Function to sample data from each country based on the sample size specified in df_samplesizes.

    For each country:
    - if the dataset holds at least the requested number of rows, that many
      rows are drawn at random without replacement;
    - otherwise, if it holds at least ``min_samplesize`` rows, all of its rows
      are taken;
    - otherwise the country is skipped and a message is printed.

    Args:
    df_samplesizes (DataFrame): DataFrame containing 'Country' and 'Country Sample Size'.
    df_dataset (DataFrame): The DataFrame from which data will be sampled.
    min_samplesize (int): Fewest rows a country may contribute when it cannot
        meet its requested sample size; countries with fewer rows are skipped.
    country_column (str): The column name in df_dataset that holds country information.
    random_state (int): Random seed for reproducibility.

    Returns:
    DataFrame: A DataFrame containing the sampled data from all countries, or
    an empty DataFrame if nothing was sampled.
    """
    sampled_dfs = []

    # Look the column up once instead of on every iteration.
    dataset_countries = df_dataset[country_column]

    for _, row in df_samplesizes.iterrows():
        country = row['Country']
        # iterrows() yields mixed-type rows; make sure ``n`` is a plain int.
        sample_size = int(row['Country Sample Size'])

        # Rows of the dataset belonging to the current country
        country_data = df_dataset[dataset_countries == country]
        available_rows = len(country_data)

        if available_rows >= sample_size:
            # Enough rows: draw exactly the requested number without replacement
            df_sample = country_data.sample(n=sample_size, replace=False, random_state=random_state)
        elif available_rows >= min_samplesize:
            # Fewer rows than requested but still above the minimum: take all
            # of them, rather than collapsing to only ``min_samplesize`` rows.
            df_sample = country_data
        else:
            # Fewer rows than the minimum (possibly none): skip the country
            print(f"Skipping {country}: not enough data available.")
            continue

        sampled_dfs.append(df_sample)

    if sampled_dfs:
        return pd.concat(sampled_dfs, ignore_index=True)

    print("No data was sampled.")
    return pd.DataFrame()  # Return an empty DataFrame if no data was sampled