import random
import numpy as np

import torch
torch.set_float32_matmul_precision('high')
import torch.nn.functional as F
from torch.optim import AdamW, RMSprop
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau, LambdaLR
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GenerationConfig
from functools import partial
from tqdm import tqdm
from solver import solve_weighting_vector_scipy
from peft import LoraConfig, get_peft_model

import time
import os
import argparse

import sys
from contextlib import nullcontext

from lora_configs import lora_configs
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from utils.utils import print_color, CheckpointManager, create_logger, ensure_and_clean_directory, create_directory
from data_prep.pre_processing_data import load_and_process_dataset_from_name 
from metrics.perplexity import compute_perplexity_ratio


def seed_everything(seed=0):
    """Seed every random number generator used by the script.

    Applies ``seed`` to torch (CPU and all CUDA devices), NumPy and Python's
    ``random`` module, then requests deterministic cuDNN kernels.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed)
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True


def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             reference_chosen_logps, reference_rejected_logps,
             beta=0.5, loss_type="sigmoid"):
    """Compute a DPO-family preference loss plus logging statistics.

    Args:
        policy_chosen_logps: Sequence log-probabilities of the chosen responses under the policy.
        policy_rejected_logps: Same for the rejected responses.
        reference_chosen_logps: Chosen-response log-probabilities under the reference model.
        reference_rejected_logps: Rejected-response log-probabilities under the reference model.
        beta: Temperature applied to the log-ratio difference.
        loss_type: ``"sigmoid"`` (standard DPO), ``"ipo"`` or ``"SLIC"`` (hinge).

    Returns:
        Tuple ``(loss, mean_chosen_logratio, mean_rejected_logratio,
        reward_accuracy, reward_margin)``; every entry is reduced with
        ``mean(dim=-1)``.

    Raises:
        NotImplementedError: If ``loss_type`` is not one of the supported names.
    """
    chosen_logratios = policy_chosen_logps - reference_chosen_logps
    rejected_logratios = policy_rejected_logps - reference_rejected_logps
    logits = chosen_logratios - rejected_logratios
    scaled_logits = beta * logits

    if loss_type == "sigmoid":
        losses = -F.logsigmoid(scaled_logits)
    elif loss_type == "ipo":
        # NOTE(review): this equals beta**2 * (logits - 1/beta)**2, which differs from
        # the IPO paper's (logits - 1/(2*beta))**2 — confirm the variant is intended.
        losses = (scaled_logits - 1) ** 2
    elif loss_type == "SLIC":
        # Hinge max(0, 1 - beta*logits); clamp_min keeps the device/dtype of the
        # logits instead of building a fresh CPU scalar tensor on every call.
        losses = (1 - scaled_logits).clamp_min(0.0)
    else:
        raise NotImplementedError(f"Unsupported loss_type: {loss_type!r}")
    loss = losses.mean(dim=-1)

    reward_accuracies = (chosen_logratios > rejected_logratios).float().mean(dim=-1)
    reward_margins = beta * logits.mean(dim=-1)

    return loss, chosen_logratios.mean(dim=-1), rejected_logratios.mean(dim=-1), reward_accuracies, reward_margins


def get_log_prob(model, prompt_ids, prompt_mask, tokenizer, average_log_prob=False):
    """Teacher-forced log-probability of each sequence in a batch.

    The logits at position ``t`` score token ``t+1``. Positions whose shifted
    mask entry is 0 do not contribute. Callers in this file pass the tokenizer
    attention mask, so prompt tokens are scored as well as response tokens.

    Args:
        model: Callable returning an object with a ``.logits`` tensor laid out as
            ``[batch, position, vocab]``.
        prompt_ids: Token ids, ``[batch, seq]``.
        prompt_mask: 0/1 mask aligned with ``prompt_ids``.
        tokenizer: Unused; kept for call-site compatibility.
        average_log_prob: If True, return the mean over scored tokens instead of the sum.
            Rows with no scored token yield 0 instead of NaN.

    Returns:
        Tensor of shape ``[batch]``.
    """
    logits = model(prompt_ids, attention_mask=prompt_mask).logits

    # Shift by one so logits[:, t] predict token t+1. Plain slices suffice:
    # log_softmax/gather never modify their inputs, so cloning only duplicated
    # the (batch x seq x vocab) logits tensor.
    labels = prompt_ids[:, 1:]
    labels_mask = prompt_mask[:, 1:]
    log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
    per_token_logs = torch.gather(log_probs, dim=2, index=labels.unsqueeze(2)).squeeze(2)

    masked_sum = (per_token_logs * labels_mask).sum(-1)
    if average_log_prob:
        # clamp_min(1) turns 0/0 (a fully masked row) into 0 rather than NaN
        return masked_sum / labels_mask.sum(-1).clamp_min(1)
    return masked_sum


def get_log_prob_ref(model, prompt_ids, prompt_mask, tokenizer, average_log_prob=False):
    """Score sequences under the reference model obtained by disabling the adapters.

    Runs ``get_log_prob`` inside ``model.disable_adapter()`` with the model in
    eval mode. Afterwards it restores the train/eval mode the model had on
    entry, even if scoring raises.

    Args:
        model: Adapter-wrapped model exposing ``disable_adapter()``.
        prompt_ids, prompt_mask, tokenizer, average_log_prob: Forwarded to ``get_log_prob``.

    Returns:
        The tensor returned by ``get_log_prob``.
    """
    # Remember the caller's mode: validation calls this with the model in eval
    # mode and must not have it switched back to train mode (dropout on).
    was_training = model.training
    with model.disable_adapter():
        model.eval()
        try:
            return get_log_prob(model, prompt_ids, prompt_mask, tokenizer, average_log_prob)
        finally:
            model.train(was_training)

def collate_fn(batch, tokenizer, max_length, device):
    """
      Tokenize a batch of preference pairs into padded chosen/rejected tensors.

      Each item provides ``'prompt'``, ``'chosen'`` and ``'rejected'`` strings. The
      prompt is concatenated directly (no separator) with each response before
      tokenization.

      Parameters:
          batch: List of samples in the batch.
          tokenizer: Tokenizer exposing ``batch_encode_plus``.
          max_length: Maximum sequence length for padding/truncation.
          device: Device the resulting tensors are moved to.

      Returns:
          A single dict with ``prompt_chosen_ids``, ``prompt_rejected_ids``,
          ``prompt_chosen_mask`` and ``prompt_rejected_mask``. The masks are the
          tokenizer attention masks, so they cover the prompt tokens as well as
          the response tokens.
      """
    def encode(texts):
        # Identical settings for both sides so chosen/rejected are tokenized alike.
        encoding = tokenizer.batch_encode_plus(
            texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
            return_attention_mask=True,
        )
        return encoding['input_ids'].to(device), encoding['attention_mask'].to(device)

    chosen_ids, chosen_mask = encode([item['prompt'] + item['chosen'] for item in batch])
    rejected_ids, rejected_mask = encode([item['prompt'] + item['rejected'] for item in batch])

    return {
        'prompt_chosen_ids': chosen_ids,
        'prompt_rejected_ids': rejected_ids,
        'prompt_chosen_mask': chosen_mask,
        'prompt_rejected_mask': rejected_mask,
    }


def flatten_gradients(model):
    """Concatenate the gradients of all trainable parameters into one 1-D CPU tensor.

    Parameters are visited in ``model.parameters()`` order. Those that do not
    require grad, or whose ``grad`` is None, are skipped. This is the same
    traversal ``restore_flat_gradients`` uses to write a flat vector back.

    Args:
        model: Module whose gradients are read (they are not modified).

    Returns:
        A newly allocated 1-D tensor on the CPU holding a copy of the gradients.
        ``torch.cat`` raises if no trainable parameter currently has a gradient.
    """
    # torch.cat always allocates its output, so cloning each gradient first was
    # an extra full copy; reshape(-1) also accepts non-contiguous gradients,
    # where view(-1) would raise.
    grads = [
        param.grad.reshape(-1)
        for param in model.parameters()
        if param.requires_grad and param.grad is not None
    ]
    # .cpu() returns the tensor itself when it is already on the CPU.
    return torch.cat(grads).cpu()


def restore_flat_gradients(model, flat_grads):
    """
    Restore the flattened gradients back to the model's parameters.

    Consecutive slices of ``flat_grads`` are copied in place into every
    parameter that requires grad and currently has a gradient. The order is
    ``model.parameters()``, the inverse of ``flatten_gradients``. ``Tensor.copy_``
    performs any needed device/dtype conversion.

    Parameters:
        model (torch.nn.Module): The model whose parameters' gradients are to be restored.
        flat_grads (torch.Tensor): Flattened gradients tensor.

    Raises:
        ValueError: If ``flat_grads`` does not contain exactly as many elements as
            the gradients being restored. This is checked before any gradient is
            modified.
    """
    targets = [param for param in model.parameters()
               if param.requires_grad and param.grad is not None]
    expected = sum(param.grad.numel() for param in targets)
    # Validate up front so a size mismatch never leaves gradients half-overwritten.
    if flat_grads.numel() != expected:
        raise ValueError(
            f"The number of elements in the gradient tensor ({flat_grads.numel()}) "
            f"does not match the model's parameters ({expected})."
        )
    with torch.no_grad():
        offset = 0
        for param in targets:
            numel = param.grad.numel()  # number of elements in this parameter's gradient
            param.grad.copy_(flat_grads[offset:offset + numel].reshape(param.grad.shape))
            offset += numel


def count_parameters_requires_grad(model):
    """Return how many scalar entries the trainable parameters of ``model`` hold in total."""
    total = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            total += tensor.numel()
    return total


# Draw the next batch from every dimension's DataLoader independently.
def fetch_batches(dict_dataloaderst, iterators):
    """Return ``{dimension: next batch}`` for every entry of ``iterators``.

    An exhausted iterator is replaced in ``iterators`` (mutated in place) by a
    fresh ``iter(dict_dataloaderst[dimension])``, and its first batch is used.
    Shorter loaders therefore cycle while longer ones keep going.

    Args:
        dict_dataloaderst: ``{dimension: dataloader for that dimension}``.
        iterators: ``{dimension: live iterator}``; updated in place on restart.
    """
    fetched = {}
    for name in list(iterators):
        try:
            fetched[name] = next(iterators[name])
        except StopIteration:
            restarted = iter(dict_dataloaderst[name])
            iterators[name] = restarted
            fetched[name] = next(restarted)
    return fetched

def get_iteration(dict_dataloaderst):
    """Return the length of the longest dataloader (the number of micro-steps per epoch)."""
    return max(len(loader) for loader in dict_dataloaderst.values())

@torch.no_grad()
def estimate_validation_loss(model, ref_model, dict_val_dataloaderst, 
                             dimensions, training_config, tokenizer, device):
    """Evaluate preference loss, reward statistics and perplexity ratios on validation data.

    For ``min(val_data_size // val_batch_size, length of the shortest validation
    loader)`` batches, computes per dimension the DPO loss, the reward accuracy
    and the reward margin. It then takes the first validation batch of every
    dimension and computes perplexity ratios for the chosen and the rejected
    sequences via ``compute_perplexity_ratio``.

    Args:
        model: Policy model. With LoRA, the reference is this model with its adapters disabled.
        ref_model: Separate reference model, or None when ``training_config["use_lora"]``.
        dict_val_dataloaderst: ``{dimension: validation dataloader}``.
        dimensions: Ordered dimension names used to index the batches.
        training_config: Reads ``use_lora``, ``val_data_size``, ``val_batch_size``,
            ``beta``, ``loss_type`` and ``max_length``.
        tokenizer: Forwarded to the scoring helpers.
        device: Forwarded to ``compute_perplexity_ratio``.

    Returns:
        ``(mean_val_loss, median_val_loss, reward_accuracy, reward_margin,
        chosen_perplexity_ratio, rejected_perplexity_ratio, chosen_perplexity,
        rejected_perplexity)``, each a NumPy array with one entry per dimension.
        If no batch is evaluated, the loss and reward entries are NaN. The model
        is put back into train mode before returning.
    """
    model.eval()
    use_lora = training_config["use_lora"]
    if use_lora:
        assert ref_model is None, "Reference model must be None when using LoRA"
    else:
        ref_model.eval()

    val_iterators = {dimension: iter(dataloader) for dimension, dataloader in dict_val_dataloaderst.items()}
    # Evaluate at most val_data_size samples and never more batches than the shortest loader has.
    min_size = min(len(dataloader) for dataloader in dict_val_dataloaderst.values())
    min_iteration = min(training_config["val_data_size"] // training_config["val_batch_size"], min_size) 
    print(f"validation size is:{min_size}, {min_iteration}")
    num_dims = len(dimensions)

    val_loss_list = [[] for _ in range(num_dims)]
    reward_accuracy_list = np.zeros(num_dims)
    reward_margin_list = np.zeros(num_dims)
    acc_per, acc_neg, model_per, model_per_neg = (np.zeros(num_dims) for _ in range(4))

    for _ in tqdm(range(min_iteration), desc="Fetching evaluation batches"):
        batches = fetch_batches(dict_val_dataloaderst, val_iterators)
        for dim, dimension in enumerate(dimensions):
            batch = batches[dimension]
            chosen_ids, chosen_mask = batch['prompt_chosen_ids'], batch['prompt_chosen_mask']
            rejected_ids, rejected_mask = batch['prompt_rejected_ids'], batch['prompt_rejected_mask']

            # Re-assert eval mode for the policy passes: with LoRA the reference
            # scoring below runs on this same model and may toggle its mode.
            model.eval()
            policy_chosen_logp = get_log_prob(model, chosen_ids, chosen_mask, tokenizer)
            policy_rejected_logp = get_log_prob(model, rejected_ids, rejected_mask, tokenizer)
            if use_lora:
                reference_chosen_logp = get_log_prob_ref(model, chosen_ids, chosen_mask, tokenizer)
                reference_rejected_logp = get_log_prob_ref(model, rejected_ids, rejected_mask, tokenizer)
            else:
                reference_chosen_logp = get_log_prob(ref_model, chosen_ids, chosen_mask, tokenizer)
                reference_rejected_logp = get_log_prob(ref_model, rejected_ids, rejected_mask, tokenizer)

            loss, _, _, reward_accuracy, reward_margin = dpo_loss(
                policy_chosen_logp, policy_rejected_logp,
                reference_chosen_logp, reference_rejected_logp,
                beta=training_config["beta"], loss_type=training_config["loss_type"])
            val_loss_list[dim].append(loss.item())
            reward_accuracy_list[dim] += reward_accuracy.item()
            reward_margin_list[dim] += reward_margin.item()

    mean_val_loss = np.array([np.mean(losses) for losses in val_loss_list])
    median_val_loss = np.array([np.median(losses) for losses in val_loss_list])

    def perplexity_of(ids, mask):
        # Returns (perplexity_ratio, perplexity); inner batch size 6 is hard-coded.
        ratio, _base_perplexity, perplexity = compute_perplexity_ratio(model=model, 
                                                    chosen_encoded_texts = ids, 
                                                    chosen_attn_masks = mask,  
                                                    tokenizer = tokenizer, 
                                                    batch_size = 6, 
                                                    add_start_token = True,
                                                    device=device, 
                                                    max_length=training_config["max_length"])
        return ratio, perplexity

    # Perplexity is computed on the first validation batch of each dimension only.
    model.eval()
    val_iterators = {dimension: iter(dataloader) for dimension, dataloader in dict_val_dataloaderst.items()}
    batches = fetch_batches(dict_val_dataloaderst, val_iterators)
    for dim, dimension in enumerate(dimensions):
        batch = batches[dimension]
        acc_per[dim], model_per[dim] = perplexity_of(batch['prompt_chosen_ids'], batch['prompt_chosen_mask'])
        acc_neg[dim], model_per_neg[dim] = perplexity_of(batch['prompt_rejected_ids'], batch['prompt_rejected_mask'])

    # Turn on the training mode after computing the validation
    torch.cuda.empty_cache()
    model.train()
    return (mean_val_loss, median_val_loss,
            reward_accuracy_list / min_iteration, reward_margin_list / min_iteration,
            acc_per, acc_neg, model_per, model_per_neg)

def _zero_grad_buffers(num_dims, num_grads):
    """Return ``num_dims`` zero-filled CPU vectors of length ``num_grads`` (per-dimension gradient accumulators)."""
    return [torch.zeros(num_grads, device='cpu') for _ in range(num_dims)]


def _mo_op_weighting(accumulated_flatten_grads_list, P, mu, accumulated_losses, device, grad_scale=1.0):
    """Solve the MO-OP weighting problem for one accumulation window.

    Args:
        accumulated_flatten_grads_list: One flattened accumulated gradient per dimension.
        P: Per-dimension preference weights.
        mu: Trade-off coefficient passed to ``solve_weighting_vector_scipy``.
        accumulated_losses: Per-dimension accumulated losses for the window.
        device: Device for the linear algebra.
        grad_scale: Factor the stored gradients are multiplied by (the AMP loss
            scale under float16, otherwise 1.0). ``Kt`` is divided by it so that
            the solver sees true gradient magnitudes.

    Returns:
        ``(gt, Kt, lambda_star, P_tensor, part1, part2)``. ``gt = Gt @ (P * lambda_star)``
        is in the same (possibly scaled) units as the stored gradients.
    """
    # Gt has one column per dimension.
    Gt = torch.stack(accumulated_flatten_grads_list, dim=1).to(device)

    # Kt = (Gt^T Gt)^(1/2): matrix square root from the SVD of the Gram matrix.
    G_T_G = torch.mm(Gt.T, Gt)
    matrix_u, vector_s, matrix_v = torch.svd(G_T_G)
    S_sqrt = torch.diag(torch.sqrt(vector_s))
    # sqrt(s^2 * G^T G) = s * sqrt(G^T G), so dividing removes the loss scale.
    Kt = torch.matmul(matrix_u, torch.matmul(S_sqrt, matrix_v.T)) / grad_scale

    P_tensor = torch.tensor(P, dtype=torch.float32).to(device)
    Hadamard_product = P * accumulated_losses
    Hadamard_product_tensor = torch.tensor(Hadamard_product, dtype=torch.float32).to(device)
    lambda_star = solve_weighting_vector_scipy(Kt, mu, Hadamard_product_tensor, P_tensor).to(device)

    # The two terms of the objective, for logging only.
    with torch.no_grad():
        grad_sum = Kt @ (lambda_star * P_tensor)
        part1 = torch.sum(grad_sum ** 2).item()
        part2 = mu * (lambda_star @ Hadamard_product_tensor)

    # Aggregated gradient g_t
    gt = torch.matmul(Gt, P_tensor * lambda_star)
    return gt, Kt, lambda_star, P_tensor, part1, part2


def train(model, ref_model, tokenizer, optimizer, dict_train_dataloaderst, dict_val_dataloaderst, 
          training_config, lora_config, logger=None,  checkpoint_manager=None): 
    """Multi-objective DPO training loop with MO-OP gradient aggregation.

    Each micro-step draws one batch per dimension. Every dimension's DPO loss
    (optionally plus a perplexity penalty) is back-propagated on its own, and its
    flattened gradient is accumulated on the CPU. After every
    ``accumulation_steps`` micro-steps:

    1. The per-dimension gradients are combined with the weights from
       ``solve_weighting_vector_scipy``.
    2. The result is written back into ``param.grad``, clipped to norm 10.0 and
       applied by the optimizer.
    3. The learning-rate scheduler advances.

    Validation/checkpointing runs every ``val_step`` optimizer steps. The
    function returns early when ``validate_and_save_checkpoint`` signals early
    stopping.

    Args:
        model: Policy model being trained.
        ref_model: Frozen reference model, or None when LoRA adapters provide the reference.
        tokenizer: Forwarded to the scoring helpers.
        optimizer: Optimizer over the trainable parameters.
        dict_train_dataloaderst / dict_val_dataloaderst: ``{dimension: dataloader}``.
        training_config: Hyper-parameters (``P``, ``beta``, ``mu``, ``use_mix_precision``,
            ``device``, ``accumulation_steps``, ``val_step``, ``use_lora``, ``warmup_steps``,
            ``epochs``, ``loss_type``, ``perplexity_penalty``, ``max_length``, ...).
        lora_config: Stored in LoRA checkpoints.
        logger: Required logger.
        checkpoint_manager: Forwarded to ``validate_and_save_checkpoint``.
    """
    assert logger is not None, "Please provide a list of loggers to record training progress."
    model.train()
    if ref_model is not None:
        ref_model.eval()

    P, beta, mu = training_config["P"], training_config["beta"],  training_config["mu"]
    use_mix_precision, device, accumulation_steps = training_config["use_mix_precision"], training_config["device"], training_config["accumulation_steps"]
    val_step = training_config["val_step"]
    use_lora = training_config["use_lora"]
    num_grads = count_parameters_requires_grad(model)
    logger.info(f"Total length of all trainable parameters: {num_grads}")
    logger.info(f"accumulationstep is {accumulation_steps}")
    dimensions, num_dims = list(dict_train_dataloaderst.keys()), len(dict_train_dataloaderst.keys())

    # Per-dimension running statistics for the current accumulation window.
    accumulated_losses, accumulated_perp = np.zeros(num_dims), np.zeros(num_dims)
    reward_margin_list, reward_accuracy_list = np.zeros(num_dims), np.zeros(num_dims)
    accumulated_flatten_grads_list = _zero_grad_buffers(num_dims, num_grads)

    # Linear warm-up over `warmup_steps` scheduler steps, constant afterwards.
    scheduler = LambdaLR(optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / (training_config["warmup_steps"] + 1)))

    early_stopping_counter = 0

    # Set up mixed precision training
    if use_mix_precision:
        assert device.type == "cuda", "Mixed precision training requires a GPU"
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        print_color(f"Using mixed precision training with {dtype} precision", 'yellow')
    else:
        dtype = torch.float32
        print_color("Using full precision training", 'yellow')
    # Loss scaling is enabled only for float16; with bfloat16 the scaler is a pass-through.
    scaler = torch.amp.GradScaler('cuda', enabled=(dtype == torch.float16)) if use_mix_precision else None
    autocast = torch.amp.autocast(device_type=device.type, dtype=dtype) if use_mix_precision else nullcontext()
    logger.info(f"dimensions are {dimensions}")

    step = 0
    max_iteration = get_iteration(dict_train_dataloaderst)
    train_iterators = {dimension: iter(dataloader) for dimension, dataloader in dict_train_dataloaderst.items()}
    for epoch in range(training_config["epochs"]):
        for _ in tqdm(range(max_iteration), desc="Fetching training batches"):
            step_start_time = time.time()
            batches = fetch_batches(dict_train_dataloaderst, train_iterators)
            # Back-propagate each dimension separately so its gradient can be captured on its own.
            for dim in range(num_dims):
                optimizer.zero_grad()  # isolate this dimension's gradient
                dimension = dimensions[dim]
                prompt_prefered_ids = batches[dimension]['prompt_chosen_ids']
                prompt_disprefered_ids = batches[dimension]['prompt_rejected_ids']
                prompt_prefered_mask = batches[dimension]['prompt_chosen_mask']
                prompt_disprefered_mask = batches[dimension]['prompt_rejected_mask']

                with autocast:
                    with torch.no_grad():
                        if use_lora:
                            assert ref_model is None, "Reference model must be None when using LoRA"
                            reference_chosen_logp = get_log_prob_ref(model, prompt_prefered_ids, prompt_prefered_mask, tokenizer)
                            reference_rejected_logp = get_log_prob_ref(model, prompt_disprefered_ids, prompt_disprefered_mask, tokenizer)
                        else:
                            reference_chosen_logp = get_log_prob(ref_model, prompt_prefered_ids, prompt_prefered_mask, tokenizer)
                            reference_rejected_logp = get_log_prob(ref_model, prompt_disprefered_ids, prompt_disprefered_mask, tokenizer)
                    policy_chosen_logp = get_log_prob(model, prompt_prefered_ids, prompt_prefered_mask, tokenizer)
                    policy_rejected_logp = get_log_prob(model, prompt_disprefered_ids, prompt_disprefered_mask, tokenizer)

                    loss, _, _, reward_accuracy, reward_margin = dpo_loss(policy_chosen_logp,
                                                                          policy_rejected_logp,
                                                                          reference_chosen_logp,
                                                                          reference_rejected_logp,
                                                                          beta=beta,
                                                                          loss_type=training_config["loss_type"])

                    # add perplexity penalty
                    if training_config["perplexity_penalty"]:
                        perplexity_ratio, base_perplexity, perplexity = compute_perplexity_ratio(model=model, 
                                                                        chosen_encoded_texts = prompt_prefered_ids, 
                                                                        chosen_attn_masks = prompt_prefered_mask,  
                                                                        tokenizer = tokenizer, 
                                                                        batch_size = 6, 
                                                                        add_start_token = True,
                                                                        device=device, 
                                                                        max_length=training_config["max_length"])
                        logger.info(f"loss; penalty is: {loss}, {perplexity_ratio}, {base_perplexity}, {perplexity}")
                        loss += perplexity_ratio / 50  # hard coded weight for penalty
                        accumulated_perp[dim] += perplexity_ratio
                    loss = loss / accumulation_steps  # Scale loss for accumulation

                if use_mix_precision:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()
                accumulated_losses[dim] += loss.item()
                reward_margin_list[dim] += reward_margin.item()
                reward_accuracy_list[dim] += reward_accuracy.item()
                accumulated_flatten_grads_list[dim] += flatten_gradients(model)
            step += 1

            # Once enough micro-steps are accumulated, combine the dimensions (MO-OP) and update.
            if step % accumulation_steps == 0:
                mo_op_start_time = time.time()
                # Stored gradients carry the current AMP loss scale (1.0 unless float16 scaling is on);
                # the scale is constant within a window because scaler.update() only runs below.
                grad_scale = scaler.get_scale() if use_mix_precision else 1.0
                gt, Kt, lambda_star, P_tensor, part1, part2 = _mo_op_weighting(
                    accumulated_flatten_grads_list, P, mu, accumulated_losses, device, grad_scale)
                logger.info(
                    f"P_tensor: {P_tensor.cpu().numpy()}; lambda_star: {lambda_star.cpu().numpy()}; delta: {(P_tensor * lambda_star).cpu().numpy()}; loss part1: {part1}; loss part2: {part2}")
                mo_op_time_taking = time.time() - mo_op_start_time

                # Restore the flat tensor back to the model's gradients
                restore_flat_gradients(model, gt)

                if use_mix_precision:
                    # Unscale before clipping so the 10.0 threshold applies to true gradient
                    # norms; scaler.step() then does not unscale a second time.
                    scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0)
                if use_mix_precision:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()

                # check L2 norm term
                norm_grad_sum = np.linalg.norm(Kt.cpu().detach().numpy() @ lambda_star.cpu().detach().numpy(), ord=2)
                # Worst-case (max) P-weighted training loss over the window
                train_loss = np.max(np.array(accumulated_losses) * P)

                log_message = create_log_message(epoch, step, mo_op_time_taking, step_start_time, 
                                                 accumulated_losses, train_loss, reward_margin_list, 
                                                 reward_accuracy_list, norm_grad_sum, accumulated_perp, 
                                                 accumulated_flatten_grads_list, accumulation_steps)

                # validation
                if step % (val_step * accumulation_steps) == 0:
                    early_stop, counter = validate_and_save_checkpoint(epoch, step, model, ref_model, tokenizer, dict_val_dataloaderst, 
                                                                        dimensions, training_config, logger, log_message,
                                                                        checkpoint_manager, early_stopping_counter,
                                                                        accumulated_losses, reward_margin_list, reward_accuracy_list,
                                                                        use_lora, lora_config, optimizer, train_loss)
                    early_stopping_counter = counter
                    if early_stop:
                        return  # Exit the training
                else:
                    logger.info(log_message)

                # Reset the per-window statistics and gradient accumulators
                accumulated_losses, accumulated_perp = np.zeros(num_dims), np.zeros(num_dims)
                accumulated_flatten_grads_list = _zero_grad_buffers(num_dims, num_grads)
                reward_margin_list, reward_accuracy_list = np.zeros(num_dims), np.zeros(num_dims)

                # Update the learning rate scheduler
                before_lr = optimizer.param_groups[0]["lr"]
                scheduler.step()
                after_lr = optimizer.param_groups[0]['lr']
                logger.info("Epoch {}: lr {:.10f} -> {:.10f}".format(epoch, before_lr, after_lr))

            torch.cuda.empty_cache()  # Clear the cache to avoid memory leaks


def validate_and_save_checkpoint(epoch, step, model,
                                 ref_model, tokenizer, dict_val_dataloaderst, 
                                 dimensions, training_config, logger, log_message,  
                                 checkpoint_manager, early_stopping_counter, 
                                 accumulated_losses, 
                                 reward_margin_list, reward_accuracy_list, use_lora, 
                                 lora_config, optimizer, train_loss):
    """Run validation, save a checkpoint when it qualifies, and track early stopping.

    The validation score is the worst-case weighted loss ``max(val_losses * P)``.
    A checkpoint is saved when that score minus the manager's
    ``get_max_topK_loss()`` is below ``early_stopping_tolerance``. Otherwise the
    early-stopping counter is incremented. With ``checkpoint_manager=None`` there
    is no top-K reference: nothing is saved and every validation counts as an
    improvement. The train/val metrics line is logged in every path, including
    the early-stop path.

    Returns:
        ``(early_stop, early_stopping_counter)``.
    """
    P = training_config["P"]
    device, accumulation_steps = training_config["device"], training_config["accumulation_steps"]

    val_losses, median_losses, val_reward_accuracy, val_reward_margins, val_perp, val_perp_neg, model_per, model_per_neg = estimate_validation_loss(
        model, ref_model, dict_val_dataloaderst, dimensions, training_config, tokenizer, device)
    # Worst-case (max) P-weighted validation loss across dimensions
    val_loss = np.max(np.array(val_losses) * P)

    log_message += f"| val loss {val_losses}" + f"| val wc loss {val_loss:.4f}" \
        + f"| val_reward_accuracy {val_reward_accuracy}" + f"| val_reward_margins {val_reward_margins}" \
        + f"| val perp ratio {val_perp}" + f"|val neg perp ratio {val_perp_neg}" + f"| val perp {model_per}" + f"|val neg perp {model_per_neg}"

    max_topK_loss = checkpoint_manager.get_max_topK_loss() if checkpoint_manager is not None else float("inf")
    logger.info(f"validation loss is:[{val_loss} , {max_topK_loss}]")
    if val_loss - max_topK_loss < training_config["early_stopping_tolerance"]:
        if checkpoint_manager is not None:
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'accumulated_losses': accumulated_losses,
                'val_losses':val_losses,
                'train_loss': train_loss,
                'val_loss':val_loss,
                'reward_margin': np.array(reward_margin_list) / accumulation_steps,
                'reward_accuracy': np.array(reward_accuracy_list) / accumulation_steps,  
                'val_reward_accuracy': val_reward_accuracy,
                'val_reward_margins': val_reward_margins,
                'epoch': epoch,
                'step': step,
            }
            if use_lora:
                checkpoint['lora_config'] = lora_config

            logger.info(f"saving checkpoint to {checkpoint_manager.out_dir}")
            losses_str = '-'.join(f"{round(loss, 4)}" for loss in accumulated_losses)
            val_losses_str = '-'.join(f"{round(loss, 4)}" for loss in val_losses)
            curr_epoch = (
                f"epoch_{epoch}_step_{step}_valloss_{val_losses_str}_trailoss{losses_str}.pt"
            )
            if use_lora:
                curr_epoch = curr_epoch.replace('.pt', 'peft_checkpoint.pt')
            checkpoint_manager.save_checkpoint(checkpoint, curr_epoch, val_loss, use_lora, model)
        early_stopping_counter = 0  # reset the counter
    else:
        early_stopping_counter += 1
        if early_stopping_counter >= training_config["early_stopping_patience"]:
            # Log this step's metrics before stopping; they would otherwise be lost.
            logger.info(log_message)
            logger.info("Early stopping triggered")
            return True, early_stopping_counter  # Exit the training

    logger.info(log_message)
    return False, early_stopping_counter


def create_log_message(epoch, step, mo_op_time_taking, step_start_time, 
                       accumulated_losses, train_loss,
                       reward_margin_list, reward_accuracy_list, 
                       norm_grad_sum, accumulated_perp,
                       accumulated_flatten_grads_list, accumulation_steps):
    """Format the one-line training progress message for an optimizer step.

    Reward margin, reward accuracy and perplexity are divided by
    ``accumulation_steps``. ``grad_max_abs`` is the largest absolute entry
    across all accumulated gradient vectors, or 0.0 when there are none. (A
    matrix ``ord=inf`` norm over the stacked vectors would instead give the
    largest row L1 sum.)

    Returns:
        The formatted message string.
    """
    peaks = []
    for buffer in accumulated_flatten_grads_list:
        values = np.asarray(buffer)
        if values.size:
            peaks.append(float(np.abs(values).max()))
    grad_max_abs = max(peaks, default=0.0)

    log_message = (
        f"| epoch {epoch:05d} " \
        f"| step {step:010d} " \
        f"| mo optimization time {mo_op_time_taking:.4f} " \
        f"| step time {time.time() - step_start_time:.4f}" \
        f"| training loss {accumulated_losses}" \
        f"| training wc loss {train_loss:.4f}"\
        f"| reward_margins {np.array(reward_margin_list) / accumulation_steps}"\
        f"| reward_accuracy {np.array(reward_accuracy_list) / accumulation_steps}"\
        f"| L2 norm {norm_grad_sum:.6f}"\
        f"| perplexity {np.array(accumulated_perp) / accumulation_steps}"\
        f"| grad_max_abs: {grad_max_abs}"
    )
    return log_message


def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            ``argparse`` read ``sys.argv[1:]``, which is how ``main`` calls it.
            Passing an explicit list allows programmatic use and testing.

    Returns:
        argparse.Namespace with one attribute per option defined below.
        ``--model`` and ``--dataset`` are required.
    """
    parser = argparse.ArgumentParser(description="Training Script")

    parser.add_argument('--epochs', type=int, default=3, help='Number of training epochs')
    parser.add_argument('--beta', type=float, default=0.1, help='Beta value for the optimizer or other relevant use')
    parser.add_argument('--batch_size', type=int, default=4, help='Batch size for training')
    parser.add_argument('--max_length', type=int, default=512, help='Maximum sequence length')
    parser.add_argument('--lr', type=float, default=1e-6, help='Learning rate')
    parser.add_argument('--P_vec', type=float, nargs='+', default=[0.5, 0.5], help='List of floats for P_vec')
    # Flags below are store_true switches: present -> True, absent -> False.
    parser.add_argument('--use_lora', action="store_true", help='Enable LoRA fine-tuning (flag, off by default)')
    parser.add_argument('--accumulation_steps', type=int, default=5,
                        help='Number of steps to accumulate gradients before performing MO-OP')
    parser.add_argument('--use_mix_precision', action="store_true", help='Enable mixed precision training (flag, off by default)')
    parser.add_argument('--load_in_X_bit', action="store_true", help='Load the model with 8-bit quantization (flag, off by default)')

    parser.add_argument('--model', type=str, required=True, help='Specify the model name')
    parser.add_argument('--dataset', type=str, required=True, help='Specify the dataset name')
    parser.add_argument('--seed', type=int, default=0, help='Specify the random seed')
    parser.add_argument('--loss_type', type=str, choices=['sigmoid', 'ipo'], default='sigmoid', help='Specify the loss type. Options are "sigmoid" or "ipo". Default is "sigmoid".')
    parser.add_argument('--mu', type=float, default=1e-2, help='Tradeoff')

    parser.add_argument('--perplexity_penalty', action="store_true", help='Enable the perplexity penalty (flag, off by default)')
    parser.add_argument("--checkpoint_path", type=str, default="checkpoint", help="Path to save or load model checkpoints")
    parser.add_argument("--exclude_cols", nargs='+', type=str, default=[], help="List of column names to be excluded")
    args = parser.parse_args(argv)
    return args

def main():
    """Entry point: configure, load model and data, train, and save the model.

    Steps:
      1. Build ``training_config`` from the CLI arguments plus fixed settings.
      2. Create the checkpoint directory, a ``CheckpointManager`` and a logger;
         seed all RNGs.
      3. Load the frozen base model (only when LoRA is disabled) and the model
         to train, optionally wrapping the latter with LoRA adapters.
      4. Build one train and one validation ``DataLoader`` per dataset
         dimension (key of the loaded dataset).
      5. Run ``train`` and ``save_pretrained`` the trained model.

    Raises:
        ValueError: if the train and validation splits have different
            dimensions, or ``--P_vec`` does not have one entry per dimension.
        OSError: re-raised when the tokenizer cannot be loaded.
    """
    args = parse_args()
    training_config = {
        "model_name": args.model,
        "dataset_name": args.dataset,
        "epochs": args.epochs,
        "beta": args.beta,
        "batch_size": args.batch_size,
        "max_length": args.max_length,
        "lr": args.lr,
        "P": np.array(args.P_vec),
        "use_lora": args.use_lora,
        "accumulation_steps": args.accumulation_steps,
        "use_mix_precision": args.use_mix_precision,
        'load_in_X_bit': args.load_in_X_bit, # 8 bit doesn't work on H100, # TODO: Quantization, Remove in full scale experiment.
        "loss_type": args.loss_type,
        "seed": args.seed,
        "mu": args.mu,

        "val_data_size": 50, # size of validation set, e.g 100
        "val_batch_size": 20, # Validation Batch Size
        "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        "val_step": 50, # No. of training iterations * accumulation_steps that are required before validation
        "lr_decay_rate":0.8,
        "early_stopping_patience": 10,
        "early_stopping_tolerance": 0.001,

        "perplexity_penalty": args.perplexity_penalty,
        "warmup_steps":150,
        "checkpoint_path": args.checkpoint_path,
        "exclude_cols": args.exclude_cols,
    }
    print_color(training_config, 'yellow')
    model_name, device, dataset_name, seed = (
        training_config["model_name"],
        training_config["device"],
        training_config["dataset_name"],
        training_config["seed"],
    )
    # Tag output paths with the launch time and PID so separate runs get separate directories/logs.
    pid = os.getpid()
    epoch_time = int(time.time())
    log_dir = f"logs/{training_config['checkpoint_path']}"
    model_name_path = model_name.replace("/", "")
    pref_string = "_".join(str(a) for a in args.P_vec)
    run_name = f"{epoch_time}_pid_{pid}_model_{model_name_path}_pref_{pref_string}"
    check_point_dir = f"checkpoint/{training_config['checkpoint_path']}/{run_name}"
    ensure_and_clean_directory(check_point_dir)
    checkpoint_manager = CheckpointManager(out_dir=check_point_dir, max_checkpoints=5)
    training_logger = create_logger(log_dir, run_name)
    training_logger.info("Training configuration: %s", training_config)
    seed_everything(seed)

    # Pick the LoRA config by model family; anything not zephyr/llama uses the gpt2 config.
    if "zephyr" in model_name.lower():
        config = lora_configs["zephyr"]
    elif "llama" in model_name.lower():
        training_logger.info("use llama version lora config")
        config = lora_configs["llama"]
    else:
        config = lora_configs["gpt2"]

    training_logger.info("Lora config: %s", config )

    # Base model (deprecated): loaded and frozen only when LoRA is disabled.
    if training_config["use_lora"]:
        base_model = None
    else:
        if training_config["load_in_X_bit"]:
            bnb_config = BitsAndBytesConfig(load_in_8bit=True)
            base_model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
        else:
            base_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

        training_logger.info(base_model)
        # Make base model not trainable
        for param in base_model.parameters():
            param.requires_grad = False

    # Model to be trained.
    if training_config["load_in_X_bit"]:
        training_logger.info("Load 8 bit quantization")
        # Request quantization via ``quantization_config`` (as for the base
        # model above); the bare ``load_in_8bit`` kwarg is deprecated in transformers.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

    if training_config["use_lora"]:
        # Apply LoRA using get_peft_model
        model = get_peft_model(model, config)
        model.print_trainable_parameters()
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    except OSError as e:
        print("Tokenizer not found!")
        training_logger.info(f"Error loading tokenizer: {e}")
        raise  # no fallback tokenizer: abort the run with the original traceback
    # Reuse EOS as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    # Columns to drop from the dataset, e.g. ["instruction_following", "truthfulness", "overall"].
    exclude_cols = training_config["exclude_cols"]
    training_dataset = load_and_process_dataset_from_name(dataset_name, split='train', seed=seed, removed_dimensions = exclude_cols)
    validation_dataset = load_and_process_dataset_from_name(dataset_name, split='validation', seed=seed, removed_dimensions = exclude_cols)
    # Explicit checks rather than ``assert`` so they still run under ``python -O``.
    if training_dataset.keys() != validation_dataset.keys():
        raise ValueError("training and validation dataset have different keys")
    dimensions = training_dataset.keys()
    print(dimensions)
    if len(training_config["P"]) != len(dimensions):
        raise ValueError("P_vec must have the same number of elements as the number of dimensions.")
    training_logger.info("size of training data set:")
    training_logger.info(training_dataset)

    dict_train_dataloaders = {}
    dict_val_dataloaders = {}
    for dimension in dimensions:
        dict_train_dataloaders[dimension] = torch.utils.data.DataLoader(
            training_dataset[dimension],
            batch_size=training_config["batch_size"],
            shuffle=True,
            collate_fn=partial(collate_fn, tokenizer=tokenizer,
                               max_length=training_config["max_length"], device=device))
        # Validation: no shuffling, and a larger batch size
        dict_val_dataloaders[dimension] = torch.utils.data.DataLoader(
            validation_dataset[dimension],
            batch_size=training_config["val_batch_size"],
            shuffle=False,
            collate_fn=partial(collate_fn, tokenizer=tokenizer,
                               max_length=training_config["max_length"], device=device))
        # Log up to two random examples from each split (train first, then
        # validation). Clamping k avoids a ValueError from random.sample on
        # splits with fewer than two rows.
        for split_dataset in (training_dataset[dimension], validation_dataset[dimension]):
            n_rows = len(split_dataset)
            for i in random.sample(range(n_rows), min(2, n_rows)):
                training_logger.info(split_dataset[i])

    optimizer = AdamW(model.parameters(), lr=training_config["lr"], weight_decay=1e-5)

    train(model, base_model, tokenizer, optimizer, dict_train_dataloaders, dict_val_dataloaders,
          training_config, lora_config = config, logger=training_logger, checkpoint_manager=checkpoint_manager)

    path = f"model/pid_{pid}_model_{model_name_path}_pref_{pref_string}.pt"
    create_directory(path)
    model.save_pretrained(path)


if __name__ == "__main__":
    # Run the training pipeline only when executed as a script, not when imported.
    main()