import os
import json
from utils.constants import Constants
# Directory containing this script; the dataset JSON files read further down
# are resolved relative to it (<current_dir>/../dataset/...).
current_dir = os.path.dirname(os.path.abspath(__file__))


# LaTeX source of the RLinear ablation table (lookback L=32, horizon H=7) that
# is cleaned and colour-highlighted further down. Rows end with "\\" and cells
# are separated by "&": column 0 is the station label, followed by RMSE/MAE
# pairs for RLinear (1-2), RLinear-1 (3-4), RLinear-LG (5-6) and
# RLinear-STLG (7-8).
# NOTE(review): the caption's "Results with colors denotes" probably should be
# "denote", and "tab:my-table" looks like a placeholder label -- left as-is
# because this text is emitted verbatim.
table_1 = r"""% Please add the following required packages to your document preamble:
% \usepackage{booktabs}
% \usepackage[table,xcdraw]{xcolor}
% Beamer presentation requires \usepackage{colortbl} instead of \usepackage[table,xcdraw]{xcolor}
\begin{table}[]
\centering
\setlength{\tabcolsep}{1mm}
\begin{tabular}{@{}l|cc|cc|cc|cc@{}}
\toprule
\multicolumn{1}{c}{Stations}           & \multicolumn{2}{c}{RLinear}            & \multicolumn{2}{c}{RLinear-1}                           & \multicolumn{2}{c}{RLinear-LG}                          & \multicolumn{2}{c}{RLinear-STLG}                         \\ \midrule
\multicolumn{1}{c|}{Metrics}           & RMSE    & \multicolumn{1}{c|}{MAE}     & \multicolumn{1}{c}{RMSE} & \multicolumn{1}{c|}{MAE}     & \multicolumn{1}{c}{RMSE} & \multicolumn{1}{c|}{MAE}     & \multicolumn{1}{c}{RMSE}       & \multicolumn{1}{c}{MAE} \\ \midrule
\multicolumn{1}{l|}{Ban Chot}          & 51.15   & \multicolumn{1}{c|}{19.12}   & 51.15                    & \multicolumn{1}{l|}{19.12}   & 51.15                    & \multicolumn{1}{l|}{19.12}   & 49.47                          & 18.66                   \\
\multicolumn{1}{l|}{Ban Huai Khayuong} & 30.69   & \multicolumn{1}{c|}{13.65}   & 29.53                    & \multicolumn{1}{l|}{12.53}   & 29.53                    & \multicolumn{1}{l|}{12.53}   & 29.53                          & 12.53                   \\
\multicolumn{1}{l|}{Ban Huai Yano Mai} & 3.34    & \multicolumn{1}{c|}{1.66}    & 3.29                     & \multicolumn{1}{l|}{1.62}    & 3.29                     & \multicolumn{1}{l|}{1.62}    & 3.29                           & 1.62                    \\
\multicolumn{1}{l|}{Ban Kengdone}      & 666.36  & \multicolumn{1}{c|}{316.55}  & 661.04                   & \multicolumn{1}{l|}{310.65}  & 661.04                   & \multicolumn{1}{l|}{310.65}  & 661.04                         & 310.65                  \\
\multicolumn{1}{l|}{Ban Na Luang}      & 71.80   & \multicolumn{1}{c|}{36.64}   & 71.45                    & \multicolumn{1}{l|}{35.68}   & 71.45                    & \multicolumn{1}{l|}{35.68}   & 71.45                          & 35.68                   \\
\multicolumn{1}{l|}{Ban Nong Kiang}    & 27.73   & \multicolumn{1}{c|}{11.01}   & 27.73                    & \multicolumn{1}{l|}{11.05}   & 27.73                    & \multicolumn{1}{l|}{11.05}   & 27.73                          & 11.05                   \\
\multicolumn{1}{l|}{Ban Pak Huai}      & 92.79   & \multicolumn{1}{c|}{27.15}   & 92.39                    & \multicolumn{1}{l|}{27.04}   & 92.39                    & \multicolumn{1}{l|}{27.04}   & 92.39                          & 27.04                   \\
\multicolumn{1}{l|}{Ban Pak Kanhoung}  & 213.96  & \multicolumn{1}{c|}{113.93}  & 213.96                   & \multicolumn{1}{l|}{113.93}  & 213.96                   & \multicolumn{1}{l|}{113.93}  & 205.08 & 112.25                  \\
\multicolumn{1}{l|}{Ban Tad Ton}       & 8.49    & \multicolumn{1}{c|}{3.56}    & 8.45                     & \multicolumn{1}{l|}{3.54}    & 8.45                     & \multicolumn{1}{l|}{3.54}    & 8.45                           & 3.54                    \\
\multicolumn{1}{l|}{Ban Tha Mai Liam}  & 29.69   & \multicolumn{1}{c|}{13.96}   & 29.16                    & \multicolumn{1}{l|}{14.00}   & 29.18                    & \multicolumn{1}{l|}{14.00}   & 28.82                          & 13.88                   \\
\multicolumn{1}{l|}{Ban Tha Ton}       & 27.89   & \multicolumn{1}{c|}{14.40}   & 28.03                    & \multicolumn{1}{l|}{14.42}   & 28.03                    & \multicolumn{1}{l|}{14.42}   & 28.03                          & 14.42                   \\
\multicolumn{1}{l|}{Cau 14 (Buon Bur)} & 211.42  & \multicolumn{1}{c|}{71.65}   & 229.77                   & \multicolumn{1}{l|}{80.59}   & 234.75                   & \multicolumn{1}{l|}{81.55}   & 219.79                         & 76.73                   \\
\multicolumn{1}{l|}{Chaktomuk}         & 215.86  & \multicolumn{1}{c|}{117.62}  & 218.10                   & \multicolumn{1}{l|}{117.90}  & 218.10                   & \multicolumn{1}{l|}{117.90}  & 204.14                         & 110.43                  \\
\multicolumn{1}{l|}{Chiang Khan}       & 1439.03 & \multicolumn{1}{c|}{752.94}  & 1282.24                  & \multicolumn{1}{l|}{694.26}  & 1285.92                  & \multicolumn{1}{l|}{694.40}  & 1294.88                        & 716.26                  \\
\multicolumn{1}{l|}{Chiang Saen}       & 965.39  & \multicolumn{1}{c|}{547.77}  & 978.16                   & \multicolumn{1}{l|}{550.84}  & 978.16                   & \multicolumn{1}{l|}{550.84}  & 978.73                         & 558.31                  \\
\multicolumn{1}{l|}{Duc Xuyen}         & 141.45  & \multicolumn{1}{c|}{41.14}   & 140.49                   & \multicolumn{1}{l|}{39.55}   & 140.49                   & \multicolumn{1}{l|}{39.55}   & 140.49                         & 39.55                   \\
\multicolumn{1}{l|}{Khong Chiam}       & 2095.19 & \multicolumn{1}{c|}{1109.83} & 1914.30                  & \multicolumn{1}{l|}{1080.67} & 1911.17                  & \multicolumn{1}{l|}{1045.38} & 1966.37                        & 1112.59                 \\
\multicolumn{1}{l|}{Kompong Cham}      & 3232.64 & \multicolumn{1}{c|}{1648.32} & 2723.24                  & \multicolumn{1}{l|}{1361.43} & 2788.17                  & \multicolumn{1}{l|}{1405.78} & 2726.50                        & 1373.48                 \\
\multicolumn{1}{l|}{Kontum}            & 66.67   & \multicolumn{1}{c|}{32.61}   & 66.52                    & \multicolumn{1}{l|}{32.46}   & 66.52                    & \multicolumn{1}{l|}{32.46}   & 66.52                          & 32.46                   \\
\multicolumn{1}{l|}{Mukdahan}          & 1782.39 & \multicolumn{1}{c|}{953.00}  & 1756.35                  & \multicolumn{1}{l|}{935.42}  & 1756.35                  & \multicolumn{1}{l|}{935.42}  & 1756.35                        & 935.42                  \\
\multicolumn{1}{l|}{Nakhon Phanom}     & 1593.65 & \multicolumn{1}{c|}{883.49}  & 1742.43                  & \multicolumn{1}{l|}{1028.18} & 1722.25                  & \multicolumn{1}{l|}{1018.63} & 1603.75                        & 934.91                  \\
\multicolumn{1}{l|}{Nong Khai}         & 1313.81 & \multicolumn{1}{c|}{704.73}  & 1355.75                  & \multicolumn{1}{l|}{742.49}  & 1353.31                  & \multicolumn{1}{l|}{739.05}  & 1319.09                        & 709.42                  \\
\multicolumn{1}{l|}{Pakse}             & 2359.31 & \multicolumn{1}{c|}{1259.04} & 2339.16                  & \multicolumn{1}{l|}{1240.41} & 2339.16                  & \multicolumn{1}{l|}{1240.41} & 2339.16                        & 1240.41                 \\
\multicolumn{1}{l|}{Stung Treng}       & 4104.71 & \multicolumn{1}{c|}{2120.51} & 4192.56                  & \multicolumn{1}{l|}{2194.98} & 4192.56                  & \multicolumn{1}{l|}{2194.98} & 3769.10                        & 2045.71                 \\
\multicolumn{1}{l|}{Vientiane KM4}     & 1445.05 & \multicolumn{1}{c|}{770.77}  & 1297.17                  & \multicolumn{1}{l|}{701.28}  & 1288.97                  & \multicolumn{1}{l|}{698.58}  & 1330.17                        & 716.35                  \\
\multicolumn{1}{l|}{Yasothon}          & 108.76  & \multicolumn{1}{c|}{46.39}   & 109.00                   & \multicolumn{1}{l|}{46.96}   & 109.00                   & \multicolumn{1}{l|}{46.96}   & 109.04                         & 45.87                   \\ \midrule
\multicolumn{1}{l|}{Avg}               & 857.66  & \multicolumn{1}{c|}{447.36}  & 829.29                   & \multicolumn{1}{l|}{438.88}  & 830.81                   & \multicolumn{1}{l|}{438.67}  & 808.82                         & 431.12                  \\ \bottomrule
\end{tabular}
\caption{Ablation results on RLinear with lookback window length $L=32$ and future horizon length $H=7$. -1 denotes phase 1, -LG denotes phase local-global, -STLG denotes local-global with seasonal-trend flow detection. Results with colors denotes fusion from other stations.}
\label{tab:my-table}
\end{table}"""


class HighlightTable:
    r"""Clean up and colour-highlight cells of LaTeX ``tabular`` source.

    The LaTeX is processed purely textually: rows are obtained by splitting on
    the row terminator ``\\`` and cells by splitting a row on ``&``.  Cell 0
    of a row is its label (e.g. a station name) and is never modified; the
    remaining cells are expected to hold numbers.

    Typical use is two passes: :meth:`clean_table` reduces every value cell to
    a bare number, then :meth:`highlight_cell_background_when_station_fused`
    prefixes selected cells with ``\cellcolor[HTML]{...}``.
    """

    # LaTeX row terminator (two backslashes) used to split and re-join rows.
    _ROW_SEP = '\\\\'
    # Wrappers removed from value cells by clean_table, applied in this order.
    _CELL_WRAPPERS = (
        r'\multicolumn{1}{c|}{',
        r'\multicolumn{1}{c}{',
        r'\multicolumn{1}{l}{',
        r'\multicolumn{1}{l|}{',
        r'\textbf{',
        r'\ul',
    )
    # Colour commands whose "{HEX}" argument precedes the cell's value.
    _COLOR_COMMANDS = (r'\cellcolor[HTML]{', r'\color[HTML]{')

    def __init__(self, table_list, float_num=2):
        """
        Args:
            table_list: LaTeX table source strings processed by
                ``highlight_cell_background_when_station_fused``.
            float_num: number of decimal places used to locate a cell's
                number when highlighting it.
        """
        self.table_list = table_list
        self.float_num = float_num

    def highlight_cell_background_when_station_fused(self, color, col_list, station_list):
        r"""Shade the selected columns of the selected stations' rows.

        Args:
            color: HTML colour code inserted as ``\cellcolor[HTML]{<color>}``,
                e.g. ``'96FFFB'``.
            col_list: cell indices (0 = row label) to shade.
            station_list: station names; a value row is shaded when any of
                them occurs in it (substring match).

        Returns:
            Every table of ``self.table_list`` with the cells shaded, each
            preceded by five newlines and followed by one, concatenated.

        Only cells that are bare numbers are shaded, so run ``clean_table``
        first on tables whose cells are wrapped in ``\multicolumn`` etc.
        """
        # Stored on the instance because the row/column helpers read them.
        self._color = color
        self._col_list = col_list
        self._station_list = station_list

        pieces = []
        for table in self.table_list:
            rows = self._color_cells(table.split(self._ROW_SEP))
            pieces.append("\n\n\n\n\n" + self._get_combined_string(rows) + "\n")
        return "".join(pieces)

    def _check_value_condition(self, s):
        r"""Return True if row segment ``s`` looks like a row of values.

        A value row contains a ``.`` (a decimal number) and is not the
        segment holding the ``\caption``.
        """
        return '.' in s and 'caption' not in s

    def _check_row_condition(self, row):
        """Return True if any name in ``self._station_list`` occurs in ``row``.

        NOTE: this is a substring test on the whole row, so a station whose
        name is contained in another station's name also selects that row.
        """
        return any(station in row for station in self._station_list)

    def _check_col_condition(self, col_index):
        """Return True if cell index ``col_index`` is one to be shaded."""
        return col_index in self._col_list

    def clean_table(self, string):
        r"""Reduce every value cell of ``string`` to its bare number.

        For each value row (see ``_check_value_condition``), every cell after
        the label has its spaces, ``\multicolumn{1}{..}{``/``\textbf{``/``\ul``
        wrappers and any ``\cellcolor[HTML]{..}``/``\color[HTML]{..}`` prefix
        removed, along with the leftover braces.  Other rows are unchanged.

        Returns:
            The re-joined LaTeX source.
        """
        rows = string.split(self._ROW_SEP)
        for i, row in enumerate(rows):
            if not self._check_value_condition(row):
                continue
            cells = row.split('&')
            # Cell 0 is the row label and is kept verbatim.
            cells[1:] = [self._clean_cell(cell) for cell in cells[1:]]
            rows[i] = '&'.join(cells)
        return self._ROW_SEP.join(rows)

    def _clean_cell(self, cell):
        """Strip spaces, formatting wrappers and colour commands from one cell."""
        text = cell.replace(" ", "")
        for wrapper in self._CELL_WRAPPERS:
            text = text.replace(wrapper, "")
        for command in self._COLOR_COMMANDS:
            if command in text:
                # Keep what follows the colour argument, e.g.
                # "96FFFB}51.15}" -> "51.15}".  Slicing a fixed 6 characters
                # kept the argument's closing brace, which then made the
                # brace-stripping step chop the number itself.
                text = text.split(command)[1].partition('}')[2]
        text = text.replace('{', "")
        # Drop closing braces left behind by the removed wrappers.
        return text.strip().rstrip('}')

    def _get_all_num_list(self, string):
        r"""Split ``string`` into rows and collect the numeric cell values.

        Returns:
            ``(split_strings, rmse_list, mae_list)``: ``split_strings`` is
            ``string`` split on ``\\``; for every value row, ``rmse_list``
            gets the values of cells 1, 3, 5, ... and ``mae_list`` those of
            cells 2, 4, 6, ... (each method occupies an RMSE, MAE pair of
            columns after the label).  Cells that are not plain numbers are
            recorded as ``99``; rows with no such cells add no entry.
        """
        split_strings = string.split(self._ROW_SEP)
        rmse_list = []
        mae_list = []
        for row in split_strings:
            if not self._check_value_condition(row):
                continue
            rmse_row = []
            mae_row = []
            for j, cell in enumerate(row.split('&')):
                if j == 0:
                    continue  # row label
                try:
                    value = float(cell)
                except ValueError:
                    value = 99  # sentinel for non-numeric cells
                if j % 2 == 1:
                    rmse_row.append(value)
                else:
                    mae_row.append(value)
            if rmse_row:
                rmse_list.append(rmse_row)
            if mae_row:
                mae_list.append(mae_row)
        return split_strings, rmse_list, mae_list

    def _color_cells(self, split_strings):
        r"""Prefix the selected cells with ``\cellcolor[HTML]{<color>}``.

        A cell is shaded when its row is a value row that mentions one of
        ``self._station_list`` and its index is in ``self._col_list``.  The
        cell's number, formatted with ``self.float_num`` decimals, is
        replaced by ``\cellcolor[HTML]{<color>} <number>``.  Cells that do
        not parse as a float (e.g. already shaded ones), or whose text does
        not contain the formatted number, are left unchanged.

        Returns:
            ``split_strings``, modified in place.
        """
        prefix = fr'\cellcolor[HTML]{{{self._color}}} '
        for i, row in enumerate(split_strings):
            if not (self._check_value_condition(row) and self._check_row_condition(row)):
                continue
            cells = row.split('&')
            for j, cell in enumerate(cells):
                if not self._check_col_condition(j):
                    continue
                try:
                    num = float(cell)
                except ValueError:
                    # Not a bare number: skip it rather than searching for a
                    # made-up "99.00", which could match inside real values.
                    continue
                formatted = f'{num:.{self.float_num}f}'
                cells[j] = cell.replace(formatted, prefix + formatted)
            split_strings[i] = '&'.join(cells)
        return split_strings

    def _get_combined_string(
            self,
            split_strings,
    ):
        r"""Re-join row segments with the LaTeX row terminator ``\\``."""
        return self._ROW_SEP.join(split_strings)


# Pass 1: select stations whose 'seasonal' or 'trend' entry in
# st_useful_other_dict.json is non-empty; their RLinear-STLG cells
# (columns 7-8 of table_1) are shaded.
# NOTE(review): indexing raises KeyError if a station from
# Constants().all_stations is missing from the JSON -- confirm that is intended.
all_stations = Constants().all_stations
station_list = []
with open(os.path.join(current_dir, '..', 'dataset', 'st_useful_other_dict.json'), 'r', encoding='utf-8') as f:
    best_other_dict = json.load(f)
for station in all_stations:
    seasonal_other_list = best_other_dict['seasonal'][station]
    trend_other_list = best_other_dict['trend'][station]
    # Presumably both are lists; the concatenation is only tested for
    # truthiness, i.e. "either list is non-empty".
    if seasonal_other_list + trend_other_list:
        station_list.append(station)
col_list = [7, 8]

# Reduce every value cell to a bare number first; shading only applies to
# cells that parse as numbers. The intermediate cleaned table is printed.
cleaned_table = HighlightTable([table_1]).clean_table(table_1)
print(cleaned_table)
highlight_table_value = HighlightTable([cleaned_table]).highlight_cell_background_when_station_fused(
    color='96FFFB',
    col_list=col_list,
    station_list=station_list)

# Pass 2: stations with a non-empty entry in useful_other_dict.json get their
# RLinear-1 and RLinear-LG cells (columns 3-6) shaded, on top of pass 1's output.
# NOTE(review): pass 1's output already carries a five-newline prefix, so the
# final string gains another one here.
col_list = [3, 4, 5, 6]
station_list = []
with open(os.path.join(current_dir, '..', 'dataset', 'useful_other_dict.json'), 'r', encoding='utf-8') as f:
    best_other_dict = json.load(f)
for station in all_stations:
    other_list = best_other_dict[station]
    if other_list:
        station_list.append(station)
highlight_table_value = HighlightTable([highlight_table_value]).highlight_cell_background_when_station_fused(
    color='96FFFB',
    col_list=col_list,
    station_list=station_list)
print(highlight_table_value)


