"""Minimal mock model and tokenizer for testing training without expensive model loading."""

import torch
from typing import List


class MockModel(torch.nn.Module):
    """Minimal mock model that satisfies trainer requirements.

    Provides the attributes and hooks a trainer touches (``config``,
    ``lm_head``, gradient-checkpointing toggles, adapter helpers) without
    loading real weights. ``forward`` ignores the token values and returns
    random logits of shape ``(batch, seq_len, vocab_size)``.
    """

    def __init__(self):
        super().__init__()

        class FakeConfig:
            """Tiny stand-in for a model config object."""

            def __init__(self):
                self.model_type = 'fake'
                self.vocab_size = 512
                self.hidden_size = 512
                self._name_or_path = 'fake_model'
                self.torch_dtype = torch.float32

            def to_dict(self):
                """Return the config fields as a plain dict (dtype stringified)."""
                return {
                    'model_type': self.model_type,
                    'vocab_size': self.vocab_size,
                    'hidden_size': self.hidden_size,
                    '_name_or_path': self._name_or_path,
                    'torch_dtype': str(self.torch_dtype)
                }

        self.config = FakeConfig()
        # Minimal linear layer to satisfy trainer requirements; sized from the
        # config so its shape always agrees with config.hidden_size/vocab_size.
        self.lm_head = torch.nn.Linear(self.config.hidden_size, self.config.vocab_size)
        # Add dtype attribute to satisfy trainer requirements
        # NOTE: this is a plain attribute and is not updated by model.to(dtype).
        self.lm_head.dtype = torch.float32
        # Add warnings_issued for unsloth compatibility
        self.warnings_issued = {"estimate_tokens": True}
        # Add fake vllm_engine to avoid initialization
        self.vllm_engine = None

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Return an object with random ``logits`` of shape (batch, seq_len, vocab).

        Args:
            input_ids: 2-D tensor; only its shape and device are used.
            attention_mask: ignored.
            **kwargs: ignored.

        Returns:
            An ad-hoc ``Output`` instance with a single ``logits`` attribute.
        """
        batch_size, seq_len = input_ids.shape
        # Match lm_head's dtype so the mock keeps working after model.to(dtype)
        # (e.g. bf16/fp64); the logits stay on the same device as the input.
        hidden_states = torch.randn(
            batch_size,
            seq_len,
            self.lm_head.in_features,
            device=input_ids.device,
            dtype=self.lm_head.weight.dtype,
        )
        logits = self.lm_head(hidden_states)
        return type('Output', (), {'logits': logits})()

    def load_lora(self, name, load_tensors=True):
        """No-op LoRA loader; always returns ``None``."""
        return None

    def disable_adapter(self):
        """Return ``self`` so ``with model.disable_adapter():`` works as a no-op."""
        return self

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning False lets any exception raised in the with-block propagate.
        return False

    def get_input_embeddings(self):
        """Return a freshly built embedding layer sized from the config.

        A new, randomly initialized module is created on every call; it is
        not registered as a submodule of this model.
        """
        embedding = torch.nn.Embedding(self.config.vocab_size, self.config.hidden_size)
        embedding.dtype = torch.float32
        return embedding

    def get_output_embeddings(self):
        """Return the ``lm_head`` linear layer."""
        return self.lm_head

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """No-op for the fake model.

        Accepts ``gradient_checkpointing_kwargs`` like the real transformers
        API, because trainers pass it by keyword.
        """
        pass

    def gradient_checkpointing_disable(self):
        # No-op for fake model
        pass

    def enable_input_require_grads(self):
        # No-op for fake model
        pass

    def add_model_tags(self, tags):
        # No-op for fake model
        pass


class MockTokenizer:
    """Minimal mock tokenizer that satisfies trainer requirements.

    "Tokenization" is whitespace splitting: a text with ``n`` words becomes
    the ids ``0..n-1``. Every sequence has at least one token, so empty,
    ``None`` or whitespace-only text maps to ``[0]``.
    """

    def __init__(self):
        self.pad_token_id = 0
        self.eos_token_id = 1
        self.bos_token_id = 2
        self.vocab_size = 512

    @staticmethod
    def _num_tokens(text):
        """Return the fake token count for *text*: its word count, minimum 1."""
        if not text:
            return 1
        # `or 1` covers whitespace-only strings, whose split() is empty.
        return len(text.split()) or 1

    def __call__(self, text, **kwargs):
        """Tokenize a single string or a list of strings.

        Args:
            text: a string, or a list of strings forming a batch.
            **kwargs: accepted for API compatibility and ignored.

        Returns:
            dict with 2-D integer tensors ``input_ids`` and ``attention_mask``.
            Batches are right-padded with ``pad_token_id`` to the longest entry.
        """
        if not isinstance(text, list):
            tokens = list(range(self._num_tokens(text)))
            return {
                'input_ids': torch.tensor([tokens]),
                'attention_mask': torch.tensor([[1] * len(tokens)])
            }

        if not text:
            # Empty batch: nothing to pad to, so return empty 2-D tensors
            # instead of letting max() raise.
            empty = torch.empty((0, 0), dtype=torch.long)
            return {'input_ids': empty, 'attention_mask': empty.clone()}

        lengths = [self._num_tokens(t) for t in text]
        max_len = max(lengths)
        input_ids = []
        attention_mask = []
        for length in lengths:
            pad = max_len - length
            input_ids.append(list(range(length)) + [self.pad_token_id] * pad)
            attention_mask.append([1] * length + [0] * pad)

        return {
            'input_ids': torch.tensor(input_ids),
            'attention_mask': torch.tensor(attention_mask)
        }

    def decode(self, token_ids, **kwargs):
        """Return a placeholder string encoding only the number of ids."""
        return f"fake_output_{len(token_ids)}"

    def batch_decode(self, token_ids_list, **kwargs):
        """Decode each sequence in *token_ids_list* with :meth:`decode`."""
        return [self.decode(ids) for ids in token_ids_list]

    def apply_chat_template(self, messages, **kwargs):
        """Flatten a chat into a plain string.

        If the first message is a dict with a ``content`` key, the contents of
        all messages are concatenated (a missing or ``None`` content counts as
        ``''``). A non-empty list of other items yields ``str`` of the first
        item. Anything else yields ``str(messages)``, or ``""`` when falsy.
        """
        if isinstance(messages, list) and messages:
            first = messages[0]
            if isinstance(first, dict) and 'content' in first:
                return ''.join(msg.get('content') or '' for msg in messages)
            return str(first)
        return str(messages) if messages else ""