# Importing all the necessary packages
import configparser
import os
import pandas as pd
import numpy as np

class RetrieveData:
    """
    Load configuration and CSV data for the local learners, the global learner,
    the HMatrix computation and the centralized Kalman filter (CKF).

    Constructing an instance parses the INI file and immediately calls
    :meth:`PrepareData`, which fills the ``*_pack`` attributes.
    """

    def __init__(self, ini_file_name):
        """
        Initialize the class from a configuration (INI) file and load all data.

        Parameters:
        - ini_file_name: path of the INI file. It must provide the sections
          [COMPONENTS], [GLOBAL_PARAMETERS], [TRAINING] and [LOCATION]
          (plus [LOCAL_PARAMETERS], read by PrepareData). A missing section or
          key raises KeyError.

        Note: ``ConfigParser.read`` silently ignores a file it cannot open, so a
        wrong path surfaces as a KeyError on the first section lookup below.
        """
        # config.ini file name
        self.ini_file_name = ini_file_name

        # Loading the configuration file
        self.config = configparser.ConfigParser()
        self.config.read(self.ini_file_name)

        # Number of components
        self.num_components = int(self.config['COMPONENTS']['num_components'])

        # Total time for the time series
        self.total_time = int(self.config['COMPONENTS']['total_time'])

        # Global learning rate / regularization rate / training proportion.
        # These values go through eval() (see _eval_config_number).
        self.global_learning_rate = self._eval_config_number('GLOBAL_PARAMETERS', 'gamma_g')
        self.global_regularization_rate = self._eval_config_number('GLOBAL_PARAMETERS', 'lambda_g')
        self.training_prop = self._eval_config_number('TRAINING', 'train_set_prop')

        # Number of leading time steps used for training (rounded up)
        self.training_time = int(np.ceil(self.training_prop * self.total_time))

        # The attributes below are filled by PrepareData(); they stay None if it
        # fails before reaching the corresponding assignment.
        # Per-component data for the local learners
        self.local_learners_pack = None
        # Data for the global learner (diagonal A blocks and global rates)
        self.global_learners_pack = None
        # Data for the HMatrix
        self.HMatrix_pack = None
        # State size of each component, shape (num_components, 1)
        self.comp_size_vec = None
        # Size of the global state vector (sum of comp_size_vec)
        self.global_state_size = None
        # Output size of each component, shape (num_components, 1)
        self.output_size_vec = None
        # Global output size (sum of output_size_vec)
        self.global_output_size = None

        # Output data location
        self.results_location = self.config['LOCATION']['results_location']

        # Centralized Kalman Filter pack
        self.CKF_pack = None

        # Full (untruncated) observation series per component, keyed '1'..'M'
        self.global_Y = None

        # Row index at which each component's outputs start in the stacked Y
        self.output_start_vec = np.zeros((self.num_components, 1), dtype=int)

        # NOTE: loading of a separate validation set (valid_data_location /
        # valid_location) previously existed only as commented-out code and is
        # not implemented here.

        # Preparing the data
        self.PrepareData()

    def _eval_config_number(self, section, key):
        """Return ``float(eval(value))`` for the config entry ``[section] key``."""
        # SECURITY: eval() executes arbitrary Python taken from the config file.
        # It is presumably used so arithmetic expressions (e.g. "1/100") are
        # accepted; only load config files from a trusted source.
        return float(eval(self.config[section][key]))

    def _read_param_vector(self, key):
        """Parse ``[LOCAL_PARAMETERS] key`` as a comma-separated list of floats."""
        return [float(x) for x in self.config['LOCAL_PARAMETERS'][key].split(',')]

    @staticmethod
    def _read_csv(path):
        """Read a header-less CSV file and return its contents as a NumPy array."""
        return pd.read_csv(path, header=None).to_numpy()

    def _load_complete_system(self, data_location):
        """
        Read ``<name>_complete.csv`` for name in A, C, B, Q, R, x0.

        Returns a dict keyed by those names; ``x0`` is transposed after loading.
        """
        names = ('A', 'C', 'B', 'Q', 'R', 'x0')
        mats = {
            name: self._read_csv(os.path.join(data_location, f'{name}_complete.csv'))
            for name in names
        }
        mats['x0'] = mats['x0'].T
        return mats

    def _load_component_files(self, data_location, comp_id):
        """Read ``C<comp_id>/<name>.csv`` for name in A, C, Y, B, Q, R, x0 (raw, untransposed)."""
        comp_dir = os.path.join(data_location, f'C{comp_id}')
        names = ('A', 'C', 'Y', 'B', 'Q', 'R', 'x0')
        return {name: self._read_csv(os.path.join(comp_dir, f'{name}.csv')) for name in names}

    def PrepareData(self):
        """
        Get the data for each component and assemble the learner packages.

        Reads from ``[LOCATION] data_location``:
        - ``{A,C,B,Q,R,x0}_complete.csv``: complete-system matrices, stored in
          ``self.CKF_pack`` (x0 is transposed; Q is also used as P0);
        - ``C<m>/{A,C,Y,B,Q,R,x0}.csv`` for m = 1..num_components.

        Populates local_learners_pack, global_learners_pack, HMatrix_pack,
        CKF_pack, global_Y, comp_size_vec, global_state_size, output_size_vec,
        global_output_size and output_start_vec.

        Error handling: FileNotFoundError and ValueError (including the size
        checks below) are caught and printed, not raised, so the instance may be
        left partially initialized. KeyError from missing config keys is not
        caught.
        """
        try:
            # Getting the data folder location from config file
            data_location = self.config['LOCATION']['data_location']

            # One value per component: local-loss learning rate, global-loss
            # learning rate and local regularization parameter.
            eta_l = self._read_param_vector('eta_l')
            eta_g = self._read_param_vector('eta_g')
            lambda_l = self._read_param_vector('lambda_l')

            if any(len(v) != self.num_components for v in (eta_l, lambda_l, eta_g)):
                raise ValueError(
                    f"Length of eta_l vector (= {len(eta_l)}) and/or length of "
                    f"lambda_l vector (= {len(lambda_l)}) and/or length of eta_g "
                    f"vector (= {len(eta_g)}) != number of components "
                    f"(= {self.num_components})"
                )

            # Containers for all component data and the global packages
            self.local_learners_pack = {}
            self.global_learners_pack = {}
            self.HMatrix_pack = {}
            self.CKF_pack = {}
            self.global_Y = {}
            A_mm = {}

            self.comp_size_vec = np.zeros((self.num_components, 1), dtype=int)
            self.global_state_size = 0
            self.output_size_vec = np.zeros((self.num_components, 1), dtype=int)
            self.global_output_size = 0

            complete = self._load_complete_system(data_location)

            for m in range(self.num_components):
                comp_id = m + 1
                comp_files = self._load_component_files(data_location, comp_id)
                A = comp_files['A']
                C = comp_files['C']
                # The CSV stores one time step per row; transpose so time runs
                # along axis 1, as checked against total_time below.
                Y = comp_files['Y'].T

                if Y.shape[1] != self.total_time:
                    raise ValueError(f"Length of time series from config file = {self.total_time} != Length of observation series {Y.shape[1]}")

                d_m, p_m = C.shape

                self.comp_size_vec[m, 0] = p_m
                self.global_state_size += p_m
                self.output_size_vec[m, 0] = d_m
                self.global_output_size += d_m

                # NOTE(review): the B, Q, R and x0 files read for this component
                # are discarded; the local model always starts from the fixed
                # values below. The files are still read so that a missing or
                # malformed file is reported as before.
                Q = 0.0005 * np.eye(p_m)
                self.local_learners_pack[f'comp_{comp_id}'] = {
                    'B': np.zeros((p_m, p_m)),
                    'Q': Q,
                    'R': 0.0005 * np.eye(d_m),
                    'P0': Q,  # same array object as 'Q'
                    'x0': np.zeros((p_m, 1)),
                    'A': A,
                    'C': C,
                    'Y': Y[:, :self.training_time],  # training prefix only
                    'd_m': d_m,
                    'p_m': p_m,
                    'eta_l': eta_l[m],
                    'eta_g': eta_g[m],
                    'lambda_l': lambda_l[m],
                }

                self.global_Y[f'{comp_id}'] = Y
                A_mm[f'{comp_id}{comp_id}'] = A

            # Stack every component's full observation series into one array and
            # record where each component's rows start.
            global_Y = np.zeros((self.global_output_size, self.total_time))
            row_start = 0
            for m in range(self.num_components):
                d_m = self.output_size_vec[m, 0]
                global_Y[row_start:row_start + d_m, :] = self.global_Y[f'{m + 1}']
                self.output_start_vec[m, 0] = row_start
                row_start += d_m

            self.global_learners_pack.update({
                'A_mm': A_mm,
                'gamma_g': self.global_learning_rate,
                'lambda_g': self.global_regularization_rate,
            })

            self.HMatrix_pack.update({
                'comp_data': self.local_learners_pack,
                'gamma_g': self.global_learning_rate,
                'lambda_g': self.global_regularization_rate,
                'p': self.global_state_size,
                'p_vec': self.comp_size_vec,
                'd': self.global_output_size,
                'd_vec': self.output_size_vec,
                'M': self.num_components,
                'total_time': self.total_time,
            })

            self.CKF_pack.update({
                'A_complete': complete['A'],
                'C_complete': complete['C'],
                'Y': global_Y,
                'B': complete['B'],
                'Q': complete['Q'],
                'R': complete['R'],
                'P0': complete['Q'],
                'x0': complete['x0'],
            })

        except (FileNotFoundError, ValueError) as e:
            print(f"An error occurred: {e}")