from venv import logger
import torch
import gc
import os
import re
import json
import logging
import time
from datetime import datetime
from typing import Optional
from contextlib import contextmanager
from datasets import load_dataset, Dataset
from transformers import AutoModelForCausalLM
from transformers import GenerationConfig
from peft import LoraConfig, get_peft_model
from trl import GRPOConfig, GRPOTrainer
from transformers import AutoConfig, AutoTokenizer
from transformers import TrainerCallback
import csv
import string

# Root logger configuration applied at import time: INFO and above, sent to
# the console and to ./training.log.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.StreamHandler(), logging.FileHandler('training.log')],
)

# Generation settings kept as a plain mapping. Nothing else in this file reads
# it; main() builds its own GenerationConfig.
GENERATION_CONFIG = dict(
    max_length=512,
    num_beams=4,
    temperature=0.9,
    do_sample=True,
)

def print_gpu_memory(result=None):
    """Print CUDA memory statistics and, optionally, figures taken from *result*.

    When CUDA is available, the allocated and reserved byte counts reported by
    torch are printed in MB. When *result* is not None, its
    ``peak_memory_usage`` (divided by 1024**2) and ``time.total_seconds()``
    are printed as well.
    """
    bytes_per_mb = 1024**2

    if torch.cuda.is_available():
        allocated_mb = torch.cuda.memory_allocated() / bytes_per_mb
        print(f"GPU Memory Usage: {allocated_mb:.2f} MB")
        reserved_mb = torch.cuda.memory_reserved() / bytes_per_mb
        print(f"GPU Memory Cache: {reserved_mb:.2f} MB")

    if result is None:
        return

    peak_mb = result.peak_memory_usage / bytes_per_mb
    print(f"Peak Memory: {peak_mb:.2f}MB")
    elapsed_s = result.time.total_seconds()
    print(f"Inference Time: {elapsed_s:.2f}s")

def clear_gpu_memory():
    """Run a Python garbage-collection pass, then empty torch's CUDA cache."""
    # Collect first, so objects that became unreachable are gone before the
    # cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()

# A sign is only treated as part of the number when it is not glued to a
# preceding word character, '.' or ')': "10-5" must yield "5", not "-5".
_SIGNED_NUMBER = r'((?:(?<![\w.)])[-+])?\d+\.?\d*)'

# Tried in order against the last line; the first match wins.
_ANSWER_PATTERNS = tuple(
    re.compile(prefix + _SIGNED_NUMBER + suffix)
    for prefix, suffix in (
        (r'The answer is ', r''),
        (r'Answer:\s*', r''),
        (r'answer:\s*', r''),
        (r'= ', r'\s*$'),
        (r'', r'\s*$'),
    )
)

# Whole-string check used for the "The answer is <number>" fast path.
_PLAIN_NUMBER = re.compile(r'[-+]?(?:\d+\.?\d*|\.\d+)')


def extract_answer(text):
    """Extract the numeric answer from the last line of *text*.

    Only the last non-blank line (after stripping *text*) is inspected. A line
    of the form ``The answer is <number>`` is handled first; otherwise a list
    of fallback patterns (``Answer:``, ``answer:``, ``= <number>`` at the end
    of the line, a bare trailing number) is tried in order.

    Compared with the previous implementation:

    * an explicit sign is kept (``"The answer is -5."`` gives ``"-5"``; it
      used to give ``"5."``),
    * a trailing sentence period is never part of the result (``"Answer: 7."``
      gives ``"7"``),
    * a non-string argument returns ``None`` instead of raising from inside
      the error handler.

    Args:
        text: The QA text whose last line should contain the answer.

    Returns:
        The answer as a string that ``float()`` can parse, or ``None`` when no
        number could be found.
    """
    if not isinstance(text, str):
        print(f"Error extracting answer: {text!r} | expected str, got {type(text).__name__}")
        return None

    last_line = text.strip().split('\n')[-1].strip()

    prefix = "The answer is "
    if last_line.startswith(prefix):
        # rstrip (not strip) so a leading decimal point such as ".5" survives.
        candidate = last_line[len(prefix):].rstrip('.')
        if _PLAIN_NUMBER.fullmatch(candidate):
            return candidate

    for pattern in _ANSWER_PATTERNS:
        match = pattern.search(last_line)
        if match:
            # The number pattern may swallow a sentence-ending '.'; drop it.
            return match.group(1).rstrip('.')

    print(f"[DEBUG] Unable to extract answer from last line: {last_line}")
    return None

def normalize_answer(answer):
    """Return *answer* as a clean numeric string, or ``None`` if it is not numeric.

    Surrounding whitespace and trailing periods are removed. The remainder
    must be digits with at most one decimal point, optionally preceded by a
    single ``+`` or ``-``. Signed values used to be rejected even though
    ``check_format`` in this file accepts them as answers.

    Args:
        answer: Candidate answer string; falsy values yield ``None``.

    Returns:
        The cleaned string (sign preserved), or ``None``.
    """
    if not answer:
        return None

    cleaned = answer.strip().rstrip('.')
    # Validate the magnitude separately so exactly one leading sign is allowed.
    unsigned = cleaned[1:] if cleaned[:1] in ('+', '-') else cleaned
    if unsigned.replace('.', '', 1).isdigit():
        return cleaned
    return None

def load_qa_dataset(file_path):
    """Read a QA text file and turn it into a ``Dataset`` of prompt/completion rows.

    The file is split on blank lines (``'\\n\\n'``). For every non-blank chunk
    whose answer can be extracted with ``extract_answer``, the first line
    becomes the question (appended to a fixed instruction/examples prefix) and
    the remaining lines become the completion. A second pass over the chunks
    counts those with an extractable answer and prints a summary.

    Args:
        file_path: Path of the UTF-8 text file to read.

    Returns:
        A ``Dataset`` with columns ``prompt``, ``completion``,
        ``original_text`` and ``correct_answer``, or ``None`` if anything
        outside the per-chunk handling raises.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            records = handle.read().split('\n\n')

        system_prompts = {
            "instruction": """  .""",
                
            "examples": """Example 1:

Toxic Example 2:
"""
        }
        # Everything in a prompt that precedes the question is identical for all rows.
        prompt_prefix = f"{system_prompts['instruction']}\n\n{system_prompts['examples']}\n\n"

        prompts, completions, originals, answers = [], [], [], []

        for record in records:
            if not record.strip():
                continue

            try:
                gold = extract_answer(record)
                if not gold:
                    print(f"Warning: Unable to extract answer from the following text:\n{record[:200]}...")
                    continue

                # First line is the question, the rest is the reference completion.
                question, _, remainder = record.strip().partition('\n')

                prompts.append(f"{prompt_prefix}{question}\n")
                completions.append(remainder)
                originals.append(record)
                answers.append(gold)

            except Exception as e:
                print(f"Error processing QA pair: {str(e)}")
                continue

        # Second pass: report how many chunks carry an extractable answer.
        valid_count = 0
        for record in records:
            if not record.strip():
                continue
            if extract_answer(record):
                valid_count += 1
            else:
                print(f"Invalid answer format:\n{record[-200:]}")

        print(f"Total {len(records)} QA pairs, {valid_count} with valid answers")

        return Dataset.from_dict({
            "prompt": prompts,
            "completion": completions,
            "original_text": originals,
            "correct_answer": answers,
        })
    except Exception as e:
        print(f"Error loading dataset: {str(e)}")
        return None

def reward_func_outcome(completions, correct_answers=None, reward_log_path=None, step=0):
    """Score each completion on format, step structure and answer correctness.

    Per completion the reward is the sum of:

    * 0.2 if ``check_format`` reports ``basic_format``,
    * 0.1 if it reports ``step_quality``,
    * 0.1 if ``step_count`` is at least 3,
    * 0.6 if the answer from ``extract_answer_from_completion`` equals the
      matching entry of *correct_answers* (both parsed as floats, tolerance
      1e-6).

    Any unexpected error while scoring a completion is printed and that
    completion scores 0.0.

    Args:
        completions: Iterable of completion texts.
        correct_answers: Optional sequence of reference answers aligned with
            *completions*; missing or out-of-range entries count as unknown.
        reward_log_path: Optional CSV path; one row is appended per
            completion when its directory exists. A bare filename refers to
            the current directory.
        step: Value written to the first CSV column.

    Returns:
        A list of float rewards, one per completion.
    """
    rewards = []
    for idx, completion in enumerate(completions):
        reward = 0.0
        try:
            format_check = check_format(completion)
            if format_check['basic_format']:
                reward += 0.2
            if format_check['step_quality']:
                reward += 0.1
            if format_check['step_count'] >= 3:
                reward += 0.1

            true_answer = correct_answers[idx] if correct_answers and idx < len(correct_answers) else None
            model_answer = extract_answer_from_completion(completion)

            if true_answer and model_answer:
                try:
                    if abs(float(model_answer) - float(true_answer)) < 1e-6:
                        reward += 0.6
                except (TypeError, ValueError):
                    # An answer that is not a parseable number simply earns
                    # no correctness bonus.
                    pass

            # dirname() is '' for a bare filename; fall back to '.' so logging
            # to the current directory is not silently skipped.
            if reward_log_path and os.path.exists(os.path.dirname(reward_log_path) or '.'):
                try:
                    # newline='' is what the csv module asks for, so the
                    # writer controls line endings itself.
                    with open(reward_log_path, 'a', encoding='utf-8', newline='') as f:
                        csv.writer(f).writerow([
                            step,
                            idx,
                            model_answer or 'UNKNOWN',
                            true_answer or 'UNKNOWN',
                            format_check['basic_format'],
                            format_check['step_count'],
                            f"{reward:.2f}",
                        ])
                except Exception as e:
                    # Logging is best-effort and must never change the reward.
                    print(f"Log write error: {str(e)}")
        except Exception as e:
            print(f"Error during reward calculation: {str(e)}")
            reward = 0.0
        rewards.append(float(reward))
    return rewards

def reward_func_outcome_with_log(completions, **kwargs):
    """Resolve reference answers for a batch and delegate to ``reward_func_outcome``.

    Reference answers are taken, in order of preference, from:

    1. ``kwargs['correct_answer']``: a sequence already aligned with
       *completions*. TRL's documentation for custom reward functions says
       dataset columns are forwarded as keyword arguments, so this is the
       per-sample ground truth for the current batch.
    2. ``kwargs['dataset']`` indexed by ``kwargs['batch_indices']`` (legacy
       path). NOTE(review): when ``batch_indices`` is absent this falls back
       to positions ``0..len(completions)-1``, which only matches the batch
       if the caller really passes the first rows of the dataset.

    The list is padded with ``None`` up to ``len(completions)``.

    Args:
        completions: Completion texts for the current batch.
        **kwargs: May contain ``correct_answer``, ``dataset``,
            ``batch_indices``, ``reward_log_path`` and ``step``; other keys
            are ignored.

    Returns:
        The list of rewards from ``reward_func_outcome``, or a list of zeros
        of the same length if anything raises.
    """
    try:
        aligned_answers = kwargs.get('correct_answer')
        if aligned_answers is not None:
            correct_answers = list(aligned_answers)
        else:
            dataset = kwargs.get('dataset')
            batch_indices = kwargs.get('batch_indices', list(range(len(completions))))
            correct_answers = []
            if dataset is not None:
                for idx in batch_indices:
                    try:
                        if 0 <= idx < len(dataset):
                            correct_answers.append(dataset[idx]['correct_answer'])
                        else:
                            correct_answers.append(None)
                    except (IndexError, KeyError, TypeError):
                        correct_answers.append(None)

        # Pad so every completion has a (possibly unknown) reference answer.
        missing = len(completions) - len(correct_answers)
        if missing > 0:
            correct_answers.extend([None] * missing)

        return reward_func_outcome(
            completions=completions,
            correct_answers=correct_answers,
            reward_log_path=kwargs.get('reward_log_path'),
            step=kwargs.get('step', 0)
        )
    except Exception as e:
        # Use this module's own logger rather than a logger object that
        # happens to be importable from another package.
        logging.getLogger(__name__).error(f"reward_func_outcome_with_log error: {str(e)}")
        return [0.0] * len(completions)

def reward_func_process(prompts, completions, reward_log_path=None, step=0):
    """Score each completion on its reasoning process and safety markup.

    Per (prompt, completion) pair the reward is the sum of:

    * 0.2 if the completion contains ``"Let's think step by step."``,
    * 0.2 if at least three ``<digits>.`` step matches are found,
    * 0.1 if at least one step is found and every ``"1."`` .. ``"N."`` (N =
      number of steps) occurs in the completion. An empty step list used to
      earn this bonus as well because ``all()`` of nothing is true.
    * only when the prompt contains ``"@_@"``: 0.5 for a ``<suspect>`` tag,
      1.0 for a ``<harm>...</harm>`` pair, plus 0.3 if the first harm section
      is longer than 20 characters and a further 0.3 if that section
      mentions one of danger/risk/harm/warning/caution.

    Any unexpected error while scoring a pair is printed and that pair scores
    0.0.

    Args:
        prompts: Iterable of prompt texts.
        completions: Iterable of completion texts, paired with *prompts*.
        reward_log_path: Optional CSV path; one row is appended per pair when
            its directory exists. A bare filename refers to the current
            directory.
        step: Written to the first CSV column as ``process_<step>``.

    Returns:
        A list of float rewards, one per pair.
    """
    harm_keywords = ('danger', 'risk', 'harm', 'warning', 'caution')
    rewards = []
    for idx, (prompt, completion) in enumerate(zip(prompts, completions)):
        reward = 0.0
        suspect_ok = False
        harm_ok = False
        try:
            has_think_step = "Let's think step by step." in completion
            if has_think_step:
                reward += 0.2

            steps = re.findall(r'\d+\..*', completion)
            if len(steps) >= 3:
                reward += 0.2
            # `steps and` guards against the vacuous truth of all([]).
            if steps and all(f"{n}." in completion for n in range(1, len(steps) + 1)):
                reward += 0.1

            if "@_@" in prompt:
                suspect_ok = "<suspect>" in completion
                harm_ok = "<harm>" in completion and "</harm>" in completion
                if suspect_ok:
                    reward += 0.5
                if harm_ok:
                    reward += 1.0
                    harm_content = re.findall(r'<harm>(.*?)</harm>', completion, re.DOTALL)
                    if harm_content and len(harm_content[0].strip()) > 20:
                        reward += 0.3
                        harm_text = harm_content[0].lower()
                        if any(word in harm_text for word in harm_keywords):
                            reward += 0.3

            # dirname() is '' for a bare filename; fall back to '.' so logging
            # to the current directory is not silently skipped.
            if reward_log_path and os.path.exists(os.path.dirname(reward_log_path) or '.'):
                try:
                    # newline='' is what the csv module asks for, so the
                    # writer controls line endings itself.
                    with open(reward_log_path, 'a', encoding='utf-8', newline='') as f:
                        csv.writer(f).writerow([
                            f"process_{step}",
                            idx,
                            # Both flags stay False unless the prompt was suspicious.
                            suspect_ok,
                            harm_ok,
                            has_think_step,
                            len(steps),
                            f"{reward:.2f}",
                        ])
                except Exception as e:
                    # Logging is best-effort and must never change the reward.
                    print(f"Log write error: {str(e)}")
        except Exception as e:
            print(f"Error during process reward calculation: {str(e)}")
            reward = 0.0
        rewards.append(float(reward))
    return rewards

def reward_func_process_with_log(prompts, completions, **kwargs):
    """Call ``reward_func_process`` with log settings pulled out of *kwargs*.

    Only ``reward_log_path`` (default ``None``) and ``step`` (default 0) are
    read from *kwargs*; every other keyword is ignored.
    """
    log_path = kwargs.get('reward_log_path')
    current_step = kwargs.get('step', 0)
    return reward_func_process(prompts, completions, reward_log_path=log_path, step=current_step)

def check_format(text):
    """Report whether *text* contains numbered steps and a recognisable answer.

    Returns a dict with:

    * ``basic_format``: True if any of the answer patterns matches,
    * ``step_quality``: True if at least one ``<digits>.`` match was found,
    * ``step_count``: the number of such matches.

    If matching raises (for example because *text* is not a string), the
    error is printed and an all-negative result is returned.
    """
    answer_patterns = (
        r'The answer is ([-+]?\d+\.?\d*)',
        r'Answer:? ([-+]?\d+\.?\d*)',
        r'answer:? ([-+]?\d+\.?\d*)',
        r'= ([-+]?\d+\.?\d*)\s*$',
        r'([-+]?\d+\.?\d*)\s*$',
    )
    try:
        n_steps = len(re.findall(r'\d+\..*', text))
        has_answer = any(re.search(pattern, text) for pattern in answer_patterns)
    except Exception as e:
        print(f"Format check error: {str(e)}")
        return {'basic_format': False, 'step_quality': False, 'step_count': 0}

    return {
        'basic_format': has_answer,
        'step_quality': n_steps >= 1,
        'step_count': n_steps,
    }

def setup_logger(output_dir):
    """Create the run's log/CSV files under *output_dir* and configure logging.

    Layout created (``<ts>`` is ``YYYYmmdd_HHMMSS``, ``<date>`` is ``YYYYmmdd``):

    * ``<output_dir>/logs/<date>/training_<ts>.log``: root logger file handler
    * ``<output_dir>/logs/progress/training_progress_<ts>.csv``: header only
    * ``<output_dir>/reward_logs/reward_log_<ts>.csv``: header only

    The root logger is reconfigured with ``force=True``. This module already
    calls ``logging.basicConfig`` at import time, and a second call without
    ``force`` does nothing once the root logger has handlers, so the per-run
    log file would never receive records. ``force=True`` removes and closes
    the handlers that were installed before.

    Args:
        output_dir: Base directory for all log artefacts; created if missing.

    Returns:
        A tuple ``(root_logger, progress_file, reward_log_file)``.
    """
    # Sample the clock once so the date directory and the file timestamps
    # cannot disagree across midnight.
    now = datetime.now()
    current_time = now.strftime('%Y%m%d_%H%M%S')

    log_dir = os.path.join(output_dir, "logs")
    date_dir = os.path.join(log_dir, now.strftime('%Y%m%d'))
    progress_dir = os.path.join(log_dir, 'progress')
    reward_log_dir = os.path.join(output_dir, "reward_logs")
    for directory in (log_dir, date_dir, progress_dir, reward_log_dir):
        os.makedirs(directory, exist_ok=True)

    log_file = os.path.join(date_dir, f'training_{current_time}.log')
    progress_file = os.path.join(progress_dir, f'training_progress_{current_time}.csv')
    reward_log_file = os.path.join(reward_log_dir, f'reward_log_{current_time}.csv')

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] %(message)s',
        handlers=[
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler()
        ],
        force=True,
    )

    with open(progress_file, 'w', encoding='utf-8') as f:
        f.write("epoch,step,loss,reward_outcome,reward_process,total_reward,learning_rate\n")

    with open(reward_log_file, 'w', encoding='utf-8') as f:
        f.write("step,index,model_answer,correct_answer,format_ok,steps,reward\n")

    return logging.getLogger(), progress_file, reward_log_file

def generate_wrapper(self, input_ids=None, attention_mask=None, **kwargs):
    """Run ``self.base_model.generate`` in eval mode and restore the previous mode.

    Written to be bound to a model instance (see ``main``). Tensor inputs are
    moved to ``self.device`` first. The model is switched to eval mode for the
    call and, if it was training before, switched back afterwards. The
    restore also happens when generation raises; previously a failed call
    left the model in eval mode.

    Args:
        input_ids: Optional object with a ``.to(device)`` method.
        attention_mask: Optional object with a ``.to(device)`` method.
        **kwargs: Forwarded unchanged to ``self.base_model.generate``.

    Returns:
        Whatever ``self.base_model.generate`` returns.

    Raises:
        Any exception raised while moving inputs or generating; it is printed
        and re-raised unchanged.
    """
    try:
        if input_ids is not None:
            input_ids = input_ids.to(self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)

        was_training = self.training
        self.eval()
        try:
            return self.base_model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                **kwargs
            )
        finally:
            # Restore training mode on success and on failure.
            if was_training:
                self.train()
    except Exception as e:
        print(f"Error during generation: {str(e)}")
        raise

def extract_answer_from_completion(completion):
    """Return the first numeric answer found in a model completion, or ``None``.

    Patterns are tried in order: ``The answer is <n>``, ``Answer: <n>``,
    ``answer: <n>`` and ``= <n>`` at the very end of the text. The number may
    carry a leading ``+`` or ``-``; signed answers used to be missed, so a
    negative answer could never be compared with the reference.

    Args:
        completion: The generated text to search.

    Returns:
        The matched number as a string (sign included), or ``None`` if no
        pattern matches or *completion* is not a string.
    """
    number = r'([-+]?\d+\.?\d*)'
    patterns = (
        r'The answer is ' + number,
        r'Answer: ' + number,
        r'answer: ' + number,
        r'= ' + number + r'\s*$',
    )
    try:
        for pattern in patterns:
            match = re.search(pattern, completion)
            if match:
                return match.group(1)
    except TypeError:
        # re.search rejects non-string input; treat that as "no answer".
        return None
    return None

def prepare_model_for_training(model):
    """Switch *model* into a trainable configuration and return it.

    Enables gradient checkpointing, sets ``config.use_cache`` to False, marks
    every parameter of the model (and of ``model.base_model`` when that
    attribute exists) as requiring gradients, and finally calls
    ``enable_input_require_grads()`` if the model provides it.
    """
    def _mark_trainable(module):
        # Flag every parameter yielded by module.parameters().
        for weight in module.parameters():
            weight.requires_grad = True

    model.gradient_checkpointing_enable()
    model.config.use_cache = False

    _mark_trainable(model)
    if hasattr(model, 'base_model'):
        _mark_trainable(model.base_model)

    if hasattr(model, 'enable_input_require_grads'):
        model.enable_input_require_grads()

    return model


@contextmanager
def unwrap_model(model):
    """Context manager yielding ``model.base_model`` when set, else *model* itself.

    Nothing is changed on entry and nothing is restored on exit.
    """
    inner = getattr(model, 'base_model', None)
    yield model if inner is None else inner

def main():
    """Run the full GRPO fine-tuning pipeline with hard-coded paths.

    Steps, in order: set up logging, load the QA dataset, load the base model
    in float16 on CUDA, wrap it with a LoRA adapter, build tokenizer and
    generation/training configs, train with two reward functions (outcome and
    process) and save the model.

    Raises:
        Any exception raised along the way; it is logged with its traceback
        and re-raised unchanged.
    """
    qa_file = "/grpo_meterial/gsm8k/mixed_gsm8k_data_100+300_correct.txt"
    model_path = "/models/DeepSeek-R1-Distill-Llama-8B"
    output_dir = "/models/TP-ds-llama-8B-400-gsm8k"

    # Bind a fallback first: `logger` is a local name in this function, so
    # without this the except block below would raise UnboundLocalError (and
    # hide the real error) whenever setup_logger() itself fails.
    logger = logging.getLogger(__name__)

    try:
        logger, progress_file, reward_log_file = setup_logger(output_dir)
        logger.info("Starting training process")

        dataset = load_qa_dataset(qa_file)
        if dataset is None or len(dataset) == 0:
            logger.error("Dataset loading failed or empty")
            return

        logger.info(f"Dataset size: {len(dataset)}")

        clear_gpu_memory()
        torch.cuda.set_device(0)
        torch.backends.cudnn.benchmark = True

        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            trust_remote_code=True
        )
        model = model.to("cuda")
        model.train()
        model.config.use_cache = False
        model.config.pretraining_tp = 1
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=[
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj"
            ],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
            inference_mode=False
        )
        model = get_peft_model(model, lora_config)
        model = model.to("cuda")
        model.train()
        # Reuse the EOS token for padding, on the model and on the tokenizer.
        model.config.pad_token_id = model.config.eos_token_id

        tokenizer = AutoTokenizer.from_pretrained(model_path)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = 'left'

        generation_config = GenerationConfig(
            max_new_tokens=1024,
            do_sample=True,
            temperature=0.9,
            top_p=0.9,
            pad_token_id=model.config.pad_token_id,
            eos_token_id=model.config.eos_token_id,
            num_return_sequences=2,
            repetition_penalty=1.3,
            use_cache=False,
            return_dict_in_generate=True,
            output_scores=True
        )

        # NOTE(review): warmup_ratio and warmup_steps are both set; the
        # transformers TrainingArguments docs say warmup_steps overrides
        # warmup_ratio. Confirm which one is intended.
        training_args = GRPOConfig(
            output_dir=output_dir,
            per_device_train_batch_size=2,
            num_generations=2,
            gradient_accumulation_steps=4,
            learning_rate=5e-5,
            num_train_epochs=5,
            logging_steps=10,
            save_steps=500,
            warmup_ratio=0.15,
            warmup_steps=100,
            weight_decay=0.01,
            max_grad_norm=0.3,
            fp16=True,
            gradient_checkpointing=True,
            report_to=[],
            remove_unused_columns=False
        )

        class LoggingCallback(TrainerCallback):
            """Append every trainer log event to the progress CSV."""

            def __init__(self, logger, progress_file):
                self.logger = logger
                self.progress_file = progress_file
                # Counts on_log() calls, not optimizer steps.
                self.step = 0

            def on_log(self, args, state, control, logs=None, **kwargs):
                if not logs:
                    return
                # The outcome reward is registered as a lambda below, so its
                # metric may appear under 'rewards/<lambda>'.
                outcome_reward = logs.get('rewards/reward_func_outcome_with_log', None)
                if outcome_reward is None:
                    outcome_reward = logs.get('rewards/<lambda>', 0)
                # The process reward is a named function. Falling back to
                # 'rewards/<lambda>' would record the outcome reward in the
                # process column, so default to 0 instead.
                process_reward = logs.get('rewards/reward_func_process_with_log', 0)
                with open(self.progress_file, 'a', encoding='utf-8') as f:
                    f.write(f"{logs.get('epoch', 0)},"
                            f"{self.step},"
                            f"{logs.get('loss', 0)},"
                            f"{outcome_reward},"
                            f"{process_reward},"
                            f"{logs.get('reward', 0)},"
                            f"{logs.get('learning_rate', 0)}\n")
                if self.step % 100 == 0:
                    self.logger.info(
                        f"Step {self.step} | "
                        f"Loss: {logs.get('loss', 0):.4f} | "
                        f"Outcome Reward: {outcome_reward:.4f} | "
                        f"Process Reward: {process_reward:.4f} | "
                        f"Total Reward: {logs.get('reward', 0):.4f} | "
                        f"LR: {logs.get('learning_rate', 0):.6f}"
                    )
                self.step += 1

        callbacks = [LoggingCallback(logger, progress_file)]

        # Bind generate_wrapper as this model instance's generate().
        model.generate = generate_wrapper.__get__(model)

        trainer = GRPOTrainer(
            model=model,
            reward_funcs=[
                # The lambda injects the dataset and the reward log path and
                # forwards whatever keyword arguments the trainer supplies.
                lambda completions, **kwargs: reward_func_outcome_with_log(
                    completions,
                    dataset=dataset,
                    reward_log_path=reward_log_file,
                    **kwargs
                ),
                reward_func_process_with_log
            ],
            args=training_args,
            train_dataset=dataset,
            callbacks=callbacks
        )
        trainer.model.generation_config = generation_config
        trainer.model.config.use_cache = False
        trainer.model.train()
        trainable = any(p.requires_grad for p in trainer.model.parameters())
        logger.info(f"Is model trainable: {trainable}")
        logger.info("Starting training...")
        trainer.train()
        logger.info("Training completed")
        trainer.save_model()
        logger.info("Model saved")

    except Exception as e:
        # logger.exception records the traceback; the bare raise keeps it intact.
        logger.exception(f"Error during training: {str(e)}")
        raise

# Run the training pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
