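'''
Water distribution system simulator.

Replays historical water-consumption and pump-flow CSV data, writes the
resulting (state, action, next_state, reward, done) transitions to a replay
memory file and then hands over to the selected reinforcement-learning
algorithm (REM, DDRQN/BCQ or MAXMIN) to feed its memory and train its model.
'''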

import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # Suppress TensorFlow warnings

import math
import datetime
import numpy as np
import pandas as pd
import tensorflow as tf
import argparse
import logging
import Algorithms.rem  # Module to train the model with REM
import Algorithms.ddrqn  # Module to train the model with DDRQN/BCQ
import Algorithms.maxmin  # Module to train the model with MAXMIN

# Setup logging: timestamped, INFO level and above, for the whole script
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Constants
DATA_PATH = './Data/'  # Root of the per-year CSV folders (relative to the working directory)
# Length of one state vector as assembled in WDS_sim.run():
# tank_level (1) + water_consumption (1) + time_of_day (1) + month one-hot (12)
# + last_action one-hot (4) + time_running per action (5) + water_turnover (1) = 25
STATE_SPACE = 25
STATE_SIZE = 4    # Number of consecutive state vectors stacked into the expanded state
YEARS = ['2013', '2014']
# German month names; they are used verbatim as CSV file names
MONTHS = ['Januar', 'Februar', 'Maerz', 'April', 'Mai', 'Juni', 'Juli', 'August', 'September', 'Oktober', 'November', 'Dezember']
# Pump selection encodings: index 0 = no pump running, index k = only pump k running
ACTIONS = [[0,0,0,0], [1,0,0,0], [0,1,0,0], [0,0,1,0], [0,0,0,1]]
# One-hot month vectors as plain Python ints, so that transitions serialized
# with str() never contain NumPy scalar reprs such as "np.int64(1)"
MONTH_ENCODING = np.eye(12, dtype=int).tolist()

# Normalization constants (min/max pairs passed to WDS_sim._normalize_state)
MIN_TK1, MAX_TK1 = 47, 57  # Tank level bounds; also used to clip the simulated level
MIN_CONSUMPTION, MAX_CONSUMPTION = 1, 3240
MIN_TIME, MAX_TIME = 59, 86399  # Time of day in seconds
MIN_REWARD, MAX_REWARD = -18, 11  # Observed reward bounds

class WDS_sim:
    def __init__(self, algorithm='rem'):
        """Create the learner for ``algorithm`` and run the full pipeline.

        Construction is not lazy: it generates the replay memory via
        ``run()`` and then calls ``feed_memory()`` and ``train_model()`` on
        the selected learner.

        Args:
            algorithm: One of ``'rem'``, ``'ddqn'`` or ``'maxmin'``.

        Raises:
            ValueError: If ``algorithm`` is not one of the supported names.
        """
        learners = {
            'rem': Algorithms.rem.REM,
            'ddqn': Algorithms.ddrqn.DDRQN,
            'maxmin': Algorithms.maxmin.MAXMIN,
        }
        learner_cls = learners.get(algorithm)
        if learner_cls is None:
            raise ValueError(f"Unknown algorithm '{algorithm}'. Choose from {list(learners.keys())}")

        self.qlearning = learner_cls()
        # Generate the transitions first, then let the learner consume them.
        self.run()
        self.qlearning.feed_memory()
        self.qlearning.train_model()

    def _initialize_pump_data(self):
        """Initialize pump characteristic curves and physical constants."""
        pumps = {
            1: {'Q': [0.0, 320.0, 570.0, 670.0, 897.0, 1099.0, 1298.0, 1597.3, 1823.2],
                'H': [76.81, 74.70, 71.61, 71.75, 72.29, 69.99, 67.95, 61.88, 55.79],
                'ETA': [0.00, 38.60, 55.55, 60.68, 72.28, 78.85, 83.12, 85.96, 85.07]},
            2: {'Q': [0.0, 240.0, 494.0, 694.0, 899.0, 1095.7, 1209.3, 1465.3],
                'H': [74.15, 72.15, 70.78, 69.34, 66.44, 62.75, 60.12, 52.57],
                'ETA': [0.00, 40.65, 63.22, 75.36, 83.02, 87.09, 88.32, 87.11]},
            3: {'Q': [0.0, 237.0, 470.0, 645.0, 797.3, 1000.3, 1208.7, 1410.0],
                'H': [63.84, 62.46, 61.98, 60.61, 58.68, 54.87, 50.07, 43.83],
                'ETA': [0.00, 38.28, 61.02, 72.36, 79.05, 84.94, 87.08, 84.91]},
            4: {'Q': [0.0, 195.0, 440.0, 659.7, 806.3, 998.7, 1204.0],
                'H': [64.13, 63.19, 61.62, 59.35, 56.61, 51.21, 43.71],
                'ETA': [0.00, 41.95, 67.50, 80.96, 84.84, 86.04, 83.42]}
        }

        A = math.pi * (44 ** 2) / 4
        rho, g = 1000, 9.81
        pressure_coeffs = [-2.713E-09, 0.000006504, 4.384E-07, -0.0001768]
        Qanlage = np.linspace(0, 2000, 2000)

        pump_curves = {
            i: {
                'H': np.polyval(np.polyfit(pumps[i]['Q'], pumps[i]['H'], 2), Qanlage),
                'ETA': np.polyfit(pumps[i]['Q'], pumps[i]['ETA'], 2)
            } for i in range(1, 5)
        }

        return pump_curves, Qanlage, A, rho, g, pressure_coeffs

    def _normalize_state(self, value: float, min_val: float, max_val: float) -> float:
        """Min-max scale ``value`` relative to the interval ``[min_val, max_val]``.

        Values inside the interval map to [0, 1]; values outside it are not
        clipped.  ``min_val`` and ``max_val`` must differ.
        """
        span = max_val - min_val
        offset = value - min_val
        return offset / span

    def _clean_data(self, year: str, month: str):
        """Load one month of CSV data, align the series and drop duplicate rows.

        Steps (kept from the earlier implementation, which its docstring
        described as matching the "WormsSimulation" intent):
            1. Read the consumption file (time, date, network consumption)
               and the flow file of each of the four pumps NP1..NP4.  Every
               file is parsed exactly once.
            2. Truncate everything except NP1 to the shortest common length.
            3. Drop NP1 rows whose minute equals the minute of the row before
               them, then pad NP1 with NaN or truncate it to that length.
            4. Drop repeated hours from every series: when a row with minute
               0 has the same hour as the row before it, that row and all
               directly following rows of the same hour are removed.

        Args:
            year: Year folder name below ``DATA_PATH``.
            month: Month file name without the ``.csv`` suffix.

        Returns:
            ``(data_time, data_date, data_waterConsumption, np1, np2, np3, np4)``
            as single-column DataFrames with a fresh 0..n-1 index, or seven
            ``None`` values when one of the CSV files does not exist.
        """
        try:
            consumption = pd.read_csv(f"{DATA_PATH}{year}/WaterConsumption/{month}.csv", delimiter=';', decimal=',', skip_blank_lines=True)
            data_time = consumption[['Time']]
            data_date = consumption[['Date']]
            data_waterConsumption = consumption[['Netzverbrauch_pval']]
            pump_files = [pd.read_csv(f"{DATA_PATH}{year}/NP/NP{j}/Q/{month}.csv", delimiter=';', decimal=',', skip_blank_lines=True) for j in range(1, 5)]
        except FileNotFoundError as e:
            logging.error(f"File not found for {year}/{month}: {e}")
            return None, None, None, None, None, None, None

        data_qr = [frame[[f'NP_{j}_Volumenfluss_pval']] for j, frame in enumerate(pump_files, start=1)]

        # Log initial sizes
        logging.info(f"{year}/{month}: Initial rows - Time: {len(data_time)}, WC: {len(data_waterConsumption)}, NP1: {len(data_qr[0])}, NP2: {len(data_qr[1])}, NP3: {len(data_qr[2])}, NP4: {len(data_qr[3])}")

        # Align all DataFrames to the shortest length; NP1 is excluded here
        # because it is cleaned first and padded/truncated afterwards.
        min_length = min(len(data_time), len(data_date), len(data_waterConsumption), len(data_qr[1]), len(data_qr[2]), len(data_qr[3]))
        data_time = data_time.iloc[:min_length]
        data_date = data_date.iloc[:min_length]
        data_waterConsumption = data_waterConsumption.iloc[:min_length]
        data_qr[1:] = [df.iloc[:min_length] for df in data_qr[1:]]  # NP2, NP3, NP4

        # First cleaning: remove NP1 rows that repeat the previous row's minute.
        minute_duplicates = []
        previous_minute = None
        for i, stamp in enumerate(pump_files[0]['Time']):
            _, minute, _ = stamp.split(':')
            if i > 0 and previous_minute == minute:
                minute_duplicates.append(i)
            else:
                previous_minute = minute
        np1 = data_qr[0].drop(index=minute_duplicates).reset_index(drop=True)

        # Pad (with NaN) or truncate NP1 so it matches the other series.
        if len(np1) < min_length:
            padding = pd.DataFrame({np1.columns[0]: [np.nan] * (min_length - len(np1))})
            np1 = pd.concat([np1, padding], ignore_index=True)
        elif len(np1) > min_length:
            np1 = np1.iloc[:min_length]
        data_qr[0] = np1.reset_index(drop=True)

        # Log after minute cleaning
        logging.info(f"{year}/{month}: After minute cleaning - NP1 rows: {len(data_qr[0])}, Others: {len(data_time)}")

        # Second cleaning: remove repeated hours (detected at minute 0) from
        # all series.  Row positions equal index labels here because every
        # frame still carries its 0..min_length-1 index.
        stamps = data_time['Time'].tolist()
        hour_duplicates = []
        previous_hour = None
        for i, stamp in enumerate(stamps):
            hour, minute, _ = stamp.split(':')
            if i > 0 and int(minute) == 0 and previous_hour == hour:
                k = i
                while k < len(stamps):
                    next_hour, _, _ = stamps[k].split(':')
                    if next_hour != hour:
                        break
                    hour_duplicates.append(k)
                    k += 1
            previous_hour = hour

        cleaned = [
            df.drop(index=hour_duplicates).reset_index(drop=True)
            for df in (data_time, data_date, data_waterConsumption, *data_qr)
        ]
        data_time, data_date, data_waterConsumption = cleaned[:3]
        data_qr = cleaned[3:]

        # Log final size
        logging.info(f"{year}/{month}: Final rows after hour cleaning - {len(data_time)}")
        return data_time, data_date, data_waterConsumption, *data_qr

    def _validate_water_consumption(self, wc, last_valid):
        """Return ``wc`` if it is a usable reading, otherwise ``last_valid``.

        A reading is rejected when it is missing (NaN/None) or not strictly
        positive.  Nothing is modified here; the caller keeps track of the
        last valid value itself.
        """
        if pd.isna(wc):
            return last_valid
        if wc <= 0:
            return last_valid
        return wc

    def _calculate_reward(self, action_idx: int, last_action: int, QBP: float, etaBP: float, HB: float, time_running: list, water_turnover: int) -> float:
        """Calculate the normalized reward for the current step.

        The raw reward is the sum of up to three terms:
          * ``exp(1 / (-QBP / etaBP))`` - only when a pump runs
            (``action_idx > 0``),
          * ``-10 * c`` where ``c`` penalizes a tank level below 50 (capped at
            1) or at/above 57 (1), and becomes a bonus (``c = -1``) while
            ``water_turnover`` is 0 and the level is in ``[50, 53)``,
          * ``log(1 / p)`` where ``p`` grows with the run time of the chosen
            action, plus 30 instead of 1 when switching to a different pump
            that already has run time.

        Args:
            action_idx: Index of the action taken (0 = no pump).
            last_action: Index of the previous action.
            QBP: Operating-point value Q of the running pump.
            etaBP: Operating-point value ETA of the running pump.
            HB: Tank level the reward is evaluated for.
            time_running: Per-action run-time counters, indexed by action.
            water_turnover: 0 or 1 flag enabling the ``[50, 53)`` bonus when 0.

        Returns:
            The raw reward scaled with ``MIN_REWARD``/``MAX_REWARD`` (not
            clipped).
        """
        runtime = time_running[action_idx]
        no_switch_penalty = action_idx == last_action or action_idx == 0 or runtime == 0
        p = (1 if no_switch_penalty else 30) + runtime

        if HB < 50:
            c = min(abs(HB - 50), 1)
        elif HB >= 57:
            c = 1
        else:
            c = 0
        if water_turnover == 0 and 50 <= HB < 53:
            c = -1

        if action_idx > 0:
            if QBP == 0:
                # A running pump at zero Q would divide by zero; it earns no
                # efficiency reward.
                efficiency = 0.0
            elif etaBP == 0:
                # -QBP / 0 would divide by zero; exp(-etaBP / QBP) is exp(0).
                efficiency = 1.0
            else:
                efficiency = math.exp(1 / (-QBP / etaBP))
            reward = efficiency - c * 10 + math.log(1 / p)
        else:
            reward = -c * 10 + math.log(1 / p)
        return self._normalize_state(reward, MIN_REWARD, MAX_REWARD)

    def run(self):
        """Replay the historical data and write the replay memory file.

        Walks through every month of ``YEARS`` x ``MONTHS`` (months whose CSV
        files are missing are skipped).  For each timestep it

        * takes the recorded action (first pump with a non-zero flow, 0 for
          no pump); during the first ``STATE_SIZE - 1`` steps of a month the
          previous action is kept instead,
        * finds the operating point where the quadratic system curve is
          closest to the fitted H curve of the selected pump,
        * advances the simulated tank level ``HB``, clipped to
          ``[MIN_TK1, MAX_TK1]``,
        * computes the reward and appends one line
          ``str((state, action, next_state, reward, done))`` to the file.

        Side effects:
            Overwrites ``../replay_memory.txt`` (relative to the working
            directory) and keeps ``self.expanded_state`` /
            ``self.current_state`` at the most recent ``STATE_SIZE`` states.
        """
        self.current_state = np.zeros((1, STATE_SIZE, STATE_SPACE))
        self.expanded_state = []
        time_running = [0] * 5  # Steps each action was taken since the last reset at `done`
        last_action = 2
        water_turnover = 0
        HB = [53.22]  # Tank level history of the current month

        # rho and g are returned by the helper but not needed below.
        pump_curves, Qanlage, A, _rho, _g, pressure_coeffs = self._initialize_pump_data()
        p1, p2, p3, p4 = pressure_coeffs
        Qanlage_sq = Qanlage ** 2  # Loop-invariant part of the system curve

        def to_seconds(stamp):
            """Convert an 'H:M:S' string to its total number of seconds."""
            parts = dict(zip(['hours', 'minutes', 'seconds'], map(int, stamp.split(':'))))
            return int(datetime.timedelta(**parts).total_seconds())

        def to_builtin(value):
            """Turn a NumPy scalar into the equivalent built-in Python value.

            The transitions are serialized with str(); NumPy >= 2 would
            otherwise render scalars as e.g. ``np.float64(0.5)``.
            """
            return value.item() if isinstance(value, np.generic) else value

        total_transitions = 0
        with open("../replay_memory.txt", "w") as f:
            for year in YEARS:
                for indexm, month in enumerate(MONTHS):
                    data_time, data_date, data_waterConsumption, *data_qr = self._clean_data(year, month)
                    if data_time is None:  # Skip if data loading failed
                        continue
                    data_length = len(data_qr[2])
                    logging.info(f"Processing {year}/{month}: {data_length} timesteps")

                    last_valid_waterConsumption = 0
                    for i in range(data_length - 1):
                        # Validate this and the next consumption value and write the
                        # repaired values back so the next iteration reads them.
                        wc = self._validate_water_consumption(data_waterConsumption.iloc[i, 0], last_valid_waterConsumption)
                        data_waterConsumption.iloc[i, 0] = wc
                        next_wc = self._validate_water_consumption(data_waterConsumption.iloc[i + 1, 0], wc)
                        data_waterConsumption.iloc[i + 1, 0] = next_wc
                        last_valid_waterConsumption = wc

                        if i == 0:
                            data_time.iloc[i, 0] = to_seconds(data_time.iloc[i, 0])
                            # Carry the last tank level over into the new month.
                            HB = [HB[-1]]

                        data_time.iloc[i + 1, 0] = to_seconds(data_time.iloc[i + 1, 0])

                        # Recorded action: first pump with a non-zero flow.  If none
                        # has one, keep the previous action while its own reading is
                        # missing, otherwise no pump.  (For last_action == 0 the
                        # index wraps to the last pump; the fallback is 0 either way.)
                        fallback_action = last_action if pd.isna(data_qr[last_action - 1].iloc[i, 0]) else 0
                        action_index = next(
                            (j + 1 for j, qr in enumerate(data_qr) if not pd.isna(qr.iloc[i, 0]) and qr.iloc[i, 0] != 0),
                            fallback_action,
                        )

                        state = [to_builtin(v) for v in [
                            self._normalize_state(HB[i], MIN_TK1, MAX_TK1),
                            self._normalize_state(wc, MIN_CONSUMPTION, MAX_CONSUMPTION),
                            self._normalize_state(data_time.iloc[i, 0], MIN_TIME, MAX_TIME)
                        ] + MONTH_ENCODING[indexm] + ACTIONS[last_action] + [t / 1440 for t in time_running] + [water_turnover]]

                        # Sliding window over the newest STATE_SIZE states.  Trimming on
                        # every step keeps the window correct across month boundaries.
                        self.expanded_state.append(state)
                        del self.expanded_state[:-STATE_SIZE]

                        if i < STATE_SIZE - 1:
                            # Warm-up: not enough states of this month yet.
                            action_index = last_action
                        else:
                            self.current_state[0] = self.expanded_state

                        if action_index > 0:
                            # Operating point: grid point where the system curve
                            # aa*Q^2 + bb*Q + cc is closest to the pump's H curve.
                            aa, bb, cc = p1 * wc + p2, p3 * wc + p4, HB[i]
                            system_head = aa * Qanlage_sq + bb * Qanlage + cc
                            res = np.argmin(np.abs(system_head - pump_curves[action_index]['H']))
                            q_op = Qanlage[res]
                            eta_op = np.polyval(pump_curves[action_index]['ETA'], q_op)
                        else:
                            q_op, eta_op = 0.0, 1.0

                        # Level change from the inflow/outflow balance (presumably
                        # per-hour rates applied to a one-minute step - TODO confirm).
                        HB.append(np.clip(HB[i] + (q_op - wc) / 60 / A, MIN_TK1, MAX_TK1))

                        time_running[action_index] += 1

                        reward = self._calculate_reward(action_index, last_action, q_op, eta_op, HB[i], time_running, water_turnover)

                        if water_turnover == 0 and 50 <= HB[-1] < 53:
                            water_turnover = 1

                        next_state = [to_builtin(v) for v in [
                            self._normalize_state(HB[-1], MIN_TK1, MAX_TK1),
                            self._normalize_state(next_wc, MIN_CONSUMPTION, MAX_CONSUMPTION),
                            self._normalize_state(data_time.iloc[i + 1, 0], MIN_TIME, MAX_TIME)
                        ] + MONTH_ENCODING[indexm] + ACTIONS[action_index] + [t / 1440 for t in time_running] + [water_turnover]]

                        last_action = action_index
                        action = ACTIONS[action_index]
                        # An episode ends at the last timestamp of the day (MAX_TIME).
                        done = data_time.iloc[i + 1, 0] == MAX_TIME

                        if done:
                            water_turnover = 0
                            time_running = [0] * 5

                        record = (state, action, next_state, to_builtin(reward), to_builtin(done))
                        f.write(str(record) + '\n')
                        total_transitions += 1

            logging.info(f"Total transitions generated: {total_transitions}")

def parse_args():
    """Read the RL algorithm selection from the command line.

    Returns:
        The parsed namespace; ``.algorithm`` is ``'rem'``, ``'ddqn'`` or
        ``'maxmin'`` and defaults to ``'rem'``.
    """
    cli = argparse.ArgumentParser(description="Water Distribution System Simulator with RL")
    algorithm_option = {
        'type': str,
        'default': 'rem',
        'choices': ['rem', 'ddqn', 'maxmin'],
        'help': "RL algorithm to use: 'rem', 'ddqn', or 'maxmin' (default: rem)",
    }
    cli.add_argument('--algorithm', **algorithm_option)
    namespace = cli.parse_args()
    return namespace

def main():
    """Command-line entry point: run the simulator with the chosen algorithm.

    Constructing ``WDS_sim`` already performs the whole pipeline, so the
    instance itself is not kept.  Errors are logged instead of propagated.
    """
    args = parse_args()
    try:
        WDS_sim(algorithm=args.algorithm)
    except ValueError as e:
        logging.error(f"Error: {e}")
    except Exception as e:
        # Top-level boundary: keep the traceback so that failures deep inside
        # the simulation or the training can be diagnosed from the log.
        logging.exception(f"Unexpected error: {e}")

# Run the simulator only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
