from data_classes import PICalibData
import torch 
import numpy as np
import pandas as pd
from utils import MeanStdevFilter, prepare_data
from build_ctxt import BuildContext
from typing import List, Tuple
import jax
import jax.numpy as jnp
from typing import Union
import diffrax
import sys, os, gc

class DataProcessor:
    def __init__(self, params):
        """Store the run configuration and cache the settings used by the pipeline.

        Args:
            params: Mutable mapping with the run configuration. It must provide
                'train_horizon', 'val_horizon', 'test_horizon', 'past_ts_ctxt',
                'init_cond_ctxt' and 'past_feat_ctxt'. Its 'ctxt_dim' entry is
                reset to 0 here.
        """
        self.params = params

        # Filters and feature sizes are filled in later by get_data().
        self.input_filter = self.output_filter = None
        self.input_dim = self.output_dim = None

        # Mirror frequently used configuration entries as attributes.
        cached_keys = (
            'train_horizon',
            'val_horizon',
            'test_horizon',
            'past_ts_ctxt',
            'init_cond_ctxt',
            'past_feat_ctxt',
        )
        for key in cached_keys:
            setattr(self, key, params[key])

        # The context builder receives the same params object; 'ctxt_dim' is
        # reset only after it has been constructed.
        self._build_ctxt = BuildContext(self.params)
        self.params['ctxt_dim'] = 0
 
    def get_data(self) -> np.ndarray:
        ################## Read Train Data ##################
        #df = pd.read_csv(self.params['data_path'])

        #grouped = df.groupby('horizon')       

        if self.params['ode_name'] == 'lorenz':
            pred_cols = ['x_pred','y_pred','z_pred']
            gr_cols = ['x_gr','y_gr','z_gr']   
        elif self.params['ode_name'] == 'glycolytic':    
            pred_cols = ['S1_pred','S2_pred','S3_pred','S4_pred','S5_pred','S6_pred','S7_pred']
            gr_cols = ['S1_gr','S2_gr','S3_gr','S4_gr','S5_gr','S6_gr','S7_gr']    
        elif self.params['ode_name'] == 'LVolt':    
            pred_cols = ['x_pred','y_pred']
            gr_cols = ['x_gr','y_gr']    
        elif self.params['ode_name'] == 'lorenz96':
            pred_cols = ['x1_pred','x2_pred','x3_pred','x4_pred','x5_pred']
            gr_cols = ['x1_gr','x2_gr','x3_gr','x4_gr','x5_gr']    
        elif self.params['ode_name'] == 'FHNag':
            pred_cols = ['v_pred','w_pred']
            gr_cols = ['v_gr','w_gr']   
        elif self.params['ode_name'] in ['walker2d','halfcheetah']:
            pred_cols = [f'S{dim}_pred' for dim in range(17)]
            gr_cols = [f'S{dim}_gr' for dim in range(17)]    
        elif self.params['ode_name'] == 'pen_expert':
            pred_cols = [f'S{dim}_pred' for dim in range(45)]
            gr_cols = [f'S{dim}_gr' for dim in range(45)]            
        elif self.params['ode_name'] == 'hopper':
            pred_cols = [f'S{dim}_pred' for dim in range(11)]
            gr_cols = [f'S{dim}_gr' for dim in range(11)]                  
        elif self.params['ode_name'] == 'exchange':
            pred_cols = [f"pred_{i}" for i in range(8)]
            gr_cols = [f"gr_{i}" for i in range(8)]   
        elif self.params['ode_name'] == 'electricity':
            pred_cols = [f"pred_{i}" for i in range(10)]
            gr_cols = [f"gr_{i}" for i in range(10)] 
        elif self.params['ode_name'] == 'weather':
            pred_cols = [f"pred_{i}" for i in range(21)]
            gr_cols = [f"gr_{i}" for i in range(21)] 
        elif self.params['ode_name'] in ['ettm2','ettm1','etth1','etth2','ili']:
            pred_cols = [f"pred_{i}" for i in range(7)]
            gr_cols = [f"gr_{i}" for i in range(7)]                                                                                                                                                

        path_ = self.params['data_path'].split("/")[:-1]  
        path_ = ("/").join(path_)
        data_pred = np.load(path_ + "/pred_train.npy")
        data_gr = np.load(path_ + "/true_train.npy")
        #data_pred = self.moving_average(data_pred)
        data_tr = (data_pred, data_gr)
        print(f"The TRAIN data shape is: {data_pred.shape}")      

        data_pred = np.load(path_ + "/pred_val.npy")
        data_gr = np.load(path_ + "/true_val.npy")
        #data_pred = self.moving_average(data_pred)
        data_val = (data_pred, data_gr)
        print(f"The VAL data shape is: {data_pred.shape}")    

        path_ = self.params['data_path'].split("/")[:-1]  
        path_ = ("/").join(path_)
        data_pred = np.load(path_ + "/pred_test.npy")
        data_gr = np.load(path_ + "/true_test.npy")
        #data_pred = self.moving_average(data_pred)
        data_test = (data_pred, data_gr)
        print(f"The TEST data shape is: {data_pred.shape}")   

        self.input_dim = len(pred_cols) 
        self.output_dim = len(pred_cols)

        self.params['ctxt_dim'] += self.past_feat_ctxt*self.input_dim # input_dim+ctxt_dim
        if self.init_cond_ctxt:
            self.params['ctxt_dim'] += self.input_dim

        self.ctxt_dim = self.params['ctxt_dim']
        self.params['input_dim'] = self.input_dim
        self.params['output_dim'] = self.output_dim
        self.input_filter = MeanStdevFilter(self.ctxt_dim)
        self.output_filter = MeanStdevFilter(self.output_dim) 

        return (data_tr, data_val, data_test)

    def moving_average(self, x, window=15):
        """
        x: np.ndarray of shape (B, T, D)
        window: int, size of the moving average window
        returns: np.ndarray of shape (B, T, D)
        """
        assert window >= 1
        assert x.ndim == 3, "Input must be a 3D array of shape (B, T, D)"
        
        B, T, D = x.shape
        left_pad = window // 2
        right_pad = window - 1 - left_pad  # asymmetric if window is even

        smoothed = np.empty_like(x)
        
        for b in range(B):
            for d in range(D):
                padded = np.pad(x[b, :, d], (left_pad, right_pad), mode='edge')
                smoothed[b, :, d] = np.convolve(padded, np.ones(window) / window, mode='valid')
        
        return smoothed    

    def data_tuples(self, data: Union[Tuple[np.ndarray],Tuple[Tuple[np.ndarray]]]) -> Tuple[PICalibData]:

        data_tr, data_val, data_test = data

        ##################### Train data #####################
        sim = data_tr[0]
        gr = data_tr[1]

        del data_tr

        calib_train_true = PICalibData(
            X=gr[:, :-1, :],
            Y=gr[:, 1:, :]
        )

        T = calib_train_true.X.shape[1]
        #self.params['delta_t'] = 1.0 / (T - 1)

        end = calib_train_true.X.shape[1] * self.params['delta_t']
        calib_train_sim = PICalibData(
            X=sim[:, :-1, :],
            Y=sim[:, 1:, :],
            #timesteps=np.linspace(0.0, 1.0, T, dtype=np.float32),
            timesteps=torch.arange(0,end,self.params['delta_t']),
            error=gr[:, 1:, :] - sim[:, 1:, :]
        )

        del sim, gr   # free train memory
        
        ##################### Val data #####################
        sim = data_val[0]
        gr = data_val[1]

        calib_val_true = PICalibData(
            X=gr[:, :-1, :],
            Y=gr[:, 1:, :]
        )

        del data_val
        T = calib_val_true.X.shape[1]
        end = calib_val_true.X.shape[1] * self.params['delta_t']
        calib_val_sim = PICalibData(
            X=sim[:, :-1, :],
            Y=sim[:, 1:, :],
            #timesteps=np.linspace(0.0, 1.0, T, dtype=np.float32),
            timesteps=torch.arange(0,end,self.params['delta_t']),
            error=gr[:, 1:, :] - sim[:, 1:, :]
        )

        del sim, gr   # free val memory
       
        ##################### Test data #####################
        sim = data_test[0]
        gr = data_test[1]

        calib_test_true = PICalibData(
            X=gr[:, :-1, :],
            Y=gr[:, 1:, :]
        )

        T = calib_test_true.X.shape[1]
        end = calib_test_true.X.shape[1] * self.params['delta_t']
        calib_test_sim = PICalibData(
            X=sim[:, :-1, :],
            Y=sim[:, 1:, :],
            #timesteps=np.linspace(0.0, 1.0, T, dtype=np.float32),
            timesteps=torch.arange(0,end,self.params['delta_t']),
            error=gr[:, 1:, :] - sim[:, 1:, :]
        )

        del sim, gr   # free test memory 
   
        return (calib_train_true, calib_val_true, calib_test_true, calib_train_sim, calib_val_sim, calib_test_sim)   


    def normalize_calib_data(self, calib_data: Tuple[PICalibData]):

        # --- Use views directly ---
        gr_tr_input_ctxt = calib_data[0].X_ctx
        gr_val_input_ctxt = calib_data[1].X_ctx
        gr_test_input_ctxt = calib_data[2].X_ctx
        sim_tr_input_ctxt = calib_data[3].X_ctx
        sim_val_input_ctxt = calib_data[4].X_ctx
        sim_test_input_ctxt = calib_data[5].X_ctx

        # --- Update filters using only training data ---
        self.calculate_mean_var(gr_tr_input_ctxt, self.input_dim)

        # --- Normalize splits ---
        gr_tr_ctxt_filter, sim_tr_ctxt_filter = self.normalize_data(gr_tr_input_ctxt, sim_tr_input_ctxt, self.input_dim)
        del gr_tr_input_ctxt, sim_tr_input_ctxt; gc.collect()

        gr_val_ctxt_filter, sim_val_ctxt_filter = self.normalize_data(gr_val_input_ctxt, sim_val_input_ctxt, self.input_dim)
        del gr_val_input_ctxt, sim_val_input_ctxt; gc.collect()

        gr_test_ctxt_filter, sim_test_ctxt_filter = self.normalize_data(gr_test_input_ctxt, sim_test_input_ctxt, self.input_dim)
        del gr_test_input_ctxt, sim_test_input_ctxt; gc.collect()

        # Save normalized X_ctx
        calib_data[0].X_ctx, calib_data[1].X_ctx, calib_data[2].X_ctx = gr_tr_ctxt_filter, gr_val_ctxt_filter, gr_test_ctxt_filter
        calib_data[3].X_ctx, calib_data[4].X_ctx, calib_data[5].X_ctx = sim_tr_ctxt_filter, sim_val_ctxt_filter, sim_test_ctxt_filter

        # Save normalized X (subset of channels)
        calib_data[0].X = calib_data[0].X_ctx[..., :self.input_dim]
        calib_data[1].X = calib_data[1].X_ctx[..., :self.input_dim]
        calib_data[2].X = calib_data[2].X_ctx[..., :self.input_dim]
        calib_data[3].X = calib_data[3].X_ctx[..., :self.input_dim]
        calib_data[4].X = calib_data[4].X_ctx[..., :self.input_dim]
        calib_data[5].X = calib_data[5].X_ctx[..., :self.input_dim]

        # --- Errors normalization ---
        error_tr, error_val, error_test = calib_data[3].error, calib_data[4].error, calib_data[5].error

        self.calculate_mean_var_out(error_tr, self.output_dim)

        calib_data[3].error = self.normalize_data_out(error_tr, self.output_dim); del error_tr; gc.collect()
        calib_data[4].error = self.normalize_data_out(error_val, self.output_dim); del error_val; gc.collect()
        calib_data[5].error = self.normalize_data_out(error_test, self.output_dim); del error_test; gc.collect()

        return calib_data  

    def calculate_mean_var_out(self, errors: np.ndarray, output_dim: int) -> None:

        errors = errors.reshape(-1, output_dim)

        total_points = errors.shape[0]

        #for i in range(total_points):
        #    self.output_filter.update(errors[i,:])
        self.output_filter.update(errors)

        self.params['output_filter'] = self.output_filter  

        return    

    def normalize_data_out(self, errors: np.ndarray, output_dim: int) -> np.ndarray:
        errors_seq_len = errors.shape[1]
        errors = errors.reshape(-1, output_dim)

        out_filter = prepare_data(errors, self.output_filter)
        del errors; gc.collect()

        out_filter = out_filter.reshape(-1, errors_seq_len, output_dim)
        return out_filter    

    def calculate_mean_var(self, gr_input_data: np.ndarray, input_dim: int) -> None:

        gr_input_data = gr_input_data.reshape(-1, input_dim)

        total_points = gr_input_data.shape[0]

        #for i in range(total_points):
        #    self.input_filter.update(gr_input_data[i,:])
        self.input_filter.update(gr_input_data)

        self.params['input_filter'] = self.input_filter  

        return 

    def normalize_data(self, gr_input_data: np.ndarray, sim_input_data: np.ndarray, input_dim: int) -> Tuple[np.ndarray]:
        gr_seq_len = gr_input_data.shape[1]
        sim_seq_len = sim_input_data.shape[1]
        
        gr_input_data = gr_input_data.reshape(-1, input_dim)
        sim_input_data = sim_input_data.reshape(-1, input_dim)

        gr_input_filter = prepare_data(gr_input_data, self.input_filter)
        sim_input_filter = prepare_data(sim_input_data, self.input_filter)
        del gr_input_data, sim_input_data; gc.collect()

        gr_input_filter = gr_input_filter.reshape(-1, gr_seq_len, input_dim)
        sim_input_filter = sim_input_filter.reshape(-1, sim_seq_len, input_dim)

        return (gr_input_filter, sim_input_filter)

    def np_to_jnp(self, calib_data: Tuple[PICalibData]):

        for i in range(3):
            calib_data[i].X = jnp.asarray(calib_data[i].X, dtype=jnp.float32)
            calib_data[i].Y = jnp.asarray(calib_data[i].Y, dtype=jnp.float32)
            calib_data[i].X_ctx = jnp.asarray(calib_data[i].X_ctx, dtype=jnp.float32)

        for i in range(3,6):
            calib_data[i].X = jnp.asarray(calib_data[i].X, dtype=jnp.float32)
            calib_data[i].Y = jnp.asarray(calib_data[i].Y, dtype=jnp.float32)
            calib_data[i].X_ctx = jnp.asarray(calib_data[i].X_ctx, dtype=jnp.float32)
            calib_data[i].error = jnp.asarray(calib_data[i].error, dtype=jnp.float32)
            calib_data[i].timesteps = jnp.asarray(calib_data[i].timesteps, dtype=jnp.float32)

        return calib_data

    def spline_coeffs(self, calib_data:Tuple[PICalibData], chunk_size=100):
        """Use cubic spline with chunked vmap to avoid OOM."""

        def compute_coeffs(X_ctx, ts):
            # f returns a tuple of 4 arrays
            f = jax.vmap(lambda x: diffrax.backward_hermite_coefficients(ts, x))

            coeffs_chunks = None
            N = X_ctx.shape[0]

            for i in range(0, N, chunk_size):
                end = min(i+chunk_size, N)
                out = f(X_ctx[i:end])   # tuple of 4 arrays

                if coeffs_chunks is None:
                    # initialize list of lists, one per tuple element
                    coeffs_chunks = [[o] for o in out]
                else:
                    for j in range(len(out)):
                        coeffs_chunks[j].append(out[j])

            # concatenate each coefficient separately
            coeffs = tuple(jnp.concatenate(chunks, axis=0) for chunks in coeffs_chunks)
            return coeffs   # (4 arrays)

        calib_data[0].X_ctx_coeffs = compute_coeffs(calib_data[0].X_ctx, calib_data[3].timesteps)
        calib_data[1].X_ctx_coeffs = compute_coeffs(calib_data[1].X_ctx, calib_data[4].timesteps)
        calib_data[2].X_ctx_coeffs = compute_coeffs(calib_data[2].X_ctx, calib_data[5].timesteps)
        calib_data[3].X_ctx_coeffs = compute_coeffs(calib_data[3].X_ctx, calib_data[3].timesteps)
        calib_data[4].X_ctx_coeffs = compute_coeffs(calib_data[4].X_ctx, calib_data[4].timesteps)
        calib_data[5].X_ctx_coeffs = compute_coeffs(calib_data[5].X_ctx, calib_data[5].timesteps)

        return calib_data
        
    def no_ctxt(self, calib_data: Tuple[PICalibData]):
        """
        In case of no context, simply copy the calib_data[0].X to
        calib_data[0].X_ctxt 
        """
        calib_data[0].X_ctx = calib_data[0].X.copy()
        calib_data[1].X_ctx = calib_data[1].X.copy()
        calib_data[2].X_ctx = calib_data[2].X.copy()
        calib_data[3].X_ctx = calib_data[3].X.copy()
        calib_data[4].X_ctx = calib_data[4].X.copy()
        calib_data[5].X_ctx = calib_data[5].X.copy()        

        return calib_data       

    def build_context(self, calib_data: Tuple[PICalibData], add_time: bool):

        # CAUTION: calib_data gets modified in-place so be careful sending it multiple times to self._build_ctxt
        # CAUTION: DON'T MODIFY X ENTRY OF PICalibData TUPLE AS IT IS USEFUL FOR ADDING INITIAL CONDITION CONTEXT

        if not add_time:
            if calib_data[0].X_ctx is None: # True if self.past_feat_ctxt == 1
                calib_data = self.no_ctxt(calib_data)
                print("X has been copied to X_ctxt!") 

            if self.init_cond_ctxt:
                calib_data = self._build_ctxt.add_init_cond(calib_data)
            else:
                print("Initial condition won't be used as a context!")  
        else:        
            if self.past_ts_ctxt == 1:
                calib_data = self._build_ctxt.add_timestep_np(calib_data)
            elif self.past_ts_ctxt > 1: 
                calib_data = self._build_ctxt.add_k_timestep(calib_data)
            else: 
                print("The past timesteps will NOT be used as context!")  

        print(f"The context length is: {calib_data[0].X_ctx.shape[-1]}")    

        return calib_data      