#!/usr/bin/env python3
"""
Single-file MVP implementation for Multi-Objective Preference Optimization (MOPO) in the (simplified) RLHF setting. 
Full code will be released later.
"""

import math
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt  
import json


class Config:
    """Experiment settings, kept as plain class attributes.

    ``main`` instantiates this class and overrides ``REWARD_SET`` /
    ``MAX_EPOCHS`` on the instance when the matching CLI flags are given.
    """

    # ---- problem definition / synthetic data ----
    # Contexts and actions are indices in range(NUM_POINTS), mapped to [0, 1]
    # by dividing by NUM_POINTS - 1; also the policy's number of output logits.
    NUM_POINTS = 200
    # "A" selects (r1A, r2A); any other value selects (r1B, r2B).
    REWARD_SET = "A"
    # Reward gaps are divided by this before the sigmoid in generate_dataset.
    BT_TEMP = 0.5
    # Number of sampled comparisons; each one is stored twice (as-is and
    # with the two actions swapped), so the dataset holds 2x this many records.
    DATASET_SIZE = 500
    # lambda_vec and chi_vec in train_mopo are created with K - 1 entries.
    K = 2

    # ---- objective hyper-parameters ----
    # compute_rho_per_sample scales its exponent by 1 / (TAU * batch size).
    TAU = 0.1
    # Multiplies chi in lower_bound_chi_objective.
    EPS = 1
    # Passed as b_vec to constraint_lambda_loss.
    THRESHOLDS = [10]

    # ---- optimisation ----
    MAX_EPOCHS = 10
    BATCH_SIZE = 512
    # Shared by the three Adam optimizers in train_mopo.
    LR = 1e-2
    HIDDEN_DIM = 128

    # ---- evaluation ----
    # Number of random contexts on which policy snapshots are evaluated.
    EVAL_CONTEXTS = 100
    # Number of random (x, y) samples used by approximate_front.
    FRONT_SAMPLES = 2500
    # Fractions of MAX_EPOCHS that are turned into snapshot epochs.
    EVAL_FRAC = [0.2, 0.4, 0.6, 0.8, 1.0]

    # ---- hardware ----
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def r1A(x, y):
    """Tensor reward 1 of set A: element-wise ``exp(x) + sqrt(y) - y``."""
    exp_part = torch.exp(x)
    root_part = torch.sqrt(y)
    return exp_part + root_part - y

def r2A(x, y):
    """Tensor reward 2 of set A: element-wise ``-sin(x) - y**2``."""
    neg_sine = -torch.sin(x)
    y_squared = y ** 2
    return neg_sine - y_squared

def r1B(x, y):
    """Tensor reward 1 of set B: element-wise ``(x + y)**2``."""
    total = x + y
    return total ** 2

def r2B(x, y):
    """Tensor reward 2 of set B: element-wise ``log((1 + x) / (1 + y))``."""
    ratio = (1 + x) / (1 + y)
    return torch.log(ratio)

def get_reward_funcs(set_name):
    """Return the ``(r1, r2)`` tensor reward pair for *set_name*.

    ``"A"`` maps to ``(r1A, r2A)``; every other value falls back to
    ``(r1B, r2B)``.
    """
    return (r1A, r2A) if set_name == "A" else (r1B, r2B)


def generate_dataset(config):
    """Sample a synthetic preference dataset from the configured reward pair.

    For each of ``config.DATASET_SIZE`` draws, a context index and two distinct
    action indices are sampled uniformly from ``range(config.NUM_POINTS)``.
    Two preference bits are then drawn, one per reward, each with probability
    ``sigmoid((r(x, y) - r(x, y')) / config.BT_TEMP)``.

    Returns:
        A list of ``(x_idx, y_idx, yp_idx, p_bit, q_bit)`` tuples.  Every draw
        contributes two records: the sampled one and its mirror image (actions
        swapped, both bits flipped).
    """
    reward_1, reward_2 = get_reward_funcs(config.REWARD_SET)
    dev = config.DEVICE
    n_points = config.NUM_POINTS
    grid_max = n_points - 1
    temperature = config.BT_TEMP

    def to_grid_tensor(index):
        # Map a grid index to its value in [0, 1] as a one-element tensor.
        return torch.tensor([index / grid_max], device=dev)

    def draw_bit(score, other_score):
        # One Bernoulli draw with success probability sigmoid(gap / temperature).
        gap = (score - other_score) / temperature
        win_prob = 1 / (1 + math.exp(-gap))
        return 1 if (np.random.rand() < win_prob) else 0

    records = []
    for _ in range(config.DATASET_SIZE):
        ctx_idx = np.random.randint(0, n_points)
        first_idx = np.random.randint(0, n_points)
        second_idx = np.random.randint(0, n_points)
        if second_idx == first_idx:
            # Shift to the next grid point (wrapping) so the pair is distinct.
            second_idx = (second_idx + 1) % n_points

        x_t = to_grid_tensor(ctx_idx)
        y_t = to_grid_tensor(first_idx)
        yp_t = to_grid_tensor(second_idx)

        # Draw order matters for the RNG stream: p first, then q.
        p_bit = draw_bit(reward_1(x_t, y_t).item(), reward_1(x_t, yp_t).item())
        q_bit = draw_bit(reward_2(x_t, y_t).item(), reward_2(x_t, yp_t).item())

        records.append((ctx_idx, first_idx, second_idx, p_bit, q_bit))
        records.append((ctx_idx, second_idx, first_idx, 1 - p_bit, 1 - q_bit))
    return records

import torch.utils.data as datautils

class MOPODataset(datautils.Dataset):
    """Torch dataset over ``(x_idx, y_idx, yp_idx, p_bit, q_bit)`` records."""

    def __init__(self, record_list):
        super().__init__()
        self.data = record_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(x_idx, y_idx, yp_idx, I)`` for record *idx*.

        ``I`` is a float32 tensor ordered ``[q_bit, p_bit]``.
        """
        context, chosen, alternative, p_bit, q_bit = self.data[idx]
        # Column 0 holds q_bit and column 1 holds p_bit.
        indicators = torch.tensor((q_bit, p_bit), dtype=torch.float32)
        return context, chosen, alternative, indicators

class PolicyNet(nn.Module):
    """MLP that maps a scalar context to one logit per action.

    Architecture: Linear(1, hidden) -> ReLU -> Linear(hidden, hidden) -> ReLU
    -> Linear(hidden, num_actions), stored as ``self.net``.
    """

    def __init__(self, hidden_dim=256, num_actions=1000):
        super().__init__()
        stack = [
            nn.Linear(1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_actions),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return the action logits for *x*, which is cast to float first."""
        as_float = x.float()
        return self.net(as_float)

def compute_rho_per_sample(I_p, I_q, lambda_vec, tau, N):
    """Return per-sample weights ``exp(clamp((I_p + I_q @ lambda) / (tau * N)))``.

    The exponent is clamped to ``[-10, 10]`` before exponentiation so the
    weights stay finite.
    """
    combined = I_p + I_q @ lambda_vec
    scale = 1.0 / (tau * N)
    bounded = torch.clamp(scale * combined, min=-10, max=10)
    return bounded.exp()

def lower_bound_chi_objective(I_q, rho, chi_vec, eps):
    """Return the loss minimised with respect to ``chi_vec``.

    For each entry ``k`` of ``chi_vec`` the term
    ``chi_k * log(mean(exp(clamp(rho * I_q[:, k] / chi_k))) + 1e-10) + chi_k * eps``
    is computed; the result is the negated sum of these terms.  The divisor is
    floored at 1e-10 and the exponent is clamped to ``[-10, 10]``.
    """

    def term(k):
        chi_k = chi_vec[k]
        scaled = (rho * I_q[:, k]) / torch.clamp_min(chi_k, 1e-10)
        avg_exp = torch.exp(torch.clamp(scaled, -10, 10)).mean()
        return chi_k * torch.log(avg_exp + 1e-10) + chi_k * eps

    # Start from 0.0 so an empty chi_vec still yields a float, as before.
    return -sum((term(k) for k in range(chi_vec.shape[0])), 0.0)

def estimate_q_lower_bounds(I_q, rho, chi_vec):
    """Return one bound estimate per column of ``I_q`` as a 1-D tensor.

    Entry ``k`` is ``chi_k * log(mean(exp(clamp(rho * I_q[:, k] / chi_k))) + 1e-10)``,
    with the divisor floored at 1e-10 and the exponent clamped to ``[-10, 10]``.
    """

    def bound(k):
        chi_k = chi_vec[k]
        scaled = (rho * I_q[:, k]) / torch.clamp_min(chi_k, 1e-10)
        avg_exp = torch.exp(torch.clamp(scaled, -10, 10)).mean()
        return chi_k * torch.log(avg_exp + 1e-10)

    return torch.stack([bound(k) for k in range(I_q.shape[1])], dim=0)

def constraint_lambda_loss(lambda_vec, b_vec, q_lower_bounds):
    """Return ``-sum(lambda * (b - q_lower_bounds))``, the loss for ``lambda_vec``."""
    slack = b_vec - q_lower_bounds
    weighted = lambda_vec * slack
    return -weighted.sum()

def policy_extraction_loss(policy_net, x_idx, y_idx, rho, config):
    """Return the rho-weighted negative log-likelihood of the chosen actions.

    Context indices are mapped to ``x_idx / (config.NUM_POINTS - 1)`` and fed
    to ``policy_net``.  The log-softmax value at each sample's ``y_idx`` is
    weighted by ``rho`` and the negated batch mean is returned.
    """
    dev = config.DEVICE
    contexts = (x_idx.float() / (config.NUM_POINTS - 1)).view(-1, 1).to(dev)
    log_policy = torch.log_softmax(policy_net(contexts), dim=1)
    # Select, for every row, the log-probability of that row's chosen action.
    rows = torch.arange(x_idx.shape[0], device=dev)
    picked = log_policy[rows, y_idx]
    return -(rho * picked).mean()

def get_policy_distribution_x0(policy_net, config):
    """Return the policy's action probabilities at context ``x = 0`` as a 1-D NumPy array."""
    origin = torch.tensor([[0.0]], device=config.DEVICE)
    action_probs = torch.softmax(policy_net(origin), dim=1)
    flat = action_probs.squeeze(0)
    return flat.detach().cpu().numpy()

def r1A_func(x, y):
    """Scalar reward 1 of set A: ``exp(x) + sqrt(y) - y``."""
    exp_part = math.exp(x)
    root_part = math.sqrt(y)
    return exp_part + root_part - y

def r2A_func(x, y):
    """Scalar reward 2 of set A: ``-sin(x) - y**2``."""
    neg_sine = -math.sin(x)
    y_squared = y ** 2
    return neg_sine - y_squared

def r1B_func(x, y):
    """Scalar reward 1 of set B: ``(x + y)**2``."""
    total = x + y
    return total ** 2

def r2B_func(x, y):
    """Scalar reward 2 of set B: ``log((1 + x) / (1 + y))``."""
    ratio = (1 + x) / (1 + y)
    return math.log(ratio)

def approximate_front(config):
    """Approximate the Pareto front of the configured reward pair by sampling.

    ``config.FRONT_SAMPLES`` pairs ``(x, y)`` are drawn uniformly from
    ``[0, 1)^2`` (all x values first, then all y values, from NumPy's global
    RNG).  Both rewards are evaluated on every pair and the non-dominated
    points are kept.

    A point ``i`` is dominated when some other point is at least as good in
    both rewards and strictly better in at least one.  Exact duplicates do not
    dominate each other.

    Args:
        config: object exposing ``FRONT_SAMPLES`` (int) and ``REWARD_SET``
            (``"A"`` selects reward set A, anything else set B).

    Returns:
        ``np.ndarray`` of shape ``(k, 2)`` holding ``(r1, r2)`` for the ``k``
        non-dominated samples, in sampling order.  The shape is ``(0, 2)``
        when ``FRONT_SAMPLES`` is 0, so column indexing by callers stays valid.
    """
    M = config.FRONT_SAMPLES
    arr_x = np.random.rand(M)
    arr_y = np.random.rand(M)

    # Evaluate both rewards on the whole sample in one vectorised pass.
    if config.REWARD_SET == "A":
        r1_vals = np.exp(arr_x) + np.sqrt(arr_y) - arr_y
        r2_vals = -np.sin(arr_x) - arr_y ** 2
    else:
        r1_vals = (arr_x + arr_y) ** 2
        r2_vals = np.log((1 + arr_x) / (1 + arr_y))
    pts = np.column_stack((r1_vals, r2_vals))

    is_dom = np.zeros(M, dtype=bool)
    for i in range(M):
        r1_i, r2_i = pts[i]
        # Compare point i against all points at once.  Point i never
        # dominates itself because the strict test below is False for it.
        no_worse = (pts[:, 0] >= r1_i) & (pts[:, 1] >= r2_i)
        strictly_better = (pts[:, 0] > r1_i) | (pts[:, 1] > r2_i)
        is_dom[i] = bool(np.any(no_worse & strictly_better))
    return pts[~is_dom]

def evaluate_policy_points(policy_net, x_eval, config):
    """Evaluate both rewards at the policy's most probable action per context.

    For every context in ``x_eval`` the action index with the highest softmax
    probability is selected and mapped to ``index / (config.NUM_POINTS - 1)``.
    The two scalar rewards of ``config.REWARD_SET`` are then computed
    (``"A"`` selects set A, anything else set B).

    Returns:
        ``np.ndarray`` with one ``(r1, r2)`` row per context.
    """
    dev = config.DEVICE
    contexts = torch.tensor(x_eval, dtype=torch.float32, device=dev).view(-1, 1)
    with torch.no_grad():
        action_probs = torch.softmax(policy_net(contexts), dim=1)
        greedy_idx = torch.argmax(action_probs, dim=1).cpu().numpy()

    use_set_a = config.REWARD_SET == "A"

    def reward_pair(x, y):
        # Scalar (r1, r2) for the selected reward set.
        if use_set_a:
            return (math.exp(x) + math.sqrt(y) - y, -math.sin(x) - y**2)
        return ((x + y)**2, math.log((1 + x) / (1 + y)))

    grid_max = config.NUM_POINTS - 1
    rows = [reward_pair(ctx, act / grid_max) for ctx, act in zip(x_eval, greedy_idx)]
    return np.array(rows)

def train_mopo(dataset, config):
    """Train the policy and the dual variables, then plot and save snapshots.

    Every mini-batch performs three successive Adam updates:

    1. ``chi_vec`` on ``lower_bound_chi_objective``;
    2. ``lambda_vec`` on ``constraint_lambda_loss``, using lower bounds
       estimated with the updated ``chi_vec``;
    3. ``policy_net`` on ``policy_extraction_loss``, with sample weights
       recomputed from the updated ``lambda_vec``.

    After each dual update the variable is projected back onto its box
    (``chi`` in ``[1e-6, 15]``, ``lambda`` in ``[0, 15]``).  The projection is
    applied after ``optimizer.step()`` so the values used afterwards, and the
    values returned, always satisfy the bounds.

    Side effects: reseeds NumPy's global RNG with 0, prints progress, writes
    ``./data/reward_model_set_<REWARD_SET>.json`` (creating ``./data`` if
    needed) and shows a matplotlib figure.

    Args:
        dataset: dataset yielding ``(x_idx, y_idx, yp_idx, I)`` items, where
            ``I[..., 0]`` is the q-indicator and ``I[..., 1]`` the p-indicator.
        config: settings object (see ``Config``).

    Returns:
        ``(policy_net, lambda_vec, chi_vec)``.
    """
    import os

    import matplotlib.pyplot as plt
    from torch.utils.data import DataLoader, RandomSampler

    device = config.DEVICE
    N = len(dataset)
    sampler = RandomSampler(dataset, replacement=True, num_samples=N)
    loader = DataLoader(dataset, batch_size=config.BATCH_SIZE, sampler=sampler)

    lambda_vec = nn.Parameter(torch.zeros(config.K-1, device=device), requires_grad=True)
    chi_vec = nn.Parameter(torch.ones(config.K-1, device=device)*1.0, requires_grad=True)
    policy_net = PolicyNet(hidden_dim=config.HIDDEN_DIM, num_actions=config.NUM_POINTS).to(device)

    opt_chi = optim.Adam([chi_vec], lr=config.LR)
    opt_lambda = optim.Adam([lambda_vec], lr=config.LR)
    opt_policy = optim.Adam(policy_net.parameters(), lr=config.LR)

    max_grad_norm = 5.0
    # The constraint thresholds never change, so build the tensor once.
    b_vec = torch.tensor(config.THRESHOLDS, dtype=torch.float32, device=device)

    np.random.seed(0)
    x_eval = np.random.rand(config.EVAL_CONTEXTS)
    front_pts = approximate_front(config)

    # Snapshot epochs are compared against the 0-based epoch index below; the
    # final epoch is always included as well.
    check_epochs = {int(config.MAX_EPOCHS*fr) for fr in config.EVAL_FRAC}
    policy_eval_dict = {}

    for epoch in range(config.MAX_EPOCHS):
        total_chi = 0.0
        total_lam = 0.0
        total_pol = 0.0
        n_batches = 0
        for batch in loader:
            x_idx, y_idx, _yp_idx, I_mat = batch
            bsz = x_idx.shape[0]
            x_idx = x_idx.to(device)
            y_idx = y_idx.to(device)
            I_mat = I_mat.to(device)

            I_p = I_mat[:, 1]
            # Only the first indicator column is used as the constrained objective.
            I_q = I_mat[:, :1]

            # rho acts as a fixed per-sample weight in every sub-step: the
            # lambda update only uses the gradient of constraint_lambda_loss.
            # No graph through lambda_vec is therefore needed here.
            with torch.no_grad():
                rho_vals = compute_rho_per_sample(I_p, I_q, lambda_vec, config.TAU, N=bsz)

            # --- chi update ---
            opt_chi.zero_grad()
            chi_loss = lower_bound_chi_objective(I_q, rho_vals, chi_vec, config.EPS)
            chi_loss.backward()
            torch.nn.utils.clip_grad_norm_([chi_vec], max_grad_norm)
            opt_chi.step()
            with torch.no_grad():
                # Project after the step so chi stays inside its box.
                chi_vec.clamp_(1e-6, 15.0)

            with torch.no_grad():
                q_lb = estimate_q_lower_bounds(I_q, rho_vals, chi_vec)

            # --- lambda update ---
            opt_lambda.zero_grad()
            lam_loss = constraint_lambda_loss(lambda_vec, b_vec, q_lb)
            lam_loss.backward()
            torch.nn.utils.clip_grad_norm_([lambda_vec], max_grad_norm)
            opt_lambda.step()
            with torch.no_grad():
                # Project after the step so lambda stays non-negative and bounded.
                lambda_vec.clamp_(0.0, 15.0)

            # --- policy update, weighted by rho under the new lambda ---
            with torch.no_grad():
                rho_vals_after = compute_rho_per_sample(I_p, I_q, lambda_vec, config.TAU, N=bsz)

            opt_policy.zero_grad()
            pol_loss = policy_extraction_loss(policy_net, x_idx, y_idx, rho_vals_after, config)
            pol_loss.backward()
            torch.nn.utils.clip_grad_norm_(policy_net.parameters(), max_grad_norm)
            opt_policy.step()

            total_chi += chi_loss.item()
            total_lam += lam_loss.item()
            total_pol += pol_loss.item()
            n_batches += 1

        if (epoch in check_epochs) or (epoch == config.MAX_EPOCHS-1):
            avg_chi_loss = total_chi/n_batches
            avg_lam_loss = total_lam/n_batches
            avg_pol_loss = total_pol/n_batches
            print(f"[Epoch {epoch}/{config.MAX_EPOCHS}] chi_loss={avg_chi_loss:.4f}, lam_loss={avg_lam_loss:.4f}, pol_loss={avg_pol_loss:.4f}")

            policy_pts = evaluate_policy_points(policy_net, x_eval, config)
            policy_eval_dict[epoch] = policy_pts

    print("Training complete!")

    print("Plotting the scatter of front plus policy snapshots ...")
    plt.figure()
    frontX = front_pts[:, 0]
    frontY = front_pts[:, 1]
    plt.scatter(frontX, frontY, label="Pareto Front")

    pareto_policy_dict = {}

    # NOTE(review): for reward sets other than 'A' the policy's r2 values are
    # sign-flipped in the plot while the front is not - confirm this is intended.
    t = 1 if config.REWARD_SET == 'A' else -1
    for ep in sorted(policy_eval_dict.keys()):
        pts = policy_eval_dict[ep]
        pareto_policy_dict[ep] = pts.tolist()
        plt.scatter(pts[:, 0], t*pts[:, 1], label=f"Epoch {ep}")

    pareto_policy_dict['pareto'] = front_pts.tolist()

    filename = './data/reward_model_set_' + config.REWARD_SET + '.json'
    # Create the output directory first; open() would fail if it is missing.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, 'w') as fp:
        print(' filename :   ', filename)
        json.dump(pareto_policy_dict, fp)

    plt.xlabel("r1")
    plt.ylabel("r2")
    plt.legend()
    plt.title("MOPO Policy Points vs. Pareto Front")
    plt.show()

    return policy_net, lambda_vec, chi_vec

def main():
    """Parse CLI overrides, build the dataset, train, and print the final duals.

    ``--reward_set`` and ``--max_epochs`` replace the corresponding ``Config``
    values when supplied.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--reward_set", type=str, default=None)
    cli.add_argument("--max_epochs", type=int, default=None)
    opts = cli.parse_args()

    config = Config()
    overrides = (
        ("REWARD_SET", opts.reward_set),
        ("MAX_EPOCHS", opts.max_epochs),
    )
    for attr_name, value in overrides:
        # Only flags that were actually passed override the defaults.
        if value is not None:
            setattr(config, attr_name, value)

    dataset = MOPODataset(generate_dataset(config))
    print("Dataset size=", len(dataset))

    _policy, lam_vec, chi_vec = train_mopo(dataset, config)

    print("Final lambda =", lam_vec.detach().cpu().numpy())
    print("Final chi    =", chi_vec.detach().cpu().numpy())

# Run the experiment only when this file is executed as a script.
if __name__=="__main__":
    main()
