import copy
import os
import random
import time
from functools import partial, wraps
from typing import Callable, List, Sequence
from datetime import datetime

import hydra
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import wandb
from hydra.utils import get_original_cwd
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
from tqdm.auto import tqdm

import src.models.nn.utils as U
import src.utils as utils
import src.utils.train
from src.dataloaders import SequenceDataset  # TODO make registry
from src.tasks import decoders, encoders, tasks
from src.utils import registry
from src.utils.optim_groups import add_optimizer_hooks
from scripts.notebooks.true_loss_level.get_transition_probabilities import get_no_path_dependency_loss
from scripts.notebooks.true_loss_level.get_transition_probabilities import get_optimal_loss
log = src.utils.train.get_logger(__name__)

# Turn on TensorFloat32 (speeds up large model training substantially)
# The flag is enabled for both the CUDA matmul backend and the cuDNN backend.
import torch.backends
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Custom OmegaConf resolvers, usable in configs as ${eval:...} and ${div_up:x,y}.
# NOTE(review): 'eval' is bound directly to Python's builtin eval, so any config
# string routed through this resolver is executed as code -- only load trusted configs.
OmegaConf.register_new_resolver('eval', eval)
# 'div_up' computes (x + y - 1) // y, i.e. division rounded up for positive integers.
OmegaConf.register_new_resolver('div_up', lambda x, y: (x + y - 1) // y)

# Lots of annoying hacks to get WandbLogger to continuously retry on failure
class DummyExperiment:
    """Inert stand-in for an experiment object.

    Any attribute lookup that is not otherwise defined yields a callable that
    accepts arbitrary arguments and returns ``None``; indexing returns the
    object itself and item assignment is ignored.
    """

    def nop(self, *args, **kw):
        """Accept any arguments and do nothing."""
        return None

    def __getattr__(self, _):
        # Only invoked for names not found by normal lookup; hand back the no-op.
        return self.nop

    def __getitem__(self, idx) -> "DummyExperiment":
        # Keeps chained usage such as self.logger.experiment[0].add_image(...) working.
        return self

    def __setitem__(self, *args, **kwargs) -> None:
        # Item assignment is silently discarded.
        return None


def rank_zero_experiment(fn: Callable) -> Callable:
    """Wrap an experiment getter so a ``DummyExperiment`` is returned as fallback.

    ``fn`` is invoked through ``rank_zero_only``; whenever that call yields a
    falsy result, a fresh ``DummyExperiment`` is returned instead.
    """

    @wraps(fn)
    def experiment(self):
        def fetch_real_experiment():
            return fn(self)

        real_experiment = rank_zero_only(fetch_real_experiment)()
        if real_experiment:
            return real_experiment
        return DummyExperiment()

    return experiment


class CustomWandbLogger(WandbLogger):
    """WandbLogger variant whose ``experiment`` property retries ``wandb.init()`` until it succeeds."""

    def __init__(self, *args, **kwargs):
        """Modified logger that insists on a wandb.init() call and catches wandb's error if thrown."""
        super().__init__(*args, **kwargs)

    @property
    @rank_zero_experiment
    def experiment(self):
        r"""
        Actual wandb object. To use wandb features in your
        :class:`~pytorch_lightning.core.lightning.LightningModule` do the following.
        Example::
        .. code-block:: python
            self.logger.experiment.some_wandb_function()
        """
        # Already resolved on an earlier access: reuse it.
        if self._experiment is not None:
            return self._experiment

        if self._offline:
            os.environ["WANDB_MODE"] = "dryrun"

        attach_id = getattr(self, "_attach_id", None)

        if wandb.run is not None:
            # A wandb run already exists in this process; reuse it.
            rank_zero_warn(
                "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse"
                " this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`."
            )
            self._experiment = wandb.run
            return self._experiment

        if attach_id is not None and hasattr(wandb, "_attach"):
            # Attach to the wandb process referenced by the stored id.
            self._experiment = wandb._attach(attach_id)
            return self._experiment

        # Otherwise create a new run, retrying indefinitely after a random 30-60 s pause.
        while True:
            try:
                # A fresh id is generated on every attempt; the entity is hard-coded.
                self._wandb_init["id"] = wandb.util.generate_id()
                self._wandb_init["entity"] = "timeseries-synthetics"

                self._experiment = wandb.init(**self._wandb_init)
            except Exception as err:
                print("wandb Exception:\n", err)
                delay = random.randint(30, 60)
                print(f"Sleeping for {delay} seconds")
                time.sleep(delay)
            else:
                break

        # Define the default x-axis (only for runs created here).
        if getattr(self._experiment, "define_metric", None):
            self._experiment.define_metric("trainer/global_step")
            self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True)

        return self._experiment


class SequenceLightningModule(pl.LightningModule):
    def __init__(self, config):
        """Store the config as hyperparameters, build the dataset and run ``setup()``.

        Args:
            config: configuration object; it is passed to ``save_hyperparameters``
                so its top-level entries are reachable as ``self.hparams.<key>``.

        Raises:
            KeyError: presumably, if ``dataset._name_`` is neither one of the
                special-cased names nor present in ``SequenceDataset.registry``
                (depends on the registry's lookup behavior).
        """
        # Disable profiling executor. This reduces memory and increases speed.
        try:
            torch._C._jit_set_profiling_executor(False)
            torch._C._jit_set_profiling_mode(False)
        except AttributeError:
            pass

        super().__init__()
        # Passing in config expands it one level, so can access by self.hparams.train instead of self.hparams.config.train
        self.save_hyperparameters(config, logger=False)

        # Dataset arguments. The dataloader modules are imported lazily so that
        # only the module for the selected dataset needs to be importable.
        dataset_name = self.hparams.dataset._name_
        if dataset_name == "amos":
            from src.dataloaders.dataloader_amos import AMOS_Dataset
            self.dataset = AMOS_Dataset(**self.hparams.dataset)
        elif dataset_name == "timeseries_synthetics":
            from src.dataloaders.dataloader_mortgage import MortgageDataset
            self.dataset = MortgageDataset(**self.hparams.dataset)
        elif dataset_name == "timeseries_etth":
            from src.dataloaders.dataloader_etth import ETTDataset
            self.dataset = ETTDataset(**self.hparams.dataset)
        elif dataset_name == "corelogic_loan_dataset":
            from src.dataloaders.dataloader_corelogic import LoanDataset
            self.dataset = LoanDataset(**self.hparams.dataset)
        elif dataset_name == "equities_dataset":
            from src.dataloaders.dataloader_equities import EquityDataset
            self.dataset = EquityDataset(**self.hparams.dataset)
        else:
            # Fall back to the generic dataset registry. (A leftover breakpoint()
            # used to sit here and would halt any non-interactive run.)
            self.dataset = SequenceDataset.registry[dataset_name](
                **self.hparams.dataset
            )

        # Check hparams
        self._check_config()

        # PL has some bugs, so add hooks and make sure they're only called once
        self._has_setup = False

        self.setup()  ## Added by KS

    def setup(self, stage=None):
        """Set up the dataset (once) and build model, task, encoder and decoder (once).

        The dataset is set up unless ``hparams.train.disable_dataset`` is truthy;
        the flag is then flipped to ``True`` so later calls skip that step. The
        model construction is guarded by ``self._has_setup`` and only runs on the
        first call.

        Args:
            stage: not used by this method.
        """
        if not self.hparams.train.disable_dataset:
            # Flip the flag so a repeated setup() call does not set the dataset up again.
            self.hparams.train.disable_dataset = True #EE
            self.dataset.setup()

        # We need to set up the model in setup() because for some reason when training with DDP, one GPU uses much more memory than the others
        # In order to not overwrite the model multiple times during different stages, we need this hack
        # TODO PL 1.5 seems to have an option to skip hooks to avoid this
        # https://github.com/PyTorchLightning/pytorch-lightning/issues/5410#issuecomment-762257024
        if self._has_setup:
            return
        else:
            self._has_setup = True
        # Convenience feature: if model specifies encoder, combine it with main encoder
        encoder_cfg = utils.to_list(self.hparams.encoder) + utils.to_list(
            self.hparams.model.pop("encoder", None)
        )
        decoder_cfg = utils.to_list(
            self.hparams.model.pop("decoder", None)
        ) + utils.to_list(self.hparams.decoder)

        # Instantiate model
        self.model = utils.instantiate(registry.model, self.hparams.model)
        # Optional post-init hook: call the named method on every submodule that has it.
        if (name := self.hparams.train.post_init_hook['_name_']) is not None:
            kwargs = self.hparams.train.post_init_hook.copy()
            del kwargs['_name_']
            for module in self.modules():
                if hasattr(module, name):
                    getattr(module, name)(**kwargs)

        # Instantiate the task
        self.task = utils.instantiate(
            tasks.registry, self.hparams.task, dataset=self.dataset, model=self.model
        )
        # Create encoders and decoders
        encoder = encoders.instantiate(
            encoder_cfg, dataset=self.dataset, model=self.model
        )
        decoder = decoders.instantiate(
            decoder_cfg, model=self.model, dataset=self.dataset
        )

        # Extract the modules so they show up in the top level parameter count
        self.encoder = U.PassthroughSequential(self.task.encoder, encoder)
        self.decoder = U.PassthroughSequential(decoder, self.task.decoder)
        self.loss = self.task.loss
        # Validation loss defaults to the training loss unless the task provides its own.
        self.loss_val = self.task.loss
        if hasattr(self.task, 'loss_val'):
            self.loss_val = self.task.loss_val
        self.metrics = self.task.metrics
        self.train_torchmetrics = self.task.train_torchmetrics
        self.val_torchmetrics = self.task.val_torchmetrics
        self.test_torchmetrics = self.task.test_torchmetrics

        if self.dataset._name_ == "equities_dataset":
            try:
                wandb.log({"Median Market Cap (billion USD)": self.dataset.dataset_train.median_mcap_billions})
            except Exception:
                # Best-effort logging: a failure here must not abort setup, but
                # KeyboardInterrupt/SystemExit are no longer swallowed.
                print("Failed to log median market cap")

        #added EE Jul 8, 2023
        # Add Portfolio Sharpe Ratio metric only for equities dataset
        if hasattr(self.dataset, '_name_') and self.dataset._name_ == "equities_dataset":
            # Import the metric classes
            from src.tasks.metrics import PortfolioSharpeRatioMetric, MarketSharpeRatioMetric

            # NOTE(review): the registrations below are commented out, so no metric
            # is actually added even though the message printed afterwards says so.
            # Add only portfolio and market Sharpe metrics
            #self.train_torchmetrics["portfolio_sharpe"] = PortfolioSharpeRatioMetric()
            #self.train_torchmetrics["market_sharpe"] = MarketSharpeRatioMetric()

            #self.val_torchmetrics["portfolio_sharpe"] = PortfolioSharpeRatioMetric()
            #self.val_torchmetrics["market_sharpe"] = MarketSharpeRatioMetric()

            #self.test_torchmetrics["portfolio_sharpe"] = PortfolioSharpeRatioMetric()
            #self.test_torchmetrics["market_sharpe"] = MarketSharpeRatioMetric()

            print("Added Portfolio and Market Sharpe Ratio metrics for equities dataset")
        try:
            # NOTE(review): this always-failing assert disables the macro-variable
            # logging below, so the except branch runs on every call. Under
            # `python -O` asserts are stripped and the block would execute.
            assert 1==2
            v = np.concatenate(
                (
                    self.dataset.dataset_train.unobserved_macro_variable,
                    self.dataset.dataset_val.unobserved_macro_variable,
                    self.dataset.dataset_test.unobserved_macro_variable
                 )
                )
            f = np.concatenate(
                (
                    self.dataset.dataset_train.observed_macro_variable,
                    self.dataset.dataset_val.observed_macro_variable,
                    self.dataset.dataset_test.observed_macro_variable
                )
                )
            x = np.linspace(1,len(v),len(v))

            loss_no_path_dep = get_no_path_dependency_loss(self.dataset.dataset_train, self.dataset.dataset_val)
            optimal_loss = get_optimal_loss(self.dataset.dataset_val)
            # Constant reference curves with the same length as the macro series.
            v_loss = np.array([loss_no_path_dep]*len(v))
            v_loss_optimal = np.array([optimal_loss]*len(v))
            wandb.log({"Optimal loss joint modeling" : wandb.plot.line_series(xs=[list(x)], ys=[list(v_loss_optimal)],
                        keys=["Log loss"],
                        title="Optimal loss joint modeling",
                        xname="Time")}
                        )
            wandb.log({"Optimal loss marginal modeling" : wandb.plot.line_series(xs=[list(x)], ys=[list(v_loss)],
                        keys=["Log loss"],
                        title="Optimal loss marginal modeling",
                        xname="Time")}
                        )
            wandb.log({"Residual Marginal vs Joint Loss" : wandb.plot.line_series(xs=[list(x)], ys=[list(v_loss-v_loss_optimal)],
                        keys=["Residual Log loss"],
                        title="Residual Marginal vs Joint Loss",
                        xname="Time")}
                        )

            wandb.log({"Hidden Macro Variables" : wandb.plot.line_series(xs=[list(x)], ys=[list(v)],
                        keys=["Unobserved Macro Variable (v)"],
                        title="Unobserved Macro Variable",
                        xname="Time")}
                        )
            wandb.log({"Visible Macro Variables" : wandb.plot.line_series(xs=[list(x)], ys=[list(f)],
                        keys=["Observed Macro Variable (f)"],
                        title="Observed Macro Variable",
                        xname="Time")}
                        )
            #self.log_dict({"Macro Variable (f)": torch.tensor(f),"Hidden Macro Variable (v)": torch.tensor(v)})
            print("logged macro variables")
        except Exception:
            # Best-effort: includes the AssertionError raised by the guard above.
            print("failed to log macro variables")
        

    def load_state_dict(self, state_dict, strict=True):
        """Load ``state_dict``, first passing it through an optional configured hook.

        When ``hparams.train.pretrained_model_state_hook['_name_']`` is set, the
        hook is instantiated from ``registry.model_state_hook`` and called with
        ``(self.model, state_dict)``; its return value replaces ``state_dict``.

        Returns:
            Whatever the parent class's ``load_state_dict`` returns.
        """
        hook_cfg = self.hparams.train.pretrained_model_state_hook
        if hook_cfg['_name_'] is not None:
            # The hook may rewrite the checkpoint's state dict, e.g. to inflate
            # 2D convs to 3D convs.
            state_hook = utils.instantiate(
                registry.model_state_hook,
                hook_cfg.copy(),
                partial=True,
            )
            state_dict = state_hook(self.model, state_dict)

        print("Custom load_state_dict function is running.")

        # Forward the parent's result so callers still get the usual return value.
        result = super().load_state_dict(state_dict, strict=strict)
        return result

    def _check_config(self):
        assert self.hparams.train.state.mode in [None, "none", "null", "reset", "bptt", "tbptt"]
        assert (
            (n := self.hparams.train.state.n_context) is None
            or isinstance(n, int)
            and n >= 0
        )
        assert (
            (n := self.hparams.train.state.n_context_eval) is None
            or isinstance(n, int)
            and n >= 0
        )

    def _initialize_state(self):
        """Called at model setup and start of epoch to completely reset state"""
        self._state = None
        self._memory_chunks = []

    def _reset_state(self, batch, device=None):
        """Called to construct default_state when necessary, e.g. during BPTT"""
        device = device or batch[0].device
        self._state = self.model.default_state(*batch[0].shape[:1], device=device)

    def _detach_state(self, state):
        if isinstance(state, torch.Tensor):
            return state.detach()
        elif isinstance(state, tuple):
            return tuple(self._detach_state(s) for s in state)
        elif isinstance(state, list):
            return [self._detach_state(s) for s in state]
        elif isinstance(state, dict):
            return {k: self._detach_state(v) for k, v in state.items()}
        elif state is None:
            return None
        else:
            raise NotImplementedError

    def _process_state(self, batch, batch_idx, train=True):
        """Handle logic for state context."""
        # Number of context steps
        key = "n_context" if train else "n_context_eval"
        n_context = self.hparams.train.state.get(key)

        # Don't need to do anything if 0 context steps. Make sure there is no state
        if n_context == 0 and self.hparams.train.state.mode not in ['tbptt']:
            self._initialize_state()
            return

        # Reset state if needed
        if self.hparams.train.state.mode == "reset":
            if batch_idx % (n_context + 1) == 0:
                self._reset_state(batch)

        # Pass through memory chunks
        elif self.hparams.train.state.mode == "bptt":
            self._reset_state(batch)
            with torch.no_grad():  # should be unnecessary because individual modules should handle this
                for _batch in self._memory_chunks:
                    self.forward(_batch)
            # Prepare for next step
            self._memory_chunks.append(batch)
            self._memory_chunks = self._memory_chunks[-n_context:]

        elif self.hparams.train.state.mode == 'tbptt':
            _, _, z = batch
            reset = z["reset"]
            if reset:
                self._reset_state(batch)
            else:
                self._state = self._detach_state(self._state)

    # def forward(self, batch):
    #     """Passes a batch through the encoder, backbone, and decoder"""
    #     # z holds arguments such as sequence length
    #     x, y, *z = batch # z holds extra dataloader info such as resolution
    #     if len(z) == 0:
    #         z = {}
    #     else:
    #         assert len(z) == 1 and isinstance(z[0], dict), "Dataloader must return dictionary of extra arguments"
    #         z = z[0]

    #     x, w = self.encoder(x, **z) # w can model-specific constructions such as key_padding_mask for transformers or state for RNNs
    #     x, state = self.model(x, **w, state=self._state)
    #     self._state = state
    #     x, w = self.decoder(x, state=state, **z)
    #     return x, y, w

    def forward(self, batch):
        return self.task.forward(batch, self.encoder, self.model, self.decoder, self._state)

    def step(self, x_t):
        x_t, *_ = self.encoder(x_t) # Potential edge case for encoders that expect (B, L, H)?
        x_t, state = self.model.step(x_t, state=self._state)
        self._state = state
        # x_t = x_t[:, None, ...] # Dummy length
        # x_t, *_ = self.decoder(x_t, state=state)
        # x_t = x_t[:, 0, ...]
        x_t, *_ = self.decoder.step(x_t, state=state)
        return x_t

    def _shared_step(self, batch, batch_idx, prefix="train"):
        """Common train/val/test step: forward pass, loss, metrics and logging.

        Args:
            batch: batch handed to ``_process_state`` and ``forward``.
            batch_idx: index of the batch.
            prefix: name used to prefix logged metrics and to pick the
                ``<prefix>_torchmetrics`` attribute; ``"train"`` selects the
                training loss, anything else the validation loss.

        Returns:
            The loss value.
        """
        is_train = prefix == "train"

        self._process_state(batch, batch_idx, train=is_train)
        x, y, w = self.forward(batch)

        # Loss
        loss_fn = self.loss if is_train else self.loss_val
        loss = loss_fn(x, y, **w)

        # Metrics, with the loss added and every key prefixed
        metrics = self.metrics(x, y, **w)
        metrics["loss"] = loss
        metrics = {f"{prefix}/{k}": v for k, v in metrics.items()}

        # Calculate torchmetrics
        torchmetrics = getattr(self, f'{prefix}_torchmetrics')
        torchmetrics(x, y, loss=loss)

        log_on_step = 'eval' in self.hparams and self.hparams.eval.get('log_on_step', False) and is_train

        # Both log_dict calls share the same logging options.
        log_options = dict(
            on_step=log_on_step,
            on_epoch=True,
            prog_bar=True,
            add_dataloader_idx=False,
            sync_dist=True,
        )
        self.log_dict(metrics, **log_options)

        # log the whole dict, otherwise lightning takes the mean to reduce it
        # https://pytorch-lightning.readthedocs.io/en/stable/visualize/logging_advanced.html#enable-metrics-for-distributed-training
        self.log_dict(torchmetrics, **log_options)
        return loss

    def on_train_epoch_start(self):
        # Reset training torchmetrics
        self.task._reset_torchmetrics("train")

    def shift_train_set_forward(self):
        """Move the val/test split dates forward by one year and save a checkpoint.

        Does nothing unless the dataset has ``rolling_model`` and a positive
        ``rolling_epoch_interval``, the current epoch is a multiple of that
        interval and greater than ``rolling_start_epoch``, the dataset is
        "corelogic_loan_dataset", and ``rolling_model`` is truthy.
        """
        epoch = self.current_epoch
        dataset = self.dataset

        if not hasattr(dataset, 'rolling_model'):
            return
        if not hasattr(dataset, 'rolling_epoch_interval') or dataset.rolling_epoch_interval <= 0:
            return
        shift_due = (
            epoch % dataset.rolling_epoch_interval == 0
            and dataset._name_ == "corelogic_loan_dataset"
            and epoch > dataset.rolling_start_epoch
            and dataset.rolling_model
        )
        if not shift_due:
            return

        print("Shifting train set forward")
        # Remember the old validation split date; it goes into the checkpoint name.
        previous_val_date = dataset.config["val_split_date"]

        # Parse the current "%Y-%m" split dates and push both one year ahead.
        val_date = datetime.strptime(dataset.config["val_split_date"], "%Y-%m")
        test_date = datetime.strptime(dataset.config["test_split_date"], "%Y-%m")
        dataset.config["val_split_date"] = val_date.replace(year=val_date.year + 1).strftime("%Y-%m")
        dataset.config["test_split_date"] = test_date.replace(year=test_date.year + 1).strftime("%Y-%m")
        dataset._split_data()

        rebuild_datasets = False
        if rebuild_datasets:
            # Drop the old datasets to free memory, then build new ones.
            del dataset.dataset_train
            del dataset.dataset_val
            del dataset.dataset_test
            import gc
            gc.collect()
            dataset.dataset_train, dataset.dataset_val, dataset.dataset_test = dataset.get_data()
        else:
            # Keep the dataset objects and only update their bounds in place.
            splits = ("train", "val", "test")
            for split in splits:
                target = getattr(dataset, f"dataset_{split}")
                limits = getattr(dataset, f"limits_{split}")
                target.lower_bound = limits[0]
                target.upper_bound = limits[1]
            for split in splits:
                target = getattr(dataset, f"dataset_{split}")
                sampling = getattr(dataset, f"sampling_{split}")
                target.lower_sampling_bound = sampling[0]
                target.upper_sampling_bound = sampling[1]

        new_val_date = dataset.config["val_split_date"]
        new_test_date = dataset.config["test_split_date"]
        print(f"New val split date: {new_val_date}")
        print(f"New test split date: {new_test_date}")

        # Save checkpoint under <BASE_PATH>/outputs/outputs/<YYYY-MM-DD>/
        now = datetime.now()
        day_directory = now.strftime("%Y-%m-%d")
        time_directory = now.strftime("%H-%M-%S")
        output_root = os.path.join(os.environ.get('BASE_PATH', ''), "outputs/outputs")
        # The time stamp is fixed on the first shift and reused by later ones.
        if not hasattr(self, '_checkpoint_time'):
            self._checkpoint_time = time_directory
        checkpoint_path = os.path.join(
            output_root,
            day_directory,
            f"shift_ckpt_epoch_{epoch}_{self._checkpoint_time}_val_{previous_val_date}.ckpt",
        )
        self.trainer.save_checkpoint(checkpoint_path)
        print(f"Saved checkpoint to {checkpoint_path}")

        # self.dataset.dataset_val
        # self.dataset.dataset_test
        # self.dataset.dataset_train
        # Update lower and upper sampling bounds
        # Save checkpoint
    def compute_metrics(self):
        """Compute and log extra evaluation metrics on selected epochs.

        Only runs on epochs 0, 1 and ``trainer.max_epochs - 1``:

        * "corelogic_loan_dataset": evaluates on the test set and logs every
          metric returned by ``get_metrics_cl`` to wandb, then returns ``None``.
        * "timeseries_synthetics" (and ``use_random_input_size`` falsy): compares
          model predictions with the dataset's true transition probabilities,
          logs the summary to wandb and returns it as a dict.

        Every other case returns ``None``.
        """
        dataset_name = self.dataset._name_

        if dataset_name == "corelogic_loan_dataset" and (self.current_epoch in [0, 1, self.trainer.max_epochs - 1]):
            from scripts.notebooks.true_loss_level.get_corelogic import get_metrics_cl
            from scripts.notebooks.true_loss_level.get_corelogic import evaluate_model_cl

            # Evaluate the full module, starting from a cleared state.
            self._state = None
            test_split = self.dataset.dataset_test
            # Evaluate on val+test together, It makes more sense to evaluate only on test
            probs_torch, y_true_torch = evaluate_model_cl(
                model=self,
                model_name="set-seq",
                val_set=test_split,
                test_set=None,
                batch_size=1,
                fix_seed=True,
                all_units_in_batch_dim=False,
            )

            # Convert to NumPy: predictions drop the last index of the third axis,
            # targets drop the first; a small tolerance is added to the predictions.
            tolerance = 1e-15
            probs = probs_torch[:, :, :-1, :].detach().cpu().numpy() + tolerance
            y_true = y_true_torch[:, :, 1:, :].detach().cpu().numpy()
            corelogic_metrics = get_metrics_cl(
                y_true, probs, name="corelogic-test-set-seq-2025-04-25", save_heatmap=False
            )

            # Log metrics to wandb, one call per metric
            print("Logging metrics:", corelogic_metrics)
            for metric_name, metric_value in corelogic_metrics.items():
                wandb.log({metric_name: metric_value})

        if dataset_name != "timeseries_synthetics":
            return
        if self.dataset.use_random_input_size:
            return
        if self.current_epoch not in [0, 1, self.trainer.max_epochs - 1]:
            return

        from scripts.notebooks.multi_class_auc import get_metrics, reshape_for_auc
        from scripts.notebooks.true_loss_level.get_transition_probabilities import evaluate_model

        # Pass self instead of self.model to use the full pipeline, with the state reset
        self._state = None
        test_split = self.dataset.dataset_test

        # Get model predictions
        print("Computing model predictions...")
        predicted = evaluate_model(self, test_split, [1000], False, False)
        model_predictions = np.stack(predicted[0], axis=0)
        model_predictions = model_predictions[:, :, :-1, :]

        # Ground truth as numpy arrays
        true_probabilities = np.array(test_split.transition_probabilities.cpu())
        true_transitions = np.array(test_split.Y.cpu())

        # Model outputs are for t >= 1, but dataset arrays include t=0
        # We need to match the time dimension by excluding t=0
        true_probabilities = true_probabilities[:, :, 1:, :]
        true_transitions = true_transitions[:, :, 1:, :]

        assert model_predictions.shape == true_probabilities.shape
        # Reshape for metric calculation using the same function as in multi_class_auc.py
        true_probabilities_flat = reshape_for_auc(true_probabilities)
        true_transitions_flat = reshape_for_auc(true_transitions)
        model_predictions_flat = reshape_for_auc(model_predictions)

        # Compute all metrics
        print("Computing metrics...")
        metrics_dict = get_metrics(true_transitions_flat, true_probabilities_flat, [model_predictions_flat])

        # Peak allocated GPU memory: bytes divided by 1024 three times (GiB)
        max_memory_gb = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024

        # Extract the first (and only) values from each list in metrics_dict
        simplified_metrics = {
            'r2_score': metrics_dict['r2'][0],
            'correlation': metrics_dict['corr'][0],
            'rel_abs_error': metrics_dict['rel_abs_err'][0],
            'auc_rare_event': metrics_dict['auc'][0],
            'memory_gpu_usage_gb': max_memory_gb,
        }

        # Log metrics to wandb, one call per metric
        print("Logging metrics:", simplified_metrics)
        for metric_name, metric_value in simplified_metrics.items():
            wandb.log({metric_name: metric_value})

        return simplified_metrics
    
    def save_rolling_model(self):
        """Save a checkpoint on the last epoch when training on the equities dataset.

        The checkpoint goes to
        ``<BASE_PATH>/outputs/equities_outputs/rolling_model/_epoch_<epoch>.ckpt``.
        """
        if self.dataset._name_ != "equities_dataset":
            return
        if self.current_epoch != self.trainer.max_epochs - 1:
            return

        # NOTE(review): the validation start year is turned into a directory name
        # here but never used in the checkpoint path below -- confirm whether it
        # was meant to be part of the path.
        val_start_year = self.dataset.config["val_split_date"].split("-")[0]
        time_directory = "val_start_year_{}".format(val_start_year)

        output_root = os.path.join(os.environ.get('BASE_PATH', ''), "outputs/equities_outputs")
        checkpoint_path = os.path.join(output_root, "rolling_model", f"_epoch_{self.current_epoch}.ckpt")
        self.trainer.save_checkpoint(checkpoint_path)
        print(f"Saved rolling model checkpoint to {checkpoint_path}")

    def training_epoch_end(self, outputs):
        """Run end-of-training-epoch work after delegating to the parent hook.

        Calls ``shift_train_set_forward()`` and then ``compute_metrics()``; both
        decide internally whether they have anything to do for this epoch.
        """
        # Log training torchmetrics
        super().training_epoch_end(outputs)
        self.shift_train_set_forward() #EE
        self.compute_metrics()
        # Saving the rolling model is currently disabled.
        #self.save_rolling_model()
        #
        # self.log_dict(
        #     {f"train/{k}": v for k, v in self.task.get_torchmetrics("train").items()},
        #     on_step=False,
        #     on_epoch=True,
        #     prog_bar=True,
        #     add_dataloader_idx=False,
        #     sync_dist=True,
        # )

    def on_validation_epoch_start(self):
        # Reset all validation torchmetrics
        for name in self.val_loader_names:
            self.task._reset_torchmetrics(name)

    def validation_epoch_end(self, outputs):
        """Delegate to the parent hook; the explicit torchmetrics logging below is disabled."""
        # Log all validation torchmetrics
        super().validation_epoch_end(outputs)
        # for name in self.val_loader_names:
        #     self.log_dict(
        #         {f"{name}/{k}": v for k, v in self.task.get_torchmetrics(name).items()},
        #         on_step=False,
        #         on_epoch=True,
        #         prog_bar=True,
        #         add_dataloader_idx=False,
        #         sync_dist=True,
        #     )

    def on_test_epoch_start(self):
        # Reset all test torchmetrics
        for name in self.test_loader_names:
            self.task._reset_torchmetrics(name)

    def test_epoch_end(self, outputs):
        """Delegate to the parent hook; the explicit torchmetrics logging below is disabled."""
        # Log all test torchmetrics
        super().test_epoch_end(outputs)


        #
        # for name in self.test_loader_names:
        #     self.log_dict(
        #         {f"{name}/{k}": v for k, v in self.task.get_torchmetrics(name).items()},
        #         on_step=False,
        #         on_epoch=True,
        #         prog_bar=True,
        #         add_dataloader_idx=False,
        #         sync_dist=True,
        #     )

    def training_step(self, batch, batch_idx, dataloader_idx=0):
        """Run the shared step with the "train" prefix and log per-step extras.

        Besides what ``_shared_step`` logs, this logs ``trainer/loss`` and
        ``trainer/epoch`` on every step, plus any ``metrics`` mapping exposed by
        submodules.

        Returns:
            The training loss.
        """
        loss = self._shared_step(batch, batch_idx, prefix="train")

        # Everything logged here is step-level only and kept off the progress bar.
        step_only = dict(
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            add_dataloader_idx=False,
            sync_dist=True,
        )

        # Log the loss explicitly so it shows up in WandB
        # Note that this currently runs into a bug in the progress bar with ddp (as of 1.4.6)
        # https://github.com/PyTorchLightning/pytorch-lightning/pull/9142
        # We additionally log the epochs under 'trainer' to get a consistent prefix with 'global_step'
        self.log_dict(
            {"trainer/loss": loss, "trainer/epoch": self.current_epoch},
            **step_only,
        )

        # Log any extra info that the models want to expose (e.g. output norms).
        # The first entry of self.modules() is skipped.
        extra_metrics = {}
        for submodule in list(self.modules())[1:]:
            if hasattr(submodule, "metrics"):
                extra_metrics.update(submodule.metrics)
        self.log_dict(extra_metrics, **step_only)

        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Run the shared step for one validation loader, swapping in EMA weights if needed.

        For loaders whose name ends with "/ema" (and once the optimizer reports
        ``stepped``), ``swap_ema()`` is called on the optimizer before the step
        and again afterwards.

        Returns:
            The loss from the shared step.
        """
        loader_name = self.val_loader_names[dataloader_idx]
        # There's a bit of an annoying edge case with the first (0-th) epoch; it
        # has to be excluded due to the initial sanity check
        use_ema = loader_name.endswith("/ema") and self.optimizers().optimizer.stepped

        if use_ema:
            self.optimizers().swap_ema()
        loss = self._shared_step(batch, batch_idx, prefix=loader_name)
        if use_ema:
            self.optimizers().swap_ema()

        return loss

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        return self._shared_step(
            batch, batch_idx, prefix=self.test_loader_names[dataloader_idx]
        )

    def configure_optimizers(self):
        """Build the optimizer and, when configured, the LR scheduler.

        Steps, in order:

        1. If ``hparams.train`` contains ``optimizer_param_grouping``, call
           ``add_optimizer_hooks`` on ``self.model`` with those settings.
        2. Instantiate the optimizer from ``hparams.optimizer`` over every
           parameter that has no ``_optim`` attribute.
        3. Add one extra param group per distinct ``_optim`` dict, using the
           base optimizer hparams overridden by that dict.
        4. If ``hparams.train.layer_decay['_name_']`` is set, discard all param
           groups built so far and replace them with one group per layer id,
           each with a decayed learning rate.
        5. If ``hparams`` contains ``scheduler``, instantiate it and wrap it in
           a scheduler config dict.

        Returns:
            The optimizer alone when no scheduler is configured; otherwise the
            pair ``([optimizer], [scheduler_dict])``.
        """
        # Set zero weight decay for some params
        if 'optimizer_param_grouping' in self.hparams.train:
            add_optimizer_hooks(self.model, **self.hparams.train.optimizer_param_grouping)

        # Normal parameters: everything that does not carry an `_optim` attribute
        all_params = list(self.parameters())
        params = [p for p in all_params if not hasattr(p, "_optim")]

        optimizer = utils.instantiate(registry.optimizer, self.hparams.optimizer, params)

        # Remove `_name_` so the remaining optimizer hparams can be splatted into
        # the param groups below.
        # NOTE(review): this mutates self.hparams; a second call to this method
        # would presumably no longer find `_name_` — confirm this is never re-run.
        del self.hparams.optimizer._name_

        # Add parameters with special hyperparameters
        hps = [getattr(p, "_optim") for p in all_params if hasattr(p, "_optim")]
        # De-duplicate the `_optim` dicts by converting each to a frozenset of items.
        # NOTE(review): frozensets order by the subset relation, which is not a
        # total order, so `sorted` here does not give a canonical ordering —
        # confirm which group order is intended.
        hps = [
            # dict(s) for s in set(frozenset(hp.items()) for hp in hps)
            dict(s) for s in sorted(list(dict.fromkeys(frozenset(hp.items()) for hp in hps)))
            # dict(s) for s in dict.fromkeys(frozenset(hp.items()) for hp in hps)
        ]  # Unique dicts
        print("Hyperparameter groups", hps)
        for hp in hps:
            # All parameters whose `_optim` dict equals this unique dict
            params = [p for p in all_params if getattr(p, "_optim", None) == hp]
            # Base optimizer hparams first, then the per-group overrides
            optimizer.add_param_group(
                {"params": params, **self.hparams.optimizer, **hp}
            )

        ### Layer Decay ###

        if self.hparams.train.layer_decay['_name_'] is not None:
            # Callable that maps a parameter name to a layer id
            get_num_layer = utils.instantiate(
                registry.layer_decay,
                self.hparams.train.layer_decay['_name_'],
                partial=True,
            )

            # Go through all parameters and get num layer
            layer_wise_groups = {}
            num_max_layers = 0
            for name, p in self.named_parameters():
                # Get layer id for each parameter in the model
                layer_id = get_num_layer(name)

                # Add to layer wise group
                if layer_id not in layer_wise_groups:
                    layer_wise_groups[layer_id] = {
                        'params': [],
                        'lr': None,
                        'weight_decay': self.hparams.optimizer.weight_decay
                    }
                layer_wise_groups[layer_id]['params'].append(p)

                # Track the largest layer id seen
                if layer_id > num_max_layers: num_max_layers = layer_id

            # Update lr for each layer: the highest layer id keeps the base lr,
            # each lower id is scaled by one more factor of `decay`
            for layer_id, group in layer_wise_groups.items():
                group['lr'] = self.hparams.optimizer.lr * (self.hparams.train.layer_decay.decay ** (num_max_layers - layer_id))

            # Reset the torch optimizer's param groups. This drops every group
            # created above, including the special `_optim` groups.
            optimizer.param_groups = []
            for layer_id, group in layer_wise_groups.items():
                optimizer.add_param_group(group)

        # Print optimizer info for debugging
        keys = set([k for hp in hps for k in hp.keys()])  # Special hparams
        utils.train.log_optimizer(log, optimizer, keys)
        # Configure scheduler
        if "scheduler" not in self.hparams:
            return optimizer
        lr_scheduler = utils.instantiate(
            registry.scheduler, self.hparams.scheduler, optimizer
        )
        scheduler = {
            "scheduler": lr_scheduler,
            "interval": self.hparams.train.interval,  # 'epoch' or 'step'
            "monitor": self.hparams.train.monitor,
            "name": "trainer/lr",  # default is e.g. 'lr-AdamW'
        }
        # See documentation for how to configure the return
        # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html#pytorch_lightning.core.lightning.LightningModule.configure_optimizers
        return [optimizer], [scheduler]

    def train_dataloader(self):
        """Return the dataset's training dataloader, built with ``hparams.loader`` as keyword arguments."""
        build_loader = self.dataset.train_dataloader
        return build_loader(**self.hparams.loader)

    def _eval_dataloaders_names(self, loaders, prefix):
        """Normalize evaluation loaders into parallel lists of names and loaders.

        Args:
            loaders: A dict of loaders, a list of loaders, or a single loader,
                as classified by ``utils.is_dict`` / ``utils.is_list``.
            prefix: Name prefix, e.g. "val" or "test".

        Returns:
            A ``(names, loaders)`` pair. Dict entries are named
            ``"{prefix}/{key}"`` (or just ``prefix`` for a ``None`` key), list
            entries ``"{prefix}/{index}"``, and a single loader is named ``prefix``.
        """
        if utils.is_dict(loaders):
            names = []
            for key in loaders.keys():
                names.append(prefix if key is None else f"{prefix}/{key}")
            return names, list(loaders.values())

        if utils.is_list(loaders):
            count = len(loaders)
            return [f"{prefix}/{position}" for position in range(count)], loaders

        # Anything else is treated as one loader.
        return [prefix], [loaders]

    def _eval_dataloaders(self):
        # Return all val + test loaders
        val_loaders = self.dataset.val_dataloader(**self.hparams.loader)
        test_loaders = self.dataset.test_dataloader(**self.hparams.loader)
        val_loader_names, val_loaders = self._eval_dataloaders_names(val_loaders, "val")
        test_loader_names, test_loaders = self._eval_dataloaders_names(
            test_loaders, "test"
        )

        # Duplicate datasets for ema
        if self.hparams.train.ema > 0.0:
            val_loader_names += [name + "/ema" for name in val_loader_names]
            val_loaders = val_loaders + val_loaders
            test_loader_names += [name + "/ema" for name in test_loader_names]
            test_loaders = test_loaders + test_loaders

        # adding option to only have val loader at eval (eg if test is duplicate)
        if self.hparams.train.get("remove_test_loader_in_eval", None) is not None:
            return val_loader_names, val_loaders
        # default behavior is to add test loaders in eval
        else:
            return val_loader_names + test_loader_names, val_loaders + test_loaders

    def val_dataloader(self):
        """Return the evaluation loaders and remember their names on ``self.val_loader_names``."""
        names, loaders = self._eval_dataloaders()
        # The names are kept on the instance for later lookup by dataloader index.
        self.val_loader_names = names
        return loaders

    def test_dataloader(self):
        test_loader_names, test_loaders = self._eval_dataloaders()
        self.test_loader_names = ["final/" + name for name in test_loader_names]
        return test_loaders


### pytorch-lightning utils and entrypoint ###

def _run_subdirectory_name(config, now):
    """Return the name of the run-specific sub-directory for checkpoints.

    For the "equities_dataset" and "corelogic_dataset" datasets the name starts
    with the first four characters of ``dataset_config.val_beg_date`` and is
    extended with model/task tags; for any other dataset it is the wall-clock
    time ``HH-MM-SS`` taken from ``now``.

    Args:
        config: Full run configuration.
        now: ``datetime`` used for the time-based fallback name.

    Returns:
        The sub-directory name as a string.
    """
    dataset_name = config.dataset._name_
    if dataset_name not in ("equities_dataset", "corelogic_dataset"):
        return now.strftime("%H-%M-%S")

    is_equities = dataset_name == "equities_dataset"
    name = str(config.dataset.dataset_config.val_beg_date)[:4]
    if not is_equities:
        name += "_corelogic"

    layer = config.model.layer
    if layer._name_ == "attention_factors":
        name += "_factors_{}".format(layer.n_factors)
        # Only the equities runs tag the absence of the factor portfolio.
        if is_equities and not layer.use_factor_portfolio:
            name += "_no_factors_portfolio"
        print("time_directory", name)
    elif config.encoder.debug:
        name += "_" + layer._name_
    elif layer._name_ == "long-conv":
        name += "_set_{}".format(layer.set_mixing_architecture)

    # Only the equities runs carry the transaction-cost and seed tags.
    if is_equities:
        if config.task.loss == "next_step_sharpe_ratio_with_transaction_cost":
            name += "_transaction_cost"
        name += "_seed_{}".format(config.train.seed)
    return name


def create_trainer(config, **kwargs):
    """Build the Lightning trainer described by ``config``.

    Sets up the optional WandB logger and the configured callbacks, points the
    last callback's ``dirpath`` at a dated output directory (created if needed),
    and fills in a DDP strategy when several devices are requested without an
    explicit strategy.

    Args:
        config: Full run configuration. Reads ``wandb``, ``callbacks``,
            ``dataset``, ``model``, ``encoder``, ``task``, ``train`` and
            ``trainer``.
        **kwargs: Accepted for backward compatibility; not used.

    Returns:
        The object produced by ``hydra.utils.instantiate(config.trainer, ...)``.
    """
    callbacks: List[pl.Callback] = []
    logger = None

    # WandB Logging
    if config.get("wandb") is not None:
        # Pass in wandb.init(config=) argument to get the nice 'x.y.0.z' hparams logged
        # Can pass in config_exclude_keys='wandb' to remove certain groups
        import wandb
        logger = CustomWandbLogger(
            config=utils.to_dict(config, recursive=True),
            settings=wandb.Settings(start_method="fork"),
            **config.wandb,
        )

    has_callbacks = "callbacks" in config

    # Lightning callbacks
    if has_callbacks:
        for _name_, callback in config.callbacks.items():
            # The learning-rate monitor is skipped when WandB is not configured.
            if config.get("wandb") is None and _name_ in ["learning_rate_monitor"]:
                continue
            log.info(f"Instantiating callback <{registry.callbacks[_name_]}>")
            callback._name_ = _name_
            callbacks.append(utils.instantiate(registry.callbacks, callback))

    # Describe the ProgressiveResizing stages. Guarded by the same "callbacks"
    # check as above so a config without callbacks does not fail here.
    if has_callbacks and config.callbacks.get("progressive_resizing", None) is not None:
        stage_params = config.callbacks.progressive_resizing.stage_params
        num_stages = len(stage_params)
        print(f"Progressive Resizing: {num_stages} stages")
        for i, e in enumerate(stage_params):
            # Stage params are resolution and epochs, pretty print
            print(f"\tStage {i}: {e['resolution']} @ {e['epochs']} epochs")

    # April 11, 2024 update: checkpoints go to <outputs>/<YYYY-MM-DD>/<run name>
    now = datetime.now()
    day_directory = now.strftime("%Y-%m-%d")
    time_directory = _run_subdirectory_name(config, now)
    print("time_directory", time_directory)

    # Set the base path (relative to the current working directory)
    base_path_outputs = "../../outputs"
    full_dir_path = os.path.join(base_path_outputs, day_directory, time_directory)

    # exist_ok avoids the check-then-create race when several processes
    # (e.g. DDP ranks) reach this point at the same time.
    os.makedirs(full_dir_path, exist_ok=True)

    # Assuming callbacks[-1] is the ModelCheckpoint callback
    if callbacks:
        callbacks[-1].dirpath = full_dir_path
    else:
        log.warning(
            f"No callbacks configured; checkpoint directory {full_dir_path} is not attached to any callback"
        )

    # Configure ddp automatically
    n_devices = config.trainer.get('devices', 1)
    if isinstance(n_devices, Sequence):  # trainer.devices could be [1, 3] for example
        n_devices = len(n_devices)
    if n_devices > 1 and config.trainer.get('strategy', None) is None:
        config.trainer.strategy = dict(
            _target_='pytorch_lightning.strategies.DDPStrategy',
            find_unused_parameters=False,
            gradient_as_bucket_view=True,  # https://pytorch-lightning.readthedocs.io/en/stable/advanced/advanced_gpu.html#ddp-optimizations
        )

    # Init lightning trainer
    log.info(f"Instantiating trainer <{config.trainer._target_}>")
    trainer = hydra.utils.instantiate(
        config.trainer, callbacks=callbacks, logger=logger)

    return trainer


def train(config):
    """Seed, build the trainer and model, then validate / fit / test as configured.

    Args:
        config: Full run configuration. ``config.train`` supplies ``seed``,
            ``validate_at_start``, ``ckpt`` and ``test``.
    """
    seed = config.train.seed
    if seed is not None:
        pl.seed_everything(seed, workers=True)

    trainer = create_trainer(config)
    model = SequenceLightningModule(config)

    # Run initial validation epoch (useful for debugging, finetuning)
    if config.train.validate_at_start:
        print("Running validation before training")
        trainer.validate(model)

    # `ckpt_path` is only passed to fit() when a checkpoint is configured.
    resume_from = config.train.ckpt
    fit_kwargs = {} if resume_from is None else {"ckpt_path": resume_from}
    trainer.fit(model, **fit_kwargs)

    if config.train.test:
        trainer.test(model)




def _train_with_profiling(config):
    """Run ``train(config)`` under cProfile and print the runtime reports.

    Args:
        config: Processed run configuration, passed straight to ``train``.
    """
    import cProfile
    import pstats

    profiler = cProfile.Profile()
    profiler.enable()
    try:
        train(config)
    finally:
        # Stop collecting even when training raises.
        profiler.disable()

    # Sort and display the stats
    stats = pstats.Stats(profiler)
    stats.sort_stats(pstats.SortKey.CUMULATIVE)

    base_path = os.environ.get("BASE_PATH", ".")

    # Top 25 entries whose location matches BASE_PATH
    print("\n=== Top 25 operations ===\n")
    stats.sort_stats("cumtime").print_stats(base_path, 25)

    # Callees for long_conv.py:189
    print("\n=== Callees within long_conv.py:189 ===\n")
    stats.print_callees("long_conv.py:189", 25)

    # Top 10 low-level operations like matmul, conv, etc.
    print("\n=== Top 10 operations involving matmul, conv, etc. ===\n")
    stats.sort_stats("cumtime").print_stats("matmul|conv|einsum|ifft|rfft|fft", 10)

    # Overall top 40, then everything matching 'forward', then
    # `SetEncoder.forward` only
    stats.print_stats(40)
    stats.print_stats('forward')
    stats.print_stats('SetEncoder.forward', 50)


@hydra.main(config_path="configs", config_name="config.yaml")
def main(config: DictConfig):
    """Hydra entry point: adjust and process the config, then train.

    When ``config.train.print_runtime_statistics`` is truthy, training runs
    under cProfile and the runtime reports are printed afterwards. By default
    training runs without the profiler, so no profiling overhead is paid for
    statistics that would not be shown.

    Args:
        config: Configuration composed by Hydra.
    """
    # Process config:
    # - register evaluation resolver
    # - filter out keys used only for interpolation
    # - optional hooks, including disabling python warnings or debug friendly configuration
    try:
        lvl = config["dataset"]["generator"]["level"]
        small_levels = (
            "veasy",
            "supereasy",
            "supereasy_2d",
            "supereasy_2d_long_lookback",
            "2d_path_dependency",
            "supereasy_2d_no_loan_specific_feature",
            "supereasy_1d",
            "2d_with_stochasticity",
        )
        is_small_level = lvl in small_levels
        config["dataset"]["num_states"] = 3 if is_small_level else 10
        config["dataset"]["num_terminal_states"] = 1 if is_small_level else 2
        # May 16, 2024
        # NOTE(review): the factor 70 is unexplained in this file — confirm what it represents.
        if config["decoder"]["forecast"]:
            config["scheduler"]["num_training_steps"] *= 70
            print("forecasting objective")
    except Exception:
        # Best effort: configs without these keys are left untouched.
        # `Exception` rather than a bare `except` so KeyboardInterrupt and
        # SystemExit still propagate.
        print("Could not load different levels")

    config = utils.train.process_config(config)

    # Pretty print config using Rich library
    utils.train.print_config(config, resolve=True)

    if config.train.get("print_runtime_statistics", False):
        _train_with_profiling(config)
    else:
        train(config)

# Script entry point: calls the Hydra-decorated `main` with no arguments.
if __name__ == "__main__":
    main()
