import torch
import copy
    
class Union_Model(torch.nn.Module):
    """Bundle a trainable policy model with a frozen reference model.

    The wrapper exposes helpers to score response tokens under either model
    (``get_logp``) and to sample responses from a generator (``generate``).
    """

    def __init__(self, policy_model, ref_model, policy_tokenizer, accelerator):
        """Store both models and fix which one receives gradients.

        Args:
            policy_model: Model whose parameters are marked trainable.
            ref_model: Model whose parameters are frozen.
            policy_tokenizer: Tokenizer used to decode prompts and responses;
                must provide ``batch_decode`` and ``eos_token_id``.
            accelerator: Object providing ``unwrap_model`` (used by
                ``generate``).
        """
        super().__init__()
        self.policy = policy_model
        self.ref = ref_model
        self.policy_tokenizer = policy_tokenizer
        self.accelerator = accelerator
        # Only the policy is optimised; the reference model stays frozen.
        self.policy.requires_grad_(True)
        self.ref.requires_grad_(False)

    def get_logp(self, inputs, prompt_len: int, model_type: str = "policy", gather: bool = True):
        """Return per-token log-probabilities of the response tokens.

        Args:
            inputs: Mapping with at least ``"input_ids"`` and
                ``"attention_mask"``; it is forwarded to the model as
                keyword arguments.
            prompt_len: Number of leading positions that belong to the
                prompt. Tokens from this index onwards are scored.
            model_type: ``"policy"`` selects the policy model; any other
                value selects the reference model.
            gather: Currently unused; kept so existing callers keep working.

        Returns:
            For ``model_type == "policy"``: a tuple ``(target_logp, entropy)``.
            Otherwise: ``target_logp`` only. Both tensors have one entry per
            response position and are zeroed where the attention mask is 0.
        """
        model = self.policy if model_type == "policy" else self.ref

        input_ids = inputs["input_ids"]
        response_ids = input_ids[:, prompt_len:]
        response_mask = inputs["attention_mask"][:, prompt_len:]

        # Logits at position t are paired with the token at position t + 1,
        # hence the one-step shift. Slice before log_softmax so the
        # normalisation is only computed for the positions actually scored.
        logits = model(**inputs).logits[:, prompt_len - 1:-1, :]
        all_logp = logits.log_softmax(-1)

        target_logp = torch.gather(all_logp, dim=2, index=response_ids.unsqueeze(2)).squeeze(2)
        target_logp = target_logp * response_mask

        if model_type != "policy":
            return target_logp

        # Probabilities are detached, so gradients flow only through the
        # log-probability factors. The target token's own term is removed.
        # NOTE(review): this is sum_v p_v * log p_v (without the leading
        # minus of Shannon entropy) — confirm the sign the caller expects.
        entropy = (all_logp.exp().detach() * all_logp).sum(-1) - target_logp.exp().detach() * target_logp
        entropy = entropy * response_mask
        return target_logp, entropy

    def generate(self, inputs, prompt_len, generator, generation_kwargs):
        """Sample responses from ``generator`` and decode them.

        During sampling the key/value cache is switched on and gradient
        checkpointing is switched off. Afterwards the generator is always
        left with ``use_cache = False`` and gradient checkpointing enabled,
        even when sampling raises.

        Args:
            inputs: Mapping forwarded to ``generator.generate`` as keyword
                arguments; ``inputs["input_ids"]`` is decoded as the prompts.
            prompt_len: Number of leading positions of the generated ids
                treated as prompt. Assumes the generated sequence starts with
                the prompt — TODO confirm for the generator in use.
            generator: Model (possibly wrapped) to sample from.
            generation_kwargs: Extra keyword arguments for
                ``generator.generate``.

        Returns:
            Tuple ``(outputs_ids, prompts, responses, length)`` where
            ``prompts`` and ``responses`` are decoded with special tokens
            skipped, and ``length`` is a float tensor with a trailing
            dimension of size 1 holding, per sequence, the number of response
            tokens that differ from the tokenizer's EOS id.
        """
        generator = self.accelerator.unwrap_model(generator)

        generator.config.use_cache = True
        generator.gradient_checkpointing_disable()
        try:
            outputs_ids = generator.generate(**inputs, **generation_kwargs)
        finally:
            # Restore the training configuration even if sampling fails, so
            # a caught error cannot leave the model in generation mode.
            generator.config.use_cache = False
            generator.gradient_checkpointing_enable()

        response_ids = outputs_ids[:, prompt_len:]
        prompts = self.policy_tokenizer.batch_decode(inputs["input_ids"], skip_special_tokens=True)
        responses = self.policy_tokenizer.batch_decode(response_ids, skip_special_tokens=True)
        # NOTE(review): counts every non-EOS token, so this equals the
        # response length only if padding uses the EOS id — confirm.
        length = (response_ids != self.policy_tokenizer.eos_token_id).float().sum(-1, keepdim=True)

        return outputs_ids, prompts, responses, length