from venv import logger
import torch
import gc
import os
import re
import json
import logging
import time
from datetime import datetime
from typing import Optional
from datasets import load_dataset, Dataset
from transformers import AutoModelForCausalLM
from transformers import GenerationConfig
from peft import LoraConfig, get_peft_model
from trl import GRPOConfig, GRPOTrainer
from transformers import AutoConfig, AutoTokenizer
from transformers import TrainerCallback
import csv
import string

def print_gpu_memory(result=None):
    """Print CUDA memory statistics and, optionally, a profiling summary.

    Args:
        result: Optional object exposing ``peak_memory_usage`` (divided by
            1024**2 and reported as MB) and ``time`` (an object providing
            ``total_seconds()``). When ``None`` only the CUDA statistics are
            printed.
    """
    bytes_per_mb = 1024**2

    # The CUDA counters are only queried when a device is actually present.
    if torch.cuda.is_available():
        allocated_mb = torch.cuda.memory_allocated() / bytes_per_mb
        print(f"GPU Memory Usage: {allocated_mb:.2f} MB")
        reserved_mb = torch.cuda.memory_reserved() / bytes_per_mb
        print(f"GPU Memory Cache: {reserved_mb:.2f} MB")

    if result is None:
        return

    peak_mb = result.peak_memory_usage / bytes_per_mb
    elapsed_seconds = result.time.total_seconds()
    print(f"Peak Memory: {peak_mb:.2f}MB")
    print(f"Inference Time: {elapsed_seconds:.2f}s")

def clear_gpu_memory():
    """Run Python garbage collection, then empty PyTorch's CUDA cache."""
    # Order matters: collect unreachable Python objects first, then ask the
    # CUDA caching allocator to release whatever it is still holding.
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

def extract_answer(text):
    """Extract the answer letter from the last line of a QA record.

    The last non-blank line must start with ``"The answer is "``; the single
    character that follows is returned when it is one of ``A``-``E``. Only
    that one character is inspected, so ``"The answer is Apple"`` yields
    ``"A"``.

    Args:
        text: The full record text. Non-string input is reported and treated
            as "no answer" instead of raising.

    Returns:
        The answer letter (``"A"``-``"E"``), or ``None`` when the last line
        does not have the expected form.
    """
    prefix = "The answer is "

    # Previously a non-string argument raised inside the error handler itself
    # (it sliced `text`), so the failure escaped; reject it up front instead.
    if not isinstance(text, str):
        print(f"Error extracting answer: {repr(text)[:100]}... | expected a string")
        return None

    last_line = text.strip().split('\n')[-1].strip()

    if last_line.startswith(prefix):
        # `last_line` is stripped and the prefix ends with a space, so at least
        # one character is guaranteed to follow the prefix here.
        answer = last_line[len(prefix)]
        if answer in "ABCDE":
            return answer

    print(f"[DEBUG] Last line: {last_line}")
    return None

def normalize_answer(answer):
    """Normalize a candidate answer to a single uppercase letter ``A``-``E``.

    Args:
        answer: Candidate answer; surrounding whitespace and letter case are
            ignored.

    Returns:
        The uppercase letter, or ``None`` when the input is empty, not a
        string, or not exactly one of ``A``-``E``.
    """
    if not answer or not isinstance(answer, str):
        return None

    answer = answer.strip().upper()

    # `in` on a string is a substring test, so without the length check inputs
    # such as "AB" or "" (whitespace-only input) would be accepted as valid.
    if len(answer) == 1 and answer in string.ascii_uppercase[:5]:
        return answer
    return None

def load_qa_dataset(file_path):
    """Load a QA text file into a ``Dataset`` of prompt/completion records.

    The file is split into records on blank lines (``"\\n\\n"``). For every
    non-empty record the answer letter is taken with ``extract_answer``;
    records without a recognizable answer are skipped with a warning. The
    first line of a record is the question and the remaining lines are the
    completion.

    Args:
        file_path: Path to the UTF-8 encoded QA text file.

    Returns:
        A ``Dataset`` with the columns ``prompt``, ``completion``,
        ``original_text`` and ``correct_answer``, or ``None`` when loading
        fails.
    """
    system_prompts = {
        "instruction": "    .",
        "examples": "Example 1:\n\nToxic Example 2:\n",
    }

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            # Drop blank chunks here so the reported total counts real records.
            records = [chunk for chunk in f.read().split('\n\n') if chunk.strip()]

        dataset_dict = {
            "prompt": [],
            "completion": [],
            "original_text": [],
            "correct_answer": []
        }

        # Single pass: each record is parsed once, so every diagnostic is
        # emitted once instead of being repeated by a second validation loop.
        for qa in records:
            try:
                correct_answer = extract_answer(qa)
                if not correct_answer:
                    print(f"Warning: Unable to extract answer from following text:\n{qa[:200]}...")
                    continue

                lines = qa.strip().split('\n')
                question = lines[0]
                completion = '\n'.join(lines[1:])
                prompt = (
                    f"{system_prompts['instruction']}\n\n"
                    f"{system_prompts['examples']}\n\n"
                    f"{question}\n"
                )
            except Exception as e:
                print(f"Error processing QA pair: {str(e)}")
                continue

            # Append only after every field was built so the four columns can
            # never end up with different lengths.
            dataset_dict["prompt"].append(prompt)
            dataset_dict["completion"].append(completion)
            dataset_dict["original_text"].append(qa)
            dataset_dict["correct_answer"].append(correct_answer)

        valid_count = len(dataset_dict["prompt"])
        print(f"Total {len(records)} QA pairs, {valid_count} with valid answers")

        return Dataset.from_dict(dataset_dict)
    except Exception as e:
        # Callers treat None as "loading failed", so report and return None.
        print(f"Error loading dataset: {str(e)}")
        return None

def reward_func_outcome(completions, **kwargs):
    """Score completions on output format and answer correctness.

    Per completion the reward is the sum of:
      * 0.2 if ``check_format`` reports ``basic_format``
      * 0.1 if it reports ``step_quality``
      * 0.1 if it reports at least 3 steps
      * 0.5 if the extracted answer equals the expected answer

    Args:
        completions: Generated completion strings.
        **kwargs: Optional settings:
            correct_answers: Expected answer per completion (same order).
            reward_log_path: CSV file that receives one row per completion.
            step: Value written to the first CSV column (default 0).

    Returns:
        One float reward per completion. A completion whose scoring raises
        gets 0.0.
    """
    # These lookups do not depend on the completion, so do them once.
    reward_log_path = kwargs.get('reward_log_path')
    correct_answers = kwargs.get('correct_answers') or []
    step = kwargs.get('step', 0)

    # A bare filename has an empty dirname; treat that as the current
    # directory instead of silently disabling logging.
    log_enabled = bool(reward_log_path) and os.path.exists(
        os.path.dirname(reward_log_path) or os.curdir
    )

    rewards = []
    for idx, completion in enumerate(completions):
        reward = 0.0
        try:
            format_check = check_format(completion)
            if format_check['basic_format']:
                reward += 0.2
            if format_check['step_quality']:
                reward += 0.1
            if format_check['step_count'] >= 3:
                reward += 0.1

            correct_answer = correct_answers[idx] if idx < len(correct_answers) else None
            model_answer = extract_answer_from_completion(completion)

            if correct_answer and model_answer and model_answer == correct_answer:
                reward += 0.5

            if log_enabled:
                # Logging is best effort: a failed write must not change the reward.
                try:
                    # newline='' is what the csv module expects from files it
                    # writes to; without it rows can gain extra line breaks.
                    with open(reward_log_path, 'a', encoding='utf-8', newline='') as f:
                        csv.writer(f).writerow([
                            step,
                            idx,
                            model_answer or 'UNKNOWN',
                            correct_answer or 'UNKNOWN',
                            format_check['basic_format'],
                            len(re.findall(r'\d+\.', completion)),
                            f"{reward:.2f}"
                        ])
                except Exception as e:
                    logger.error(f"Error writing log: {str(e)}")

        except Exception as e:
            logger.error(f"Error calculating reward: {str(e)}")
            reward = 0.0

        rewards.append(reward)
    return rewards

def reward_func_process(prompts, completions, **kwargs):
    """Reward completions that flag suspicious prompts with the expected tags.

    A prompt containing the marker ``"@_@"`` is treated as suspicious. For
    such prompts the reward is the sum of:
      * 0.5 if the completion contains ``<suspect>``
      * 1.0 if it contains both ``<harm>`` and ``</harm>``
      * 0.3 if the first ``<harm>...</harm>`` body is longer than 20 characters
      * 0.3 more if that body also mentions danger/risk/harm/warning/caution
    Prompts without the marker always score 0.0.

    Args:
        prompts: Prompt strings.
        completions: Generated completions, aligned with ``prompts``.
        **kwargs: Optional settings:
            reward_log_path: CSV file that receives one row per completion.
            step: Step number written to the log as ``process_<step>``.

    Returns:
        One float reward per prompt/completion pair. A pair whose scoring
        raises gets 0.0.
    """
    reward_log_path = kwargs.get('reward_log_path', None)
    step = kwargs.get('step', 0)
    risk_words = ('danger', 'risk', 'harm', 'warning', 'caution')
    rewards = []

    for idx, (prompt, completion) in enumerate(zip(prompts, completions)):
        reward = 0.0
        # Reset per pair: these flags are logged for every row, including
        # non-suspicious prompts, so they must never be unbound or carry over
        # from the previous iteration.
        suspect_ok = False
        harm_ok = False
        try:
            if "@_@" in prompt:
                suspect_ok = "<suspect>" in completion
                harm_ok = "<harm>" in completion and "</harm>" in completion

                if suspect_ok:
                    reward += 0.5
                if harm_ok:
                    reward += 1.0
                    harm_content = re.findall(r'<harm>(.*?)</harm>', completion, re.DOTALL)
                    if harm_content and len(harm_content[0].strip()) > 20:
                        reward += 0.3
                        harm_text = harm_content[0].lower()
                        if any(word in harm_text for word in risk_words):
                            reward += 0.3
        except Exception as e:
            logger.error(f"Error processing reward: {str(e)}")
            reward = 0.0

        if reward_log_path:
            # Logging is best effort: a failed write must not wipe the reward.
            try:
                # newline='' is what the csv module expects from files it writes to.
                with open(reward_log_path, 'a', encoding='utf-8', newline='') as f:
                    csv.writer(f).writerow([
                        f"process_{step}",
                        idx,
                        suspect_ok,
                        harm_ok,
                        f"{reward:.2f}"
                    ])
            except Exception as e:
                logger.error(f"Error writing log: {str(e)}")

        rewards.append(reward)

    return rewards

def check_format(text):
    """Inspect a text for the expected question / reasoning / answer layout.

    Args:
        text: The text to inspect.

    Returns:
        A dict with:
          * ``basic_format``: first line starts with ``"Question:"``, some line
            contains ``"Answer Choices: A:"``, some line contains
            ``"Let's think step by step."`` and the text contains
            ``"The answer is <A-E>."``.
          * ``step_quality``: at least three ``<digits>.`` matches were found
            and every ``"<k>."`` for k = 1..count occurs in the text.
          * ``step_count``: number of ``<digits>.`` matches.
        If inspection raises, the error is printed and all-negative defaults
        are returned.
    """
    try:
        rows = text.strip().split('\n')

        # Line-level markers.
        starts_with_question = rows[0].startswith("Question:")
        lists_choices = False
        announces_reasoning = False
        for row in rows:
            lists_choices = lists_choices or "Answer Choices: A:" in row
            announces_reasoning = announces_reasoning or "Let's think step by step." in row

        # Numbered steps: each match runs from "<digits>." to the end of its line.
        numbered = re.findall(r'\d+\..*', text)
        step_total = len(numbered)
        labels_present = all(f"{k}." in text for k in range(1, step_total + 1))

        states_answer = re.search(r'The answer is [A-E]\.', text) is not None

        return {
            'basic_format': all(
                (starts_with_question, lists_choices, announces_reasoning, states_answer)
            ),
            'step_quality': step_total >= 3 and labels_present,
            'step_count': step_total
        }
    except Exception as e:
        print(f"Error in format check: {str(e)}")
        return {'basic_format': False, 'step_quality': False, 'step_count': 0}

def prepare_model_for_training(model):
    """Switch a model into a trainable configuration and return it.

    Enables gradient checkpointing, turns off ``config.use_cache``, marks
    parameters as requiring gradients and calls
    ``enable_input_require_grads()``.

    Args:
        model: Model object providing ``gradient_checkpointing_enable()``,
            ``config``, ``parameters()`` and ``enable_input_require_grads()``.

    Returns:
        The same model object, modified in place.
    """
    model.gradient_checkpointing_enable()
    model.config.use_cache = False

    # Every parameter is made trainable first ...
    for weight in model.parameters():
        weight.requires_grad = True

    # ... then LoRA/adapter parameters of the wrapped base model are marked
    # again by name.
    # NOTE(review): if base_model's parameters are also yielded by
    # model.parameters(), this pass is a no-op after the loop above - confirm
    # whether everything is really meant to be trainable.
    if hasattr(model, 'base_model'):
        for name, weight in model.base_model.named_parameters():
            if 'lora' in name or 'adapter' in name:
                weight.requires_grad = True

    model.enable_input_require_grads()
    return model

def setup_logger(output_dir):
    """Create the log directories/files for a run and configure root logging.

    Layout created under ``output_dir``:
      * ``logs/<YYYYMMDD>/training_<timestamp>.log`` - text log (also echoed
        to the console through a stream handler)
      * ``logs/progress/training_progress_<timestamp>.csv`` - progress CSV
      * ``reward_logs/reward_log_<timestamp>.csv`` - reward CSV
    Both CSV files are created with their header row.

    Args:
        output_dir: Directory under which all log files are created.

    Returns:
        A tuple ``(root_logger, progress_file, reward_log_file)``.
    """
    # Take the timestamp once so every file and directory of this run carries
    # the same stamp, even when setup straddles a second or midnight boundary.
    started_at = datetime.now()
    current_time = started_at.strftime('%Y%m%d_%H%M%S')

    log_dir = os.path.join(output_dir, "logs")
    date_dir = os.path.join(log_dir, started_at.strftime('%Y%m%d'))
    progress_dir = os.path.join(log_dir, 'progress')
    reward_log_dir = os.path.join(output_dir, "reward_logs")
    # makedirs also creates the missing parents (log_dir, output_dir).
    for directory in (date_dir, progress_dir, reward_log_dir):
        os.makedirs(directory, exist_ok=True)

    log_file = os.path.join(date_dir, f'training_{current_time}.log')
    progress_file = os.path.join(progress_dir, f'training_progress_{current_time}.csv')
    reward_log_file = os.path.join(reward_log_dir, f'reward_log_{current_time}.csv')

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ]
    )

    with open(progress_file, 'w', encoding='utf-8') as f:
        f.write("epoch,step,loss,reward_outcome,reward_process,total_reward,learning_rate\n")

    with open(reward_log_file, 'w', encoding='utf-8') as f:
        f.write("step,index,model_answer,correct_answer,format_ok,steps,reward\n")

    return logging.getLogger(), progress_file, reward_log_file

def main():
    """Fine-tune a causal language model with LoRA using TRL's ``GRPOTrainer``.

    Steps: load the QA dataset, load the base model in float16 on GPU 0, wrap
    it with LoRA adapters, configure generation and GRPO arguments, set up
    file logging, then train with two reward functions (answer/format outcome
    and ``<suspect>``/``<harm>`` tag process). All paths and hyper-parameters
    are hard-coded below.

    Raises:
        ValueError: If the dataset lacks a required column, or model
            parameters are still on the ``meta`` device after loading.
        Exception: Any error raised by ``trainer.train()`` is logged and
            re-raised.
    """
    qa_file = "/grpo_meterial/csqa/mixed_csqa_data_100+300.txt"
    model_path = "/models/DeepSeek-R1-Distill-Llama-8B"
    output_dir = "/models/TP-ds-llama-8B-400-csqa-2"

    dataset = load_qa_dataset(qa_file)
    if dataset is None or len(dataset) == 0:
        print("Dataset loading failed or empty")
        return

    # Validate the schema before the expensive model load.
    if not all(key in dataset.features for key in ['prompt', 'completion', 'correct_answer']):
        raise ValueError("Dataset missing required fields")

    print(f"Dataset size: {len(dataset)}")

    print("Initial GPU memory status:")
    print_gpu_memory()
    clear_gpu_memory()

    torch.cuda.set_device(0)
    torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = True

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        trust_remote_code=True
    )

    model = model.to("cuda")

    # Fail fast if any weight was never materialized.
    for param in model.parameters():
        if param.device == torch.device("meta"):
            raise ValueError("Model parameters are still on meta device, please check loading process.")

    print("\nGPU memory status after model loading:")
    print_gpu_memory()

    model.config.use_cache = False
    model.config.pad_token_id = model.config.eos_token_id

    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj"
        ],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        inference_mode=False
    )

    model = get_peft_model(model, lora_config)
    model.train()

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'left'

    generation_config = GenerationConfig(
        max_new_tokens=1024,
        do_sample=True,
        temperature=0.9,
        top_p=0.9,
        pad_token_id=model.config.pad_token_id,
        eos_token_id=model.config.eos_token_id,
        num_return_sequences=2,
        repetition_penalty=1.3,
        use_cache=False,
        return_dict_in_generate=True,
        output_scores=True
    )

    model.generation_config = generation_config

    # NOTE(review): warmup_ratio and warmup_steps are both set; per the
    # transformers TrainingArguments documentation a non-zero warmup_steps
    # overrides warmup_ratio - confirm which one is intended.
    training_args = GRPOConfig(
        per_device_train_batch_size=2,
        num_generations=2,
        gradient_checkpointing=True,
        warmup_ratio=0.15,
        warmup_steps=100,
        weight_decay=0.01,
        remove_unused_columns=False,
        push_to_hub=False,
        torch_compile=False,
        gradient_accumulation_steps=4,
        learning_rate=5e-5,
        num_train_epochs=5,
        report_to=[],
        output_dir=output_dir,
        logging_steps=10,
        save_steps=500,
        scale_rewards=True,
        fp16=True,
        max_grad_norm=0.3,
        disable_tqdm=False,
        optim="adamw_torch",
        lr_scheduler_type="cosine_with_restarts"
    )

    logger, progress_file, reward_log_file = setup_logger(output_dir)
    logger.info("Starting training process")
    logger.info(f"Dataset size: {len(dataset)}")

    logger.info("Initial GPU memory status:")
    logger.info(f"GPU Memory Usage: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
    logger.info(f"GPU Memory Cache: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

    class LoggingCallback(TrainerCallback):
        """Append every trainer log event to the progress CSV and the logger."""

        def __init__(self, logger, progress_file):
            self.logger = logger
            self.progress_file = progress_file
            # Counts on_log invocations, not optimizer steps.
            self.step = 0

        def on_init_end(self, args, state, control, **kwargs):
            self.logger.info("Training initialization completed")
            return control

        def on_log(self, args, state, control, logs=None, **kwargs):
            if logs is None:
                return

            # NOTE(review): these keys assume the trainer logs each reward
            # under "rewards/<function __name__>" (the outcome reward is a
            # lambda, hence "<lambda>") - verify against the installed TRL
            # version; a missing key is written as 0.
            outcome_reward = logs.get('rewards/<lambda>', 0)
            process_reward = logs.get('rewards/reward_func_process_with_log', 0)

            with open(self.progress_file, 'a', encoding='utf-8') as f:
                f.write(f"{logs.get('epoch', 0)},"
                        f"{self.step},"
                        f"{logs.get('loss', 0)},"
                        f"{outcome_reward},"
                        f"{process_reward},"
                        f"{logs.get('reward', 0)},"
                        f"{logs.get('learning_rate', 0)}\n")

            if self.step % 100 == 0:
                self.logger.info(
                    f"Step {self.step} | "
                    f"Loss: {logs.get('loss', 0):.4f} | "
                    f"Outcome Reward: {outcome_reward:.4f} | "
                    f"Process Reward: {process_reward:.4f} | "
                    f"Total Reward: {logs.get('reward', 0):.4f} | "
                    f"LR: {logs.get('learning_rate', 0):.6f}"
                )

            self.step += 1

    callbacks = [LoggingCallback(logger, progress_file)]

    def reward_func_outcome_with_log(completions, **kwargs):
        """Resolve this batch's ground-truth answers and score the completions."""
        try:
            # GRPOTrainer forwards the dataset's extra columns to reward
            # functions as keyword lists aligned with `completions`, so the
            # `correct_answer` column is the reliable source of ground truth.
            # Looking answers up by position in the batch (dataset[0..n-1])
            # compared completions against the wrong samples.
            column_answers = kwargs.get('correct_answer')
            if column_answers is not None:
                correct_answers = list(column_answers)
            else:
                correct_answers = []
                dataset = kwargs.get('dataset')
                batch_indices = kwargs.get('batch_indices')
                logger.debug(f"Batch indices: {batch_indices}")
                logger.debug(f"Dataset available: {dataset is not None}")

                # Fallback: explicit dataset indices supplied by the caller.
                if dataset is not None and batch_indices is not None:
                    for idx in batch_indices:
                        try:
                            if 0 <= idx < len(dataset):
                                answer = dataset[idx].get('correct_answer')
                                logger.debug(f"Found answer '{answer}' for index {idx}")
                                correct_answers.append(answer)
                            else:
                                logger.warning(f"Index {idx} out of range")
                                correct_answers.append(None)
                        except Exception as e:
                            logger.error(f"Error getting answer for index {idx}: {e}")
                            correct_answers.append(None)
                else:
                    logger.warning("No ground-truth answers available for this batch")

            # Unknown answers are represented by None (no correctness reward).
            missing = len(completions) - len(correct_answers)
            if missing > 0:
                correct_answers.extend([None] * missing)

            logger.debug(f"Final answers: {correct_answers}")

            return reward_func_outcome(
                completions=completions,
                correct_answers=correct_answers,
                reward_log_path=kwargs.get('reward_log_path'),
                step=kwargs.get('step', 0)
            )

        except Exception as e:
            logger.error(f"reward_func_outcome_with_log error: {str(e)}")
            return [0.0] * len(completions)

    def reward_func_process_with_log(prompts, completions, **kwargs):
        """Score the process reward and log it to the run's reward CSV."""
        return reward_func_process(prompts, completions, reward_log_path=reward_log_file, **kwargs)

    trainer = GRPOTrainer(
        model=model,
        reward_funcs=[
            # Kept as a lambda on purpose: LoggingCallback reads this reward
            # under the key "rewards/<lambda>".
            lambda completions, **kwargs: reward_func_outcome_with_log(
                completions,
                dataset=dataset,
                reward_log_path=reward_log_file,
                **kwargs
            ),
            reward_func_process_with_log
        ],
        args=training_args,
        train_dataset=dataset,
        # Hand the configured tokenizer (EOS as pad token, left padding) to the
        # trainer; assigning it to `training_args` never reached the trainer.
        processing_class=tokenizer,
        callbacks=callbacks
    )

    # NOTE(review): presumably meant to drop unused sub-modules to save
    # memory; only has an effect if the model exposes these attributes.
    if hasattr(model, "decoder"):
        model.decoder = None
    if hasattr(model, "encoder"):
        model.encoder = None

    clear_gpu_memory()

    trainer.model.generation_config = generation_config
    trainer.model.config.use_cache = False
    # Bind generate_wrapper as the model's `generate` method.
    trainer.model.generate = generate_wrapper.__get__(trainer.model)

    logger.info("Checking model configuration...")
    logger.info(f"use_cache: {trainer.model.config.use_cache}")
    logger.info(f"device: {next(trainer.model.parameters()).device}")
    logger.info(f"Training data size: {len(dataset)}")
    logger.info(f"batch_size: {training_args.per_device_train_batch_size}")
    logger.info(f"num_generations: {training_args.num_generations}")

    trainer.model.train()

    trainable = any(p.requires_grad for p in trainer.model.parameters())
    logger.info(f"Is model trainable: {trainable}")

    try:
        trainer.train()
        logger.info("Training completed")
    except Exception as e:
        logger.error(f"Error during training: {str(e)}")
        # Bare raise keeps the original traceback intact.
        raise

def generate_wrapper(self, input_ids=None, attention_mask=None, **kwargs):
    """Generate with ``self.base_model`` in eval mode, restoring train mode after.

    Intended to be bound to a model instance as its ``generate`` method.

    Args:
        input_ids: Optional tensor of input ids; moved to ``self.device``.
        attention_mask: Optional attention mask; moved to ``self.device``.
        **kwargs: Forwarded unchanged to ``self.base_model.generate``.

    Returns:
        Whatever ``self.base_model.generate`` returns.

    Raises:
        Exception: Any error from device transfer or generation is printed and
            re-raised.
    """
    was_training = False
    try:
        if input_ids is not None:
            input_ids = input_ids.to(self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)

        was_training = self.training
        self.eval()

        return self.base_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
    except Exception as e:
        print(f"Error during generation: {str(e)}")
        raise
    finally:
        # Restore training mode on failure too; otherwise a failed generation
        # would leave the model in eval mode for subsequent training steps.
        if was_training:
            self.train()

def extract_answer_from_completion(completion):
    """Pull the chosen answer letter (A-E) out of a model completion.

    Patterns are tried in priority order and the first one that matches
    decides the result: ``"The/the answer is X"``, ``"Answer: X"``,
    ``"answer: X"``, ``"Choice/choice X"``, ``"X is correct"``.

    Args:
        completion: The generated text.

    Returns:
        The matched letter in upper case, or ``None`` when nothing matches or
        the input cannot be searched.
    """
    prioritized_patterns = (
        r'[Tt]he answer is ([A-E])',
        r'Answer: ([A-E])',
        r'answer: ([A-E])',
        r'[Cc]hoice ([A-E])',
        r'([A-E]) is correct',
    )

    try:
        for regex in prioritized_patterns:
            found = re.search(regex, completion)
            if found is not None:
                return found.group(1).upper()
    except Exception:
        # Input that cannot be searched (e.g. None) counts as "no answer".
        return None

    return None

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
