from src.trainer.base import MyTrainingArguments
from src.models.modeling_gpt2 import GPT2LMHeadModel, GPT2Model

from dataclasses import dataclass, field
from transformers import Trainer, PreTrainedModel, Conv1D, PretrainedConfig
from typing import Optional, Literal
import torch
import torch.nn as nn
from einops import rearrange
import sys

@dataclass
class DownstreamArguments(MyTrainingArguments):
    """Training arguments for downstream fine-tuning runs.

    Subclasses ``MyTrainingArguments`` (defined in ``src.trainer.base`` and
    not visible here) and declares the fields below. Presumably some of them
    override same-named fields of the parent -- TODO confirm against the
    parent class.
    """
    # Number of training epochs; defaults to 3.
    num_train_epochs: float = 3.0
    # Optional learning rate; defaults to None. How None is resolved is not
    # visible in this file -- presumably handled by the parent class or the
    # launching script; verify before relying on it.
    learning_rate: Optional[float] = None
    # Names of the label entries in a batch. DownstreamTrainer.compute_loss
    # reads inputs["labels"], which matches this default.
    label_names: list[str] = field(default_factory=lambda: ["labels"])
    # When True, DownstreamTrainer freezes the transformer backbone, trains
    # only the linear head, and switches gradient checkpointing off.
    train_only_head: bool = False
    # Evaluation schedule; defaults to "epoch".
    eval_strategy: Literal["no", "steps", "epoch"] = "epoch"

class DownstreamNet(PreTrainedModel):
    """Multiple-choice scoring head on top of a GPT-2 transformer backbone.

    Every candidate sequence is encoded by ``seq_model`` and reduced to a
    single scalar by ``linear``, applied to the hidden state of the
    candidate's last attended token.
    """

    seq_model: GPT2Model

    def __init__(self, model: GPT2LMHeadModel):
        """Build the network from an existing language model.

        Args:
            model: Language model whose ``transformer`` sub-module becomes the
                backbone and whose ``config.n_embd`` sizes the scoring head.
        """
        # NOTE: a fresh default PretrainedConfig is passed here rather than
        # ``model.config``; the backbone keeps its own config internally.
        super().__init__(config=PretrainedConfig())
        self.supports_gradient_checkpointing = True

        self.seq_model = model.transformer
        # Projects an ``n_embd``-sized hidden state to one score.
        self.linear = Conv1D(nx=model.config.n_embd, nf=1)

    def enable_grad_only_head(self):
        """Freeze every backbone parameter and leave only the head trainable."""
        for param in self.seq_model.parameters():
            param.requires_grad = False
        for param in self.linear.parameters():
            param.requires_grad = True

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs):
        """Score every choice of every example.

        Args:
            input_ids: Token ids laid out as ``(batch, choices, length)``.
            attention_mask: Mask with the same layout as ``input_ids``;
                non-zero entries mark tokens to attend to.
            **kwargs: Accepted and ignored, so a whole batch dict (including
                ``labels``) can be splatted into the call.

        Returns:
            Tensor of shape ``(batch, choices)`` holding one score per choice.
        """
        num_choices = input_ids.size(1)
        _input_ids = rearrange(input_ids, 'b c l -> (b c) l')
        _attention_mask = rearrange(attention_mask, 'b c l -> (b c) l')

        output = self.seq_model(input_ids=_input_ids, attention_mask=_attention_mask)
        hidden = output.last_hidden_state  # ((b c), l, d)

        # Pool at the last *attended* position of each row instead of blindly
        # at index -1: with right-padded batches index -1 is a padding slot.
        # argmax returns the first maximum, so on the flipped mask it finds
        # the last non-zero entry. Rows that are unpadded, left-padded or
        # fully masked resolve to ``length - 1``, i.e. the previous behavior.
        seq_len = _attention_mask.size(-1)
        attended = (_attention_mask != 0).to(torch.long)
        last_index = seq_len - 1 - attended.flip(-1).argmax(dim=-1)
        rows = torch.arange(hidden.size(0), device=hidden.device)
        pooled = hidden[rows, last_index.to(hidden.device)]  # ((b c), d)

        scores = self.linear(pooled)[..., 0]  # ((b c),)

        return rearrange(scores, '(b c) -> b c', c=num_choices)
 
class DownstreamTrainer(Trainer):
    """Trainer that wraps a language model in ``DownstreamNet`` for fine-tuning.

    The supplied model is converted into a ``DownstreamNet`` before being
    handed to the base ``Trainer``; the loss is ``nn.CrossEntropyLoss``
    between the network's outputs and ``inputs["labels"]``.
    """

    def __init__(self, model: PreTrainedModel, task_name: str,
                 args: DownstreamArguments,
                 *_args, **kwargs):
        """Wrap ``model``, configure which parameters train, and init the Trainer.

        Args:
            model: Language model passed to ``DownstreamNet``.
            task_name: Name of the downstream task; stored on the instance.
            args: Training arguments. When ``args.train_only_head`` is set,
                ``args.gradient_checkpointing`` is overwritten with False.
            *_args: Extra positional arguments forwarded to ``Trainer``.
            **kwargs: Extra keyword arguments forwarded to ``Trainer``.
        """
        self.task_name = task_name
        self.loss_fn = nn.CrossEntropyLoss()
        net = DownstreamNet(model)

        if args.train_only_head:
            # Head-only training: freeze the backbone and turn gradient
            # checkpointing off on the arguments object.
            net.enable_grad_only_head()
            args.gradient_checkpointing = False
            print("[Message] Gradient checkpointing is disabled because train_only_head is True.")
        else:
            # Full fine-tuning: make sure nothing is left frozen.
            for weight in net.parameters():
                weight.requires_grad = True

        super().__init__(net, args, *_args, **kwargs)

    def training_step(self, *args, **kwargs):
        """Run the base training step, then flush stdout."""
        step_result = super().training_step(*args, **kwargs)
        sys.stdout.flush()
        return step_result

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the cross-entropy loss for one batch.

        Args:
            model: Model called with the whole ``inputs`` dict as keywords.
            inputs: Batch dict; must contain a ``"labels"`` entry.
            return_outputs: When true, also return ``{"logits": ...}``.
            num_items_in_batch: Accepted for signature compatibility; unused.

        Returns:
            The loss, or ``(loss, {"logits": logits})`` if ``return_outputs``.
        """
        scores = model(**inputs)
        loss = self.loss_fn(scores, inputs["labels"])
        if not return_outputs:
            return loss
        return (loss, {"logits": scores})