import ast
import yaml
import torch
import numpy as np
import xarray as xr
import seaborn as sns
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
from .mfrnp import Emulator
from .utils import get_test_data
import os

# Absolute directory of this module; model and data files are resolved relative to it.
base_dir = os.path.dirname(os.path.abspath(__file__))

# Emulator configuration, trained weights and z dictionary live under <base_dir>/model.
_model_dir = os.path.join(base_dir, "model")
_checkpoint_dir = os.path.join(_model_dir, "checkpoints")
model_config_pth = os.path.join(_model_dir, "tas_reanalysis.yaml")
model_checkpoint_pth = os.path.join(_checkpoint_dir, "tas_model.pt")
model_z_dict_pth = os.path.join(_checkpoint_dir, "z_dict.pth")

# Supported SSP scenarios. Order matters: the greenhouse functions use a scenario's
# position in this list to select its block of rows in data/x_all.npy.
settings = ["ssp126", "ssp245", "ssp370", "ssp585"]

# Single shared emulator instance, constructed once at import time.
emulator = Emulator(model_config_pth, model_checkpoint_pth, model_z_dict_pth)


def ll2index(longitude, latitude):
    """
    Map a (longitude, latitude) pair to (row, column) indices on the 721 x 1440 output grid.

    The mapping assumes row 0 is latitude 90 and row 720 is latitude -90 (4 rows per
    degree), and scales longitude so that -180 maps to column 0 and 180 to column 1439.
    Indices are clamped into the grid, so out-of-range coordinates select the nearest
    edge cell. Without clamping they would produce negative indices, which Python
    indexing silently wraps to the opposite side of the grid.

    Args:
        longitude: Longitude in degrees, expected in [-180, 180].
        latitude: Latitude in degrees, expected in [-90, 90].

    Returns:
        A tuple ``(row, column)`` of ints with 0 <= row <= 720 and 0 <= column <= 1439.
    """
    index_0 = int((90 - latitude) * 4)
    # NOTE(review): 1439/360 is not exactly 4 columns per degree; kept unchanged so
    # in-range lookups are not shifted — confirm against the emulator's longitude grid.
    index_1 = int((longitude + 180) * 1439 / 360)

    index_0 = max(0, min(index_0, 720))
    index_1 = max(0, min(index_1, 1439))
    return index_0, index_1


def ll2index2(longitude, latitude):
    """
    Map a (longitude, latitude) pair to (lat_index, lon_index) on a 96 x 144 grid.

    Latitude -90 maps to row 0 and longitude -180 to column 0; both indices are
    truncated to ints and clamped into the grid.

    Args:
        longitude: Longitude in degrees.
        latitude: Latitude in degrees.

    Returns:
        A tuple ``(lat_index, lon_index)`` — note the latitude-first order.
    """
    n_lat, n_lon = 96, 144

    lon_idx = int((longitude + 180) / 360 * n_lon)
    lat_idx = int((latitude + 90) / 180 * n_lat)

    lon_idx = min(max(lon_idx, 0), n_lon - 1)
    lat_idx = min(max(lat_idx, 0), n_lat - 1)

    return lat_idx, lon_idx


def diy_greenhouse(longitude, latitude, setting, year, delta_CO2=0, delta_CH4=0):
    """
    Predict the temperature at a location in a future year under an SSP scenario, after scaling that scenario's CO2 and CH4 inputs.

    Args:
        longitude: The longitude of the location, a float from -180 to 180.
        latitude: The latitude of the location, a float from -90 to 90.
        setting: Future climate scenario, a string from ssp126, ssp245, ssp370, ssp585.
        year: The year to predict, an integer from 2015 to 2100.
        delta_CO2: Fractional change of CO2, a float. CO2_after = CO2_before * (1 + delta_CO2).
        delta_CH4: Fractional change of CH4, a float. CH4_after = CH4_before * (1 + delta_CH4).

    Returns:
        ``(temperature, message)`` on success, or an error-message string when
        ``year`` is outside 2015-2100.
    """
    year = int(year)
    longitude, latitude = float(longitude), float(latitude)
    delta_CO2, delta_CH4 = float(delta_CO2), float(delta_CH4)

    row, col = ll2index(longitude, latitude)

    if not 2015 <= year <= 2100:
        return "We only have future data from 2015 to 2100."

    # x_all.npy rows: after an offset of 165 rows (presumably earlier/historical
    # years — confirm), each scenario in `settings` owns 86 consecutive yearly rows
    # covering 2015-2100.
    x_all = np.load(os.path.join(base_dir, "data", "x_all.npy"))
    sample_row = 165 + settings.index(setting) * 86 + (year - 2015)
    x = x_all[sample_row].reshape(1, 12)

    # Column 0 holds the CO2 input and column 1 the CH4 input.
    x[0, 0] *= 1 + delta_CO2
    x[0, 1] *= 1 + delta_CH4

    temperature = emulator.pred(x).reshape(721, 1440)[row][col]
    return temperature, f"The temperature is {temperature}."


def diy_aerosol(longitude, latitude, setting, year, delta_SO2, delta_BC, modify_points):
    """
    Predict the temperature of a place in the future under a specific climate scenario with DIY change of SO2 and BC based on the original setting.

    Args:
        longitude: The longitude of the place you would check the temperature for, a float from -180 to 180.
        latitude: The latitude of the place you would check the temperature for, a float from -90 to 90.
        setting: Future climate scenarios, a string from ssp126, ssp245, ssp370, ssp585.
        year: The year you would check the temperature for, an integer from 2015 to 2100.
        delta_SO2: The change of SO2 you would like to make, a float.
        delta_BC: The change of BC you would like to make, a float.
        modify_points: The points you would like to modify, a string representing a list of (lon, lat) tuples,
            e.g. '[(-10, 20)]'. An already-parsed list of (lon, lat) pairs is also accepted.

    Returns:
        ``(temperature, message)`` on success. If ``year`` is outside 2015-2100 or
        ``modify_points`` is malformed, an error-message string instead.
    """
    longitude = float(longitude)
    latitude = float(latitude)
    year = int(year)
    delta_SO2 = float(delta_SO2)
    delta_BC = float(delta_BC)

    if year < 2015 or year > 2100:
        return "We only have future data from 2015 to 2100."

    modification_method = "percent"

    # Parse the points and convert each (lon, lat) to 96 x 144 grid indices via ll2index2.
    # TypeError covers literals that parse but are not (lon, lat) pairs, e.g. '[1, 2]'.
    try:
        if isinstance(modify_points, str):
            points = ast.literal_eval(modify_points)
        else:
            points = modify_points
        modify_points = [ll2index2(lon, lat) for lon, lat in points]
    except (ValueError, SyntaxError, TypeError):
        return "Invalid format for modify_points. It should be a string representing a list of tuples, e.g., '[(-10, 20)]'."

    index_0, index_1 = ll2index(longitude, latitude)

    # Baseline inputs (no modification).
    x = get_test_data(
        year - 2015, setting, "SO2", 0, modification_method, modify_points
    )

    # Overlay the SO2-modified columns (2-6) onto the baseline when requested.
    if delta_SO2 != 0:
        x_so2 = get_test_data(
            year - 2015, setting, "SO2", delta_SO2, modification_method, modify_points
        )
        x.iloc[:, 2:7] = x_so2.iloc[:, 2:7]

    # Overlay the BC-modified columns (7-11) onto the current inputs when requested.
    if delta_BC != 0:
        x_bc = get_test_data(
            year - 2015, setting, "BC", delta_BC, modification_method, modify_points
        )
        x.iloc[:, 7:12] = x_bc.iloc[:, 7:12]

    # Select this year's row as a single-sample model input.
    x = x.iloc[year - 2015].values.reshape(1, -1)
    y = emulator.pred(x).reshape(721, 1440)

    return y[index_0][index_1], f"The temperature is {y[index_0][index_1]}."


def diy_aerosol_mean(setting, year, delta_SO2, delta_BC, modify_points):
    """
    Predict the mean temperature over the whole output grid in a future year under a specific climate scenario with DIY change of SO2 and BC based on the original setting.

    Args:
        setting: Future climate scenarios, a string from ssp126, ssp245, ssp370, ssp585.
        year: The year you would check the temperature for, an integer from 2015 to 2100.
        delta_SO2: The change of SO2 you would like to make, a float.
        delta_BC: The change of BC you would like to make, a float.
        modify_points: The points you would like to modify, a string representing a list of (lon, lat) tuples,
            e.g. '[(-10, 20)]'. An already-parsed list of (lon, lat) pairs is also accepted.

    Returns:
        A tuple ``(message, mean_temperature)``. For an invalid ``year`` or
        ``modify_points``, ``message`` describes the problem and ``mean_temperature`` is 0.
    """
    year = int(year)
    delta_SO2 = float(delta_SO2)
    delta_BC = float(delta_BC)

    if year < 2015 or year > 2100:
        return "The year should be between 2015 and 2100.", 0

    modification_method = "percent"

    # TypeError covers literals that parse but are not (lon, lat) pairs,
    # e.g. '[1, 2]', '(-10, 20)' or "[('a', 'b')]".
    try:
        if isinstance(modify_points, str):
            points = ast.literal_eval(modify_points)
        else:
            points = modify_points
        modify_points = [ll2index2(lon, lat) for lon, lat in points]
    except (ValueError, SyntaxError, TypeError):
        return "Invalid format for modify_points. It should be a string representing a list of tuples, e.g., '[(-10, 20)]'.", 0

    # Baseline inputs (no modification).
    x = get_test_data(
        year - 2015, setting, "SO2", 0, modification_method, modify_points
    )

    # Overlay the SO2-modified columns (2-6) onto the baseline when requested.
    if delta_SO2 != 0:
        x_so2 = get_test_data(
            year - 2015, setting, "SO2", delta_SO2, modification_method, modify_points
        )
        x.iloc[:, 2:7] = x_so2.iloc[:, 2:7]

    # Overlay the BC-modified columns (7-11) onto the current inputs when requested.
    if delta_BC != 0:
        x_bc = get_test_data(
            year - 2015, setting, "BC", delta_BC, modification_method, modify_points
        )
        x.iloc[:, 7:12] = x_bc.iloc[:, 7:12]

    # Select this year's row as a single-sample model input.
    x = x.iloc[year - 2015].values.reshape(1, -1)
    y = emulator.pred(x).reshape(721, 1440)

    # NOTE(review): np.mean weights every grid cell equally; a global average is usually
    # area-weighted by cos(latitude) — confirm which is intended before relying on it.
    mean_temp = np.mean(y)

    return f"The average temperature is {mean_temp}.", mean_temp


def diff_diy_aerosol_mean(setting, year, delta_SO2, delta_BC, modify_points):
    """
    Predict the change in grid-mean temperature caused by DIY changes of SO2 and BC, relative to the unmodified scenario.

    Args:
        setting: Future climate scenarios, a string from ssp126, ssp245, ssp370, ssp585.
        year: The year you would check the temperature for, an integer from 2015 to 2100.
        delta_SO2: The change of SO2 you would like to make, a float.
        delta_BC: The change of BC you would like to make, a float.
        modify_points: The points you would like to modify, a string representing a list of (lon, lat) tuples.

    Returns:
        A tuple ``(message, difference)`` where ``difference`` is the modified minus the
        baseline mean. If the inputs are rejected by ``diy_aerosol_mean``, its error
        message is returned with a difference of 0.
    """
    message, temp_with_aerosol = diy_aerosol_mean(setting, year, delta_SO2, delta_BC, modify_points)

    # diy_aerosol_mean reports invalid input as (error_message, 0). Without this check
    # both calls would return 0 and we would report a misleading difference of 0.
    if not message.startswith("The average temperature is"):
        return message, 0

    _, temp_baseline = diy_aerosol_mean(setting, year, 0, 0, modify_points)

    temp_diff = temp_with_aerosol - temp_baseline

    return f"The average temperature difference is {temp_diff}.", temp_diff


def diy_greenhouse_summary(longitude, latitude, delta_CO2=0, delta_CH4=0):
    """
    Provide a summary of the temperature changes for a specific location and range of climate scenarios.

    Args:
        longitude: The longitude of the place, a float from -180 to 180.
        latitude: The latitude of the place, a float from -90 to 90.
        delta_CO2: The change of CO2, a float. CO2_after = CO2_before * (1 + delta_CO2).
        delta_CH4: The change of CH4, a float. CH4_after = CH4_before * (1 + delta_CH4).
    """
    index_0, index_1 = ll2index(longitude, latitude)

    CO2_change = (
        f"The emission of CO2 is {'increased' if delta_CO2 > 0 else 'decreased'} by {delta_CO2}%."
        if delta_CO2 != 0
        else ""
    )
    CH4_change = (
        f"The emission of CH4 is {'increased' if delta_CH4 > 0 else 'decreased'} by {delta_CH4}%."
        if delta_CH4 != 0
        else ""
    )

    return_string = f"\nFollowing is the temperature under different scenarios if: {CO2_change}{CH4_change}\n\n"

    for setting in settings:
        for year in [2050, 2100]:
            x = np.load(os.path.join(base_dir, "data", "x_all.npy"))[
                165 + settings.index(setting) * 86 + (year - 2015)
            ].reshape(1, 12)

            x[0, 0] *= 1 + delta_CO2
            x[0, 1] *= 1 + delta_CH4
            y = emulator.pred(x).reshape(721, 1440)

            return_string += f"the temperature in {year} under {setting} scenario is {y[index_0][index_1]},"

    return_string = return_string[:-1] + "."
    return return_string


if __name__ == "__main__":
    # Run the same greenhouse query four times with identical inputs
    # (presumably to eyeball that repeated predictions agree — confirm intent).
    for _ in range(4):
        print(diy_greenhouse(-84.08, 9.9325, "ssp245", 2093, 0, -0.05))
