"""
Evaluate a trained RL policy on water distribution system data for 2012 or any specified year.
"""

import math
import datetime
from typing import List, Tuple
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import logging

# Logging: timestamped, INFO level and above, for the whole script.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)

# --- Paths -------------------------------------------------------------------
DATA_PATH = './Data/'
MODEL_PATH = './MODEL/WDS_Q/0'

# --- Model input shape -------------------------------------------------------
# Features per time step: tank level, consumption, time of day (3 scalars)
# + one-hot month (12) + one-hot last action (4) + per-action runtime (5)
# + water-turnover flag (1)  ->  25 features.
STATE_SPACE = 25
# Number of consecutive per-step feature vectors stacked into one model input.
STATE_SIZE = 4

YEAR = '2012'  # default evaluation year; Evaluation(year=...) overrides it

# German month names are the CSV file stems; English names label the plots.
MONTHS = [
    'Januar', 'Februar', 'Maerz', 'April', 'Mai', 'Juni',
    'Juli', 'August', 'September', 'Oktober', 'November', 'Dezember',
]
MONTHSL = [
    'January', 'February', 'March', 'April', 'May', 'June',
    'July', 'August', 'September', 'October', 'November', 'December',
]
# Action 0 is the all-zero vector (no pump); action k (1..4) is one-hot on slot k-1.
ACTIONS = [[int(slot == row) for slot in range(4)] for row in range(-1, 4)]
# One-hot month vectors (rows of a 12x12 identity matrix).
MONTH_ENCODING = [list(one_hot) for one_hot in np.eye(12, dtype=int)]

# --- Normalization bounds ----------------------------------------------------
MIN_TK1, MAX_TK1 = 47, 57                  # tank level range
MIN_CONSUMPTION, MAX_CONSUMPTION = 1, 3240
MIN_TIME, MAX_TIME = 59, 86399             # seconds of day: 00:00:59 .. 23:59:59
MIN_REWARD, MAX_REWARD = -18, 11           # raw reward range, from WDS_sim

class Evaluation:
    """Evaluate a trained Q-network policy on one year of recorded WDS data.

    Creating an instance loads the model from ``MODEL_PATH``, prepares the pump
    curves and immediately runs :meth:`evaluate`. That writes
    ``results_evaluation.txt`` and ``actions.txt`` and saves PNG plots to the
    current working directory.
    """

    def __init__(self, year: str = YEAR) -> None:
        """Load the model and pump data, then evaluate the policy.

        Args:
            year: Sub-directory of ``DATA_PATH`` to evaluate (default ``YEAR``).
        """
        self.year = year
        self.model = self._load_model()
        (self.pump_curves, self.Qanlage, self.A, self.rho, self.g,
         self.pressure_coeffs) = self._initialize_pump_data()
        self.evaluate()

    def _load_model(self) -> tf.keras.Model:
        """Load and return the Keras model stored at ``MODEL_PATH``.

        Raises:
            Exception: Whatever ``tf.keras.models.load_model`` raises; it is
                logged and re-raised unchanged.
        """
        try:
            model = tf.keras.models.load_model(MODEL_PATH)
        except Exception as e:
            logging.error("Failed to load model from %s: %s", MODEL_PATH, e)
            raise
        logging.info("Model loaded from %s", MODEL_PATH)
        return model

    def _initialize_pump_data(self) -> Tuple[dict, np.ndarray, float, float, float, List[float]]:
        """Build pump curves and physical constants (values from WDS_sim).

        Returns:
            ``(pump_curves, Qanlage, A, rho, g, pressure_coeffs)``:

            - ``pump_curves[k]['H']`` is pump *k*'s quadratic head fit,
              evaluated on the flow grid ``Qanlage`` (2000 points over [0, 2000]).
            - ``pump_curves[k]['ETA']`` holds the quadratic efficiency-fit
              coefficients, for use with ``np.polyval``.
            - ``A`` is the area of a circle of diameter 44. It is used as the
              tank cross-section in the level update (units presumably m²).
            - ``rho`` and ``g`` are water density and gravitational acceleration.
            - ``pressure_coeffs`` are the system-curve coefficients used by
              :meth:`evaluate`.
        """
        pumps = {
            1: {'Q': [0.0, 320.0, 570.0, 670.0, 897.0, 1099.0, 1298.0, 1597.3, 1823.2],
                'H': [76.81, 74.70, 71.61, 71.75, 72.29, 69.99, 67.95, 61.88, 55.79],
                'ETA': [0.00, 38.60, 55.55, 60.68, 72.28, 78.85, 83.12, 85.96, 85.07]},
            2: {'Q': [0.0, 240.0, 494.0, 694.0, 899.0, 1095.7, 1209.3, 1465.3],
                'H': [74.15, 72.15, 70.78, 69.34, 66.44, 62.75, 60.12, 52.57],
                'ETA': [0.00, 40.65, 63.22, 75.36, 83.02, 87.09, 88.32, 87.11]},
            3: {'Q': [0.0, 237.0, 470.0, 645.0, 797.3, 1000.3, 1208.7, 1410.0],
                'H': [63.84, 62.46, 61.98, 60.61, 58.68, 54.87, 50.07, 43.83],
                'ETA': [0.00, 38.28, 61.02, 72.36, 79.05, 84.94, 87.08, 84.91]},
            4: {'Q': [0.0, 195.0, 440.0, 659.7, 806.3, 998.7, 1204.0],
                'H': [64.13, 63.19, 61.62, 59.35, 56.61, 51.21, 43.71],
                'ETA': [0.00, 41.95, 67.50, 80.96, 84.84, 86.04, 83.42]}
        }
        A = math.pi * (44 ** 2) / 4
        rho, g = 1000, 9.81
        pressure_coeffs = [-2.713E-09, 0.000006504, 4.384E-07, -0.0001768]
        Qanlage = np.linspace(0, 2000, 2000)
        pump_curves = {
            k: {
                'H': np.polyval(np.polyfit(spec['Q'], spec['H'], 2), Qanlage),
                'ETA': np.polyfit(spec['Q'], spec['ETA'], 2),
            }
            for k, spec in pumps.items()
        }
        return pump_curves, Qanlage, A, rho, g, pressure_coeffs

    def _normalize_state(self, value: float, min_val: float, max_val: float) -> float:
        """Linearly map *value* from [min_val, max_val] onto [0, 1] (from WDS_sim).

        Values outside the range map outside [0, 1]; no clipping is applied.
        """
        return (value - min_val) / (max_val - min_val)

    @staticmethod
    def _time_to_seconds(time_str: str) -> int:
        """Convert an ``'HH:MM:SS'`` string to seconds since midnight."""
        hours, minutes, seconds = map(int, time_str.split(':'))
        return hours * 3600 + minutes * 60 + seconds

    @staticmethod
    def _repeated_minute_rows(times: List[str]) -> List[int]:
        """Return the positions whose minute field equals the preceding row's minute."""
        drops = []
        current_minute = None
        for pos, stamp in enumerate(times):
            _, minute, _ = stamp.split(':')
            if pos > 0 and minute == current_minute:
                drops.append(pos)
            else:
                current_minute = minute
        return drops

    @staticmethod
    def _repeated_hour_rows(times: List[str]) -> List[int]:
        """Return the positions that belong to a repeated hour.

        A repeated hour starts at a row with minute ``00`` whose hour equals the
        previous row's hour. That row and every following row with the same
        hour are returned.
        """
        drops = set()
        for pos in range(1, len(times)):
            hour, minute, _ = times[pos].split(':')
            prev_hour = times[pos - 1].split(':')[0]
            if int(minute) == 0 and hour == prev_hour:
                k = pos
                while k < len(times) and times[k].split(':')[0] == prev_hour:
                    drops.add(k)
                    k += 1
        return sorted(drops)

    def _clean_data(self, year: str, month: str):
        """Load and align one month of consumption and pump-flow data.

        Steps:
          1. Read the water-consumption CSV and the four pump-flow CSVs.
          2. Truncate Time/Date/consumption and NP2-NP4 to their common length.
             NP1 is excluded here because it is cleaned separately.
          3. Drop NP1 rows whose minute repeats the previous row's minute. Then
             pad NP1 with NaN, or truncate it, to the common length.
          4. From every frame, drop the rows of repeated hours
             (see :meth:`_repeated_hour_rows`).

        Args:
            year: Data sub-directory under ``DATA_PATH``.
            month: German month name used as the CSV file stem.

        Returns:
            ``(data_time, data_date, data_waterConsumption, np1, np2, np3, np4)``.
            Each is a single-column DataFrame with a fresh RangeIndex. If a file
            is missing, seven ``None`` values are returned instead.
        """
        try:
            # Each file is read once. 'Time' and 'Date' are text columns, so
            # decimal=',' only affects the numeric columns.
            wc_frame = pd.read_csv(f"{DATA_PATH}{year}/WaterConsumption/{month}.csv",
                                   delimiter=';', decimal=',', skip_blank_lines=True)
            np_frames = [pd.read_csv(f"{DATA_PATH}{year}/NP/NP{j}/Q/{month}.csv",
                                     delimiter=';', decimal=',', skip_blank_lines=True)
                         for j in range(1, 5)]
        except FileNotFoundError as e:
            logging.error(f"File not found for {year}/{month}: {e}")
            return None, None, None, None, None, None, None

        data_time = wc_frame[['Time']]
        data_date = wc_frame[['Date']]
        data_waterConsumption = wc_frame[['Netzverbrauch_pval']]
        data_qr = [frame[[f'NP_{j}_Volumenfluss_pval']] for j, frame in enumerate(np_frames, start=1)]
        np1_times = np_frames[0]['Time'].tolist()

        logging.info(f"{year}/{month}: Initial rows - Time: {len(data_time)}, WC: {len(data_waterConsumption)}, NP1: {len(data_qr[0])}, NP2: {len(data_qr[1])}, NP3: {len(data_qr[2])}, NP4: {len(data_qr[3])}")

        # Align to the shortest water-related frame (NP1 has its own granularity).
        min_length = min(len(wc_frame), *(len(frame) for frame in data_qr[1:]))
        data_time = data_time.iloc[:min_length]
        data_date = data_date.iloc[:min_length]
        data_waterConsumption = data_waterConsumption.iloc[:min_length]
        data_qr[1:] = [frame.iloc[:min_length] for frame in data_qr[1:]]

        # NP1: drop repeated minutes, then pad with NaN / truncate to min_length.
        data_qr[0] = (data_qr[0]
                      .drop(index=self._repeated_minute_rows(np1_times))
                      .reset_index(drop=True)
                      .reindex(range(min_length)))

        logging.info(f"{year}/{month}: After minute cleaning - NP1 rows: {len(data_qr[0])}, Others: {len(data_time)}")

        # Remove repeated hours from every frame, keeping the frames aligned.
        hour_rows = self._repeated_hour_rows(data_time['Time'].tolist())
        data_time, data_date, data_waterConsumption, *data_qr = [
            frame.drop(index=hour_rows).reset_index(drop=True)
            for frame in (data_time, data_date, data_waterConsumption, *data_qr)
        ]

        logging.info(f"{year}/{month}: Final rows after hour cleaning - {len(data_time)}")
        return data_time, data_date, data_waterConsumption, *data_qr

    def _calculate_reward(self, action_idx: int, last_action: int, QBP: float, etaBP: float, HB_current: float, HB_next: float, time_running: List[int], water_turnover: int) -> float:
        """Compute the normalized step reward (matches WDS_sim).

        The raw reward combines three terms:

        - a level penalty ``c``: ``min(50 - HB_next, 1)`` below 50, ``1`` at or
          above 57, otherwise ``0``.
        - a bonus of ``c = -1`` when ``water_turnover`` is 0 and the level moves
          from outside [50, 53) into [50, 53).
        - a runtime term ``log(1 / p)``. ``p`` is ``1 + runtime`` when the
          action is unchanged, is 0, or has not run yet; otherwise ``p`` is
          ``30 + runtime``.

        When a pump is running with positive flow, ``exp(-etaBP / QBP)`` is
        added. The raw reward is normalized with ``MIN_REWARD``/``MAX_REWARD``.
        """
        runtime = time_running[action_idx]
        if action_idx == last_action or action_idx == 0 or runtime == 0:
            p = 1 + runtime
        else:
            p = 30 + runtime

        if HB_next < 50:
            c = min(abs(HB_next - 50), 1)
        elif HB_next >= 57:
            c = 1
        else:
            c = 0
        if water_turnover == 0 and (HB_current < 50 or HB_current >= 53) and 50 <= HB_next < 53:
            c = -1

        reward = -c * 10 + math.log(1 / p)
        if action_idx > 0:
            if QBP <= 0:  # safeguard against zero flow
                logging.warning("Zero QBP with action %d at time %d, HB_next=%f", action_idx, self.current_time, HB_next)
            else:
                reward += math.exp(1 / (-QBP / etaBP))
        return self._normalize_state(reward, MIN_REWARD, MAX_REWARD)

    @tf.function(reduce_retracing=True)
    def _predict(self, state: tf.Tensor) -> tf.Tensor:
        """Return the model's Q-values for *state* (inference mode)."""
        return self.model(state, training=False)

    def get_action(self) -> int:
        """Return the greedy action (argmax Q-value) for ``self.current_state``."""
        qs = self._predict(tf.convert_to_tensor(self.current_state, dtype=tf.float32)).numpy()[0]
        action = int(np.argmax(qs))
        print(f"Time: {self.current_time:.2f}s, Action: {action}, Tank: {self.HB[-1]:.2f}m")
        return action

    @staticmethod
    def _record_hourly_average(avg_tank_level: np.ndarray, daily_counts: np.ndarray, month_idx: int,
                               hour: str, level_sum: float, samples: int) -> None:
        """Add one hour's mean tank level to the (hour, month) accumulators."""
        if samples <= 0:
            return
        hourly_avg = level_sum / samples
        print(f"Month {month_idx}, Hour {hour}, Daily Avg: {hourly_avg:.2f}")
        avg_tank_level[int(hour), month_idx] += hourly_avg
        daily_counts[int(hour), month_idx] += 1

    def evaluate(self) -> None:
        """Run the greedy policy through every month of ``self.year``.

        For every row of cleaned data the method:

        1. builds the state vector;
        2. asks the model for an action once ``STATE_SIZE`` states are
           available;
        3. applies the pump-switch damping;
        4. intersects the chosen pump's head curve with the system curve to get
           the operating point;
        5. updates the tank level;
        6. accumulates energy, switch, tank-level and reward statistics.

        Actions go to ``actions.txt``. The summary goes to
        ``results_evaluation.txt``, and the plots are saved as PNG files.
        """
        with open("results_evaluation.txt", "w") as f1, open("actions.txt", "w") as f2:
            # Monthly metrics (columns = month).
            avg_electricity = np.zeros((4, 12))  # per pump
            avg_switches = np.zeros((4, 12))     # per pump
            avg_tank_level = np.zeros((24, 12))  # per hour of day
            daily_counts = np.zeros((24, 12))    # hourly averages summed per (hour, month)
            total_consumption = 0
            cumulative_rewards = []

            # Simulation state carried across months.
            self.current_state = np.zeros((1, STATE_SIZE, STATE_SPACE))
            self.expanded_state = []
            time_running = [0] * 5   # steps each action was taken since the last day reset
            last_action = 2
            pump_switch = 0
            water_turnover = 0
            self.HB = [53.42]        # tank level history; HB[i] is the level before step i
            p1, p2, p3, p4 = self.pressure_coeffs
            maerz_idx = MONTHS.index('Maerz')

            for indexm, month in enumerate(MONTHS):
                data_time, data_date, data_wc, *data_qr = self._clean_data(self.year, month)
                if data_time is None:  # data missing for this month
                    continue
                data_length = len(data_wc)
                # Number of days in this month's data. It is the divisor for the
                # per-day averages and must not be incremented again below.
                days = data_date['Date'].nunique()

                consume = [0.0] * 4
                switches = [0] * 4
                tk_sum = 0
                minutes_counter = 0
                h_i = None
                episode_reward = 0

                QBP = np.zeros(data_length)
                HBP = np.zeros(data_length)
                etaBP = np.ones(data_length)

                self.HB = [self.HB[-1]]
                last_valid_wc = 0
                for i in range(data_length):
                    # Replace missing or non-positive consumption with the last valid value.
                    raw_wc = data_wc.iloc[i, 0]
                    wc = raw_wc if not pd.isna(raw_wc) and raw_wc > 0 else last_valid_wc
                    last_valid_wc = wc
                    time_str = data_time.iloc[i, 0]
                    h = time_str.split(':')[0]
                    self.current_time = self._time_to_seconds(time_str)
                    if i == 0:
                        h_i = h

                    # System head curve for the current consumption and tank level,
                    # sampled on the Qanlage grid. Only the current curve is kept.
                    aa, bb, cc = p1 * wc + p2, p3 * wc + p4, self.HB[i]
                    hanlage = aa * self.Qanlage ** 2 + bb * self.Qanlage + cc

                    # State construction with water_turnover (matches WDS_sim).
                    state = [
                        self._normalize_state(self.HB[i], MIN_TK1, MAX_TK1),
                        self._normalize_state(wc, MIN_CONSUMPTION, MAX_CONSUMPTION),
                        self._normalize_state(self.current_time, MIN_TIME, MAX_TIME)
                    ] + MONTH_ENCODING[indexm] + ACTIONS[last_action] + [t / 1440 for t in time_running] + [water_turnover]
                    self.expanded_state.append(state)

                    # Sliding window of the last STATE_SIZE states. Until the
                    # window is full, the previous action is kept.
                    if len(self.expanded_state) > STATE_SIZE:
                        self.expanded_state.pop(0)
                    if len(self.expanded_state) < STATE_SIZE:
                        action_index = last_action
                    else:
                        self.current_state[0] = self.expanded_state
                        action_index = self.get_action()

                    # Pump switch damping (from original). A change of action is
                    # only applied after 4 consecutive steps requesting a change.
                    # Any step that matches the current action resets the counter.
                    if action_index != last_action and pump_switch < 4:
                        pump_switch += 1
                        action_index = last_action
                    else:
                        pump_switch = 0

                    # Operating point: grid flow where pump head is closest to system head.
                    if action_index > 0:
                        HPumpe = self.pump_curves[action_index]['H']
                        res = np.argmin(np.abs(hanlage - HPumpe))
                        QBP[i] = self.Qanlage[res]
                        HBP[i] = HPumpe[res]
                        etaBP[i] = np.polyval(self.pump_curves[action_index]['ETA'], QBP[i])
                    else:
                        QBP[i], HBP[i], etaBP[i] = 0, 0, 1

                    # Tank balance for one step, clipped to [MIN_TK1, MAX_TK1].
                    # The /60 presumably converts a per-hour flow to a one-minute
                    # step; TODO confirm units.
                    HB_i = self.HB[i] + (QBP[i] - wc) / 60 / self.A
                    self.HB.append(np.clip(HB_i, MIN_TK1, MAX_TK1))

                    # Electricity consumption (from original).
                    # NOTE(review): hydraulic power is multiplied by eta/55 here,
                    # not divided by the efficiency; confirm against the reference.
                    pbp = ((self.rho * self.g * QBP[i] / 3600 * HBP[i]) / 55 * etaBP[i]) / 1000
                    total_consumption += pbp
                    if action_index > 0:
                        consume[action_index - 1] += pbp
                        switches[action_index - 1] += 2 if action_index != last_action else 0

                    # Hourly tank-level average. Close the previous hour BEFORE
                    # adding this sample, so each hour averages only its own samples.
                    if h != h_i:
                        self._record_hourly_average(avg_tank_level, daily_counts, indexm, h_i, tk_sum, minutes_counter)
                        tk_sum = 0
                        minutes_counter = 0
                        h_i = h
                    tk_sum += self.HB[i]
                    minutes_counter += 1

                    # Reward with the transition-based water_turnover bonus.
                    reward = self._calculate_reward(action_index, last_action, QBP[i], etaBP[i], self.HB[i], self.HB[-1], time_running, water_turnover)
                    episode_reward += reward
                    print(f"Reward: {reward:.4f}, Consumption: {pbp:.4f}kW")

                    # Update water_turnover only after the reward has been computed.
                    if water_turnover == 0 and 50 <= self.HB[-1] < 53:
                        water_turnover = 1

                    time_running[action_index] += 1
                    last_action = action_index
                    f2.write(f"{action_index}\n")

                    # End of day: close the episode and reset the daily counters.
                    if self.current_time == MAX_TIME:
                        cumulative_rewards.append(episode_reward)
                        episode_reward = 0
                        time_running = [0] * 5
                        water_turnover = 0

                # Close the month's last, still-open hour.
                self._record_hourly_average(avg_tank_level, daily_counts, indexm, h_i, tk_sum, minutes_counter)

                # Turn the summed hourly averages into means over days.
                for h in range(24):
                    if daily_counts[h, indexm] > 0:
                        avg_tank_level[h, indexm] /= daily_counts[h, indexm]
                        print(f"Month {indexm}, Hour {h}, Final Avg: {avg_tank_level[h, indexm]:.2f}")

                # Missing 02:00 hour on March 25, 2012 (presumably the DST switch):
                # rescale that hour from 30 to 31 days. Applied once, right after
                # March is processed.
                # NOTE(review): the mean is already divided by the real sample
                # count, so this rescaling lowers it by 1/31; confirm it is intended.
                if self.year == '2012' and indexm == maerz_idx:
                    expected_days = 31
                    hour = 2
                    if daily_counts[hour, maerz_idx] == expected_days - 1:
                        avg_tank_level[hour, maerz_idx] *= daily_counts[hour, maerz_idx]
                        avg_tank_level[hour, maerz_idx] /= expected_days

                # Per-day averages for electricity and switches.
                for p in range(4):
                    avg_electricity[p, indexm] = consume[p] / days if days > 0 else 0
                    avg_switches[p, indexm] = switches[p] / days if days > 0 else 0

            self._log_and_plot_results(f1, avg_electricity, avg_switches, avg_tank_level, total_consumption, cumulative_rewards)

    def _log_and_plot_results(self, f1, electricity: np.ndarray, switches: np.ndarray, tank_level: np.ndarray,
                             total_consumption: float, rewards: List[float]) -> None:
        """Write the summary to *f1* and save the four evaluation plots.

        *f1* is owned by the caller (an open text file) and is not closed here.
        Each plot is drawn on its own figure, which is closed after saving.
        """
        f1.write("### Evaluation Results ###\n")
        for p in range(4):
            f1.write(f"NP{p+1}: {electricity[p]}\n")
            f1.write(f"SW{p+1}: {switches[p]}\n")
        f1.write(f"TOTAL: {total_consumption:.2f}\n")
        f1.write(f"TK LEVEL: {tank_level}\n")
        f1.write(f"Cumulative Reward: {rewards}\n")

        # Only the 2012 plots use enlarged y-axis label fonts.
        ylabel_kw = {'fontsize': 14} if self.year == '2012' else {}

        # Electricity consumption per pump and month.
        plt.figure()
        for p in range(4):
            plt.plot(MONTHSL, electricity[p], label=f'NP{p+1}')
        plt.xticks(rotation=45, ha='right', fontsize=12)
        plt.ylabel('Average Energy Consumption (kW)', **ylabel_kw)
        plt.title(f'Average Energy Consumption ({self.year}) - REM', fontsize=15)
        plt.legend()
        plt.grid(alpha=0.2, linestyle='--')
        plt.tight_layout()
        plt.savefig(f'ElectricityConsumption{self.year}.png')
        plt.close()

        # Tank level per hour of day and month, with the 50/57 bounds drawn.
        hours = [f'{hour:02d}:00' for hour in range(24)]
        plt.figure()
        plt.axhline(y=50, color='k', linestyle='--')
        plt.axhline(y=57, color='k', linestyle='--')
        for m in range(12):
            plt.plot(hours, tank_level[:, m], label=MONTHSL[m])
        plt.xticks(rotation=90, ha='right', fontsize=10)
        plt.xlabel('Time of Day', fontsize=14)
        plt.ylabel('Average Level (m)', **ylabel_kw)
        plt.ylim([49.5, 57.5])
        plt.title(f'Tank Level per Month ({self.year}) - REM', fontsize=16)
        plt.legend()
        plt.grid(alpha=0.2, linestyle='--')
        plt.tight_layout()
        plt.savefig(f'TankLevel_evaluation{self.year}.png')
        plt.close()

        # Average daily pump switches per pump and month.
        plt.figure()
        for p in range(4):
            plt.plot(MONTHSL, switches[p], label=f'NP{p+1}')
        plt.xticks(rotation=45, ha='right', fontsize=10)
        plt.ylabel('Average Daily Switches ON/OFF', **ylabel_kw)
        plt.title(f'Average Daily Switches ON/OFF ({self.year}) - REM', fontsize=16)
        plt.legend()
        plt.grid(alpha=0.2, linestyle='--')
        plt.tight_layout()
        plt.savefig(f'PumpSwitches_evaluation{self.year}.png')
        plt.close()

        # Cumulative reward per completed day.
        plt.figure(figsize=(10, 6))
        plt.plot(range(1, len(rewards) + 1), rewards, label='Reward')
        plt.xlabel('Day', fontsize=14)
        plt.ylabel('Cumulative Reward', fontsize=14)
        plt.title(f'Cumulative Reward per Day ({self.year}) - REM', fontsize=16)
        plt.legend()
        plt.grid(alpha=0.2, linestyle='--')
        plt.tight_layout()
        plt.savefig(f'Cumulative_Reward_evaluation{self.year}.png')
        plt.close()

if __name__ == "__main__":
    # Script entry point: evaluate one year (default YEAR) given as an optional CLI argument.
    import argparse
    import sys

    parser = argparse.ArgumentParser(description="Evaluate the trained WDS policy on one year of data.")
    parser.add_argument("year", nargs="?", default=YEAR, help=f"data year to evaluate (default: {YEAR})")
    args = parser.parse_args()
    try:
        Evaluation(args.year)
    except Exception:
        # Keep the full traceback and report failure to the shell, not exit status 0.
        logging.exception("Evaluation failed")
        sys.exit(1)