import json
import sys
import ast
import time
import matplotlib.pyplot as plt
import pandas as pd
import yaml
from pathlib import Path 
from datetime import timedelta
import os
from datasets import Dataset, DatasetDict,load_dataset
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling
from datetime import datetime
import re
import gc
from langchain_community.chat_models import ChatOpenAI
import numpy as np
import torch
import torch.nn.functional as F
from datasets import concatenate_datasets, load_dataset
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
from torch.utils.data import DataLoader
from tqdm import tqdm
from rag_agent import FeedbackRAGAgent
from code_agent import CodeAgent
from torch.utils.data import DataLoader
from pydantic import BaseModel
from openai import OpenAI
from loaders import load_model_distributed_ddp, load_model_distributed_flashattn, load_model_distributed_ddp_lora_fast
from peft import PeftModel
from huggingface_hub import login
# Hugging Face Hub access token. It is an empty placeholder in this file.
# NOTE(review): presumably filled in by the user before running; consider
# reading it from the environment rather than hard-coding a secret here.
hf_token = ""
# Authenticates against the Hugging Face Hub as a side effect of importing
# this module, using the token defined above.
login(token=hf_token)

# DDP setup
import argparse
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import torch.multiprocessing as mp

def ddp_setup(rank, world_size, port):
    """Initialise the default NCCL process group with a localhost rendezvous.

    Args:
        rank: Index of the calling process inside the group.
        world_size: Total number of processes in the group.
        port: Rendezvous port. Strings and ints are both accepted; the value
            is converted to ``str`` because ``os.environ`` only stores strings.
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = str(port)
    # Collectives that stall for longer than 10 minutes abort with a timeout.
    init_process_group("nccl", rank=rank, world_size=world_size, timeout=timedelta(minutes=10))


def generate_prompt(sample, tokenizer):
    """
    Render one physics sample as a chat-template prompt string.

    The conversation has a fixed system message that prescribes the
    "## Step N:" answer layout, followed by a user message holding
    ``sample["question"]`` (empty string when the key is missing).
    The messages are passed to ``tokenizer.apply_chat_template`` with
    ``tokenize=False`` and its result is returned unchanged.
    """
    system_text = (
        "You are a physics expert assistant. "
        "You are given a physics problem. Your task is to solve it in a clear and rigorous step-by-step format. "
        "Each step should be clearly numbered.\n"
        "Strictly follow this format:\n"
        "## Step 1: <step_1>\n"
        "## Step 2: <step_2>\n"
        "... and so on until the final answer is derived.\n\n"
    )
    user_text = sample.get("question", "")
    conversation = [
        {"role": "system", "content": system_text},
        {"role": "user", "content": user_text},
    ]
    return tokenizer.apply_chat_template(conversation, tokenize=False)
    
    
def format_example(sample,tokenizer):
    """
    Store the rendered chat prompt on ``sample`` under the ``'prompt'`` key.

    The sample is modified in place and the same object is returned, which
    is the shape expected by ``Dataset.map``-style callers.
    """
    sample['prompt'] = generate_prompt(sample, tokenizer)
    return sample
    

# Feedback template for a detected error. It is filled with str.format using
# the fields error_type, error_step and error_explanation, and repeats the
# "## Step N:" layout the revised solution must follow.
feedback_prompt = """There is a {error_type} in Step {error_step}.
The mistake occurred because: {error_explanation}.
Please revise the complete solution by correcting this step and generate the complete correct solution.\n
Strictly follow this format:\n## Step 1: <step_1>\n## Step 2: <step_2>\n... and so on until the final answer is derived.\n\n"""

# Fixed feedback text used when the evaluator reports error_step == 0.
feeback_prompt_no_error = """There is a no error in your solution."""

# Instruction template sent to the evaluator model. It is filled with
# str.format using the fields question, cot_solution and error_solution.
# The braces of the example JSON objects are doubled so that str.format
# emits them as literal "{" and "}".
eval_prompt = '''
You are a physics reasoning evaluator. Your task is to identify and explain the first reasoning error (if any) made by a language model in its step-by-step solution to a physics problem.

You will be provided with:
(1) A physics question.
(2) A correct step-by-step solution (ground truth).
(3) An step-by-step solution generated by a language model.

Please do the following:
1. Identify the first error step (if any) in the model's solution where the reasoning deviates from the correct solution or contains an error.
2. Classify the error into one of the following categories:
   - "Problem Miscomprehension": 
        This error occurs before any physics principles or formulas are applied. It involves a misunderstanding of the problem statement itself:
        Misinterpreting given values or conditions (e.g., using M instead of 3M, treating “frictionless” as if friction is present).
   - "Conceptual Error":
        A conceptual error is a mistake that arises from applying the wrong physics law, formula, or principle to a problem. It involves incorrect reasoning due to a misunderstanding of the underlying physics concepts, rather than a simple calculation or arithmetic mistake.
   - "Computational Error":
        Arithmetic operations (addition, subtraction, multiplication, division)
        Calculus operations (differentiation, integration, limits)
        Handling of exponents, roots, fractions, radicals
        Reasonable Rounding is allowed.
3. Provide a explanation of the error and the correction required.

Note:
     - Do not penalize for uncessary steps or correct unit coversions. 
     - Different LLM reasons in different ways from Ground Truth CoT, what matters is problem is correctly understood, concepts are correct, computations are correct and final answer is correct even if correctly converted into different units.
     - Final answer should be in reasonable error limit compared to ground truth CoT, rounding off is allowed.

Return your output in the following JSON format:

{{
  "error_step": <first error step number>,
  "error_type": "<one of the three categories>",
  "error_explanation": "<explanation of the error in the step and what the correct approach should be>"
}}


If no error and llm solution is correct return:

{{
  "error_step": 0,
  "error_type": "no_error",
  "error_explanation": "There is a no error in your solution."
}}

Note: When generating the error_explanation, ensure the output is valid JSON by following these requirements:
- Use only plain text with no special formatting
- Avoid backslashes, quotes within the text, or LaTeX symbols (e.g., $, ^, _)
- Do not include line breaks, tabs, or other control characters within string values
- Keep all text on single lines without manual line wrapping
- Use simple punctuation and avoid complex mathematical notation
- If mathematical expressions are needed, write them in plain text (e.g., "x squared" instead of "x^2")
- Ensure the explanation reads as continuous prose without embedded formatting

### Input:

Question: {question}
Ground Truth COT: {cot_solution}
LLM Generated Solution: {error_solution}

'''


def generate_feedback(decoded, question):
    """Turn an evaluator verdict into feedback text for the second attempt.

    Args:
        decoded: Evaluator verdict. Expected to be a mapping with the keys
            ``error_step``, ``error_type`` and ``error_explanation``; any
            other value (for example a raw string that could not be parsed)
            is tolerated and handed back unchanged.
        question: The physics question the verdict refers to.

    Returns:
        The feedback produced for the verdict. When the verdict cannot be
        used (wrong shape, unknown ``error_type``, or the feedback agent
        raises) the original ``decoded`` value is returned as-is.
    """
    failure_message = "\nError decoding the output. Please check the model output format."
    try:
        if decoded['error_step'] == 0:
            return feeback_prompt_no_error

        error_type = decoded['error_type']

        if error_type == 'Problem Miscomprehension':
            return feedback_prompt.format(
                error_step=decoded['error_step'],
                error_type=error_type,
                error_explanation=decoded['error_explanation']
            )

        if error_type == 'Conceptual Error':
            # NOTE(review): the API key below is a hard-coded placeholder.
            agent = FeedbackRAGAgent(pdf_path="physics-formulas-for-neet-2023.pdf", chroma_path="./chroma_db_feedback", embedding_model_name="sentence-transformers/all-MiniLM-L6-v2", llm_model_name="meta-llama/Llama-3.2-3B-Instruct", llm_api_key="your_nebius_key_here", llm_api_base="https://api.studio.nebius.com/v1/")
            return agent.run_feedback_cycle(
                question=question,
                error_step=decoded["error_step"],
                error_explanation=decoded["error_explanation"]
            )

        if error_type == 'Computational Error':
            # NOTE(review): the API key below is a hard-coded placeholder.
            agent = CodeAgent(question, decoded['error_step'], decoded['error_explanation'], llm_model_name="meta-llama/Llama-3.2-3B-Instruct", llm_api_key="your_nebius_key_here", llm_api_base="https://api.studio.nebius.com/v1/")
            return agent.run()
    except Exception as exc:
        # Best effort: keep the training loop alive, but do not swallow
        # KeyboardInterrupt / SystemExit the way a bare ``except`` would.
        print(failure_message)
        print(exc)
        return decoded  # Return the raw output if parsing fails

    # Unrecognised error category: same fallback as a malformed verdict.
    print(failure_message)
    return decoded

def _parse_evaluator_output(text):
    """Parse the evaluator reply into a dict, accepting JSON or a Python literal."""
    # Keep only the outermost {...} span so that code fences or stray prose
    # around the object do not break parsing.
    start = text.find("{")
    end = text.rfind("}")
    candidate = text[start:end + 1] if start != -1 and end > start else text
    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        # Fall back to the original behaviour for Python-style dict literals
        # (single quotes, trailing commas, ...).
        parsed = ast.literal_eval(candidate)
    if not isinstance(parsed, dict):
        raise ValueError("evaluator output is not a JSON object")
    return parsed


def generate_reward(sample, eval_model):
    """Score one generated solution with the evaluator model.

    Args:
        sample: Mapping with the keys ``question``, ``cot_solution`` and
            ``incorrect_solution`` (the solution to be judged).
        eval_model: Object whose ``invoke(prompt)`` result exposes the reply
            text as ``.content``.

    Returns:
        ``(reward, decoded)``. ``reward`` is 1.0 when the evaluator reports
        no error, ``error_step / (total_steps + 1)`` otherwise, and the
        sentinel ``0.01`` when no steps are found or the reply cannot be
        used. ``decoded`` is the parsed verdict dict, or the raw reply when
        parsing failed.
    """
    decoded = eval_model.invoke(eval_prompt.format(question=sample['question'], cot_solution = sample['cot_solution'], error_solution = sample['incorrect_solution'])).content
    try:
        decoded = _parse_evaluator_output(decoded)
        # Evaluators sometimes return the step as a string; normalise it so
        # the comparisons below and downstream consumers see an int.
        error_step = int(decoded['error_step'])
        decoded['error_step'] = error_step
        if error_step == 0:
            reward = 1.0  # If no error, reward is 1.0
        else:
            step_pattern = r"(?:#+\s*)?Step\s+\d+:"
            matches = re.findall(step_pattern, sample['incorrect_solution'])
            if matches:
                total_steps = len(matches)
                # A step index beyond the last detected step would push the
                # reward of an erroneous solution to 1.0 or above.
                error_step = min(error_step, total_steps)
                reward = (error_step / (total_steps + 1))  # Reward is proportional to the step where the first error occurs
            else:
                reward = 0.01
                print("\nNo steps found in the incorrect solution. Assigning minimal reward.")
        return reward,decoded

    except Exception as e:
        print("\nError decoding the output. Please check the model output format.")
        print(e)
        print(decoded)
        reward = 0.01  # Assign a minimal reward if decoding fails
        return reward,decoded  # Return the raw output if parsing fails
        
def eval_answers(answers, solutions, questions, eval_model):
    """
    Score a batch of answers, one ``generate_reward`` call per sample.

    ``answers``, ``solutions`` and ``questions`` are walked in lockstep
    (stopping at the shortest). Returns two parallel lists: the rewards and
    the decoded evaluator verdicts.
    """
    scored = [
        generate_reward(
            {
                "question": question,
                "cot_solution": solution,
                "incorrect_solution": answer,
            },
            eval_model,
        )
        for answer, solution, question in zip(answers, solutions, questions)
    ]
    rewards = [reward for reward, _ in scored]
    decoded_batch = [verdict for _, verdict in scored]
    return rewards, decoded_batch
        
@torch.no_grad()
def get_log_probs_incremental_batched(model, input_ids, generated_ids):
    """
    Compute log-probs of ``generated_ids`` given ``input_ids`` using the KV cache (batched).

    The context is run once to build ``past_key_values``; the generated
    tokens except the last are then fed with that cache (teacher forcing),
    giving log p(generated_ids[i] | context + generated_ids[:i]) for i > 0.

    Args:
        model: Callable model accepting ``input_ids``, ``past_key_values``
            and ``use_cache`` and returning an object with ``.logits`` and
            ``.past_key_values``.
        input_ids: Context token ids, one row per sample.
        generated_ids: Generated token ids, 2-D ``[B, T_gen]``.

    Returns:
        Tensor ``[B, T_gen - 1]`` of per-token log-probs. The whole function
        runs under ``torch.no_grad()``, so the result carries no gradient.
    """
    # Unpacking doubles as a check that generated_ids is 2-D.
    _batch_size, _gen_len = generated_ids.shape

    # Switch to eval mode only for the duration of this call; leaving the
    # model in eval mode would silently change the caller's training passes.
    was_training = getattr(model, "training", False)
    model.eval()
    try:
        # Step 1: Run full context to build KV cache
        with torch.inference_mode():
            past_key_values = model(input_ids=input_ids, use_cache=True).past_key_values

        # Step 2: Feed all generated_ids except last token (teacher forcing)
        out = model(
            input_ids=generated_ids[:, :-1],
            past_key_values=past_key_values,
            use_cache=False,
        )
        log_probs = F.log_softmax(out.logits, dim=-1)  # [B, T_gen-1, vocab]
        targets_gen = generated_ids[:, 1:]
        return torch.gather(log_probs, dim=-1, index=targets_gen.unsqueeze(-1)).squeeze(-1)  # [B, T_gen-1]
    finally:
        if was_training:
            model.train()
    
def get_log_probs(model, input_ids, prompt_len, return_probs=False):
    """
    Return per-token log-probs of the completion part of ``input_ids``.

    The logits at position t are matched with the token at position t + 1,
    then everything before index ``prompt_len - 1`` of the shifted sequence
    is dropped. With ``return_probs=True`` the full log-softmax
    distributions for the same positions are returned as a second tensor.
    """
    shifted_logits = model(input_ids).logits[:, :-1, :]
    next_tokens = input_ids[:, 1:]

    # Rows are processed one at a time to keep the memory peak low.
    token_rows = []
    distribution_rows = []
    for row_logits, row_targets in zip(shifted_logits, next_tokens):
        row_log_probs = row_logits.log_softmax(dim=-1)
        if return_probs:
            distribution_rows.append(row_log_probs)
        picked = row_log_probs.gather(1, row_targets.unsqueeze(1)).squeeze(1)
        token_rows.append(picked)

    # The -1 accounts for the one-position shift applied above.
    first_answer_pos = prompt_len - 1
    token_log_probs = torch.stack(token_rows)[:, first_answer_pos:]
    if return_probs:
        return token_log_probs, torch.stack(distribution_rows)[:, first_answer_pos:]
    return token_log_probs


def get_eos_mask(answer_ids, tokenizer):
    """Build a 0/1 mask that keeps each row up to and including its first EOS.

    Args:
        answer_ids: Integer tensor ``[B, T]`` of generated token ids.
        tokenizer: Object exposing ``eos_token_id``.

    Returns:
        ``int32`` tensor of the same shape: 1 for tokens before the first
        EOS and for the first EOS itself, 0 for everything after it. Rows
        without any EOS are all ones.
    """
    is_eos = (answer_ids == tokenizer.eos_token_id).long()
    # Number of EOS tokens strictly before each position. Testing this
    # against zero (instead of ``cumsum <= 1``) also masks non-EOS padding
    # that follows the first EOS.
    eos_seen_before = torch.cumsum(is_eos, dim=1) - is_eos
    mask = (eos_seen_before == 0).int()
    return mask

def eval_stage(
    model,
    tokenizer,
    eval_dataset,
    BATCH_SIZE,
    eval_model,
    result_folder: str,
    epoch: int
):
    """Generate answers for the eval set, score them and dump the results.

    Args:
        model: Wrapped model; generation goes through ``model.module`` and
            inputs are moved to ``model.device``.
        tokenizer: Tokenizer used to encode prompts and decode answers.
        eval_dataset: Dataset whose items provide ``prompt``, ``question``
            and ``cot_solution``.
        BATCH_SIZE: Batch size of the evaluation loader.
        eval_model: Evaluator passed through to ``eval_answers``.
        result_folder: Directory that receives ``eval_logs_{epoch}_.json``.
        epoch: Zero-based epoch index (used in the file name and messages).

    The model is put into eval mode for the pass and switched back to train
    mode before returning.
    """
    # build a dataloader over the eval set
    # NOTE(review): DistributedSampler hands each rank only its own shard,
    # so a caller that runs this on a single rank evaluates a subset.
    eval_loader = DataLoader(
        eval_dataset,
        sampler=DistributedSampler(eval_dataset),
        batch_size=BATCH_SIZE,
        pin_memory=True,
    )

    eval_logs = []
    model.eval()
    with torch.no_grad():
        for batch in tqdm(eval_loader, desc=f"Eval Epoch {epoch+1}"):
            # 1) build prompt batch
            prompts = batch["prompt"]
            inputs = tokenizer(
                prompts, return_tensors="pt", padding=True, truncation=True,
                max_length=1024
            ).to(model.device)

            # 2) generate
            out = model.module.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=2048,
                temperature=0.7
            )
            # Drop the prompt tokens, as train_stage_1 does, so the evaluator
            # and the step-counting regex only see the generated answer and
            # not the "## Step 1/2" template from the system prompt.
            prompt_len = inputs["input_ids"].shape[1]
            answers = tokenizer.batch_decode(out[:, prompt_len:], skip_special_tokens=True)

            # 3) reward
            rewards, _ = eval_answers(
                answers,
                batch["cot_solution"],
                batch["question"],
                eval_model,
            )

            # 4) accumulate logs
            for q, a, r in zip(batch["question"], answers, rewards):
                eval_logs.append({"question": q, "answer": a, "reward": r})

    os.makedirs(result_folder, exist_ok=True)
    eval_file_path = os.path.join(result_folder, f"eval_logs_{epoch}_.json")
    with open(eval_file_path, "w") as f:
        json.dump(eval_logs, f, indent=2)
    print(f"[Epoch {epoch+1}] saved evaluation to {eval_file_path}")

    model.train()

def get_kl_div(base_logprobs, logprobs, mask):
    """
    Masked per-sequence KL term between two log-prob distributions.

    The pointwise KL (``base_logprobs`` as log-space target, ``logprobs``
    as input) is averaged over the last axis, weighted by ``mask`` and
    averaged over the unmasked positions of each sequence.
    """
    pointwise = F.kl_div(input=logprobs, target=base_logprobs, reduction="none", log_target=True)
    per_token = pointwise.mean(dim=-1)
    masked_total = (per_token * mask).sum(dim=-1)
    valid_tokens = mask.sum(dim=-1)
    return masked_total / valid_tokens

# Module-level accumulator of per-sample training records; train_stage_1
# appends to it and re-dumps the whole list to a JSON file after each batch.
logs = []

def train_stage_1(
    model,
    tokenizer,
    train_dataset,
    eval_dataset,
    num_epochs,
    LEARNING_RATE,
    beta2,
    BATCH_SIZE=6,
    num_workers=2,
    prefetch_factor=4,
    MAX_PROMPT_LEN1=1024,
    MAX_PROMPT_LEN2=1024,
    mnt_attempt1=128,
    mnt_attempt2=128,
    TEMP=1.0,
    eval_model=None,
    save_dir=None
):
    """Run the stage-1 two-attempt self-correction training loop.

    Per batch: generate a first answer, score it with ``eval_answers``, drop
    samples whose reward is above 0.95 or equal to the 0.01 failure
    sentinel, build feedback from the evaluator verdicts, generate a second
    answer conditioned on prompt + first answer + feedback, score it, drop
    0.01-sentinel samples again, then take one optimizer step on the loss.
    Per-sample records are appended to the module-level ``logs`` list and
    written to ``<save_dir>/RESULTS``. After every epoch rank 0 runs
    ``eval_stage`` and saves a checkpoint; a final checkpoint is saved at
    the end.

    Args:
        model: DDP-wrapped model (``model.module`` is used for generation
            and saving).
        tokenizer: Tokenizer matching the model.
        train_dataset: Dataset providing ``prompt``, ``cot_solution`` and
            ``question`` per item.
        eval_dataset: Dataset passed to ``eval_stage``.
        num_epochs: Number of passes over ``train_dataset``.
        LEARNING_RATE: AdamW learning rate.
        beta2: Weight of the KL term in the loss.
        BATCH_SIZE, num_workers, prefetch_factor: DataLoader settings.
        MAX_PROMPT_LEN1, MAX_PROMPT_LEN2: Padded/truncated prompt lengths for
            the first and second attempt.
        mnt_attempt1, mnt_attempt2: ``max_new_tokens`` for each attempt.
        TEMP: Generation temperature.
        eval_model: Evaluator passed to ``eval_answers``.
        save_dir: Base output directory.

    Returns:
        The (wrapped) model that was passed in.

    NOTE(review): ``att1_probs`` is detached before the KL term and the
    second-attempt log-probs come from a ``torch.no_grad`` helper, so only
    the first-attempt log-prob term of the loss carries gradient.
    NOTE(review): a rank that hits one of the ``continue`` statements skips
    ``backward`` for that batch; this presumably desynchronises DDP
    gradient reduction across ranks. Confirm before multi-GPU runs.
    """
    sampler = DistributedSampler(train_dataset)
    dataloader = DataLoader(train_dataset,
                            batch_size=BATCH_SIZE,
                            num_workers=num_workers,
                            pin_memory=True,
                            prefetch_factor=prefetch_factor,
                            sampler=sampler)

    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),lr=LEARNING_RATE)
    result_folder = Path(save_dir, "RESULTS")
    result_folder.mkdir(parents=True, exist_ok=True)

    # The DDP wrapper does not forward arbitrary attributes, so PEFT's
    # disable_adapter() has to be looked up on the wrapped module.
    inner_model = getattr(model, "module", model)

    for epoch in range(num_epochs):
        # Without this the sampler reuses the same shuffle order every epoch.
        sampler.set_epoch(epoch)

        epoch_loss = []
        count_high_reward_1 = 0
        count_high_reward_2 = 0

        for i, batch in tqdm(enumerate(dataloader), desc=f"Stage 1 Training {epoch+1}/{num_epochs}"):

            mean_reward_a1 = []
            mean_reward_a2 = []
            device = model.device

            prompt_texts = batch["prompt"]
            solutions = batch["cot_solution"]
            questions = batch["question"]

            # Tokenize first attempt inputs
            x1 = tokenizer.batch_encode_plus(
                prompt_texts,
                padding="max_length",
                truncation=True,
                max_length=MAX_PROMPT_LEN1,
                return_tensors="pt",
            )
            x1 = {k: v.to(device) for k, v in x1.items()}

            # Generate attempt 1 (on-policy)
            with torch.no_grad():
                action1_token = model.module.generate(
                    x1["input_ids"],
                    attention_mask=x1["attention_mask"],
                    max_new_tokens=mnt_attempt1,
                    temperature=TEMP,
                )

            # Everything after the padded prompt is the generated answer.
            x1_len = x1["input_ids"].shape[1]
            attempt1_answer_tokens = action1_token[:, x1_len:]
            attempt1_answer_mask = get_eos_mask(attempt1_answer_tokens, tokenizer)

            answer1 = tokenizer.batch_decode(attempt1_answer_tokens, skip_special_tokens=True)
            reward1,decoded_batch = eval_answers(answer1, solutions, questions, eval_model)

            for r1 in reward1:
                if r1 > 0.95:
                    count_high_reward_1 += 1

            # Keep only samples that were judged wrong (reward <= 0.95) and
            # whose evaluation did not fail (0.01 is the failure sentinel).
            filtered_indices = [idx for idx, r in enumerate(reward1) if r <= 0.95 and r != 0.01]

            # Skip entire batch if nothing remains
            if len(filtered_indices) == 0:
                continue

            # Filter all per-sample data accordingly
            prompt_texts = [prompt_texts[idx] for idx in filtered_indices]
            solutions = [solutions[idx] for idx in filtered_indices]
            questions = [questions[idx] for idx in filtered_indices]
            answer1 = [answer1[idx] for idx in filtered_indices]
            reward1 = [reward1[idx] for idx in filtered_indices]
            # The verdicts must be filtered too, otherwise the feedback below
            # is paired with the wrong question/answer.
            decoded_batch = [decoded_batch[idx] for idx in filtered_indices]
            action1_token = action1_token[filtered_indices]
            attempt1_answer_mask = attempt1_answer_mask[filtered_indices]
            attempt1_answer_tokens = attempt1_answer_tokens[filtered_indices]
            x1["input_ids"] = x1["input_ids"][filtered_indices]
            mean_reward_a1.append(np.mean(reward1))

            # Feedback for self-correction, one entry per remaining sample
            feedbacks = []
            for d, q in zip(decoded_batch, questions):
                feedback = generate_feedback(d, q)
                feedbacks.append(feedback)

            messages_batch = []
            for pt, ans, feedback in zip(prompt_texts, answer1, feedbacks):
                messages = tokenizer.apply_chat_template(
                    [
                        {"role": "user", "content": pt},
                        {"role": "assistant", "content": ans},
                        {"role": "user", "content": feedback}
                    ],
                    add_generation_prompt=True,
                    tokenize=False
                )
                messages_batch.append(messages)

            x2 = tokenizer.batch_encode_plus(
                messages_batch,
                padding="max_length",
                truncation=True,
                max_length=MAX_PROMPT_LEN2,
                return_tensors="pt",
            )
            x2 = {k: v.to(device) for k, v in x2.items()}

            # Generate attempt 2
            with torch.no_grad():
                action2_token = model.module.generate(
                    x2["input_ids"],
                    attention_mask=x2["attention_mask"],
                    max_new_tokens=mnt_attempt2,
                    temperature=TEMP,
                )

            x2_len = x2["input_ids"].shape[1]
            attempt2_answer_tokens = action2_token[:, x2_len:]
            attempt2_answer_mask = get_eos_mask(attempt2_answer_tokens, tokenizer)

            answer2 = tokenizer.batch_decode(attempt2_answer_tokens, skip_special_tokens=True)
            reward2, decoded_batch_2 = eval_answers(answer2, solutions, questions, eval_model)

            feedbacks2 = []
            for d, q in zip(decoded_batch_2, questions):
                feedback = generate_feedback(d, q)
                feedbacks2.append(feedback)

            # Drop samples whose second evaluation failed.
            filtered_indices = [idx for idx, r in enumerate(reward2) if r != 0.01]
            if len(filtered_indices) == 0:
                continue
            for r2 in reward2:
                if r2 > 0.95:
                    count_high_reward_2 += 1
            # Filter again, keeping every per-sample container aligned
            prompt_texts = [prompt_texts[idx] for idx in filtered_indices]
            solutions = [solutions[idx] for idx in filtered_indices]
            questions = [questions[idx] for idx in filtered_indices]
            answer1 = [answer1[idx] for idx in filtered_indices]
            answer2 = [answer2[idx] for idx in filtered_indices]
            feedbacks = [feedbacks[idx] for idx in filtered_indices]
            reward1 = [reward1[idx] for idx in filtered_indices]
            reward2 = [reward2[idx] for idx in filtered_indices]
            feedbacks2 = [feedbacks2[idx] for idx in filtered_indices]
            action1_token = action1_token[filtered_indices]
            action2_token = action2_token[filtered_indices]
            attempt1_answer_mask = attempt1_answer_mask[filtered_indices]
            attempt2_answer_mask = attempt2_answer_mask[filtered_indices]
            attempt1_answer_tokens = attempt1_answer_tokens[filtered_indices]
            attempt2_answer_tokens = attempt2_answer_tokens[filtered_indices]
            x1["input_ids"] = x1["input_ids"][filtered_indices]
            x2["input_ids"] = x2["input_ids"][filtered_indices]
            mean_reward_a2.append(np.mean(reward2))

            # Compute RL loss
            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                action1_token = action1_token.clone()
                with torch.no_grad():
                    if hasattr(inner_model, "disable_adapter"):
                        # PEFT/LoRA: reference distribution = base weights
                        # with the adapter switched off.
                        with inner_model.disable_adapter():
                            _, base_probs = get_log_probs(model, action1_token, x1_len, return_probs=True)
                    else:
                        # For models without PEFT
                        _, base_probs = get_log_probs(model, action1_token, x1_len, return_probs=True)

                att1_log_probs, att1_probs = get_log_probs(model, action1_token, x1_len, return_probs=True)
                kl_div = get_kl_div(base_probs, att1_probs.detach(), attempt1_answer_mask)

                action2_token = action2_token.clone()
                att2_log_probs = get_log_probs_incremental_batched(
                    model = model,
                    input_ids=x2["input_ids"],                  # prompt + a1 + feedback
                    generated_ids=attempt2_answer_tokens,       # only the generated action2 tokens
                )

                reward2_tensor = torch.tensor(reward2, device=device, dtype=torch.bfloat16)

                # The first generated token of each attempt is skipped
                # ([:, 1:]) because the incremental helper only yields
                # log-probs for positions i > 0.
                loss = -(
                    (
                        (att2_log_probs * attempt2_answer_mask[:, 1:]).sum(-1) / attempt2_answer_mask[:, 1:].sum(-1)
                        - (att1_log_probs[:, 1:] * attempt1_answer_mask[:, 1:]).sum(-1) / attempt1_answer_mask[:, 1:].sum(-1)
                    ) * reward2_tensor - beta2 * kl_div
                ).mean()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                epoch_loss.append(loss.item())

            if torch.distributed.get_rank() == 0:
                # Logging
                print(
                    f"mean_reward_attmpt1: {np.mean(mean_reward_a1):.4f}, "
                    f"mean_reward_attmpt2: {np.mean(mean_reward_a2):.4f}, "
                    f"difference_at1_at2: {np.mean(mean_reward_a2) - np.mean(mean_reward_a1):.4f}, "
                    f"loss: {np.mean(epoch_loss):.4f}, "
                    f"kl_div: {kl_div.mean().item():.4f}"
                )

            for q, a1, r1, fb1, a2, r2, fb2 in zip(questions, answer1, reward1, feedbacks, answer2, reward2, feedbacks2):
                log = {}
                log['question'] = q
                log['answer1'] = a1
                log['reward1'] = r1
                log['feedback'] = fb1
                log['answer2'] = a2
                log['reward2'] = r2
                log['feedback2'] = fb2
                log['epoch'] = epoch + 1
                logs.append(log)

            # One file per (epoch, rank); it is rewritten after every batch.
            log_path = os.path.join(result_folder,f'logs_{epoch}_{torch.distributed.get_rank()}.json')
            with open(log_path, 'w') as log_file:
                json.dump(logs, log_file, indent = 2)

            # === FREE MEMORY ===
            del x1, x2
            del action1_token, action2_token
            del attempt1_answer_tokens, attempt2_answer_tokens
            del att1_log_probs, att2_log_probs, att1_probs, base_probs
            del reward1, reward2
            gc.collect()
            torch.cuda.empty_cache()

            if torch.distributed.get_rank() == 0:
                print("Epoch {}: , Count High Reward 1: {}, Count High Reward 2: {}".format(
                    epoch + 1,
                    count_high_reward_1,
                    count_high_reward_2))

        # NOTE(review): only rank 0 evaluates and saves here and there is no
        # barrier, so the other ranks move on to the next epoch meanwhile.
        if torch.distributed.get_rank() == 0:
            eval_stage(model,tokenizer,eval_dataset,BATCH_SIZE,eval_model,result_folder, epoch)

        if torch.distributed.get_rank() == 0:
            model_save_dir = Path(result_folder, f"Model_checkpoint_{epoch}")
            model_save_dir.mkdir(parents=True, exist_ok=True)
            model.module.save_pretrained(model_save_dir)
            tokenizer.save_pretrained(model_save_dir)
            print(f"Model and tokenizer saved to {model_save_dir}")

    if torch.distributed.get_rank() == 0:
        model_save_dir = Path(result_folder,"Model_final")
        model_save_dir.mkdir(parents=True, exist_ok=True)
        model.module.save_pretrained(model_save_dir)
        tokenizer.save_pretrained(model_save_dir)
        print(f"Model and tokenizer saved to {model_save_dir}")

    torch.distributed.barrier()
    return model

def main(rank:int, world_size:int, port:str, gpu_per_model:int):
    """Per-process entry point: set up DDP, load everything, run stage 1.

    Args:
        rank: Index of this process (supplied by ``mp.spawn``).
        world_size: Number of processes in the group.
        port: Rendezvous port passed to ``ddp_setup``.
        gpu_per_model: Number of consecutive GPU indices assigned to this
            process; rank ``r`` gets ``[r * gpu_per_model, (r + 1) * gpu_per_model)``.

    The process group is destroyed even when loading or training raises.
    """
    # DDP setup
    ddp_setup(rank, world_size, port)

    try:
        device_start_idx = rank * gpu_per_model
        device_end_idx = device_start_idx + gpu_per_model
        device_list = list(range(device_start_idx, device_end_idx))
        print(f"Device list: {device_list}")

        # NOTE(review): output directory, model name, API key, evaluator
        # model name and dataset path are empty placeholders in this file.
        OUT_DIR = Path("")
        OUT_DIR.mkdir(parents=True, exist_ok=True)
        # Load model and tokenizer
        model_name = ''

        os.environ['OPENAI_API_KEY'] = ""
        llm_model = ""

        save_dir = os.path.join(OUT_DIR, model_name.split('/')[-1])
        os.makedirs(save_dir, exist_ok=True)

        model, tokenizer = load_model_distributed_ddp_lora_fast(model_name, device_list)
        eval_model = ChatOpenAI(model=llm_model,)

        # DDP setup
        model_ddp = DDP(model,find_unused_parameters=True)

        raw_dataset = load_dataset("json", data_files='',split="train")

        formatted_dataset = raw_dataset.map(format_example, fn_kwargs={"tokenizer": tokenizer})

        # Split into train and eval
        dataset = formatted_dataset.train_test_split(test_size=0.2, seed=42)

        train_dataset = dataset["train"]
        eval_dataset = dataset["test"]

        # Cap the eval set at 100 samples without failing on smaller splits.
        eval_dataset = eval_dataset.select(range(min(100, len(eval_dataset))))

        if torch.distributed.get_rank() == 0:
            print(len(train_dataset))
            print(len(eval_dataset))

        # Hyperparameters
        BATCH_SIZE = 5
        NUM_WORKERS = 2
        PREFETCH_FACTOR = 4
        LEARNING_RATE = 5e-6
        MAX_PROMPT_LEN1 = 512
        mnt_attempt1 = 1240
        mnt_attempt2 = 1240
        MAX_PROMPT_LEN2 = 2480
        stage_1_epochs = 10 # 10 epochs for stage 1 training
        # The next three values are not passed to train_stage_1.
        stage_2_epochs = 60
        BETA1 = 0.01
        ALPHA = 10  # 𝛼 is a positive constant multiplier, ideally larger than 1.0
        BETA2 = 0.1
        TEMP = 1.0

        start_time = time.time()

        train_stage_1(
            model_ddp,
            tokenizer,
            train_dataset,
            eval_dataset,
            stage_1_epochs,
            LEARNING_RATE,
            BETA2,
            BATCH_SIZE=BATCH_SIZE,
            num_workers=NUM_WORKERS,
            prefetch_factor=PREFETCH_FACTOR,
            MAX_PROMPT_LEN1=MAX_PROMPT_LEN1,
            MAX_PROMPT_LEN2=MAX_PROMPT_LEN2,
            mnt_attempt1=mnt_attempt1,
            mnt_attempt2=mnt_attempt2,
            TEMP=TEMP,
            eval_model=eval_model,
            save_dir=save_dir,
        )
        end_time = time.time()
        print(f"Time used: {end_time-start_time}",)
    finally:
        # Release the process group on every exit path.
        destroy_process_group()



if __name__ == "__main__":
    # Command-line options as (flag, type, default). --rank is parsed but
    # not forwarded: mp.spawn supplies each worker's rank itself.
    parser = argparse.ArgumentParser()
    for flag, arg_type, default in (
        ("--rank", int, 0),
        ("--world_size", int, 1),
        ("--port", str, "12355"),
        ("--gpu_per_model", int, 1),
    ):
        parser.add_argument(flag, type=arg_type, default=default)
    args = parser.parse_args()

    # One worker process per model replica.
    world_size = args.world_size // args.gpu_per_model
    mp.spawn(
        main,
        args=(world_size, args.port, args.gpu_per_model),
        nprocs=world_size,
        join=True,
    )