import re
import torch
import random
import numpy as np
import os

def model_setting(model, grad_ckpt=True):
    """Configure ``model`` in place for training.

    Args:
        model: Object exposing ``gradient_checkpointing_enable()``,
            ``enable_input_require_grads()`` and a ``config`` attribute
            (only touched when ``grad_ckpt`` is true).
        grad_ckpt: When true, flag the model as supporting gradient
            checkpointing, call the two methods above in that order, and
            set ``model.config.use_cache`` to ``False``.

    Returns:
        None. ``is_parallelizable`` and ``model_parallel`` are set to
        ``True`` on the model regardless of ``grad_ckpt``.
    """
    if grad_ckpt:
        setattr(model, "supports_gradient_checkpointing", True)
        # Methods are looked up and invoked one at a time, in this order.
        for method_name in ("gradient_checkpointing_enable", "enable_input_require_grads"):
            getattr(model, method_name)()
        # NOTE(review): presumably the cache is disabled because it does not
        # combine with gradient checkpointing -- confirm against the model API.
        setattr(model.config, "use_cache", False)

    # Always applied, independent of the checkpointing switch.
    for parallel_flag in ("is_parallelizable", "model_parallel"):
        setattr(model, parallel_flag, True)

def setup_seed(seed):
    """Seed the PyTorch, NumPy and Python RNGs with ``seed``.

    Also sets the ``CUBLAS_WORKSPACE_CONFIG`` environment variable to
    ``':4096:8'``.

    NOTE(review): this function only sets the environment variable; it does
    not itself switch PyTorch into deterministic-algorithm mode.

    Args:
        seed: Value handed unchanged to every seeding function.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

def tokenizer_setting(tokenizer):
    """Apply the fixed padding/truncation/special-token options to ``tokenizer``.

    Sets, in place: right-side padding, right-side truncation,
    ``add_bos_token = False`` and ``add_eos_token = True``.

    Args:
        tokenizer: Any object accepting these four attributes.

    Returns:
        None.
    """
    options = (
        ("padding_side", "right"),
        ("truncation_side", "right"),
        ("add_bos_token", False),
        ("add_eos_token", True),
    )
    for attribute, value in options:
        setattr(tokenizer, attribute, value)

class Align_Collator:
    """Turn one chosen/rejected preference example into a tokenized pair.

    Args:
        tokenizer: Tokenizer providing ``apply_chat_template`` and callable
            with the keyword arguments used in ``collate_fn``.
        config: Object exposing ``max_length`` (the truncation limit).
        device: Device the tokenized batch is moved to. Defaults to
            ``"cuda"``, the value that used to be hard-coded.
    """

    def __init__(self, tokenizer, config, device="cuda"):
        self.tokenizer = tokenizer
        self.config = config
        self.device = device

    def common_prefix_length(self, tensor):
        """Return how many leading columns are identical across all rows.

        Args:
            tensor: 2-D tensor; rows are compared column by column against
                the first row.

        Returns:
            int: Index of the first column in which any row differs from
            row 0, or ``tensor.shape[1]`` when no column differs.
        """
        # A single vectorised comparison replaces the per-column Python loop,
        # which evaluated a tensor in an ``if`` once per column.
        column_differs = tensor.ne(tensor[:1]).any(dim=0)
        mismatch_positions = column_differs.nonzero()
        if mismatch_positions.numel() == 0:
            return tensor.shape[1]
        # ``nonzero`` lists indices in ascending order, so the first entry is
        # the earliest mismatching column.
        return int(mismatch_positions[0, 0])

    def collate_fn(self, batch):
        """Tokenize the chosen and rejected responses of a single example.

        Args:
            batch: Sequence of examples; each is indexed by the keys
                ``"chosen"`` and ``"rejected"``, whose values are passed to
                ``tokenizer.apply_chat_template``.

        Returns:
            tuple: ``(inputs, rewards, prefix_index)`` where ``inputs`` is the
            padded tokenizer output for ``[chosen, rejected]`` moved to
            ``self.device``, ``rewards`` is ``[1, 0]`` (same order), and
            ``prefix_index`` is the common token-prefix length of the two
            rows minus one.
        """
        # NOTE(review): only the first element is used; any further elements
        # of ``batch`` are ignored -- presumably the loader uses batch size 1.
        example = batch[0]
        responses = [example["chosen"], example["rejected"]]
        rewards = [1, 0]

        tokenizer_setting(self.tokenizer)
        rendered = [
            self.tokenizer.apply_chat_template(response, tokenize=False)
            for response in responses
        ]
        inputs = self.tokenizer(
            rendered,
            padding=True,
            return_tensors="pt",
            max_length=self.config.max_length,
            truncation=True,
        ).to(self.device)

        # The shared leading tokens of both rows are taken as the prompt.
        prompt_length = self.common_prefix_length(inputs["input_ids"])

        # NOTE(review): the -1 presumably converts the length into the index
        # of the last shared token -- confirm against the consumer.
        return inputs, rewards, prompt_length - 1
