import logging
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer

from models.dataset import TextDataset

def _unfreeze_trailing_parameters(model, unfreeze_pct):
    """Freeze every parameter tensor except the trailing ``unfreeze_pct`` fraction."""
    all_params = list(model.parameters())
    num_params = len(all_params)
    # Clamp so a fraction above 1.0 cannot push the split index negative,
    # which would silently freeze/unfreeze the wrong slice.
    num_unfrozen = min(num_params, max(1, int(num_params * unfreeze_pct)))
    split = num_params - num_unfrozen

    for param in all_params[:split]:
        param.requires_grad = False
    for param in all_params[split:]:
        param.requires_grad = True


def _configure_trainable_parameters(model, unfreeze_pct):
    """Set ``requires_grad`` flags on ``model`` and return ``(trainable, total)`` element counts."""
    base_model = getattr(model, "model", None)
    layers = getattr(base_model, "layers", None)

    if layers is None:
        # No ``model.model.layers`` stack: fall back to a fraction of the raw
        # parameter list, counted from the end.
        _unfreeze_trailing_parameters(model, unfreeze_pct)
    else:
        num_layers = len(layers)

        for param in model.parameters():
            param.requires_grad = False

        # Unfreeze the last ``num_unfrozen`` layers of the stack.
        num_unfrozen = min(num_layers, max(1, int(num_layers * unfreeze_pct)))
        for i in range(num_layers - num_unfrozen, num_layers):
            for param in layers[i].parameters():
                param.requires_grad = True

        # NOTE(review): anything outside ``layers`` and ``lm_head`` stays
        # frozen here even when unfreeze_pct is 1.0 -- confirm this is intended.
        if hasattr(model, "lm_head"):
            for param in model.lm_head.parameters():
                param.requires_grad = True

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if trainable_params == 0:
        # Nothing ended up trainable; open up the head (or the last few
        # parameter tensors) so the optimizer has something to update.
        if hasattr(model, "lm_head"):
            fallback_params = list(model.lm_head.parameters())
        else:
            fallback_params = list(model.parameters())[-10:]
        for param in fallback_params:
            param.requires_grad = True
        # Recount so the reported number reflects the fallback.
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    total_params = sum(p.numel() for p in model.parameters())
    return trainable_params, total_params


def fine_tune_model(model_name, training_data, config):
    """Fine-tune a causal language model on ``training_data`` and return it.

    Only items that contain the key ``"object"`` and do not contain the key
    ``"source_predicate"`` are used for training.

    Args:
        model_name: Identifier or path passed to ``from_pretrained``.
        training_data: Iterable of items supporting ``in`` key checks
            (presumably dicts -- confirm against the caller).
        config: Object providing ``cache_dir``, ``freeze_layers``,
            ``unfreeze_pct`` (read only when ``freeze_layers`` is truthy),
            ``model_ckpt_dir``, ``experiment_name``, ``num_train_epochs``,
            ``batch_size``, ``gradient_accumulation_steps``, ``save_steps``,
            ``save_total_limit``, ``logging_steps`` and ``learning_rate``.

    Returns:
        A ``(model, tokenizer)`` tuple. Training and saving are best-effort:
        if either raises, the error is logged with its traceback and the
        model is returned in whatever state it reached.
    """
    logger = logging.getLogger(__name__)

    training_data_with_answers = [
        item for item in training_data
        if "object" in item and "source_predicate" not in item
    ]

    tokenizer = AutoTokenizer.from_pretrained(
        model_name, trust_remote_code=True, cache_dir=config.cache_dir
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name, trust_remote_code=True, cache_dir=config.cache_dir
    )

    # Partial fine-tuning: train only a trailing fraction of the network when
    # layer freezing is requested, otherwise use a fraction of 1.0.
    unfreeze_pct = config.unfreeze_pct if config.freeze_layers else 1.0

    try:
        trainable_params, total_params = _configure_trainable_parameters(model, unfreeze_pct)
        print(f"Number of trainable parameters: {trainable_params} / {total_params}")
    except Exception:
        # Keep the original fallback (train everything) but make it visible.
        logger.warning(
            "Could not apply layer freezing; training all parameters instead.",
            exc_info=True,
        )
        for param in model.parameters():
            param.requires_grad = True

    train_dataset = TextDataset(training_data_with_answers, tokenizer, max_length=32)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    output_dir = os.path.join(config.model_ckpt_dir, "checkpoints")
    os.makedirs(output_dir, exist_ok=True)

    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        num_train_epochs=config.num_train_epochs,
        per_device_train_batch_size=config.batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        save_steps=config.save_steps,
        save_total_limit=config.save_total_limit,
        logging_steps=config.logging_steps,
        learning_rate=config.learning_rate,
        weight_decay=0.1,
        fp16=torch.cuda.is_available(),
        dataloader_num_workers=0,
        local_rank=-1,
        no_cuda=False,
        lr_scheduler_type="linear",
    )

    class SingleGPUTrainer(Trainer):
        # Return the model untouched instead of letting the base class wrap it.
        def _wrap_model(self, model, training=True):
            self._model = model
            return model

    trainer = SingleGPUTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
    )

    trainer.model_parallel = False
    trainer.model_wrapped = model

    try:
        trainer.train()
        final_model_path = os.path.join(
            config.model_ckpt_dir, f"model_{config.experiment_name}_final"
        )
        os.makedirs(final_model_path, exist_ok=True)
        model.save_pretrained(final_model_path)
        tokenizer.save_pretrained(final_model_path)
    except Exception:
        # Best-effort by design: still return the model, but never hide the
        # failure -- callers would otherwise get an untrained/unsaved model
        # with no indication anything went wrong.
        logger.exception("Fine-tuning or saving the final model failed.")

    return model, tokenizer


def load_model(model_path):
    """Load a causal language model and its tokenizer from ``model_path``.

    The model's token embeddings are resized when the tokenizer length
    differs from ``model.config.vocab_size``, and the model is moved to CUDA
    when available (CPU otherwise).

    Args:
        model_path: Identifier or path passed to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        Any exception raised while loading, resizing or moving the model
        propagates to the caller unchanged.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, trust_remote_code=True, low_cpu_mem_usage=True
    )

    # Keep the embedding matrix in step with the tokenizer.
    # NOTE(review): this also shrinks the matrix when the model's vocab is
    # larger than the tokenizer (e.g. padded vocabularies) -- confirm intended.
    if len(tokenizer) != model.config.vocab_size:
        model.resize_token_embeddings(len(tokenizer))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    return model, tokenizer