
import os
import gymnasium as gym
import numpy as np
import pandas as pd
import torch  
from gymnasium import spaces
from sklearn.metrics import average_precision_score

from src.utils.utils_env import read_random_ts
from src.utils.data_loader import DataLoader
from src.utils.scores_loader import ScoresLoader

from torchmetrics.classification import AveragePrecision
from src.validation import apply_wts_batch

from config import path_raw, path_data, path_scores, metrics

class EnvNewAPI_raw(gym.Env):
    """Single-step Gymnasium environment over raw time-series windows.

    Each episode serves one window as the observation. The action is a vector
    of per-detector weights; the reward is the average precision (AUC-PR) of
    the weighted combination of the detectors' anomaly scores against the
    labels of the current item.
    """

    def __init__(self, window_size, data, device="cpu", n_detectors=12):
        """
        Args:
            window_size (int): Length of the windows; the observation space
                has shape ``(1, 1, window_size)``.
            data: Indexable dataset supporting ``len(data)`` and ``data[i]``;
                each item unpacks as ``(window, all_scores, label, dataset)``.
            device (str): Stored as ``self.device``; not read inside this class.
            n_detectors (int): Number of detectors, i.e. the length of the
                action vector. Defaults to 12.
        """
        super().__init__()

        self.data = data
        self.n_detectors = n_detectors
        self.window_size = window_size
        self.device = device

        self.observation_space = spaces.Box(
            low=-np.float32(np.inf),
            high=np.float32(np.inf),
            shape=(1, 1, window_size),
            dtype=np.float32,
        )

        # Action space: continuous weights in [0, 1] for each detector
        self.action_space = spaces.Box(
            low=0.0, high=1.0, shape=(self.n_detectors,), dtype=np.float32
        )

        self.index_window = 0
        self.current_window = None
        self.current_scores = None
        self.current_label = None

        # Counts how many steps fell back to uniform weights (all-zero action).
        self.n_average = 0

        # Populated by set_res_dict_plots(); defined here so the attribute
        # always exists.
        self.res_dict = None

    def set_res_dict_plots(self, res_dict):
        """
        Set the result dictionary to store evaluation results.
        """
        self.res_dict = res_dict

    def reset(self, seed=None, options=None):
        """Advance to the next item of ``data`` and return its window.

        Args:
            seed (int | None): Forwarded to ``gym.Env.reset`` so the
                environment's RNG is seeded the standard Gymnasium way.
            options (dict | None): Unused.

        Returns:
            tuple: ``(observation, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed)

        # NOTE(review): the index is advanced *before* reading, so when data
        # holds more than one item the very first reset() serves item 1 and
        # item 0 is only served after the first wrap-around. Kept as-is since
        # external code may rely on index_window pointing at the current item.
        self.index_window += 1
        if self.index_window >= len(self.data):
            self.index_window = 0

        # Items are served sequentially (not sampled at random).
        window, all_scores, label, dataset = self.data[self.index_window]
        self.current_window = window
        self.current_scores = all_scores
        self.current_label = label
        self.dataset = dataset

        return np.array(self.current_window), {}

    def step(self, action):
        """Score the weighted detector ensemble on the current item.

        Args:
            action (np.ndarray): One weight per detector.

        Returns:
            tuple: ``(next_observation, reward, terminated, truncated, info)``.
            ``reward`` is the AUC-PR, also reported as ``info["AUC-PR"]``;
            ``terminated`` is always True (single-step episodes).
        """
        # L2-normalise the weights for stability. np.average divides by the
        # sum of the weights, so only their relative magnitudes matter.
        norm = np.linalg.norm(action)
        weights = action / (norm + 1e-8)

        if norm == 0:
            # An all-zero action cannot be normalised by np.average (weights
            # summing to zero raise); fall back to a plain mean and count it.
            weights = np.ones_like(weights) / len(weights)
            self.n_average += 1

        # Weighted anomaly score, self.current_scores shape : [ts_len, n_detectors]
        combined_score = np.average(self.current_scores, axis=1, weights=weights)

        AUCPR = average_precision_score(self.current_label, combined_score)

        # Single step episode
        terminated = True
        truncated = False
        info = {"AUC-PR": AUCPR}
        reward = AUCPR

        # The episode is over, so the returned observation is simply the next
        # item's window; reset() also advances the internal index.
        # NOTE(review): a wrapper that auto-resets on `terminated` would call
        # reset() again and skip this item — confirm against the training loop.
        next_state, _ = self.reset()

        return next_state, reward, terminated, truncated, info

    def render(self, mode="human"):
        """No-op: this environment has nothing to render."""
        pass

class EnvNewAPI_feat(gym.Env):
    """Single-step Gymnasium environment over per-series feature vectors.

    Each episode serves one feature vector as the observation. The action is
    a vector of per-detector weights; the reward is the average precision
    (AUC-PR) of the weighted combination of the detectors' anomaly scores
    against the labels of the current item.
    """

    def __init__(self, data, n_detectors=12, feat_size=22):
        """
        Args:
            data: Indexable dataset supporting ``len(data)`` and ``data[i]``;
                each item unpacks as ``(feat, all_scores, label)``.
            n_detectors (int): Number of detectors, i.e. the length of the
                action vector. Defaults to 12.
            feat_size (int): Length of the feature vector, i.e. the
                observation shape ``(feat_size,)``. Defaults to 22.
        """
        super().__init__()

        self.data = data
        self.n_detectors = n_detectors
        self.feat_size = feat_size

        self.observation_space = spaces.Box(
            low=-np.float32(np.inf),
            high=np.float32(np.inf),
            shape=(self.feat_size,),
            dtype=np.float32,
        )

        # Action space: continuous weights in [0, 1] for each detector
        self.action_space = spaces.Box(
            low=0.0, high=1.0, shape=(self.n_detectors,), dtype=np.float32
        )

        self.index_feat = 0
        self.current_feat = None
        self.current_scores = None
        self.current_label = None

        # Evaluation bookkeeping. res_dict_plots / eval_set are expected to be
        # assigned by the caller before eval_mode is switched on; they are
        # defined here so step() never hits a missing attribute.
        self.eval_mode = False
        self.aucpr_list = []
        self.res_dict_plots = None
        self.eval_set = None

    def reset(self, seed=None, options=None):
        """Serve the item at the current index, then advance the index.

        Args:
            seed (int | None): Forwarded to ``gym.Env.reset`` so the
                environment's RNG is seeded the standard Gymnasium way.
            options (dict | None): Unused.

        Returns:
            tuple: ``(observation, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed)

        # Items are served sequentially (not sampled at random).
        feat, all_scores, label = self.data[self.index_feat]
        self.current_feat = feat
        self.current_scores = all_scores
        self.current_label = label

        # Advance and wrap around at the end of the dataset.
        self.index_feat += 1
        if self.index_feat >= len(self.data):
            self.index_feat = 0

        return np.array(self.current_feat), {}

    def step(self, action):
        """Score the weighted detector ensemble on the current item.

        Args:
            action (np.ndarray): One weight per detector.

        Returns:
            tuple: ``(next_observation, reward, terminated, truncated, info)``.
            ``reward`` is the AUC-PR, also reported as ``info["AUC-PR"]``;
            ``terminated`` is always True (single-step episodes).
        """
        # L2-normalise the weights. np.average divides by the sum of the
        # weights, so only their relative magnitudes matter.
        weights = action / (np.linalg.norm(action) + 1e-8)

        # np.average cannot normalise weights that sum to zero; fall back to
        # a plain mean in that case.
        if np.sum(weights) == 0:
            weights = np.ones_like(weights) / len(weights)

        # Weighted anomaly score, self.current_scores shape : [ts_len, n_detectors]
        combined_score = np.average(self.current_scores, axis=1, weights=weights)

        AUCPR = average_precision_score(self.current_label, combined_score)
        reward = AUCPR

        # Single step episode
        terminated = True
        truncated = False
        info = {"AUC-PR": AUCPR}

        if self.eval_mode:
            self.aucpr_list.append(AUCPR)
            # Only write to the plot dictionary when one has been provided.
            if self.res_dict_plots is not None:
                # NOTE(review): this appends the same list object on every
                # step, so the target ends up holding repeated references to
                # one growing list — confirm this is what the plots expect.
                self.res_dict_plots[f"{self.eval_set}_AUC-PR_box"].append(self.aucpr_list)

        # The episode is over, so the returned observation is simply the next
        # item's features; reset() also advances the internal index.
        # NOTE(review): a wrapper that auto-resets on `terminated` would call
        # reset() again and skip this item — confirm against the training loop.
        next_state, _ = self.reset()

        return next_state, reward, terminated, truncated, info

    def render(self, mode="human"):
        """No-op: this environment has nothing to render."""
        pass


