from typing import *
import numpy as np
import pandas as pd

class EquityEnv:
    """Load a pre-built equity ``.npz`` archive and split its daily data by date.

    The archive at ``daily_price_data_path`` must contain the arrays below
    (shapes as documented for the reference dataset):

        compustat_tensor: (dates, permnos, accounting_vars), (745, 36669, 31)
        accounting_vars: accounting variable names (31,)
        crsp_tensor: (dates, permnos, monthly_price_vars), (745, 36669, 10)
        monthly_price_vars: monthly price variable names (10,)
        dates: months (745,)
        ff_3f_daily: (daily_dates, ff_3f), (15628, 3)
        rf_daily: (daily_dates, ret), (15628, 1)
        daily_crsp_tensor: (daily_dates, permnos, daily_price_vars), (15628, 36669, 5)
        daily_price_vars: daily price variable names (5,)
        daily_dates: (15628,)
        compustat_yr_tensor_filled: (dates, permnos, yr_accounting_vars), (745, 36669, 16)
        compustat_yr_tensor: (dates, permnos, yr_accounting_vars), (745, 36669, 16)
        yr_accounting_vars: yearly accounting variable names (16,)
        ff_monthly_data: (dates, ff), (745, 3)
        rf_monthly_rate: (dates, rf_rate), (745, 1)
        returns: (dates, permnos), (745, 36669)
        permnos: (permnos,), (36669,)
        pmno_merge_mat: (permnos, permnos), (36669, 36669) -- not loaded

    Split boundaries are integer dates (YYYYMMDD, judging by the defaults)
    compared directly with ``daily_dates``; begin and end dates are both
    inclusive. ``daily_dates`` is assumed to be sorted ascending.
    """

    def __init__(
            self,
            daily_price_data_path: str,
            # env parameters
            train_beg_date: int=19780101,
            train_end_date: int=20201231, # inclusive
            val_beg_date:   int=20210101,
            val_end_date:   int=20211231, # inclusive
            test_beg_date:  int=20220101,
            test_end_date:  int=20221231, # inclusive
        ):
        """Load the archive and locate each split inside ``daily_dates``.

        Args:
            daily_price_data_path: path to the ``.npz`` archive.
            train_beg_date, train_end_date: inclusive date range of the train split.
            val_beg_date, val_end_date: inclusive date range of the val split.
            test_beg_date, test_end_date: inclusive date range of the test split.

        Raises:
            AssertionError: if a date range is inverted, ranges overlap, or a
                boundary lies outside ``dates``.
        """
        # dates
        self.train_beg_date: int = train_beg_date
        self.train_end_date: int = train_end_date
        self.val_beg_date:   int = val_beg_date
        self.val_end_date:   int = val_end_date
        self.test_beg_date:  int = test_beg_date
        self.test_end_date:  int = test_end_date
        self._sanity_check()

        # The context manager closes the NpzFile handle. Indexing an NpzFile
        # reads the member fully into memory, so the arrays remain valid after
        # the file is closed.
        with np.load(daily_price_data_path) as daily_price_data:
            self.compustat_tensor   : np.ndarray = daily_price_data['compustat_tensor']
            self.accounting_vars    : np.ndarray = daily_price_data['accounting_vars']
            self.crsp_tensor        : np.ndarray = daily_price_data['crsp_tensor']
            self.monthly_price_vars : np.ndarray = daily_price_data['monthly_price_vars']
            self.dates              : np.ndarray = daily_price_data['dates']
            self.ff_3f_daily        : np.ndarray = daily_price_data['ff_3f_daily']
            self.rf_daily           : np.ndarray = daily_price_data['rf_daily']
            self.daily_crsp_tensor  : np.ndarray = daily_price_data['daily_crsp_tensor']
            self.daily_price_vars   : np.ndarray = daily_price_data['daily_price_vars']
            self.daily_dates        : np.ndarray = daily_price_data['daily_dates']
            self.compustat_yr_tensor_filled: np.ndarray = daily_price_data['compustat_yr_tensor_filled']
            self.compustat_yr_tensor: np.ndarray = daily_price_data['compustat_yr_tensor']
            self.yr_accounting_vars : np.ndarray = daily_price_data['yr_accounting_vars']
            self.ff_monthly_data    : np.ndarray = daily_price_data['ff_monthly_data']
            self.rf_monthly_rate    : np.ndarray = daily_price_data['rf_monthly_rate']
            self.returns            : np.ndarray = daily_price_data['returns']
            self.permnos            : np.ndarray = daily_price_data['permnos']
            # pmno_merge_mat (permnos x permnos) is deliberately not loaded.

        # Index range of each split inside daily_dates; end indices are exclusive.
        #   searchsorted(side="left")  -> first day >= begin date
        #   searchsorted(side="right") -> one past the last day <= end date
        # Unlike argmin over a boolean mask, this stays correct when a boundary
        # is at or beyond the last available day (argmin would return 0 there
        # and silently produce an empty split).
        dd = self.daily_dates
        self.train_beg_idx = int(np.searchsorted(dd, self.train_beg_date, side="left"))
        self.train_end_idx = int(np.searchsorted(dd, self.train_end_date, side="right"))
        self.val_beg_idx   = int(np.searchsorted(dd, self.val_beg_date,   side="left"))
        self.val_end_idx   = int(np.searchsorted(dd, self.val_end_date,   side="right"))
        self.test_beg_idx  = int(np.searchsorted(dd, self.test_beg_date,  side="left"))
        self.test_end_idx  = int(np.searchsorted(dd, self.test_end_date,  side="right"))
        self._date_check()

    def _split_bounds(self) -> dict:
        """Map each split name to its ``(beg_idx, end_idx)`` range in daily_dates (end exclusive)."""
        return {
            "train": (self.train_beg_idx, self.train_end_idx),
            "val":   (self.val_beg_idx,   self.val_end_idx),
            "test":  (self.test_beg_idx,  self.test_end_idx),
        }

    def _slice_splits(self, crsp_channels) -> dict:
        """Slice daily dates, selected daily_crsp_tensor channels and rf_daily per split.

        Args:
            crsp_channels: index or slice applied to the last axis of
                ``daily_crsp_tensor``.

        Returns:
            Dict with keys ``{split}_daily_dates``, ``{split}_daily_crsp_tensor``
            and ``{split}_rf_daily`` for split in train, val, test (in that order).
        """
        sliced = {}
        for split, (beg, end) in self._split_bounds().items():
            sliced[f"{split}_daily_dates"] = self.daily_dates[beg:end]
            sliced[f"{split}_daily_crsp_tensor"] = self.daily_crsp_tensor[beg:end, :, crsp_channels]
            sliced[f"{split}_rf_daily"] = self.rf_daily[beg:end, :]
        return sliced

    def load_raw(self):
        """Return per-split daily dates, daily returns and daily risk-free rates."""
        # daily_crsp_tensor[:, :, 3] is the returns
        return self._slice_splits(3)

    def load_raw_with_volume(self):
        """Like ``load_raw``, but keep channel 3 (returns) and every later daily channel.

        NOTE(review): the later channel(s) are presumably volume, per the method
        name -- confirm against ``daily_price_vars``.
        """
        return self._slice_splits(slice(3, None))

    def load_raw_full(self):
        """Return every array loaded from the archive, keyed by its archive name."""
        return {
            "compustat_tensor": self.compustat_tensor,
            "accounting_vars": self.accounting_vars,
            "crsp_tensor": self.crsp_tensor,
            "monthly_price_vars": self.monthly_price_vars,
            "dates": self.dates,
            "ff_3f_daily": self.ff_3f_daily,
            "rf_daily": self.rf_daily,
            "daily_crsp_tensor": self.daily_crsp_tensor,
            "daily_price_vars": self.daily_price_vars,
            "daily_dates": self.daily_dates,
            "compustat_yr_tensor_filled": self.compustat_yr_tensor_filled,
            "compustat_yr_tensor": self.compustat_yr_tensor,
            "yr_accounting_vars": self.yr_accounting_vars,
            "ff_monthly_data": self.ff_monthly_data,
            "rf_monthly_rate": self.rf_monthly_rate,
            "returns": self.returns,
            "permnos": self.permnos,
        }

    def _load(self, lookback_window: int, daily_dates, daily_crsp_tensor, rf_daily):
        """Build (lookback window -> next day) samples from one split.

        Sample ``k`` uses days ``[k, k + lookback_window)`` as inputs and day
        ``k + lookback_window`` as the target, for
        ``N = max(0, len(daily_dates) - lookback_window)`` samples.

        Returns:
            Tuple ``(X_daily_dates, X, Y_daily_dates, Y, X_rf_daily, Y_rf_daily)``.
            The X arrays have shape ``(N, lookback_window, *rest)`` and the Y
            arrays ``(N, *rest)``, where ``rest`` is each input's trailing shape.
        """
        target_idx = np.arange(lookback_window, len(daily_dates))
        # Row k holds the indices of the lookback days preceding target_idx[k].
        window_idx = target_idx[:, None] + np.arange(-lookback_window, 0)[None, :]
        return (
            daily_dates[window_idx],
            daily_crsp_tensor[window_idx],
            daily_dates[target_idx],
            daily_crsp_tensor[target_idx],
            rf_daily[window_idx],
            rf_daily[target_idx],
        )

    def load(self, lookback_window: int):
        """Return lookback-windowed samples for the train, val and test splits.

        Args:
            lookback_window: number of past days used as input for each target day.

        Returns:
            Dict with keys ``{split}_X_daily_dates``, ``{split}_X``,
            ``{split}_Y_daily_dates``, ``{split}_Y``, ``{split}_X_rf_daily`` and
            ``{split}_Y_rf_daily`` for split in train, val, test.

        Raises:
            AssertionError: if ``lookback_window`` is not in ``[1, split length)``
                for every split.
        """
        self._lookback_window_check(lookback_window)

        raw_data = self.load_raw()
        result = {}
        for split in self._split_bounds():
            X_dates, X, Y_dates, Y, X_rf, Y_rf = self._load(
                lookback_window=lookback_window,
                daily_dates=raw_data[f"{split}_daily_dates"],
                daily_crsp_tensor=raw_data[f"{split}_daily_crsp_tensor"],
                rf_daily=raw_data[f"{split}_rf_daily"],
            )
            result[f"{split}_X_daily_dates"] = X_dates
            result[f"{split}_X"] = X
            result[f"{split}_Y_daily_dates"] = Y_dates
            result[f"{split}_Y"] = Y
            result[f"{split}_X_rf_daily"] = X_rf
            result[f"{split}_Y_rf_daily"] = Y_rf
        return result

    #######################################################################################################################
    #######################################################################################################################

    def _lookback_window_check(self, lookback_window: int):
        """Assert every split is long enough to yield at least one windowed sample."""
        assert lookback_window > 0, f"Lookback window {lookback_window} must be greater than 0"
        for split, (beg, end) in self._split_bounds().items():
            # end is exclusive, so the split holds end - beg days; a sample needs
            # lookback_window input days plus one target day.
            n_days = end - beg
            assert lookback_window < n_days, \
                f"Lookback window {lookback_window} must be smaller than {split} window {n_days}"

    def _sanity_check(self):
        """Assert each date range is ordered and consecutive ranges do not overlap."""
        assert self.train_beg_date < self.train_end_date, f"Train end date {self.train_end_date} is before train beg date {self.train_beg_date}"
        assert self.val_beg_date < self.val_end_date, f"Val end date {self.val_end_date} is before val beg date {self.val_beg_date}"
        assert self.test_beg_date < self.test_end_date, f"Test end date {self.test_end_date} is before test beg date {self.test_beg_date}"
        # Both ends are inclusive, so equal boundary dates would share a day.
        assert self.train_end_date < self.val_beg_date, f"Train end date {self.train_end_date} is not before val beg date {self.val_beg_date}"
        assert self.val_end_date < self.test_beg_date, f"Val end date {self.val_end_date} is not before test beg date {self.test_beg_date}"

    def _date_check(self):
        """Assert every split boundary lies within ``[dates[0], dates[-1]]``."""
        # NOTE(review): this compares against the monthly `dates` array although
        # the messages mention daily data; confirm which array is intended.
        assert self.train_beg_date >= self.dates[0], f"Train beg date {self.train_beg_date} is before daily data beg date {self.dates[0]}"
        assert self.train_end_date <= self.dates[-1], f"Train end date {self.train_end_date} is after daily data end date {self.dates[-1]}"
        assert self.val_beg_date >= self.dates[0], f"Val beg date {self.val_beg_date} is before daily data beg date {self.dates[0]}"
        assert self.val_end_date <= self.dates[-1], f"Val end date {self.val_end_date} is after daily data end date {self.dates[-1]}"
        assert self.test_beg_date >= self.dates[0], f"Test beg date {self.test_beg_date} is before daily data beg date {self.dates[0]}"
        assert self.test_end_date <= self.dates[-1], f"Test end date {self.test_end_date} is after daily data end date {self.dates[-1]}"