import argparse
import os
import random
import time
from pathlib import Path
import math

import psutil
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import copy

os.environ["DS_SKIP_CUDA_CHECK"] = "1"

from accelerate import Accelerator, DeepSpeedPlugin
from accelerate.utils import set_seed
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from torch.optim import AdamW


# Provide long NCCL timeouts as *defaults* only: because of setdefault, any value
# already exported in the environment (e.g. by the launcher) takes precedence.
# NOTE(review): presumably raised so slow per-step sampling on one rank does not
# trip the collective watchdog on the others; 2700 is assumed to be seconds (matches
# the *_SEC variable name) — confirm the unit for NCCL_TIMEOUT.
os.environ.setdefault("NCCL_TIMEOUT", "2700")
os.environ.setdefault("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", "2700")

def parse_args(argv=None):
    """Parse command-line options for the self-training script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; passing a list lets tests and
            notebooks drive the parser without touching ``sys.argv``.

    Returns:
        argparse.Namespace holding the options declared below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default='Qwen2.5-Math-7B', help='Model name')
    parser.add_argument('--model_path', type=str, default=None, help='Local model path')
    # NOTE(review): main() in this file loads GSM8K directly and does not read this option.
    parser.add_argument('--train_data', type=str, default='', help='Training data file path')
    parser.add_argument('--save_root', type=str, default=None, help='Checkpoint save root directory')
    parser.add_argument('--effective_batch', type=int, default=64, help='Global batch size')
    parser.add_argument('--temperature', type=float, default=1.0,
                        help='Temperature applied to logits when scoring log-probabilities')
    parser.add_argument('--learning_rate', type=float, default=2e-5, help='Learning rate')
    parser.add_argument('--log_steps', type=int, default=1, help='Logging step interval')
    parser.add_argument('--save_steps', type=int, default=1, help='Checkpoint saving step interval')
    parser.add_argument('--max_steps', type=int, default=50, help='Maximum training steps')
    parser.add_argument('--sample_temp', type=float, default=1.0, help='Generation temperature parameter')
    parser.add_argument('--run_name', type=str, default=None, help='Experiment run name')
    parser.add_argument('--seed', type=int, default=15, help='Random seed')
    parser.add_argument('--max_length', type=int, default=2048,
                        help='Maximum prompt length in tokens (longer prompts are truncated)')
    parser.add_argument('--max_new_tokens', type=int, default=512,
                        help='Maximum number of tokens generated per response')
    parser.add_argument('--num_return_sequences', type=int, default=1,
                        help='Number of responses sampled per prompt')
    return parser.parse_args(argv)

class FTDataset(Dataset):
    """Map-style dataset that serves the rows of a pre-built sequence unchanged."""

    def __init__(self, rows):
        # Keep a reference to the caller's sequence; nothing is copied.
        self.rows = rows

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        row = self.rows[idx]
        return row

def custom_collate(batch):
    """Gather the ``"input"`` string of every row into ``{"input": [...]}``.

    Other keys of the rows are ignored.
    """
    prompts = []
    for row in batch:
        prompts.append(row["input"])
    return {"input": prompts}

# Template variant without a system prompt, intended for non-Llama models.
def apply_chat_template(tokenizer, problem: str) -> str:
    """Render *problem* as a single user turn followed by the generation prompt.

    Returns the formatted text (``tokenize=False``), not token ids.
    """
    messages = [{"role": "user", "content": problem}]
    rendered = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    return rendered

def apply_chat_template_llama(tokenizer, problem: str) -> str:
    """Render *problem* after a step-by-step / ``\\boxed{}`` system instruction.

    Produces a system turn plus a user turn, ends with the generation prompt,
    and returns the formatted text rather than token ids.
    """
    system_prompt = "Please reason step by step, and put your final answer within \\boxed{}. "
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": problem},
    ]
    return tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)

def _build_ds_config(accum_steps: int, world_size: int) -> dict:
    """Return the DeepSpeed ZeRO-2 config: bf16, CPU-offloaded optimizer, grad clip 1.0."""
    micro_batch = 1  # one prompt per GPU per micro-step
    return {
        "train_micro_batch_size_per_gpu": micro_batch,
        # DeepSpeed requires train_batch_size == micro * accum * world_size exactly.
        # Deriving it (instead of echoing --effective_batch) avoids a startup error
        # when --effective_batch is not a multiple of WORLD_SIZE.
        "train_batch_size": micro_batch * accum_steps * world_size,
        "gradient_accumulation_steps": accum_steps,
        "bf16": {"enabled": True},
        "zero_optimization": {
            "stage": 2,
            "offload_optimizer": {"device": "cpu"},
            "offload_param": {"device": "none"},
        },
        "gradient_clipping": 1.0,
    }


def _rollout_tensors(input_ids, prompt_attention, gen_ids, pad_token_id, num_return_sequences):
    """Split ``generate()`` output into ``(seq, attention_mask, gen_mask)``.

    ``seq`` is the prompt (repeated once per sample) followed by the generated
    tokens. ``attention_mask`` reuses the tokenizer's prompt mask and marks every
    generated non-pad token as visible. ``gen_mask`` is True only on generated,
    non-pad positions, i.e. the tokens whose log-probabilities are trained.

    The prompt/generation boundary is the padded prompt width, not a count of
    non-pad tokens: when pad == eos, eos tokens that occur *inside* the chat
    prompt would otherwise be miscounted as padding.
    """
    prompt_width = input_ids.shape[1]
    prompt_ids = input_ids.repeat_interleave(num_return_sequences, dim=0)
    prompt_attn = prompt_attention.repeat_interleave(num_return_sequences, dim=0).bool()
    gen_part = gen_ids[:, prompt_width:]
    # Generation pads finished samples with pad_token_id; exclude those positions.
    gen_valid = gen_part.ne(pad_token_id)

    seq = torch.cat([prompt_ids, gen_part], dim=1)
    attention_mask = torch.cat([prompt_attn, gen_valid], dim=1)
    gen_mask = torch.cat([torch.zeros_like(prompt_attn), gen_valid], dim=1)
    return seq, attention_mask, gen_mask


def _get_seq_logprob(m, seq, attention_mask, gen_mask, temp):
    """Sum temperature-scaled log-probs of the generated tokens of each sequence.

    Returns ``(seq_logprob, lengths)``, one entry per row of ``seq``;
    ``lengths`` is the number of generated target positions counted (float).
    """
    device = next(m.parameters()).device
    seq = seq.to(device)
    attention_mask = attention_mask.to(device)
    gen_mask = gen_mask.to(device)

    # float32 so the log-softmax over the full vocabulary is not done in bf16.
    logits = m(seq, attention_mask=attention_mask.long()).logits.float()
    # Position t predicts token t+1: drop the last logit and the first target.
    log_probs = torch.log_softmax(logits[:, :-1, :] / temp, dim=-1)
    targets = seq[:, 1:]
    token_logprobs = torch.gather(log_probs, -1, targets.unsqueeze(-1)).squeeze(-1)
    target_mask = gen_mask[:, 1:]
    # masked_fill (not multiply) so a -inf at a masked position cannot become NaN.
    seq_logprob = token_logprobs.masked_fill(~target_mask, 0.0).sum(dim=1)
    lengths = target_mask.sum(dim=1).float()
    return seq_logprob, lengths


def _save_checkpoint(accelerator, model, tokenizer, path):
    """Write the unwrapped model (safetensors) and the tokenizer to *path*."""
    path.mkdir(parents=True, exist_ok=True)
    accelerator.unwrap_model(model).save_pretrained(path, safe_serialization=True)
    tokenizer.save_pretrained(path)


def main():
    """Self-train on sampled answers to GSM8K test questions.

    For every prompt the model samples ``--num_return_sequences`` completions,
    re-scores them with gradients, and minimises
    ``-(exp(logp).detach() + alpha) * logp`` averaged over the samples, where
    ``logp`` is the summed log-probability of the generated tokens. Checkpoints
    go under ``--save_root`` (default ``checkpoints/<model_name>[/<run_name>]``).
    """
    args = parse_args()
    set_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    world_size = int(os.getenv("WORLD_SIZE", "1"))
    # The micro-batch is one prompt per GPU, so accumulate to reach the global batch.
    accum_steps = max(1, args.effective_batch // world_size)
    temp = args.temperature

    if args.save_root:
        save_root = args.save_root
    elif args.run_name:
        save_root = f"checkpoints/{args.model_name}/{args.run_name}"
    else:
        save_root = f"checkpoints/{args.model_name}"

    accelerator = Accelerator(
        mixed_precision="bf16",
        gradient_accumulation_steps=accum_steps,
        deepspeed_plugin=DeepSpeedPlugin(hf_ds_config=_build_ds_config(accum_steps, world_size)),
    )
    print = accelerator.print  # prints on the main process only

    # Fall back to --model_name (hub id or local dir) instead of crashing on None.
    model_path = args.model_path or args.model_name
    config = AutoConfig.from_pretrained(model_path)
    config.use_cache = False  # KV cache disabled since gradient checkpointing is enabled below
    model = AutoModelForCausalLM.from_pretrained(model_path, config=config)
    model.gradient_checkpointing_enable()

    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # NOTE: trains on the GSM8K *test* questions; --train_data is not read.
    gsm8k = pd.DataFrame(load_dataset("openai/gsm8k", "main", split="test"))
    train_data = [
        {"input": apply_chat_template_llama(tokenizer, question)}
        for question in gsm8k["question"].dropna().tolist()
    ]
    train_loader = DataLoader(
        FTDataset(train_data), batch_size=1, shuffle=True, collate_fn=custom_collate
    )

    optimizer = AdamW(model.parameters(), lr=args.learning_rate)
    model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
    print(f"Total training batches: {len(train_loader)}")

    model.train()
    for param in model.parameters():
        param.requires_grad = True

    for step, batch in enumerate(train_loader):
        # `step` counts micro-batches; run exactly --max_steps of them (0 .. max_steps-1).
        if step >= args.max_steps:
            print(f"Exceed max step {args.max_steps}, training stopped.")
            break

        with accelerator.accumulate(model):
            enc = tokenizer(
                batch["input"],
                return_tensors="pt",
                padding="longest",
                truncation=True,
                max_length=args.max_length,
            ).to(accelerator.device)

            with torch.no_grad():
                gen_ids = accelerator.unwrap_model(model).generate(
                    **enc,
                    max_new_tokens=args.max_new_tokens,
                    do_sample=True,
                    top_p=0.95,
                    temperature=args.sample_temp,
                    synced_gpus=False,
                    repetition_penalty=1.05,
                    pad_token_id=tokenizer.pad_token_id,
                    num_return_sequences=args.num_return_sequences,
                    use_cache=False,
                )

            # Decoded first sample (prompt + completion) for monitoring.
            print(tokenizer.decode(gen_ids[0].tolist(), skip_special_tokens=True))

            seq, attention_mask, gen_mask = _rollout_tensors(
                enc.input_ids,
                enc.attention_mask,
                gen_ids,
                tokenizer.pad_token_id,
                args.num_return_sequences,
            )
            # `lengths` is only needed by the length-normalised Option 2 below.
            logprob_new, lengths = _get_seq_logprob(model, seq, attention_mask, gen_mask, temp)

            # Option 1 (active): confidence-weighted likelihood.
            # .mean() keeps the loss scalar when --num_return_sequences > 1.
            alpha = 1
            prob_seq = torch.exp(logprob_new.detach())
            loss = -((prob_seq + alpha) * logprob_new).mean()

            # Option 2 --- scale the confidence:
            #   reward = torch.exp(logprob_new.detach() / lengths)
            #   loss = -(reward * logprob_new / lengths).mean()
            # Option 3 --- initial confidence:
            #   prob_seq = torch.exp(logprob_new.detach())
            #   loss = -(prob_seq * logprob_new).mean()

            accelerator.backward(loss)
            accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

        if accelerator.is_main_process:
            if step % args.log_steps == 0:
                print(f"Step {step} | loss={loss.item():.6f}")

            if step % args.save_steps == 0:
                ckpt = Path(save_root) / f"step_{step}"
                _save_checkpoint(accelerator, model, tokenizer, ckpt)
                print(f"Checkpoint saved to {ckpt}")

    # Make sure every rank has finished its last update before the final save.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        final = Path(save_root) / "final"
        _save_checkpoint(accelerator, model, tokenizer, final)
        print(f"Final checkpoint saved to {final}")


if __name__ == "__main__":
    main()