#  Copyright (c) 2024, Salesforce, Inc.
#  SPDX-License-Identifier: Apache-2.0
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

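# NOTE: disabled debugging aid. Uncomment the block below to print a short
# stack trace every time a `uni2ts` module is imported.
# import sys, importlib.abc, traceback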
# class Dbg(importlib.abc.MetaPathFinder):
#     def find_spec(self, fullname, path=None, target=None):
#         # fullname is the dotted module path being imported
#         if fullname.startswith("uni2ts"):
#             print(f"\n=== import '{fullname}' requested ===")
#             traceback.print_stack(limit=6)          # show who asked
#         return None      # allow the normal import machinery to keep looking

# # put our finder **before** the default ones
# sys.meta_path.insert(0, Dbg())
# sys.meta_path.insert(0, importlib.machinery.PathFinder())  # keep default finder
# print("[dbg] Dbg hook installed")
# print("[dbg] uni2ts modules already present:", [k for k in sys.modules if k.startswith("uni2ts")])

from functools import partial
from typing import Callable, Optional

import hydra
import lightning as L
import torch
from hydra.utils import instantiate
from omegaconf import DictConfig
from torch.utils._pytree import tree_map
from torch.utils.data import Dataset, DistributedSampler
from aurora.utils.data.curriculum_sampler import DistributedCurriculumSampler

from aurora.common import hydra_util  # noqa: hydra resolvers
from aurora.common.env import env
from aurora.data.loader import DataLoader
import os
import json





class DataModule(L.LightningDataModule):
    """Lightning data module that builds the train and validation dataloaders.

    On multi-process runs (``world_size > 1``) the training dataloader is
    driven by a ``DistributedCurriculumSampler`` whose bins are read from a
    pre-computed JSON cache, while validation dataloaders use a plain
    ``DistributedSampler`` over the validation dataset itself.
    """

    # File name of the pre-computed curriculum binning, looked up inside
    # ``env.CURRICULUM_CACHE_PATH``.
    CURRICULUM_CACHE_FILENAME = (
        "dataset__bins_12.0-15.7_15.7-18.3_18.3-22.0_weight_1.0.json"
    )

    def __init__(
        self,
        cfg: DictConfig,
        train_dataset: Dataset,
        val_dataset: Optional[Dataset | list[Dataset]],
    ):
        """Store the config and datasets.

        Args:
            cfg: Full run configuration; ``cfg.train_dataloader`` and
                ``cfg.val_dataloader`` are read when dataloaders are built.
            train_dataset: Dataset used for training.
            val_dataset: A dataset, a list of datasets, or ``None`` to
                disable validation.
        """
        super().__init__()
        self.cfg = cfg
        self.train_dataset = train_dataset

        if val_dataset is not None:
            self.val_dataset = val_dataset
            # Only expose ``val_dataloader`` when there is something to
            # validate on; otherwise the base-class hook is left untouched.
            self.val_dataloader = self._val_dataloader

    @staticmethod
    def _load_curriculum_cache(cache_file: str) -> tuple:
        """Read the pre-computed curriculum binning from ``cache_file``.

        The JSON document is expected to look like::

            {
                "bin_indices": ...,
                "bin_weights": ...,
                "metadata": {
                    "dataset_size": ...,
                    "magnitude_bins": ...,
                    "error_weight_factor": ...,
                    "computation_time": ...
                }
            }

        Returns:
            ``(bin_indices, bin_weights, num_samples)`` where ``num_samples``
            is ``metadata.dataset_size``.

        Raises:
            FileNotFoundError: If ``cache_file`` does not exist.
        """
        if not os.path.exists(cache_file):
            raise FileNotFoundError(
                f"Curriculum binning cache not found: {cache_file}. "
                "It must be pre-computed before the curriculum sampler "
                "can be used."
            )
        print(f"Loading pre-computed curriculum binning from {cache_file}")
        with open(cache_file, "r", encoding="utf-8") as f:
            cached_data = json.load(f)
        return (
            cached_data["bin_indices"],
            cached_data["bin_weights"],
            cached_data["metadata"]["dataset_size"],
        )

    @staticmethod
    def get_dataloader(
        dataset: Dataset,
        dataloader_func: Callable[..., DataLoader],
        shuffle: bool,
        world_size: int,
        batch_size: int,
        num_batches_per_epoch: Optional[int] = None,
        global_rank: Optional[int] = None,
        use_curriculum: bool = True,
    ) -> DataLoader:
        """Build a dataloader, attaching a distributed sampler when needed.

        Args:
            dataset: Dataset to iterate over.
            dataloader_func: Factory invoked with ``dataset``, ``shuffle``,
                ``sampler``, ``batch_size`` and ``num_batches_per_epoch``.
            shuffle: Whether to shuffle. Forwarded to the sampler when one is
                created, otherwise to ``dataloader_func``.
            world_size: Number of participating processes; a sampler is only
                created when this is greater than 1.
            batch_size: Per-process batch size.
            num_batches_per_epoch: Forwarded unchanged to ``dataloader_func``.
            global_rank: Rank of the current process, forwarded to the
                sampler.
            use_curriculum: When ``True`` (default) use the curriculum
                sampler backed by the pre-computed binning cache; when
                ``False`` use a plain ``DistributedSampler`` over ``dataset``.

        Returns:
            Whatever ``dataloader_func`` returns.

        Raises:
            FileNotFoundError: If the curriculum sampler is required but the
                binning cache file is missing.
        """
        sampler = None
        # The cache is only touched when a curriculum sampler is actually
        # needed, so single-process runs do not require it to exist.
        if world_size > 1:
            if use_curriculum:
                cache_file = os.path.join(
                    env.CURRICULUM_CACHE_PATH,
                    DataModule.CURRICULUM_CACHE_FILENAME,
                )
                bin_indices, bin_weights, num_samples = (
                    DataModule._load_curriculum_cache(cache_file)
                )
                sampler = DistributedCurriculumSampler(
                    bin_indices,
                    bin_weights,
                    num_samples,
                    num_replicas=world_size,
                    rank=global_rank,
                    shuffle=shuffle,
                    seed=0,
                    curriculum_phase=0,
                    drop_last=False,
                )
            else:
                sampler = DistributedSampler(
                    dataset,
                    num_replicas=world_size,
                    rank=global_rank,
                    shuffle=shuffle,
                    seed=0,
                    drop_last=False,
                )

        return dataloader_func(
            dataset=dataset,
            # Compare against None explicitly: samplers may define __len__,
            # so an empty sampler would otherwise be treated as "no sampler".
            shuffle=False if sampler is not None else shuffle,
            sampler=sampler,
            batch_size=batch_size,
            num_batches_per_epoch=num_batches_per_epoch,
        )

    def train_dataloader(self) -> DataLoader:
        """Return the training dataloader for the current trainer state."""
        return self.get_dataloader(
            self.train_dataset,
            instantiate(self.cfg.train_dataloader, _partial_=True),
            self.cfg.train_dataloader.shuffle,
            self.trainer.world_size,
            self.train_batch_size,
            num_batches_per_epoch=self.train_num_batches_per_epoch,
            global_rank=self.trainer.global_rank,
        )

    def _val_dataloader(self) -> DataLoader | list[DataLoader]:
        """Return one validation dataloader per validation dataset.

        ``tree_map`` mirrors the structure of ``self.val_dataset`` (a single
        dataset or a list of datasets) in the returned value.
        """
        return tree_map(
            partial(
                self.get_dataloader,
                dataloader_func=instantiate(self.cfg.val_dataloader, _partial_=True),
                shuffle=self.cfg.val_dataloader.shuffle,
                world_size=self.trainer.world_size,
                batch_size=self.val_batch_size,
                num_batches_per_epoch=None,
                global_rank=self.trainer.global_rank,
                # The curriculum bins are read from a single fixed cache file
                # and are not derived from the dataset passed in, so
                # validation sets are sharded with a plain DistributedSampler.
                use_curriculum=False,
            ),
            self.val_dataset,
        )

    @property
    def train_batch_size(self) -> int:
        """Configured train batch size divided by world size times accumulation steps."""
        return self.cfg.train_dataloader.batch_size // (
            self.trainer.world_size * self.trainer.accumulate_grad_batches
        )

    @property
    def val_batch_size(self) -> int:
        """Configured val batch size divided by world size times accumulation steps."""
        return self.cfg.val_dataloader.batch_size // (
            self.trainer.world_size * self.trainer.accumulate_grad_batches
        )

    @property
    def train_num_batches_per_epoch(self) -> int:
        """Configured batches per epoch multiplied by the accumulation steps."""
        return (
            self.cfg.train_dataloader.num_batches_per_epoch
            * self.trainer.accumulate_grad_batches
        )


@hydra.main(version_base="1.3", config_name="default.yaml")
def main(cfg: DictConfig):
    """Instantiate model, trainer and datasets from ``cfg`` and run training.

    Args:
        cfg: Hydra configuration. Reads ``tf32``, ``trainer``, ``model``,
            ``compile``, ``data``, optional ``val_data``, ``seed`` and
            ``ckpt_path``.

    Raises:
        ValueError: If ``cfg.tf32`` is enabled while ``cfg.trainer.precision``
            is not 32.
    """
    if cfg.tf32:
        # Explicit check instead of `assert`, which is stripped under
        # `python -O` and would let a mismatched config through silently.
        if cfg.trainer.precision != 32:
            raise ValueError(
                "tf32 requires trainer.precision == 32, "
                f"got {cfg.trainer.precision!r}"
            )
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    model: L.LightningModule = instantiate(cfg.model, _convert_="all")

    if cfg.compile:
        model.module.compile(mode=cfg.compile)
    trainer: L.Trainer = instantiate(cfg.trainer)
    train_dataset: Dataset = instantiate(cfg.data).load_dataset(
        model.train_transform_map
    )

    # Validation is optional: only built when the config has a `val_data` key.
    val_dataset: Optional[Dataset | list[Dataset]] = (
        tree_map(
            lambda ds: ds.load_dataset(model.val_transform_map),
            instantiate(cfg.val_data, _convert_="all"),
        )
        if "val_data" in cfg
        else None
    )
    L.seed_everything(cfg.seed, workers=True)
    trainer.fit(
        model,
        datamodule=DataModule(cfg, train_dataset, val_dataset),
        ckpt_path=cfg.ckpt_path,
    )


# Script entry point: run the Hydra-decorated training routine when this file
# is executed directly rather than imported.
if __name__ == "__main__":
    main()
