#!/usr/bin/env python3
"""
Single-file MVP implementation for Multi-Objective Preference Optimization (MOPO) in the bandit setting. 
Full code will be released later.
"""

import argparse
import os

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.special import softmax
from torch.utils.data import Dataset, DataLoader, RandomSampler

# Default hyper-parameters for train_mopo(). main() overwrites "max_epochs"
# and "batch_size" in place when the matching CLI flags are given.
CONFIG = {
    "K": 2,                 # number of I_k columns read from the CSV (I_0 .. I_{K-1})
    "y_dim": 3,             # length of the policy logit vector; y values index into it
    "tau": 1,              # divisor in the rho exponent: (I_p + lambda . I_q) / (tau * N)
    "eps": 0.05,              # coefficient of the chi_k * eps term in the chi objective
    "thresholds": [1],      # b_vec in the lambda loss; train_mopo asserts it has K-1 entries
    "max_epochs": 20000,    # number of passes over the (resampled) dataset
    "batch_size": 8,        # DataLoader batch size
    "learning_rate": 1e-3,  # shared by the three Adam optimizers (chi, lambda, policy)
    "device": "cuda" if torch.cuda.is_available() else "cpu",  # prefer GPU when present
}


class PreferenceDataset(Dataset):
    """Tensor-backed view over a preference DataFrame.

    The frame must contain the columns ``y`` and ``yprime`` (stored as int64
    tensors) and ``I_0`` .. ``I_{K-1}`` (stored together as one float32
    tensor with one row per sample and one column per ``I_k``). Any other
    column is ignored.

    Each item is the tuple ``(y, yprime, I_vec)`` for one row.
    """

    def __init__(self, df: pd.DataFrame, K: int):
        super().__init__()
        indicator_columns = ["I_" + str(k) for k in range(K)]

        # torch.tensor copies, so later edits to `df` do not leak into the dataset.
        self.y_data = torch.tensor(df["y"].values, dtype=torch.long)
        self.yprime_data = torch.tensor(df["yprime"].values, dtype=torch.long)
        self.I_data = torch.tensor(
            df[indicator_columns].values, dtype=torch.float32
        )
        self.N = len(df)

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        return self.y_data[idx], self.yprime_data[idx], self.I_data[idx]


def read_dataset(csv_file: str, K: int) -> PreferenceDataset:
    """Load ``csv_file`` (first line is the header) into a PreferenceDataset.

    ``K`` is forwarded unchanged and decides which ``I_k`` columns are read.
    """
    frame = pd.read_csv(csv_file, header=0)
    return PreferenceDataset(frame, K)


class PolicyVector(nn.Module):
    """Input-independent policy: a single learnable logit vector.

    ``theta`` has ``y_dim`` entries and starts at zero, i.e. a uniform
    distribution after softmax.
    """

    def __init__(self, y_dim: int):
        super().__init__()
        initial_logits = torch.zeros(y_dim)
        self.theta = nn.Parameter(initial_logits)

    def forward(self, batch_size: int) -> torch.Tensor:
        """Return ``theta`` repeated as a ``(batch_size, y_dim)`` view."""
        # expand() creates a broadcast view, so gradients still flow to theta.
        return self.theta[None, :].expand(batch_size, -1)



def compute_rho_per_sample(I_p, I_q, lambda_vec, tau, N):
    """Return the per-sample weight ``exp((I_p + I_q @ lambda_vec) / (tau * N))``.

    The exponent is clamped to ``[-10, 10]`` before exponentiation so the
    result stays within ``[exp(-10), exp(10)]``.

    Args:
        I_p: one value per sample.
        I_q: one row per sample, one column per entry of ``lambda_vec``.
        lambda_vec: weights applied to the columns of ``I_q``.
        tau: scalar; divides the exponent together with ``N``.
        N: scalar normaliser; divides the exponent together with ``tau``.

    Returns:
        Tensor with one weight per sample.
    """
    exp_limit = 10.0
    scale = 1.0 / (tau * N)
    combined = I_p + I_q @ lambda_vec
    return (scale * combined).clamp(min=-exp_limit, max=exp_limit).exp()

def lower_bound_chi_objective(I_q, rho, chi_vec, eps):
    """Return the loss minimised with respect to ``chi_vec``.

    For every entry ``k`` of ``chi_vec`` the term

        chi_k * log(mean_i exp(rho_i * I_q[i, k] / chi_k)) + chi_k * eps

    is computed, and the negated sum of all terms is returned. The division
    uses ``max(chi_k, 1e-10)`` and the exponent is clamped to ``[-10, 10]``
    for numerical safety.

    Args:
        I_q: one row per sample; column ``k`` belongs to ``chi_vec[k]``.
        rho: one weight per sample.
        chi_vec: 1-D tensor of the variables being optimised.
        eps: scalar multiplying each ``chi_k`` in the additive term.

    Returns:
        A 0-dim tensor. When ``chi_vec`` is empty the result is a zero tensor
        that is still attached to ``chi_vec``'s graph, so ``.backward()``
        keeps working instead of failing on a plain Python float.
    """
    n_terms = chi_vec.shape[0]
    if n_terms == 0:
        # Sum over an empty tensor is a (differentiable) zero.
        return -chi_vec.sum()

    total_obj = 0.0
    for k in range(n_terms):
        chi_k = chi_vec[k]
        scaled = (rho * I_q[:, k]) / torch.clamp_min(chi_k, 1e-10)
        scaled = torch.clamp(scaled, -10.0, 10.0)
        # The 1e-10 offset guards log() against a zero mean.
        log_mean_exp = torch.log(torch.mean(torch.exp(scaled)) + 1e-10)
        total_obj = total_obj + (chi_k * log_mean_exp + chi_k * eps)
    return -total_obj

def estimate_q_lower_bounds(I_q, rho, chi_vec):
    """Return one value per column ``k`` of ``I_q``:

        chi_k * log(mean_i exp(rho_i * I_q[i, k] / chi_k))

    The division uses ``max(chi_k, 1e-10)`` and the exponent is clamped to
    ``[-10, 10]`` for numerical safety.

    Args:
        I_q: one row per sample, one column per entry of ``chi_vec``.
        rho: one weight per sample.
        chi_vec: 1-D tensor with at least as many entries as ``I_q`` has
            columns.

    Returns:
        1-D tensor with ``I_q.shape[1]`` entries. If ``I_q`` has no columns
        an empty tensor is returned (``torch.stack`` would raise on an empty
        list).
    """
    n_cols = I_q.shape[1]
    if n_cols == 0:
        return I_q.new_zeros(0)

    bounds = []
    for k in range(n_cols):
        chi_k = chi_vec[k]
        scaled = (rho * I_q[:, k]) / torch.clamp_min(chi_k, 1e-10)
        scaled = torch.clamp(scaled, -10.0, 10.0)
        # The 1e-10 offset guards log() against a zero mean.
        bounds.append(chi_k * torch.log(torch.mean(torch.exp(scaled)) + 1e-10))
    return torch.stack(bounds, dim=0)

def constraint_lambda_loss(lambda_vec, b_vec, q_lower_bounds):
    """Return ``-sum(lambda_vec * (b_vec - q_lower_bounds))``.

    Minimising this value with respect to ``lambda_vec`` pushes an entry up
    wherever its threshold ``b_vec[k]`` exceeds ``q_lower_bounds[k]`` and
    down otherwise.
    """
    slack = b_vec - q_lower_bounds
    return -(lambda_vec * slack).sum()

def policy_extraction_loss(policy_net, y_indices, rho, y_dim):
    """Return the rho-weighted negative log-likelihood of ``y_indices``.

    ``policy_net`` is called with the batch size and must return one row of
    logits per sample. The loss is ``-mean(rho_i * log_softmax(logits)[i, y_i])``.

    Args:
        policy_net: callable mapping a batch size to a 2-D logits tensor.
        y_indices: integer tensor with the selected column for each sample.
        rho: one weight per sample.
        y_dim: not used by this function; kept for the caller's signature.
    """
    n_samples = y_indices.shape[0]
    log_probs = policy_net(n_samples).log_softmax(dim=1)
    rows = torch.arange(n_samples)
    picked = log_probs[rows, y_indices]
    return -(rho * picked).mean()

def get_policy_distribution(policy_net, y_dim, device):
    """Return the policy's softmax distribution as a NumPy array.

    ``policy_net`` is queried with a batch size of 1 and the leading batch
    dimension is squeezed away. ``y_dim`` and ``device`` are accepted but not
    used here.
    """
    single_row = policy_net(1).softmax(dim=1)
    flat = single_row.squeeze(0)
    return flat.detach().cpu().numpy()

def _save_policy_logs(path, policy_logs, y_dim):
    """Write the logged ``(epoch, distribution)`` pairs to ``path`` as CSV."""
    # Create the target directory up front so a finished training run is not
    # lost to FileNotFoundError when the directory does not exist yet.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    with open(path, "w") as f:
        headers = ["epoch"] + [f"p{i}" for i in range(y_dim)]
        f.write(",".join(headers) + "\n")
        for ep, distr in policy_logs:
            row = [str(ep)] + [str(x) for x in distr]
            f.write(",".join(row) + "\n")


def train_mopo(dataset: PreferenceDataset, config: dict):
    """Run the alternating chi / lambda / policy updates and return the result.

    Every batch performs three Adam steps in sequence:

    1. ``chi_vec`` is updated on ``lower_bound_chi_objective`` and projected
       back to ``[1e-6, 15]``.
    2. ``lambda_vec`` is updated on ``constraint_lambda_loss`` (using lower
       bounds computed from the fresh ``chi_vec``) and projected back to
       ``[0, 15]``.
    3. The policy logits are updated on ``policy_extraction_loss`` with the
       weights ``rho`` recomputed from the fresh ``lambda_vec``.

    Within a batch, column ``K-1`` of the indicator tensor is used as ``I_p``
    and columns ``0 .. K-2`` as ``I_q``.

    The policy distribution is recorded every 100 epochs (and at the final
    epoch) and written to ``./data/policy_logs.csv``; progress is printed
    every 1000 epochs.

    Args:
        dataset: dataset yielding ``(y, yprime, I_vec)`` items; ``yprime`` is
            not used during training.
        config: dict with the keys ``device``, ``K``, ``y_dim``, ``tau``,
            ``eps``, ``thresholds`` (``K-1`` entries), ``max_epochs``,
            ``batch_size`` and ``learning_rate``.

    Returns:
        ``(policy_net, lambda_vec, chi_vec)`` after training.

    Raises:
        AssertionError: if ``thresholds`` does not have ``K-1`` entries.
        ValueError: if ``dataset`` is empty.
    """
    device = torch.device(config["device"])
    K = config["K"]
    y_dim = config["y_dim"]
    tau = config["tau"]
    eps = config["eps"]
    b_vec = torch.tensor(config["thresholds"], dtype=torch.float32, device=device)
    max_epochs = config["max_epochs"]
    batch_size = config["batch_size"]
    lr = config["learning_rate"]

    assert len(b_vec) == (K - 1), "thresholds must match K-1."
    if len(dataset) == 0:
        raise ValueError("Cannot train on an empty dataset.")

    # Each epoch draws len(dataset) samples with replacement.
    sampler = RandomSampler(dataset, replacement=True, num_samples=len(dataset))
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

    lambda_vec = nn.Parameter(torch.zeros(K - 1, device=device), requires_grad=True)
    chi_vec = nn.Parameter(torch.ones(K - 1, device=device) * 1.0, requires_grad=True)
    policy_net = PolicyVector(y_dim=y_dim).to(device)

    opt_chi = optim.Adam([chi_vec], lr=lr)
    opt_lambda = optim.Adam([lambda_vec], lr=lr)
    opt_policy = optim.Adam(policy_net.parameters(), lr=lr)

    max_grad_norm = 5.0
    clip = torch.nn.utils.clip_grad_norm_

    policy_logs = []

    for epoch in range(max_epochs):
        total_lambda_loss = 0.0
        n_batches = 0

        for y_indices, _yprime_indices, I_batch in loader:
            y_indices = y_indices.to(device)
            I_batch = I_batch.to(device)
            bsz = y_indices.shape[0]

            I_p = I_batch[:, K - 1]
            I_q = I_batch[:, :K - 1]

            # rho only acts as a fixed weight in the chi and policy losses, so
            # it is built without autograd; no optimizer consumes gradients
            # that would flow through it into lambda_vec.
            # NOTE(review): rho is normalised by the batch size, not by
            # len(dataset) -- confirm this is the intended N.
            with torch.no_grad():
                rho_vals = compute_rho_per_sample(I_p, I_q, lambda_vec, tau, N=bsz)

            # --- chi step -------------------------------------------------
            opt_chi.zero_grad()
            chi_loss = lower_bound_chi_objective(I_q, rho_vals, chi_vec, eps)
            chi_loss.backward()
            clip([chi_vec], max_norm=max_grad_norm)
            opt_chi.step()
            # Project AFTER the step so the value used below (and returned at
            # the end) always lies in the feasible range.
            with torch.no_grad():
                chi_vec.clamp_(min=1e-6, max=15.0)

            with torch.no_grad():
                q_lb = estimate_q_lower_bounds(I_q, rho_vals, chi_vec)

            # --- lambda step ----------------------------------------------
            opt_lambda.zero_grad()
            lam_loss = constraint_lambda_loss(lambda_vec, b_vec, q_lb)
            lam_loss.backward()
            clip([lambda_vec], max_norm=max_grad_norm)
            opt_lambda.step()
            # Same projection order: lambda must be non-negative when it is
            # fed into rho for the policy step.
            with torch.no_grad():
                lambda_vec.clamp_(min=0.0, max=15.0)

            with torch.no_grad():
                rho_vals_after = compute_rho_per_sample(
                    I_p, I_q, lambda_vec, tau, N=bsz
                )

            # --- policy step ----------------------------------------------
            opt_policy.zero_grad()
            pol_loss = policy_extraction_loss(
                policy_net, y_indices, rho_vals_after, y_dim
            )
            pol_loss.backward()
            clip(policy_net.parameters(), max_norm=max_grad_norm)
            opt_policy.step()

            total_lambda_loss += lam_loss.item()
            n_batches += 1

        avg_lambda_loss = total_lambda_loss / n_batches

        if epoch % 100 == 0 or epoch == max_epochs - 1:
            dist = get_policy_distribution(policy_net, y_dim, device)
            policy_logs.append((epoch, dist.copy()))

            if epoch % 1000 == 0:
                # q_lb holds K-1 values (from the last batch); format each one
                # instead of calling .item(), which only works for K == 2.
                q_lb_text = ", ".join(f"{v:.4f}" for v in q_lb.tolist())
                print(f"[Epoch {epoch}/{max_epochs}] ",
                      f"q_lb={q_lb_text}",
                      f"lambda_loss={avg_lambda_loss:.4f}"
                      )
                print("Current policy:", dist)

    print("Training complete!")

    _save_policy_logs("./data/policy_logs.csv", policy_logs, y_dim)
    print("Saved policy logs to policy_logs.csv")

    return policy_net, lambda_vec, chi_vec

def main():
    """Command-line entry point: parse flags, train, and print the results.

    ``--epochs`` and ``--batch_size`` overwrite the matching entries of the
    module-level ``CONFIG`` when given; ``--dataset`` selects the CSV file.
    """
    cli = argparse.ArgumentParser(description="MOPO Implementation")
    cli.add_argument("--dataset", type=str, default="dataset.csv",
                     help="Path to CSV dataset.")
    cli.add_argument("--epochs", type=int, default=None,
                     help="Number of training epochs (overrides CONFIG).")
    cli.add_argument("--batch_size", type=int, default=None,
                     help="Batch size (overrides CONFIG).")
    options = cli.parse_args()

    # Only flags that were actually supplied replace the defaults.
    overrides = {"max_epochs": options.epochs, "batch_size": options.batch_size}
    CONFIG.update(
        {key: value for key, value in overrides.items() if value is not None}
    )

    data = read_dataset(options.dataset, K=CONFIG["K"])
    trained_policy, final_lambda, final_chi = train_mopo(data, CONFIG)

    theta = trained_policy.theta.detach().cpu().numpy()
    print("Final lambda =", final_lambda.detach().cpu().numpy())
    print("Final chi    =", final_chi.detach().cpu().numpy())
    print("Final policy theta =", theta)
    print("Final policy distribution =", softmax(theta))

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
