# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import itertools
import logging
import os
import random
from typing import List, Optional, Union

import torch
import transformers
from torch.utils.data import DataLoader, Dataset, Sampler
from transformers.trainer import (
    ALL_LAYERNORM_LAYERS,
    TRAINER_STATE_NAME,
    TrainerState,
    get_last_checkpoint,
    get_parameter_names,
    is_sagemaker_mp_enabled,
)

# Module-level logger, named after this module so handlers/levels can be configured per module.
logger = logging.getLogger(name=__name__)

class StratifiedBatchSampler(Sampler):
    """Batch sampler that gives every task the same number of slots per batch.

    ``data_source`` must expose ``task_to_indices`` (task id -> sequence of
    dataset indices) and ``batch_size``. Each batch takes
    ``batch_size // num_tasks`` indices from every task's shuffled pool, in
    ``task_to_indices`` order; any remaining slots are filled with indices
    drawn from all tasks (these may repeat indices already in the batch).

    The iterator is infinite. When a task's pool can no longer supply its
    share, the pool is reshuffled from that task's own indices, so the
    per-task quota holds for every batch. The consumer must bound iteration.
    ``__len__`` reports how many batches fit before the smallest task first
    needs reshuffling.

    Shuffling uses a private ``random.Random`` seeded with ``seed + epoch``
    (see ``set_epoch``), so the global ``random`` state is left untouched.

    Raises:
        ValueError: if there are no tasks, if ``batch_size`` is smaller than
            the number of tasks, or if a task has fewer indices than its
            per-batch share.
    """

    def __init__(self, data_source: Dataset, seed=0):
        self.data_source = data_source
        self.task_to_indices = self.data_source.task_to_indices
        self.task_ids = list(self.task_to_indices.keys())
        self.batch_size = self.data_source.batch_size
        self.seed = seed
        self.epoch = 0

        if not self.task_ids:
            raise ValueError("StratifiedBatchSampler requires at least one task in `task_to_indices`.")
        self.samples_per_task = self.batch_size // len(self.task_ids)
        if self.samples_per_task < 1:
            # With zero slots per task, stratification is impossible and __len__ would divide by zero.
            raise ValueError(
                f"Batch size {self.batch_size} is smaller than the number of tasks "
                f"({len(self.task_ids)}); every task needs at least one slot per batch."
            )
        for task, indices in self.task_to_indices.items():
            if len(indices) < self.samples_per_task:
                raise ValueError(f"Not enough samples in task {task} to satisfy `samples_per_task`.")

        self.all_indices = [idx for indices in self.task_to_indices.values() for idx in indices]

        logger.info("Task IDs: %s", self.task_ids)
        logger.info("Batch size: %s", self.batch_size)

    def __iter__(self):
        # Private RNG: does not disturb the global `random` state and lets
        # set_epoch() change the order between epochs.
        rng = random.Random(self.seed + self.epoch)
        task_pools = {
            task: rng.sample(indices, len(indices))
            for task, indices in self.task_to_indices.items()
        }

        while True:
            batch = []
            # Take exactly `samples_per_task` from each task.
            for task in self.task_ids:
                pool = task_pools[task]
                if len(pool) < self.samples_per_task:
                    # Refill from this task's own indices only; refilling from
                    # all tasks would break the per-task quota. Leftover
                    # indices in the old pool are dropped.
                    indices = self.task_to_indices[task]
                    pool = task_pools[task] = rng.sample(indices, len(indices))
                    if len(pool) < self.samples_per_task:
                        raise ValueError(f"Not enough samples in task {task} to satisfy `samples_per_task`.")
                batch.extend(pool.pop() for _ in range(self.samples_per_task))

            # Fill any slots left by the integer division from all tasks.
            remaining_samples = self.batch_size - len(batch)
            if remaining_samples > 0:
                batch.extend(rng.sample(self.all_indices, remaining_samples))

            yield batch

    def __len__(self):
        """Number of batches before the smallest task's pool is first exhausted."""
        min_size = min(len(indices) for indices in self.task_to_indices.values())
        return min_size // self.samples_per_task

    def set_epoch(self, epoch):
        """Record ``epoch`` (reseeds the next iteration) and forward it to the dataset if supported."""
        self.epoch = epoch
        if hasattr(self.data_source, "set_epoch"):
            self.data_source.set_epoch(epoch)

class NeedleFocusedBatchSampler(Sampler):
    """Batch sampler that reserves half of every batch for 'needle' tasks.

    ``data_source`` must expose ``task_to_indices``, ``instruction_to_idx``
    (task name -> key into ``task_to_indices``) and ``batch_size``. Tasks
    whose name contains "needle" (case-insensitive) form the needle group;
    all other tasks form the other group. Each batch contains
    ``batch_size // 2`` needle indices and the rest from the other group,
    shuffled together.

    The iterator is infinite. Each group's pool is reshuffled once it can no
    longer fill its share, so the consumer must bound iteration. ``__len__``
    reports how many batches fit in one pass over the group that runs out
    first.

    Shuffling uses a private ``random.Random`` seeded with ``seed + epoch``
    (see ``set_epoch``), so the global ``random`` state is left untouched.

    Raises:
        ValueError: if either group is empty or too small to fill its share
            of one batch.
    """

    def __init__(self, data_source: Dataset, seed=0):
        self.data_source = data_source
        self.task_to_indices = self.data_source.task_to_indices
        self.instruction_to_idx = self.data_source.instruction_to_idx
        self.batch_size = self.data_source.batch_size
        self.seed = seed
        self.epoch = 0

        self.needle_indices = []
        self.other_indices = []
        self.needle_tasks = []
        self.other_tasks = []

        # The task *name* decides the group; its value keys into task_to_indices.
        for task, index in self.instruction_to_idx.items():
            if "needle" in task.lower():
                self.needle_indices.extend(self.task_to_indices[index])
                self.needle_tasks.append(task)
            else:
                self.other_indices.extend(self.task_to_indices[index])
                self.other_tasks.append(task)

        if not self.needle_indices:
            raise ValueError("NeedleFocusedBatchSampler requires at least one 'needle' task, but none were found.")
        if not self.other_indices:
            raise ValueError("NeedleFocusedBatchSampler requires at least one non-'needle' task, but none were found.")

        logger.info(
            "NeedleFocusedBatchSampler: found %d 'needle' tasks (%d samples) and %d 'other' tasks (%d samples).",
            len(self.needle_tasks),
            len(self.needle_indices),
            len(self.other_tasks),
            len(self.other_indices),
        )
        logger.info("Needle tasks: %s", self.needle_tasks)
        logger.info("Other tasks: %s", self.other_tasks)

        self.needle_samples_per_batch = self.batch_size // 2
        self.other_samples_per_batch = self.batch_size - self.needle_samples_per_batch

        # These checks guarantee that a reshuffled pool can always fill its
        # share, so __iter__ needs no further size checks.
        if len(self.needle_indices) < self.needle_samples_per_batch:
            raise ValueError(
                f"Not enough 'needle' samples ({len(self.needle_indices)}) to fill 50% of the batch "
                f"({self.needle_samples_per_batch})."
            )
        if len(self.other_indices) < self.other_samples_per_batch:
            raise ValueError(
                f"Not enough 'other' task samples ({len(self.other_indices)}) to fill 50% of the batch "
                f"({self.other_samples_per_batch})."
            )

    def __iter__(self):
        # Private RNG, reseeded per epoch, instead of reseeding the global module.
        rng = random.Random(self.seed + self.epoch)

        needle_pool = rng.sample(self.needle_indices, len(self.needle_indices))
        other_pool = rng.sample(self.other_indices, len(self.other_indices))

        while True:
            # Reshuffle a group once it cannot fill its share; leftovers are dropped.
            if len(needle_pool) < self.needle_samples_per_batch:
                needle_pool = rng.sample(self.needle_indices, len(self.needle_indices))
            if len(other_pool) < self.other_samples_per_batch:
                other_pool = rng.sample(self.other_indices, len(self.other_indices))

            batch = [needle_pool.pop() for _ in range(self.needle_samples_per_batch)]
            batch.extend(other_pool.pop() for _ in range(self.other_samples_per_batch))

            # Mix the two groups so their position in the batch carries no signal.
            rng.shuffle(batch)
            yield batch

    def __len__(self):
        """Batches per pass over the group that runs out first (0 if neither group is sampled)."""
        # A group with a zero share never runs out, so it does not limit the length.
        limits = []
        if self.needle_samples_per_batch > 0:
            limits.append(len(self.needle_indices) // self.needle_samples_per_batch)
        if self.other_samples_per_batch > 0:
            limits.append(len(self.other_indices) // self.other_samples_per_batch)
        return min(limits, default=0)

    def set_epoch(self, epoch):
        """Record ``epoch`` (reseeds the next iteration) and forward it to the dataset if supported."""
        self.epoch = epoch
        if hasattr(self.data_source, "set_epoch"):
            # this is important for dataset
            self.data_source.set_epoch(epoch)


class BaseSampler(Sampler):
    """Per-sample index sampler that forwards ``set_epoch`` to its dataset.

    With ``shuffle=True`` each pass yields a permutation of
    ``range(len(data_source))`` drawn from a ``torch.Generator`` seeded with
    ``seed + epoch``; otherwise indices come out in order. ``set_epoch``
    stores the epoch and, when the dataset defines ``set_epoch``, calls it
    too.
    """

    def __init__(self, data_source: Dataset, shuffle: bool = False, seed: int = 0):
        self.data_source = data_source
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0

    def __len__(self):
        return len(self.data_source)

    def __iter__(self):
        size = len(self.data_source)
        if not self.shuffle:
            return iter(range(size))
        # The seed intentionally omits the process rank: every rank must
        # produce the same permutation.
        generator = torch.Generator()
        generator.manual_seed(self.seed + self.epoch)
        permutation = torch.randperm(size, generator=generator)
        return iter(permutation.tolist())

    def set_epoch(self, epoch):
        self.epoch = epoch
        if not hasattr(self.data_source, "set_epoch"):
            return
        # Datasets that vary their content per epoch rely on this call.
        self.data_source.set_epoch(epoch)


class DualBrainTrainer(transformers.Trainer):
    """``transformers.Trainer`` subclass used to train the dual-brain model.

    Deviations from the stock Trainer:

    * requires an extra ``compute_dtype`` constructor kwarg, which is stored
      on the trainer and not forwarded to ``transformers.Trainer``;
    * train and eval dataloaders draw batches from ``StratifiedBatchSampler``,
      whose iterator never terminates on its own;
    * ``compute_loss`` calls ``model(inputs)`` with the whole batch dict and
      reads ``outputs["loss"]``;
    * ``evaluate`` averages the loss over the first ``eval_max_batches``
      batches only;
    * ``save_model`` always goes through ``model.save_pretrained`` with an
      explicitly gathered state dict.
    """

    # Number of eval batches averaged by `evaluate`. The stratified eval
    # sampler is infinite, so this cap is what ends an evaluation run.
    eval_max_batches: int = 2

    def __init__(self, **kwargs):
        # `compute_dtype` is not a `transformers.Trainer` argument; pop it
        # before forwarding. It is required (KeyError if absent).
        self.compute_dtype = kwargs.pop("compute_dtype")
        super().__init__(**kwargs)

    def _get_train_sampler(self, train_dataset=None):
        """Return a shuffled `BaseSampler` over the training set.

        This class's `get_train_dataloader` installs its own batch sampler and
        does not call this. `train_dataset` is optional, so the method works
        whether or not the caller passes a dataset; it defaults to
        `self.train_dataset`.
        """
        dataset = self.train_dataset if train_dataset is None else train_dataset
        return BaseSampler(dataset, shuffle=True, seed=self.args.seed)

    def _get_eval_sampler(self, eval_dataset):
        """Return an in-order (unshuffled) `BaseSampler` over `eval_dataset`."""
        return BaseSampler(eval_dataset, shuffle=False)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run ``model(inputs)`` and return ``outputs["loss"]``, plus outputs if requested.

        `num_items_in_batch` is accepted for signature compatibility and is
        not used: the loss is whatever the model computed.
        """
        outputs = model(inputs)
        loss = outputs["loss"]
        return (loss, outputs) if return_outputs else loss

    def evaluate(
        self,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> dict[str, float]:
        """Estimate the eval loss from the first `eval_max_batches` batches.

        Args:
            eval_dataset: Dataset to evaluate, a dict of named datasets, or
                None to use `self.eval_dataset`. Each dict entry is evaluated
                separately under the prefix ``f"{metric_key_prefix}_{name}"``.
            ignore_keys: Accepted for signature compatibility; unused.
            metric_key_prefix: Prefix of the returned metric names.

        Returns:
            ``{f"{metric_key_prefix}_loss": mean_loss}``, merged across entries
            for a dict. The mean is 0.0, with a logged warning, if no loss was
            recorded. Metrics are also passed to `self.log` and to the
            callbacks' `on_evaluate`.
        """
        resolved = eval_dataset if eval_dataset is not None else self.eval_dataset
        if isinstance(resolved, dict):
            metrics = {}
            for name, dataset in resolved.items():
                metrics.update(
                    self.evaluate(
                        eval_dataset=dataset,
                        ignore_keys=ignore_keys,
                        metric_key_prefix=f"{metric_key_prefix}_{name}",
                    )
                )
            return metrics

        eval_dataloader = self.get_eval_dataloader(resolved)

        model = self.model
        model.eval()

        all_losses = []
        with torch.no_grad():
            # islice bounds the loop to exactly `eval_max_batches` batches; the
            # stratified eval sampler never stops on its own.
            for inputs_batch in itertools.islice(eval_dataloader, self.eval_max_batches):
                prepared_inputs = self._prepare_inputs(inputs_batch)
                loss = self.compute_loss(model, prepared_inputs)

                if isinstance(loss, torch.Tensor):
                    # A non-scalar loss (e.g. one value per GPU under
                    # DataParallel) is averaged before conversion to float.
                    if loss.ndim > 0:
                        loss = loss.mean()
                    all_losses.append(loss.item())
                elif loss is not None:  # not a tensor; fall back to float()
                    all_losses.append(float(loss))

        if all_losses:
            avg_loss = sum(all_losses) / len(all_losses)
        else:
            avg_loss = 0.0
            logger.warning(
                "No losses recorded during evaluation (%s). "
                "Check if eval_dataloader is empty or if all batches resulted in errors.",
                metric_key_prefix,
            )

        metrics = {f"{metric_key_prefix}_loss": avg_loss}
        self.log(metrics)
        # Let callbacks (e.g. early stopping) observe the metrics, as the base
        # Trainer.evaluate does.
        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
        return metrics

    def _build_stratified_dataloader(self, dataset, sampler_source) -> DataLoader:
        """Wrap `dataset` in an accelerate-prepared DataLoader driven by a StratifiedBatchSampler.

        `sampler_source` is the dataset object exposing `task_to_indices` and
        `batch_size`. It is passed separately because `_remove_unused_columns`
        may return a different object for `dataset`.
        """
        dataloader_params = {
            "collate_fn": self.data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
            # The batch sampler decides batch composition and size; DataLoader
            # forbids combining it with batch_size/shuffle/sampler/drop_last.
            "batch_sampler": StratifiedBatchSampler(sampler_source, seed=self.args.seed),
        }
        if not isinstance(dataset, torch.utils.data.IterableDataset):
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        return self.accelerator.prepare(DataLoader(dataset, **dataloader_params))

    def get_train_dataloader(self) -> DataLoader:
        """Return the training DataLoader.

        Batches come from a `StratifiedBatchSampler` over `self.train_dataset`,
        seeded with `args.seed`. The sampler never stops on its own, so
        iteration must be bounded by the consumer.

        Raises:
            ValueError: If no training dataset was provided.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self._remove_unused_columns(self.train_dataset, description="training")
        return self._build_stratified_dataloader(train_dataset, self.train_dataset)

    def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
        """Return the evaluation DataLoader.

        Args:
            eval_dataset: A dataset, the name of an entry in a dict-valued
                `self.eval_dataset`, or None to use `self.eval_dataset`.

        Batches come from a `StratifiedBatchSampler` built over the resolved
        dataset (the same one the DataLoader reads from). The sampler never
        stops on its own.

        Raises:
            ValueError: If no evaluation dataset is available.
        """
        if isinstance(eval_dataset, str):
            eval_dataset = self.eval_dataset[eval_dataset]
        elif eval_dataset is None:
            eval_dataset = self.eval_dataset
        if eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        stripped_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
        return self._build_stratified_dataloader(stripped_dataset, eval_dataset)

    def create_optimizer(self):
        """Create `self.optimizer` if it does not exist yet, and return it.

        Trainable parameters are split into two groups. Names returned by
        `get_parameter_names(model, ALL_LAYERNORM_LAYERS)` that do not contain
        "bias" get `args.weight_decay`; all others get no weight decay.
        The optimizer class and kwargs come from
        `transformers.Trainer.get_optimizer_cls_and_kwargs(args)`. Under
        SageMaker model parallelism the base implementation is used instead.
        """
        if is_sagemaker_mp_enabled():
            return super().create_optimizer()

        opt_model = self.model

        if self.optimizer is None:
            # A set gives O(1) membership tests while grouping parameters.
            decay_parameters = {
                name
                for name in get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
                if "bias" not in name
            }
            trainable = [(n, p) for n, p in opt_model.named_parameters() if p.requires_grad]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in trainable if n in decay_parameters],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in trainable if n not in decay_parameters],
                    "weight_decay": 0.0,
                },
            ]

            optimizer_cls, optimizer_kwargs = transformers.Trainer.get_optimizer_cls_and_kwargs(
                self.args
            )
            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

        return self.optimizer

    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        """Save the model with `save_pretrained` on processes where `args.should_save` is true.

        Args:
            output_dir: Destination directory; defaults to `args.output_dir`.
            _internal_call: Accepted for signature compatibility with the base
                Trainer; unused.
        """
        if output_dir is None:
            output_dir = self.args.output_dir

        # The state dict is gathered on every process, before the should_save
        # check. NOTE(review): presumably required because the DeepSpeed gather
        # may need all ranks to participate — confirm for the ZeRO stage in use.
        if self.is_deepspeed_enabled:
            state_dict = self.accelerator.get_state_dict(self.deepspeed)
        else:
            state_dict = self.model.state_dict()

        if self.args.should_save:
            return self.model.save_pretrained(output_dir, state_dict=state_dict)

    def train(
        self,
        resume_from_checkpoint=None,
        trial=None,
        ignore_keys_for_eval=None,
        **kwargs,
    ):
        """Load the checkpoint's TrainerState into `self.state`, then delegate to `Trainer.train`.

        `resume_from_checkpoint` may be a path, True (use the last checkpoint
        in `args.output_dir`), or False/None (start fresh).

        Raises:
            ValueError: If True was passed but `args.output_dir` holds no checkpoint.
        """
        if resume_from_checkpoint is False:
            resume_from_checkpoint = None

        if resume_from_checkpoint is True:
            resume_from_checkpoint = get_last_checkpoint(self.args.output_dir)
            if resume_from_checkpoint is None:
                # Fail loudly rather than silently training from scratch.
                raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})")

        if resume_from_checkpoint is not None:
            # Make the resumed state visible on `self.state` before the base
            # training loop starts.
            self.state = TrainerState.load_from_json(
                os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
            )
        return super().train(resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
