import logging
logger = logging.getLogger(__name__)
from datetime import datetime
import torch
import gc
import os
import re
import json
import time
from contextlib import contextmanager
from typing import Optional
from datasets import load_dataset, Dataset
from transformers import AutoModelForCausalLM
from transformers import GenerationConfig
from peft import LoraConfig, get_peft_model
from trl import GRPOConfig, GRPOTrainer
from transformers import AutoConfig, AutoTokenizer
from transformers import TrainerCallback
import csv
import numpy as np

def setup_logger(output_dir):
    """Create the run's log directories/files and configure root logging.

    Creates under ``output_dir``:
      * ``logs/<YYYYmmdd>/training_<timestamp>.log`` - attached to the root
        logger via ``logging.basicConfig`` together with a console handler,
      * ``logs/progress/training_progress_<timestamp>.csv`` - progress CSV,
        initialised with its header row,
      * ``reward_logs/reward_log_<timestamp>.csv`` - reward CSV, initialised
        with its header row.

    ``logging.basicConfig`` does nothing if the root logger already has
    handlers, in which case the ``.log`` file handler is not installed.

    Args:
        output_dir: Base directory for all run artifacts.

    Returns:
        Tuple ``(root_logger, progress_file_path, reward_log_file_path)``.
    """
    # One timestamp for everything: two separate now() calls could straddle
    # midnight and put a file in a date directory that doesn't match its name.
    now = datetime.now()
    current_time = now.strftime('%Y%m%d_%H%M%S')

    log_dir = os.path.join(output_dir, "logs")
    date_dir = os.path.join(log_dir, now.strftime('%Y%m%d'))
    os.makedirs(date_dir, exist_ok=True)  # also creates log_dir
    log_file = os.path.join(date_dir, f'training_{current_time}.log')

    progress_dir = os.path.join(log_dir, 'progress')
    os.makedirs(progress_dir, exist_ok=True)
    progress_file = os.path.join(progress_dir, f'training_progress_{current_time}.csv')

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] %(message)s',
        handlers=[
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler()
        ]
    )

    with open(progress_file, 'w', encoding='utf-8') as f:
        f.write("epoch,step,loss,reward_outcome,reward_process,total_reward,learning_rate\n")

    reward_log_dir = os.path.join(output_dir, "reward_logs")
    os.makedirs(reward_log_dir, exist_ok=True)
    reward_log_file = os.path.join(reward_log_dir, f'reward_log_{current_time}.csv')
    with open(reward_log_file, 'w', encoding='utf-8') as f:
        f.write("step,index,model_answer,correct_answer,format_ok,steps,reward\n")

    return logging.getLogger(), progress_file, reward_log_file

class LoggingCallback(TrainerCallback):
    """Append each trainer log event to the progress CSV created by ``setup_logger``.

    Every ``on_log`` call writes one row
    ``epoch,step,loss,reward_outcome,reward_process,total_reward,learning_rate``.
    Every 50th log event also emits a one-line summary through ``logger``.
    """

    def __init__(self, logger, progress_file):
        self.logger = logger
        self.progress_file = progress_file
        # Number of on_log events handled so far; paces the console summary.
        self.step = 0

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is None:
            return
        # Per-reward-function keys are presumably named after each reward
        # callable's __name__ ('<lambda>' for the outcome lambda in main()).
        # NOTE(review): '<lambda_1>' is not a name Python generates, and newer
        # TRL versions may log 'rewards/<name>/mean' instead - confirm keys.
        outcome_reward = logs.get('rewards/reward_func_outcome_with_log', logs.get('rewards/<lambda>', 0))
        total_reward = logs.get('reward', outcome_reward)
        process_reward = logs.get('rewards/reward_func_process_with_log', logs.get('rewards/<lambda_1>', total_reward - outcome_reward))
        loss = logs.get('loss', 0)
        lr = logs.get('learning_rate', 0)
        epoch = logs.get('epoch', 0)
        # Record the trainer's optimizer step. Counting log events instead
        # under-reports the step by a factor of ``logging_steps``.
        global_step = getattr(state, 'global_step', None)
        step = self.step if global_step is None else global_step
        with open(self.progress_file, 'a', encoding='utf-8') as f:
            f.write(f"{epoch},{step},{loss},{outcome_reward},{process_reward},{total_reward},{lr}\n")
        if self.step % 50 == 0:
            self.logger.info(f"[Step {step}] Loss: {loss:.4f} | Outcome Reward: {outcome_reward:.4f} | Process Reward: {process_reward:.4f} | Total Reward: {total_reward:.4f} | LR: {lr:.2e}")
        self.step += 1

def print_gpu_memory(result=None):
    """Print current CUDA allocator usage and, optionally, stats from ``result``.

    Args:
        result: Optional object exposing ``peak_memory_usage`` (presumably in
            bytes, since it is divided by 1024**2 and labelled MB) and
            ``time`` (an object with ``total_seconds()``, e.g. a timedelta).
    """
    bytes_per_mb = 1024 ** 2
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / bytes_per_mb
        reserved = torch.cuda.memory_reserved() / bytes_per_mb
        print(f"GPU Memory Usage: {allocated:.2f} MB")
        print(f"GPU Memory Cache: {reserved:.2f} MB")

    if result is None:
        return
    print(f"Peak Memory: {result.peak_memory_usage/bytes_per_mb:.2f}MB")
    print(f"Inference Time: {result.time.total_seconds():.2f}s")

def clear_gpu_memory():
    """Run the Python garbage collector, then empty the CUDA cache.

    Collection runs first so objects freed by it can be returned by
    ``torch.cuda.empty_cache`` in the same call.
    """
    for release_step in (gc.collect, torch.cuda.empty_cache):
        release_step()

def extract_answer(text):
    """Return the final yes/no verdict contained in ``text``.

    The text is lower-cased, and quotes, newlines, periods, colons, commas
    and whitespace are replaced by spaces. The LAST remaining token among
    ``yes``/``can``/``no``/``not``/``cannot`` decides the answer.

    NOTE(review): because apostrophes become spaces, "can't" yields the
    token "can" and is read as "Yes" - confirm this is acceptable for the
    dataset text.

    Args:
        text: Free-form text, e.g. a complete QA record.

    Returns:
        ``"Yes"`` or ``"No"``, or ``None`` if no keyword is present or
        ``text`` is not a string.
    """
    # Guard up front: previously non-string input fell into an except
    # handler whose own ``text[:100]`` raised TypeError for None.
    if not isinstance(text, str):
        return None
    verdicts = {"yes": "Yes", "can": "Yes", "no": "No", "not": "No", "cannot": "No"}
    normalized = re.sub(r'"|\'|\n|\.|\s|\:|\,', ' ', text.lower())
    keywords = [tok for tok in normalized.split(' ') if tok in verdicts]
    return verdicts[keywords[-1]] if keywords else None

def normalize_answer(answer):
    """Map a single answer word to ``"Yes"`` or ``"No"``.

    Matching ignores case and surrounding whitespace: ``yes``/``can`` map to
    ``"Yes"``; ``no``/``not``/``cannot`` map to ``"No"``. Falsy input or any
    other word yields ``None``.
    """
    if answer:
        canonical = {
            "yes": "Yes",
            "can": "Yes",
            "no": "No",
            "not": "No",
            "cannot": "No",
        }
        return canonical.get(answer.strip().lower())
    return None

def extract_answer_from_completion(completion):
    """Extract a ``"Yes"``/``"No"`` verdict from a model completion.

    First checks whether the last non-empty line starts with
    ``"The answer is "``. If the rest of that line, with a trailing period
    removed and case-normalised, is ``Yes`` or ``No``, that is returned.
    Otherwise the function falls back to the last
    ``yes``/``can``/``no``/``not``/``cannot`` token anywhere in the completion.

    Returns:
        ``"Yes"``, ``"No"``, or ``None`` if no verdict is found or
        ``completion`` is not a string.
    """
    prefix = "The answer is "
    try:
        last_line = completion.strip().split('\n')[-1].strip()
        if last_line.startswith(prefix):
            # Take the whole suffix. Previously ``last_line[14]`` took a
            # single character, so this fast path could never match.
            answer = last_line[len(prefix):].strip().rstrip('.').strip().capitalize()
            if answer in ('Yes', 'No'):
                return answer

        text = re.sub(r'"|\'|\n|\.|\s|\:|\,', ' ', completion.lower())
        tokens = [tok for tok in text.split(' ') if tok in ("yes", "can", "no", "not", "cannot")]
        if not tokens:
            return None
        if tokens[-1] in ("not", "cannot", "no"):
            return "No"
        return "Yes"  # remaining keywords: "can", "yes"
    except (AttributeError, TypeError):
        # Non-string completions (None, bytes, message lists, ...).
        return None

def load_qa_dataset(file_path):
    """Load blank-line-separated QA records into a ``datasets.Dataset``.

    Records are separated by blank lines. Each record's yes/no label comes
    from ``extract_answer``; records without one are skipped with a warning.
    Within a record, the line starting with ``Question:`` and the line
    starting with ``Facts:`` (case-insensitive, last occurrence wins) go into
    the prompt after the system instruction and examples. All other lines
    form the reference ``completion``.

    Args:
        file_path: Path to a UTF-8 text file.

    Returns:
        A ``Dataset`` with columns ``prompt``, ``completion``,
        ``original_text`` and ``correct_answer``, or ``None`` on any loading
        error (the error is printed).
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = f.read().split('\n\n')

        dataset_dict = {
            "prompt": [],
            "completion": [],
            "original_text": [],
            "correct_answer": []
        }

        # Placeholder prompt text - replace with the real instruction and
        # few-shot examples.
        system_prompts = {
            "instruction": """YOUR INSTRUCTION.""",

            "examples": """Example 1:

Toxic Example 2:
"""
        }
        prompt_header = f"{system_prompts['instruction']}\n\n{system_prompts['examples']}\n\n"

        # Single pass: previously a second loop re-ran extract_answer on every
        # record just to count, re-printing every invalid record.
        total_count = 0
        valid_count = 0
        for qa in data:
            if not qa.strip():
                continue
            total_count += 1
            try:
                correct_answer = extract_answer(qa)
                if not correct_answer:
                    print(f"Warning: Unable to extract answer from the following text:\n{qa[:200]}...")
                    continue
                valid_count += 1

                question = ""
                facts = ""
                body_lines = []
                for line in qa.strip().split('\n'):
                    lowered = line.lower()
                    if lowered.startswith("question:"):
                        question = line
                    elif lowered.startswith("facts:"):
                        facts = line
                    else:
                        body_lines.append(line)

                dataset_dict["prompt"].append(f"{prompt_header}{question}\n{facts}\n")
                dataset_dict["completion"].append('\n'.join(body_lines))
                dataset_dict["original_text"].append(qa)
                dataset_dict["correct_answer"].append(correct_answer)
            except Exception as e:
                print(f"Error processing QA pair: {str(e)}")
                continue

        # Count only non-empty records (blank chunks were counted before).
        print(f"Total {total_count} QA pairs, {valid_count} valid answers")

        return Dataset.from_dict(dataset_dict)
    except Exception as e:
        print(f"Error loading dataset: {str(e)}")
        return None

def reward_func_outcome(completions, **kwargs):
    """Score completions on format (up to 0.4) and answer correctness (0.6).

    Per completion:
      * +0.2 if ``check_format`` reports ``basic_format``,
      * +0.1 if it reports ``step_quality``,
      * +0.1 if its ``step_count`` is at least 3,
      * +0.6 if ``extract_answer_from_completion`` equals the reference answer.
    An exception while scoring a completion gives that completion 0.0.

    Keyword Args:
        correct_answers: Sequence aligned with ``completions``. Missing or
            falsy entries mean "unknown" (no correctness bonus).
        reward_log_path: Optional CSV path. If its directory exists, one row
            per scored completion is appended.
        step: Value written to the log's ``step`` column (default 0).

    Returns:
        List of float rewards, one per completion.
    """
    correct_answers = kwargs.get('correct_answers')
    if correct_answers is None:
        correct_answers = []
    reward_log_path = kwargs.get('reward_log_path')
    step = kwargs.get('step', 0)
    rewards = []
    log_rows = []

    for idx, completion in enumerate(completions):
        reward = 0.0
        try:
            format_check = check_format(completion)
            if format_check['basic_format']:
                reward += 0.2
            if format_check['step_quality']:
                reward += 0.1
            if format_check['step_count'] >= 3:
                reward += 0.1

            correct_answer = correct_answers[idx] if idx < len(correct_answers) else None
            model_answer = extract_answer_from_completion(completion)
            if correct_answer and model_answer and model_answer == correct_answer:
                reward += 0.6

            # "steps" column: raw count of "<digits>." occurrences (differs
            # from check_format's per-line step count).
            numbered = len(re.findall(r'\d+\.', completion)) if isinstance(completion, str) else 0
            log_rows.append([
                step,
                idx,
                model_answer or 'UNKNOWN',
                correct_answer or 'UNKNOWN',
                format_check['basic_format'],
                numbered,
                f"{reward:.2f}"
            ])
        except Exception as e:
            print(f"Error during reward calculation: {str(e)}")
            reward = 0.0

        rewards.append(float(reward))

    # One write per batch. A logging failure never alters the rewards.
    if reward_log_path and log_rows and os.path.exists(os.path.dirname(reward_log_path)):
        try:
            # newline='' as the csv module requires (no blank rows on Windows).
            with open(reward_log_path, 'a', encoding='utf-8', newline='') as f:
                csv.writer(f).writerows(log_rows)
        except Exception as e:
            print(f"Log write error: {str(e)}")

    return rewards

def reward_func_outcome_with_log(completions, **kwargs):
    """Trainer-facing adapter that resolves reference answers for ``reward_func_outcome``.

    Reference answers are taken, in order of preference, from:
      1. ``kwargs['correct_answer']``: a list aligned with ``completions``.
         With ``remove_unused_columns=False``, TRL's GRPOTrainer presumably
         forwards the dataset's ``correct_answer`` column this way - verify
         for the installed TRL version.
      2. ``dataset[idx]['correct_answer']`` for ``idx`` in
         ``kwargs['batch_indices']`` (default ``0..len(completions)-1``).
    Missing entries are padded with ``None``. Any unexpected error yields 0.0
    for every completion.
    """
    try:
        n_completions = len(completions)

        forwarded = kwargs.get('correct_answer')
        if isinstance(forwarded, (list, tuple)) and len(forwarded) == n_completions:
            # Aligned one-to-one with the completions actually being scored.
            correct_answers = list(forwarded)
        else:
            correct_answers = []
            dataset = kwargs.get('dataset')
            if dataset is not None:
                # NOTE(review): without explicit batch_indices this reads
                # dataset rows 0..n-1, which need not be the rows being scored.
                batch_indices = kwargs.get('batch_indices', list(range(n_completions)))
                for idx in batch_indices:
                    try:
                        if 0 <= idx < len(dataset):
                            correct_answers.append(dataset[idx].get('correct_answer'))
                        else:
                            correct_answers.append(None)
                    except Exception:
                        correct_answers.append(None)
        correct_answers.extend([None] * (n_completions - len(correct_answers)))

        step = kwargs.get('step')
        if step is None:
            # Fall back to a trainer_state kwarg if one is supplied
            # (presumably by newer TRL versions); otherwise 0.
            step = getattr(kwargs.get('trainer_state'), 'global_step', 0)

        return reward_func_outcome(
            completions=completions,
            correct_answers=correct_answers,
            reward_log_path=kwargs.get('reward_log_path'),
            step=step
        )
    except Exception as e:
        print(f"reward_func_outcome_with_log error: {str(e)}")
        return [0.0] * len(completions)

def reward_func_process(prompts, completions, **kwargs):
    """Score the reasoning process of each (prompt, completion) pair.

    Per pair:
      * +0.2 if the completion contains "Let's think step by step.",
      * +0.2 if at least three numbered-step matches (``<digits>.`` to end of
        line) are found,
      * +0.1 if there is at least one such match and every number ``1..k``
        (k = number of matches) appears as ``"<n>."``,
      * if the prompt contains ``@_@``: +0.5 for a ``<suspect>`` tag, +1.0
        for a ``<harm>...</harm>`` pair, then +0.3 if the first harm span is
        longer than 20 characters after stripping, and another +0.3 if that
        span also contains danger/risk/harm/warning/caution.
    A scoring error yields 0.0 for that pair. Failing to write the optional
    log never changes any reward.

    Keyword Args:
        reward_log_path: Optional CSV path; one row per scored pair is appended.
        step: Label for the first log column, written as ``process_<step>``.

    Returns:
        List of float rewards, one per pair.
    """
    reward_log_path = kwargs.get('reward_log_path', None)
    step = kwargs.get('step', 0)
    rewards = []
    log_rows = []

    for idx, (prompt, completion) in enumerate(zip(prompts, completions)):
        reward = 0.0
        try:
            has_think_step = "Let's think step by step." in completion
            if has_think_step:
                reward += 0.2

            steps = re.findall(r'\d+\..*', completion)
            if len(steps) >= 3:
                reward += 0.2
            # Require at least one step: all() over an empty range is True,
            # which used to reward completions with no numbered steps at all.
            if steps and all(f"{i + 1}." in completion for i in range(len(steps))):
                reward += 0.1

            suspect_ok = False
            harm_ok = False
            if "@_@" in prompt:
                suspect_ok = "<suspect>" in completion
                harm_ok = "<harm>" in completion and "</harm>" in completion
                if suspect_ok:
                    reward += 0.5
                if harm_ok:
                    reward += 1.0
                    harm_content = re.findall(r'<harm>(.*?)</harm>', completion, re.DOTALL)
                    if harm_content and len(harm_content[0].strip()) > 20:
                        reward += 0.3
                        if any(word in harm_content[0].lower() for word in ['danger', 'risk', 'harm', 'warning', 'caution']):
                            reward += 0.3

            log_rows.append([
                f"process_{step}",
                idx,
                suspect_ok,
                harm_ok,
                has_think_step,
                len(steps),
                f"{reward:.2f}"
            ])
        except Exception as e:
            print(f"Error during process reward calculation: {str(e)}")
            reward = 0.0

        rewards.append(float(reward))

    # Logging is kept outside the scoring try: an I/O error used to reset
    # the reward to 0.0.
    if reward_log_path and log_rows:
        try:
            with open(reward_log_path, 'a', encoding='utf-8', newline='') as f:
                csv.writer(f).writerows(log_rows)
        except Exception as e:
            print(f"Process reward log write error: {str(e)}")

    return rewards

def reward_func_process_with_log(prompts, completions, **kwargs):
    """Trainer-facing wrapper around ``reward_func_process``.

    Forwards the prompts, the completions, ``reward_log_path`` and ``step``
    (default 0). The process reward does not use reference answers, so no
    per-call dataset lookup is performed. Any unexpected error yields 0.0
    for every completion.
    """
    try:
        return reward_func_process(
            prompts=prompts,
            completions=completions,
            reward_log_path=kwargs.get('reward_log_path'),
            step=kwargs.get('step', 0)
        )
    except Exception as e:
        print(f"reward_func_process_with_log error: {str(e)}")
        return [0.0] * len(completions)

def check_format(text):
    """Check ``text`` against the expected chain-of-thought layout.

    Returns a dict with:
      * ``basic_format``: True only if the first line starts with
        ``Question:``, some line contains ``Answer Choices: A:``, some line
        contains "Let's think step by step.", and ``The answer is <A-E>.``
        appears somewhere;
      * ``step_quality``: at least three numbered-step matches
        (``<digits>.`` to end of line), with every ``1..k`` present as ``"<n>."``;
      * ``step_count``: the number of numbered-step matches.
    Errors (e.g. non-string input) are printed and reported as an all-false
    result with ``step_count`` 0.

    NOTE(review): this layout is multiple-choice (A-E), but the rest of the
    file extracts Yes/No answers, so ``basic_format`` looks unattainable for
    yes/no completions - confirm the intended format.
    """
    try:
        lines = text.strip().split('\n')
        step_lines = re.findall(r'\d+\..*', text)
        step_total = len(step_lines)

        basic_format = (
            lines[0].startswith("Question:")
            and any("Answer Choices: A:" in ln for ln in lines)
            and any("Let's think step by step." in ln for ln in lines)
            and bool(re.search(r'The answer is [A-E]\.', text))
        )
        step_quality = step_total >= 3 and all(
            f"{n}." in text for n in range(1, step_total + 1)
        )
        return {
            'basic_format': basic_format,
            'step_quality': step_quality,
            'step_count': step_total,
        }
    except Exception as e:
        print(f"Format check error: {str(e)}")
        return {'basic_format': False, 'step_quality': False, 'step_count': 0}

def prepare_model_for_training(model):
    """Prepare ``model`` for gradient-checkpointed fine-tuning.

    Enables gradient checkpointing, sets ``config.use_cache = False`` and
    calls ``enable_input_require_grads()``. Trainable parameters are chosen
    as follows:
      * if ``model`` has a ``base_model`` attribute (e.g. a PEFT wrapper),
        only parameters under ``base_model`` whose name contains ``lora`` or
        ``adapter`` are trainable, and all others are frozen;
      * otherwise every parameter is made trainable.

    Returns:
        The same model object.
    """
    model.gradient_checkpointing_enable()
    model.config.use_cache = False

    if hasattr(model, 'base_model'):
        # Previously every parameter was first set trainable, which silently
        # turned LoRA into full fine-tuning and made this adapter-only pass a
        # no-op. Now the adapter/base split is explicit.
        for name, param in model.base_model.named_parameters():
            param.requires_grad = any(tag in name for tag in ('lora', 'adapter'))
    else:
        for param in model.parameters():
            param.requires_grad = True

    model.enable_input_require_grads()
    return model

class GRPOLogger:
    """File-based logger for a GRPO training run.

    On construction, creates ``<log_dir>/run_<YYYYmmdd_HHMMSS>/`` containing
    ``metrics/training_metrics.csv`` (header row), an empty
    ``samples/generation_samples.jsonl`` and ``errors/error_log.csv`` (header
    row). It also keeps an in-memory history of selected metrics for
    ``get_summary_stats``.
    """

    # Maps row-level metric names accepted by ``log_metrics`` to the history
    # lists in ``self.metrics``. 'total_reward' feeds 'step_rewards'.
    _HISTORY_KEYS = {
        'total_reward': 'step_rewards',
        'accuracy': 'accuracies',
        'loss': 'losses',
        'learning_rate': 'learning_rates',
        'calc_quality': 'calc_qualities',
        'step_count': 'step_counts',
        'format_score': 'format_scores',
    }

    def __init__(self, log_dir):
        self.log_dir = log_dir
        self.metrics = {
            'step_rewards': [],
            'accuracies': [],
            'losses': [],
            'learning_rates': [],
            'calc_qualities': [],
            'step_counts': [],
            'format_scores': [],
        }
        os.makedirs(log_dir, exist_ok=True)
        self.timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        self.run_dir = os.path.join(log_dir, f"run_{self.timestamp}")
        os.makedirs(self.run_dir, exist_ok=True)
        self.metrics_dir = os.path.join(self.run_dir, 'metrics')
        self.samples_dir = os.path.join(self.run_dir, 'samples')
        self.error_dir = os.path.join(self.run_dir, 'errors')
        os.makedirs(self.metrics_dir, exist_ok=True)
        os.makedirs(self.samples_dir, exist_ok=True)
        os.makedirs(self.error_dir, exist_ok=True)
        self.metrics_file = os.path.join(self.metrics_dir, 'training_metrics.csv')
        self.samples_file = os.path.join(self.samples_dir, 'generation_samples.jsonl')
        self.errors_file = os.path.join(self.error_dir, 'error_log.csv')
        self._init_log_files()
        print(f"Log directory: {self.run_dir}")
        print(f"Metrics file: {self.metrics_file}")
        print(f"Samples file: {self.samples_file}")

    def _init_log_files(self):
        """Create (truncate) the three log files and write the CSV headers."""
        with open(self.metrics_file, 'w', encoding='utf-8', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([
                "step", "timestamp", "loss", "reward_process", "reward_outcome",
                "total_reward", "accuracy", "learning_rate", "calc_quality",
                "step_count", "format_score", "gpu_memory"
            ])
        with open(self.samples_file, 'w', encoding='utf-8') as f:
            pass
        with open(self.errors_file, 'w', encoding='utf-8', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(["timestamp", "step", "error_type", "description", "sample_id"])

    def log_metrics(self, step, metrics):
        """Append one metrics row (plus timestamp and GPU MB) and update the history.

        Args:
            step: Step number written to the ``step`` column.
            metrics: Dict keyed by the CSV column names (``loss``,
                ``total_reward``, ...). Missing values default to 0.
        """
        try:
            current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            gpu_memory = torch.cuda.memory_allocated() / 1024**2 if torch.cuda.is_available() else 0

            row = [
                step,
                current_time,
                metrics.get('loss', 0.0),
                metrics.get('reward_process', 0.0),
                metrics.get('reward_outcome', 0.0),
                metrics.get('total_reward', 0.0),
                metrics.get('accuracy', 0.0),
                metrics.get('learning_rate', 0.0),
                metrics.get('calc_quality', 0.0),
                metrics.get('step_count', 0),
                metrics.get('format_score', 0.0),
                gpu_memory
            ]

            with open(self.metrics_file, 'a', encoding='utf-8', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(row)

            # Accept both row-level names ('loss') and history names
            # ('losses'). Previously only the latter matched, so inputs such as
            # {'loss': ...} never reached the history and get_summary_stats
            # stayed empty.
            for key, value in metrics.items():
                history_key = self._HISTORY_KEYS.get(key, key)
                if history_key in self.metrics:
                    self.metrics[history_key].append(value)

        except Exception as e:
            print(f"Error writing metrics: {str(e)}")
            self.log_error(step, "metrics_error", str(e))

    def log_sample(self, step, sample_id, prompt, completion, reward_info):
        """Append one generation sample as a JSON line to the samples file."""
        try:
            sample_data = {
                'step': step,
                'sample_id': sample_id,
                'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                'prompt': prompt,
                'completion': completion,
                'reward_info': reward_info
            }
            with open(self.samples_file, 'a', encoding='utf-8') as f:
                f.write(json.dumps(sample_data, ensure_ascii=False) + '\n')
        except Exception as e:
            print(f"Error writing sample: {str(e)}")
            self.log_error(step, "sample_logging_error", str(e))

    def log_error(self, step, error_type, description, sample_id=None):
        """Append one row to the error CSV. Failures here are only printed."""
        try:
            with open(self.errors_file, 'a', encoding='utf-8', newline='') as f:
                writer = csv.writer(f)
                writer.writerow([
                    datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                    step,
                    error_type,
                    description,
                    sample_id or ''
                ])
        except Exception as e:
            print(f"Error writing error log: {str(e)}")

    def get_summary_stats(self):
        """Return ``<name>_mean/_max/_min`` for every non-empty history list."""
        stats = {}
        for key, values in self.metrics.items():
            if values:
                stats[f"{key}_mean"] = sum(values) / len(values)
                stats[f"{key}_max"] = max(values)
                stats[f"{key}_min"] = min(values)
        return stats

class EnhancedTrainingCallback(TrainerCallback):
    """Record per-step training metrics and print a periodic summary.

    Each ``on_step_end`` builds a metrics dict (loss, averaged process and
    outcome rewards, learning rate, ...) and passes it to
    ``logger.log_metrics`` (a ``GRPOLogger``-like object). Every
    ``eval_steps`` steps it also prints a summary with throughput, GPU memory
    and the best total reward seen so far.
    """

    def __init__(self, logger, eval_steps=100):
        self.logger = logger
        self.eval_steps = eval_steps
        self.step = 0
        # -inf rather than 0 so a first total reward <= 0 can still be "best".
        self.best_reward = float('-inf')
        self.start_time = time.time()
        # Most recent logs dict received through on_log.
        self._last_logs = {}

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Cache the latest trainer log dict for use in ``on_step_end``."""
        if logs:
            self._last_logs = dict(logs)
        return control

    def on_step_end(self, args, state, control, **kwargs):
        # NOTE(review): the transformers Trainer does not appear to pass
        # ``logs`` to on_step_end (which left every metric at 0), so fall back
        # to the latest dict seen in on_log. Between logging events the last
        # logged values are repeated.
        logs = kwargs.get('logs') or self._last_logs
        self.step += 1
        process_rewards = []
        outcome_rewards = []
        for k, v in logs.items():
            if isinstance(v, (list, tuple, float, int)):
                if 'process' in k:
                    if isinstance(v, (list, tuple)):
                        process_rewards.extend(v)
                    else:
                        process_rewards.append(v)
                elif 'outcome' in k:
                    if isinstance(v, (list, tuple)):
                        outcome_rewards.extend(v)
                    else:
                        outcome_rewards.append(v)
        avg_process_reward = sum(process_rewards) / len(process_rewards) if process_rewards else 0
        avg_outcome_reward = sum(outcome_rewards) / len(outcome_rewards) if outcome_rewards else 0
        metrics = {
            'loss': logs.get('loss', 0),
            'reward_process': avg_process_reward,
            'reward_outcome': avg_outcome_reward,
            'total_reward': avg_process_reward + avg_outcome_reward,
            'learning_rate': logs.get('learning_rate', 0),
            'calc_quality': logs.get('calc_quality', 0),
            'step_count': logs.get('step_count', 0),
            'format_score': logs.get('format_score', 0),
            'accuracy': logs.get('accuracy', 0)
        }
        self.logger.log_metrics(self.step, metrics)
        if self.step % self.eval_steps == 0:
            elapsed_time = time.time() - self.start_time
            # Guard against a zero interval from a coarse clock.
            steps_per_second = self.step / elapsed_time if elapsed_time > 0 else 0.0
            gpu_memory = torch.cuda.memory_allocated() / 1024**2 if torch.cuda.is_available() else 0
            print(f"\n=== Step {self.step} ===")
            print(f"Loss: {metrics['loss']:.4f}")
            print(f"Total Reward: {metrics['total_reward']:.4f}")
            print(f"Process Reward: {metrics['reward_process']:.4f}")
            print(f"Outcome Reward: {metrics['reward_outcome']:.4f}")
            print(f"Accuracy: {metrics['accuracy']:.4f}")
            print(f"Format Score: {metrics['format_score']:.4f}")
            print(f"Steps/sec: {steps_per_second:.2f}")
            print(f"GPU Memory: {gpu_memory:.2f}MB")
            if metrics['total_reward'] > self.best_reward:
                self.best_reward = metrics['total_reward']
                print(f"New best reward: {self.best_reward:.4f}")
        return control

@contextmanager
def unwrap_model(model):
    """Context manager that yields the model underneath a wrapper.

    If ``model`` has a ``base_model`` attribute (e.g. a PEFT wrapper), the
    managed value is ``model.base_model.model``; otherwise it is ``model``
    itself. Nothing is changed or restored on exit.
    """
    target = model.base_model.model if hasattr(model, "base_model") else model
    yield target

def generate_wrapper(self, input_ids=None, attention_mask=None, **kwargs):
    """Generate via ``self.base_model.generate`` in eval mode, then restore train mode.

    Meant to be bound as a method (``generate_wrapper.__get__(model)``).
    Tensors are moved to ``self.device`` first. Extra keyword arguments are
    forwarded unchanged.

    The previous training mode is restored even if generation raises.
    Errors are printed and re-raised with their original traceback.
    """
    was_training = False
    try:
        if input_ids is not None:
            input_ids = input_ids.to(self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)

        was_training = self.training
        self.eval()

        return self.base_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
    except Exception as e:
        print(f"Error during generation: {str(e)}")
        raise
    finally:
        # Previously a failed generate() left the model stuck in eval mode.
        if was_training:
            self.train()

def main():
    """Run GRPO + LoRA fine-tuning end to end.

    Loads the QA dataset, loads the base model in fp16 onto CUDA, wraps it
    with LoRA adapters, configures generation and ``GRPOConfig``, trains with
    the outcome and process reward functions, and saves to ``output_dir``.
    Returns early without training if the dataset is missing or empty.
    Training errors are logged and re-raised.
    """
    # Hard-coded input/output locations for this experiment.
    qa_file = "/grpo_meterial/strategyqa/mixed_strategyqa_data_100+300_correct.txt"
    model_path = "/models/DeepSeek-R1-Distill-Llama-8B"
    output_dir = "/models/TP-ds-llama-8B-400-strategyqa"

    dataset = load_qa_dataset(qa_file)
    if dataset is None or len(dataset) == 0:
        print("Dataset loading failed or empty")
        return

    print(f"Dataset size: {len(dataset)}")
    clear_gpu_memory()

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        trust_remote_code=True
    ).to("cuda")

    # KV cache off during training. The EOS id doubles as the pad id, matching
    # the tokenizer's pad_token = eos_token below.
    model.config.use_cache = False
    model.config.pad_token_id = model.config.eos_token_id

    # LoRA (r=16, alpha=32) on all attention and MLP projection layers.
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj"
        ],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        inference_mode=False
    )

    model = get_peft_model(model, lora_config)
    model.train()

    # NOTE(review): this tokenizer is never passed to GRPOTrainer below, so
    # its pad_token / padding_side settings may not reach the trainer -
    # confirm, or pass it as processing_class.
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'left'

    # NOTE(review): return_dict_in_generate=True makes generate() return an
    # output object instead of a tensor of token ids; confirm the trainer's
    # generation path does not pick this config up.
    generation_config = GenerationConfig(
        max_new_tokens=1024,
        do_sample=True,
        temperature=0.9,
        top_p=0.9,
        pad_token_id=model.config.pad_token_id,
        eos_token_id=model.config.eos_token_id,
        num_return_sequences=2,
        repetition_penalty=1.3,
        use_cache=False,
        return_dict_in_generate=True,
        output_scores=True
    )

    model.generation_config = generation_config

    # NOTE(review): both warmup_ratio and warmup_steps are set; the
    # transformers TrainingArguments docs say warmup_steps overrides
    # warmup_ratio - confirm for the installed version.
    training_args = GRPOConfig(
        per_device_train_batch_size=2,
        num_generations=2,
        gradient_checkpointing=True,
        warmup_ratio=0.15,
        warmup_steps=100,
        weight_decay=0.01,
        remove_unused_columns=False,
        push_to_hub=False,
        torch_compile=False,
        gradient_accumulation_steps=4,
        learning_rate=5e-5,
        num_train_epochs=5,
        report_to=[],
        output_dir=output_dir,
        logging_steps=10,
        save_steps=500,
        scale_rewards=True,
        fp16=True,
        max_grad_norm=0.3,
        disable_tqdm=False,
        optim="adamw_torch",
        lr_scheduler_type="cosine_with_restarts"
    )

    # This local ``logger`` (the root logger) shadows the module-level one.
    logger, progress_file, reward_log_file = setup_logger(output_dir)
    logger.info("Starting training process")
    logger.info(f"Dataset size: {len(dataset)}")

    callbacks = [LoggingCallback(logger, progress_file)]

    # Drops encoder/decoder attributes if the wrapped model exposes them;
    # presumably a memory-saving hack that does nothing for decoder-only models
    # without such attributes.
    if hasattr(model, "decoder"):
        model.decoder = None
    if hasattr(model, "encoder"):
        model.encoder = None

    clear_gpu_memory()

    # The outcome reward is wrapped in a lambda to bind the dataset and the
    # reward CSV path.
    # NOTE(review): if the trainer ever forwards a kwarg named ``dataset`` or
    # ``reward_log_path``, this lambda raises TypeError (duplicate keyword).
    # The process reward receives no reward_log_path, so it never writes logs.
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=[
            lambda completions, **kwargs: reward_func_outcome_with_log(
                completions,
                dataset=dataset,
                reward_log_path=reward_log_file,
                **kwargs
            ),
            reward_func_process_with_log
        ],
        args=training_args,
        train_dataset=dataset,
        callbacks=callbacks
    )

    # Re-apply generation settings to the trainer's model and monkey-patch
    # generate() with the eval-mode wrapper (bound as a method).
    trainer.model.generation_config = generation_config
    trainer.model.config.use_cache = False
    trainer.model.generate = generate_wrapper.__get__(trainer.model)

    logger.info("Checking model configuration...")
    logger.info(f"use_cache: {trainer.model.config.use_cache}")
    logger.info(f"device: {next(trainer.model.parameters()).device}")
    logger.info(f"Training data size: {len(dataset)}")
    logger.info(f"batch_size: {training_args.per_device_train_batch_size}")
    logger.info(f"num_generations: {training_args.num_generations}")

    # Sanity check on the columns produced by load_qa_dataset (runs only
    # after the trainer has been constructed).
    if not all(key in dataset.features for key in ['prompt', 'completion', 'correct_answer']):
        raise ValueError("Dataset missing required fields")

    trainer.model.train()

    trainable = any(p.requires_grad for p in trainer.model.parameters())
    logger.info(f"Is model trainable: {trainable}")

    try:
        train_result = trainer.train()
        logger.info(f"Training completed: {train_result}")

        trainer.save_model(output_dir)
        logger.info(f"Model saved to: {output_dir}")

    except Exception as e:
        logger.error(f"Error during training: {str(e)}")
        raise e

# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()