# Adapted from Coste's (tlc4418) llm_optimization repository.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset, Dataset
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
from typing import Union
import time
from alpaca_farm.models.reward_model import RewardModel
from accelerate import Accelerator, DistributedType
from src.data_utils.rm_dataset_formatter import RMPromptDataset
# from src.reward_modeling.scoring.infer2 import EBM_DNN
from model_training.models.reward_model import (
    GPTNeoXRewardModel,
    GPTNeoXRewardModelConfig,
)

# Token budget passed as ``max_length`` (with truncation) when tokenizing
# reward-model inputs in ``get_reward``.
MAX_LEN = 776  # 520 instruction + 256 answer


def get_reward(
    samples,
    reward_models,
    reward_tokenizer,
    reward_device,
    batch_size,
    objective_function=None,
    weight=None,
    is_alpacafarm_rm=False,
    ebm_model=None,
    T=50,
    lam=0.5,
    eta=0.1,
):
    """Score ``samples`` with one reward model or an ensemble of them.

    Every model is run over the tokenized samples in chunks of ``batch_size``.
    While a model runs, a forward hook on its ``out_proj`` sub-module records
    the first positional input of that layer so it can be returned as a
    per-sample embedding (or fed to ``ebm_model``).

    Args:
        samples: Texts to score; passed straight to ``reward_tokenizer``.
        reward_models: A single model or a list of models. Each must expose
            an ``out_proj`` sub-module.
        reward_tokenizer: Callable tokenizer; invoked with padding, truncation
            to ``MAX_LEN`` and ``return_tensors="pt"``.
        reward_device: Device the tokenized batch (and the captured
            embeddings) are moved to.
        batch_size: Number of samples per forward pass.
        objective_function: Optional callable ``(stacked_rewards, weight)``
            used to combine ensemble rewards. Ignored for a single model.
        weight: Second argument forwarded to ``objective_function``.
        is_alpacafarm_rm: If true, read ``output.rewards``; otherwise read
            ``output.logits[:, 0]``.
        ebm_model: Optional energy-based model. Only used on the single-model
            path, where the raw rewards are refined with
            ``infer_rewards_batch_dramatic``.
        T, lam, eta: Iteration count, initial step size and step-size decay
            forwarded to ``infer_rewards_batch_dramatic``.

    Returns:
        The arity depends on the path taken (kept as-is for existing callers):

        * single model, ``ebm_model`` given: ``(refined_rewards, placeholder)``
        * single model, no ``ebm_model``: ``(rewards, placeholder, embeddings)``
          where ``embeddings`` is a list with one tensor per sample
        * several models: ``(rewards, variance_across_models)``

        ``placeholder`` is ``torch.empty_like(rewards)``, i.e. uninitialised
        memory that only stands in for the missing variance.

    Raises:
        ValueError: If ``reward_models`` is empty, or if ``ebm_model`` is given
            but no ``out_proj`` input was captured.
    """
    if not isinstance(reward_models, list):
        reward_models = [reward_models]
    if not reward_models:
        raise ValueError("reward_models must contain at least one model")

    encoded = reward_tokenizer(
        samples,
        padding=True,
        truncation=True,
        max_length=MAX_LEN,
        return_tensors="pt",
    ).to(reward_device)

    # Function-local slot written by the hook. A closure avoids leaking state
    # through a module-level global between calls.
    captured = {"out_proj_input": None}

    def capture_out_proj_input(module, inputs, output):
        captured["out_proj_input"] = inputs[0]

    num_batches = math.ceil(len(samples) / batch_size)
    all_rewards = []
    all_out_proj_inputs = []

    for reward_model in reward_models:
        hook_handle = reward_model.out_proj.register_forward_hook(
            capture_out_proj_input
        )
        model_rewards = []
        model_embeddings = []
        try:
            for i in range(num_batches):
                batch_ixs = slice(i * batch_size, (i + 1) * batch_size)
                # Reset so a forward pass that skips ``out_proj`` cannot reuse
                # the previous batch's activation.
                captured["out_proj_input"] = None
                output = reward_model(
                    encoded.input_ids[batch_ixs],
                    encoded.attention_mask[batch_ixs],
                )
                rewards = output.rewards if is_alpacafarm_rm else output.logits[:, 0]
                model_rewards.extend(rewards)
                if captured["out_proj_input"] is not None:
                    model_embeddings.extend(captured["out_proj_input"].detach())
        finally:
            # Detach the hook from *every* model, even when a forward pass
            # raises; otherwise hooks pile up on the models across calls.
            hook_handle.remove()

        all_rewards.append(torch.hstack(model_rewards))
        if model_embeddings:
            all_out_proj_inputs.append(torch.vstack(model_embeddings))

    if len(all_rewards) == 1:
        rewards = all_rewards[0]
        if all_out_proj_inputs:
            embeddings_tensor = all_out_proj_inputs[0].to(reward_device)
            embeddings = list(embeddings_tensor)
        else:
            embeddings_tensor = None
            embeddings = []

        if ebm_model is None:
            return rewards, torch.empty_like(rewards), embeddings

        if embeddings_tensor is None:
            raise ValueError(
                "ebm_model was provided but no out_proj input was captured"
            )
        inferred_rewards = infer_rewards_batch_dramatic(
            ebm_model,
            embeddings_tensor,
            y_init=rewards,
            init_range=[-4.0, 4.0],
            T=T,
            batch_size=1024,
            lambda_init=lam,
            eta=eta,
        )
        return inferred_rewards, torch.empty_like(rewards)

    # Ensemble path: one row of rewards per model.
    stacked_rewards = torch.stack(all_rewards, 0)
    var = torch.var(stacked_rewards, dim=0)
    if objective_function:
        stacked_rewards = objective_function(stacked_rewards, weight)
    return stacked_rewards, var


def score_answers(
    model_name: str,
    dataset: Union[str, Dataset],
    ebm_model: Union[nn.Module, None] = None,
    split: str = "validation",
    scores_type: str = "gold_scores",
    sort: bool = False,
    split_size: int = 32,
    is_alpacafarm_rm: bool = False,
) -> list:
    """Score every answer of ``dataset`` with a reward model.

    Args:
        model_name: Checkpoint name/path used for both the model and the
            tokenizer ``from_pretrained`` calls.
        dataset: A dataset object, or a name that is loaded with
            ``load_dataset`` and indexed by ``split``.
        ebm_model: Optional energy-based model forwarded to ``get_reward``;
            it is switched to eval mode here. ``.eval()`` is called on it, so
            it is expected to be a module rather than a string.
        split: Split to read when ``dataset`` is given as a name.
        scores_type: Key under which the scores are stored in each entry.
        sort: For multi-answer entries, reorder ``answers`` (and any existing
            ``gold_scores`` / ``proxy_scores``) by ascending score.
        split_size: Batch size handed to ``get_reward``.
        is_alpacafarm_rm: Load an AlpacaFarm ``RewardModel`` instead of an
            ``AutoModelForSequenceClassification`` and format prompts
            accordingly.

    Returns:
        The list of dataset entries, each with ``entry[scores_type]`` set to a
        list of scores (a one-element list for single-answer entries). An
        empty dataset yields an empty list.
    """
    dataset = load_dataset(dataset)[split] if isinstance(dataset, str) else dataset

    prompt_dataset = RMPromptDataset(
        dataset,
        output_alpaca=is_alpacafarm_rm,
    )
    model = (
        RewardModel.from_pretrained(model_name, flash_attn=False, bf16=True)
        if is_alpacafarm_rm
        else AutoModelForSequenceClassification.from_pretrained(model_name)
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    accelerator = Accelerator()

    # Inference only: force ZeRO stage 0 when running under DeepSpeed.
    if accelerator.distributed_type == DistributedType.DEEPSPEED:
        accelerator.state.deepspeed_plugin.deepspeed_config["zero_optimization"]["stage"] = 0

    model, tokenizer = accelerator.prepare(model, tokenizer)
    model.eval()
    model.requires_grad_(False)
    if ebm_model is not None:
        ebm_model.eval()

    samples = [prompts for _, prompts in prompt_dataset]
    if not samples:
        # Nothing to score; avoids indexing samples[0] on an empty dataset.
        return []
    has_multi_answers = len(samples[0]) > 1

    # get_reward returns a 2- or 3-tuple depending on its arguments, so the
    # rewards are always taken by position rather than by tuple unpacking.
    if has_multi_answers:
        rewards = [
            get_reward(
                prompts,
                model,
                tokenizer,
                model.device,
                split_size,
                is_alpacafarm_rm=is_alpacafarm_rm,
                ebm_model=ebm_model,
            )[0]
            for prompts in samples
        ]
    else:
        rewards = get_reward(
            [prompts[0] for prompts in samples],
            model,
            tokenizer,
            model.device,
            split_size,
            is_alpacafarm_rm=is_alpacafarm_rm,
            ebm_model=ebm_model,
        )[0]

    data = []
    for (entry, _), reward in zip(prompt_dataset, rewards):
        scores = reward.cpu().detach()

        if has_multi_answers and sort:
            scores, indices = torch.sort(scores)
            order = indices.tolist()
            entry["answers"] = [entry["answers"][j] for j in order]
            if entry.get("gold_scores"):
                entry["gold_scores"] = [entry["gold_scores"][j] for j in order]
            if entry.get("proxy_scores"):
                entry["proxy_scores"] = [entry["proxy_scores"][j] for j in order]

        entry[scores_type] = scores.tolist() if has_multi_answers else [scores.item()]
        data.append(entry)

    return data


# Import-time side effect: register the custom GPT-NeoX reward model with the
# transformers Auto* factories. The "gpt_neox_reward_model" model type is
# mapped to its config class, and that config class is mapped to the model
# class for AutoModelForSequenceClassification (used by score_answers).
AutoConfig.register("gpt_neox_reward_model", GPTNeoXRewardModelConfig)
AutoModelForSequenceClassification.register(GPTNeoXRewardModelConfig, GPTNeoXRewardModel)

import torch

def infer_rewards_batch_dramatic(
    ebm_model,
    embeddings,
    batch_size=32,
    T=100,
    lambda_init=0.1,
    eta=0.5,
    init_range=(-2.0, 2.0),
    y_init=None,
):
    """Refine one scalar per embedding by gradient ascent on ``ebm_model``.

    Each sample keeps its own step size. On every iteration the gradient of
    ``ebm_model(embedding, y)`` with respect to ``y`` is taken, the candidate
    ``y + step * grad`` is evaluated, and it is accepted only if the model
    output strictly increases. A rejected candidate multiplies that sample's
    step size by ``eta``; a sample whose step size drops below ``1e-6`` is no
    longer updated.

    Args:
        ebm_model: Callable module invoked as ``ebm_model(embeddings, y)``;
            it is put in eval mode. Its output is squeezed and compared
            element-wise per sample.
        embeddings: Tensor whose first dimension indexes samples; its device
            is used for all intermediate tensors.
        batch_size: Number of active samples processed per model call.
        T: Maximum number of ascent iterations.
        lambda_init: Initial step size for every sample.
        eta: Multiplicative step-size decay applied after a rejected step.
        init_range: ``(low, high)`` bounds of the uniform random
            initialisation. Any indexable pair works (list or tuple).
        y_init: Optional initial values, one per sample. Entries inside
            ``[-2, 2]`` are kept; the rest are replaced by random draws from
            ``init_range``.

    Returns:
        A detached 1-D tensor with one refined value per sample.

    Note:
        Autograd must be enabled when this is called, since the update uses
        ``torch.autograd.grad``.
    """
    # NOTE(review): the interval in which ``y_init`` is trusted is hard-coded
    # and independent of ``init_range`` -- confirm this is intended.
    trusted_low, trusted_high = -2, 2
    min_step = 1e-6

    ebm_model.eval()
    device = embeddings.device
    num_samples = embeddings.size(0)

    if y_init is not None:
        y_init_ = y_init.to(device).detach()
        random_init = torch.empty(num_samples, device=device).uniform_(
            init_range[0], init_range[1]
        )
        mask_in_range = (y_init_ >= trusted_low) & (y_init_ <= trusted_high)
        y = torch.where(mask_in_range, y_init_, random_init)
    else:
        y = torch.empty(num_samples, device=device).uniform_(
            init_range[0], init_range[1]
        )

    y.requires_grad_(False)
    lambda_steps = torch.full((num_samples,), lambda_init, device=device)
    active_mask = torch.ones(num_samples, dtype=torch.bool, device=device)

    for _ in range(T):
        active_indices = torch.nonzero(active_mask).squeeze(-1)
        if active_indices.numel() == 0:
            break

        for start in range(0, len(active_indices), batch_size):
            batch_idx = active_indices[start : start + batch_size]
            batch_emb = embeddings[batch_idx]

            # A fresh leaf per batch so the gradient is taken w.r.t. y only.
            y_batch = y[batch_idx].detach().requires_grad_(True)
            prev_values = ebm_model(batch_emb, y_batch).squeeze()
            grad = torch.autograd.grad(
                prev_values.sum(), y_batch, retain_graph=False
            )[0]

            with torch.no_grad():
                y_tilde = y_batch + lambda_steps[batch_idx] * grad
                new_values = ebm_model(batch_emb, y_tilde).squeeze()

                improved = new_values > prev_values

                # Accept the candidate only where it raised the model output.
                y[batch_idx] = torch.where(improved, y_tilde, y_batch)

                # Shrink the step size of samples whose candidate was rejected.
                lambda_steps[batch_idx] = torch.where(
                    improved,
                    lambda_steps[batch_idx],
                    lambda_steps[batch_idx] * eta,
                )

        # Freeze samples whose step size has collapsed.
        active_mask &= lambda_steps >= min_step

    return y.detach()