import math
from typing import Any, Dict, List, Optional, Tuple, Union
from configs import ScriptArguments
from transformers import PreTrainedModel, GenerationConfig, LogitsProcessor
from trl.models import unwrap_model_for_generation
from trl.data_utils import is_conversational, apply_chat_template
from trl.trainer.utils import pad
from accelerate.utils import broadcast_object_list, gather_object
import copy

import torch
from torch import nn

class LogitsProcessorWithLossMask(LogitsProcessor):
    """Force generation to follow ``labels`` wherever ``loss_mask`` is falsy.

    At every decoding step the processor looks up the column of
    ``loss_mask`` / ``labels`` matching the token about to be generated. For
    each (batch, beam) row whose mask entry is falsy, every logit is set to
    ``-inf`` except the label token's, which is set to ``0``, so decoding has
    to pick the label token. Rows whose mask entry is truthy keep their
    original scores.

    Args:
        inputs: Batch dict with ``input_ids`` (the prompt; its last-dim length
            marks where generated tokens start), ``loss_mask`` and ``labels``.
            ``loss_mask`` / ``labels`` are indexed as
            ``[batch, generated_position]`` -- assumes they are aligned with the
            generated continuation rather than the prompt (TODO confirm with
            the data collator).
        num_beams: Beams per batch element; rows of ``input_ids`` / ``scores``
            are laid out as ``batch * num_beams`` with beams contiguous.
    """

    def __init__(self, inputs: Dict[str, torch.Tensor], num_beams: int):
        super().__init__()
        self._num_beams = num_beams
        self.loss_mask = inputs['loss_mask']
        self.labels = inputs['labels']
        # Length of the (padded) prompt: generated position 0 starts here.
        self.idx_offset = inputs['input_ids'].shape[-1]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Mask ``scores`` in place for rows that must emit their label token.

        Returns:
            The same ``scores`` tensor, modified in place.

        Raises:
            ValueError: if the number of score rows does not equal
                ``batch * num_beams`` for the stored mask.
        """
        # During generate() every row has the same length, so the target
        # position is a single scalar shared by all rows.
        pos = input_ids.shape[-1] - self.idx_offset
        if pos < 0 or pos >= self.loss_mask.shape[-1]:
            # Generating beyond the mask/labels: there is nothing to force.
            return scores

        expected_rows = self.loss_mask.shape[0] * self._num_beams
        if scores.shape[0] != expected_rows:
            raise ValueError(
                f"scores has {scores.shape[0]} rows but loss_mask implies "
                f"{expected_rows} (batch={self.loss_mask.shape[0]}, num_beams={self._num_beams})"
            )

        keep = self.loss_mask[:, pos].to(scores.device).bool()
        label_col = self.labels[:, pos].to(scores.device)
        # A negative label (e.g. a padding/ignore value) is not a token id;
        # using it as an index would silently select a token from the end of
        # the vocabulary, so such rows are left untouched.
        force = ~keep & (label_col >= 0)

        # Expand per-batch decisions to per-(batch, beam) rows.
        force = force.repeat_interleave(self._num_beams)
        label_col = label_col.repeat_interleave(self._num_beams)

        rows = force.nonzero(as_tuple=True)[0]
        if rows.numel() > 0:
            scores[rows] = -math.inf
            scores[rows, label_col[rows]] = 0
        return scores

def trainer_class_factory(args: ScriptArguments):
    """Build SFT/GRPO trainer classes whose evaluation step generates text.

    The base trainer classes are chosen by ``args.use_unsloth``: Unsloth's
    compiled trainers when set, otherwise TRL's ``SFTTrainer`` /
    ``GRPOTrainer``. The imports are done here so Unsloth is only loaded when
    requested.

    Args:
        args: Script arguments; only ``use_unsloth`` is read.

    Returns:
        ``(SFTTrainerWithGenerate, GRPOTrainerWithGenerate)`` -- two classes.
    """
    from contextlib import nullcontext

    if args.use_unsloth:
        from unsloth_compiled_cache.UnslothSFTTrainer import UnslothSFTTrainer as SFTTrainer
        from unsloth_compiled_cache.UnslothGRPOTrainer import UnslothGRPOTrainer as GRPOTrainer
        from unsloth import FastLanguageModel
    else:
        from trl import SFTTrainer, GRPOTrainer
        # Placeholder so the ``prediction_step`` annotation below still resolves.
        FastLanguageModel = Any

    # Import explicitly: ``import torch`` alone does not guarantee that the
    # ``torch.distributed.fsdp`` attribute is available.
    try:
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    except ImportError:  # e.g. a torch build without distributed support
        FSDP = None

    def _is_fsdp(model) -> bool:
        """Return True if ``model`` is, or wraps, an FSDP-sharded module."""
        if FSDP is None:
            return False
        return hasattr(model, '_fsdp_wrapped_module') or isinstance(model, FSDP)

    class PredictWithGenerateMixin:
        """Replace ``prediction_step`` with autoregressive generation.

        The step returns ``(None, generated_token_ids, labels)``, so metric
        computation receives generated continuations instead of logits.
        """

        def prediction_step(
            self,
            model: PreTrainedModel | FastLanguageModel,
            inputs: Dict[str, Union[torch.Tensor, Any]],
            prediction_loss_only: bool,
            ignore_keys: Optional[List[str]] = None,
        ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
            """Generate continuations for a batch.

            Args:
                model: Model to generate with. Optional ``for_inference`` /
                    ``for_training`` hooks are called around generation when
                    present.
                inputs: Batch with ``input_ids``, ``attention_mask``,
                    ``labels`` and optionally ``loss_mask``. When
                    ``loss_mask`` has any entry other than 1, positions where
                    it is 0 are forced to the label token (see
                    ``LogitsProcessorWithLossMask``).
                prediction_loss_only: Unused; no loss is computed.
                ignore_keys: Unused.

            Returns:
                ``(None, generated_ids, labels)``, where ``generated_ids``
                excludes the prompt tokens.
            """
            if hasattr(model, 'for_inference'):
                model.for_inference()
            try:
                generation_config = model.generation_config

                gen_kwargs = {}
                # Without a max_length, cap generation at the label length.
                # Otherwise pass nothing, so an explicit max_new_tokens=None
                # cannot override a limit configured on generation_config.
                if generation_config.max_length is None:
                    gen_kwargs['max_new_tokens'] = len(inputs['labels'][0])

                if 'loss_mask' not in inputs or (inputs['loss_mask'] == 1).all():
                    logits_processor = None
                else:
                    logits_processor = [
                        LogitsProcessorWithLossMask(inputs, generation_config.num_beams)
                    ]

                # FSDP shards parameters; gather them (read-only) for generation.
                param_ctx = (
                    FSDP.summon_full_params(model, writeback=False)
                    if _is_fsdp(model)
                    else nullcontext()
                )
                with param_ctx:
                    outputs = model.generate(
                        inputs['input_ids'],
                        attention_mask=inputs['attention_mask'],
                        logits_processor=logits_processor,
                        **gen_kwargs,
                    )
            finally:
                # Restore training mode even if generation fails.
                if hasattr(model, 'for_training'):
                    model.for_training()

            # generate() returns prompt + continuation; keep the continuation only.
            prompt_len = inputs['input_ids'].shape[-1]
            return (None, outputs[:, prompt_len:], inputs['labels'])

    class SFTTrainerWithGenerate(PredictWithGenerateMixin, SFTTrainer):
        """SFT trainer whose evaluation generates text via the mixin."""

    class GRPOTrainerWithGenerate(PredictWithGenerateMixin, GRPOTrainer):
        """GRPO trainer whose evaluation generates text via the mixin.

        ``compute_metrics`` and ``data_collator`` are accepted here and
        assigned after the base initializer, which is not passed them.
        """

        def __init__(self, *args, compute_metrics=None, data_collator=None, **kwargs):
            super().__init__(*args, **kwargs)
            self.compute_metrics = compute_metrics
            self.generation_config = self.model.generation_config
            # Only override when one was given; otherwise keep whatever
            # collator the base trainer installed instead of replacing it with None.
            if data_collator is not None:
                self.data_collator = data_collator

    return SFTTrainerWithGenerate, GRPOTrainerWithGenerate
