import numpy as np
from moviedata import preprocess
import agent as agent_module
import os
import time


def get_reward_prob(user_idx, item_idx, U, V, theta_star):
    """Return the pair feature vector and its logistic click probability.

    Args:
        user_idx: Index of the user row in ``U``.
        item_idx: Index of the item row in ``V``.
        U: Indexable collection of user vectors.
        V: Indexable collection of item vectors.
        theta_star: Parameter vector dotted with the flattened feature.

    Returns:
        A tuple ``(x, prob)`` where ``x`` is the flattened outer product of
        ``U[user_idx]`` and ``V[item_idx]`` and ``prob`` is the sigmoid of
        ``x . theta_star``.
    """
    user_vec, item_vec = U[user_idx], V[item_idx]

    # Flattened outer product: one entry per (user dim, item dim) pair.
    features = np.outer(user_vec, item_vec).flatten()
    logit = np.dot(features, theta_star)

    return features, 1 / (1 + np.exp(-logit))

def get_feature(u_idx, i_idx, U, V):
    """Return the flattened outer product of ``U[u_idx]`` and ``V[i_idx]``.

    Args:
        u_idx: Index of the user row in ``U``.
        i_idx: Index of the item row in ``V``.
        U: Indexable collection of user vectors.
        V: Indexable collection of item vectors.

    Returns:
        A 1-D array holding the outer product of the two rows, flattened.
    """
    user_vec, item_vec = U[u_idx], V[i_idx]
    return np.outer(user_vec, item_vec).flatten()

def calculate_expected_reward(item_indices, u_idx, theta, U, V):
    """Return the probability that at least one listed item is clicked.

    Each item's click probability is the sigmoid of its feature vector
    (from ``get_feature``) dotted with ``theta``; the result is one minus
    the product of the per-item no-click probabilities.

    Args:
        item_indices: Iterable of item indices forming the list.
        u_idx: Index of the user row in ``U``.
        theta: Parameter vector dotted with each feature vector.
        U: Indexable collection of user vectors.
        V: Indexable collection of item vectors.

    Returns:
        ``1 - prod(1 - p_i)`` over the items in ``item_indices``.
    """
    no_click_probs = []
    for item in item_indices:
        logit = np.dot(get_feature(u_idx, item, U, V), theta)
        click_prob = 1 / (1 + np.exp(-logit))
        no_click_probs.append(1 - click_prob)

    return 1 - np.prod(no_click_probs)

def run(T, K, dim, S, seed, algorithm,
        regret_log_dir="/home/choihyunjun/neurips2025/regretlog",
        time_log_dir="/home/choihyunjun/neurips2025/timelog"):
    """Simulate ``T`` rounds of click feedback and save regret/time logs.

    Each round a user index is drawn at random, the agent picks a list of
    items, and the items are examined in order until the first simulated
    click. Only the examined items (features and 0/1 outcomes) are passed
    to ``agent.update``. Per-round regret is the difference between the
    expected reward of ``optimal_lists[u_idx]`` and that of the agent's
    list, both evaluated under ``theta_star``.

    Args:
        T: Number of simulation rounds.
        K: Forwarded to ``preprocess`` and the agent constructor; also part
            of the output filename.
        dim: Forwarded to ``preprocess``; the agent is built with feature
            dimension ``dim * dim``.
        S: Forwarded to ``preprocess`` and the agent constructor.
        seed: Forwarded to ``preprocess``; ``seed + 3`` and ``seed + 4``
            seed the user-sampling and click-sampling generators.
        algorithm: Name of an attribute of the ``agent`` module, called as
            ``AgentClass(d, T, K, S, kappa)``.
        regret_log_dir: Directory for the cumulative-regret array. Defaults
            to the path that was previously hard-coded.
        time_log_dir: Directory for the elapsed-time array. Defaults to the
            path that was previously hard-coded.

    Raises:
        ValueError: If ``algorithm`` is not an attribute of the ``agent``
            module.
    """
    # Resolve the agent class up front so an unknown algorithm fails before
    # any directories are created or preprocessing is run.
    if not hasattr(agent_module, algorithm):
        raise ValueError(f"Algorithm '{algorithm}' not found in agent.py")
    AgentClass = getattr(agent_module, algorithm)

    os.makedirs(regret_log_dir, exist_ok=True)
    os.makedirs(time_log_dir, exist_ok=True)

    R, U, V, theta_star, kappa, optimal_lists = preprocess(seed, K, dim, S)

    # Only the sizes of R's index and columns are needed here.
    # NOTE(review): presumably R is a pandas DataFrame (users x items) -- confirm.
    n_users = len(R.index)
    n_items = len(R.columns)

    # Features are flattened outer products of two `dim`-sized vectors.
    d = dim * dim
    agent = AgentClass(d, T, K, S, kappa)

    cumulative_regret = 0
    regret_history = []
    time_history = []

    # Separate generators so user sampling and click sampling do not share
    # a random stream.
    rng_user = np.random.RandomState(seed+3)
    rng_click = np.random.RandomState(seed+4)

    print(f"Simulation Start (T={T})...")

    start_time = time.time()

    for t in range(T):
        u_idx = rng_user.randint(n_users)

        # Rebuilt every round so the agent always receives a fresh array.
        candidates = np.arange(n_items)
        selected_items, _selected_features = agent.select_arms(U, V, u_idx, candidates)

        obs_features = []
        obs_rewards = []
        clicked = False

        # Examine items in list order; stop at the first click so items
        # after it are never observed.
        for item_idx in selected_items:
            x_feat, true_prob = get_reward_prob(u_idx, item_idx, U, V, theta_star)

            is_click = 1 if rng_click.rand() < true_prob else 0

            obs_features.append(x_feat)
            obs_rewards.append(is_click)

            if is_click == 1:
                clicked = True
                break

        # Regret uses expected rewards under the true parameter, not the
        # sampled clicks.
        opt_reward = calculate_expected_reward(optimal_lists[u_idx], u_idx, theta_star, U, V)
        agent_reward = calculate_expected_reward(selected_items, u_idx, theta_star, U, V)

        instant_regret = opt_reward - agent_reward
        cumulative_regret += instant_regret
        regret_history.append(cumulative_regret)

        agent.update(np.array(obs_features), obs_rewards)

        # Wall-clock time since the start of the simulation loop.
        time_history.append(time.time() - start_time)

        if t % 1000 == 0:
            print(f"Round {t}: Regret {instant_regret:.4f} | CumRegret {cumulative_regret:.2f} | Clicked? {clicked}")
            theta_norm = np.linalg.norm(agent.theta)
            print(f"Round {t}: Regret {cumulative_regret} | Theta Norm: {theta_norm:.2f}")

    filename = f"{K}_{seed}_{algorithm}"

    # np.save appends ".npy" because the path has no such suffix, hence the
    # suffix in the messages below.
    regret_path = os.path.join(regret_log_dir, filename)
    np.save(regret_path, np.array(regret_history))
    print(f"Saved regret history to {regret_path}.npy")

    time_path = os.path.join(time_log_dir, filename)
    np.save(time_path, np.array(time_history))
    print(f"Saved time history to {time_path}.npy")