import os
import warnings
import sys
import torch
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from selector import *
from online_update import *

# --------------------
# Utility Functions
# --------------------
def break_into_pairs(ranking):
    """Break a ranking into every ordered (earlier, later) pair.

    Each element is paired with every element that appears after it, so the
    first item of each tuple is the one ranked higher. A ranking of length m
    yields m * (m - 1) / 2 pairs; empty or single-element rankings yield [].
    """
    pairs = []
    for pos, higher in enumerate(ranking):
        for lower in ranking[pos + 1:]:
            pairs.append((higher, lower))
    return pairs

# --------------------
# Main Experiment
# --------------------
def _exit_with_error(message):
    """Emit ``message`` as a UserWarning and terminate the process with exit status 1."""
    warnings.warn(message, category=UserWarning)
    sys.exit(1)


def _load_inputs(dataset_name, output_path, C, device):
    """Load the embedding and Mistral-score arrays for ``dataset_name``.

    Both arrays are truncated to their first ``C`` rows and moved to ``device``.
    The embeddings are divided by their largest L1 norm taken along dim 2
    (clamped to at least 1e-8), so that norm is at most 1 after scaling.

    Returns:
        (embeddings, mistral_score) as float32 tensors on ``device``.

    Raises:
        ValueError: if ``dataset_name`` is unknown, or if the two files disagree
            on the number of rows after truncation.
    """
    # Number suffix used in the precomputed file names for each dataset.
    file_sizes = {'msmarco': 5000, 'nectar': 10000}
    if dataset_name not in file_sizes:
        raise ValueError(
            f"Unrecognized dataset: {dataset_name!r}. Must be one of {sorted(file_sizes)}."
        )
    size = file_sizes[dataset_name]

    embeddings_np = np.load(os.path.join(output_path, f'{dataset_name}_embedded_{size}.npy'))[:C, :, :]
    scores_np = np.load(os.path.join(output_path, f'{dataset_name}_mistral_score_{size}.npy'))[0:C, :]
    if embeddings_np.shape[0] != scores_np.shape[0]:
        raise ValueError(
            f"Row mismatch: {embeddings_np.shape[0]} embedding rows vs "
            f"{scores_np.shape[0]} score rows."
        )

    embeddings = torch.from_numpy(embeddings_np).float().to(device)
    max_l1_norm = torch.clamp(embeddings.norm(p=1, dim=2).max(), min=1e-8)  # avoid division by zero
    embeddings = embeddings / max_l1_norm

    mistral_score = torch.from_numpy(scores_np).float().to(device)
    return embeddings, mistral_score


def run_experiment(dataset_name, num_epochs, output_path, id, loss, assortment_size, seed, info, device=None):
    """Run one online assortment-selection experiment and save its history.

    Each round:
    1. Sample a context index from an Exponential(rate=0.1) distribution
       clamped to ``n - 1``.
    2. Select an assortment of candidates for that context:
       - 'uniform': ``assortment_size`` random candidates;
       - 'm_aupo' / 'dope': via the selector helpers.
    3. Rank the selected candidates by the Mistral score, highest first.
    4. Update theta with the PL ('pl') or rank-breaking ('rb') online update.
    5. Record the simple regret: the best true score minus the true score of
       the candidate with the highest estimated mean ``X @ theta``.

    Args:
        dataset_name: 'msmarco' or 'nectar'; selects the input .npy files.
        num_epochs: number of rounds.
        output_path: directory containing the inputs; results are written to
            its ``simple_regret/`` and ``assortment_size/`` subdirectories.
        id: agent type, one of 'uniform', 'm_aupo', 'dope'.
        loss: 'pl' or 'rb'.
        assortment_size: K, passed to the selectors / used for uniform sampling.
        seed, info: only used in the output file names.
        device: torch device. Defaults to the module-level ``device`` global
            if one exists (as set by the script entry point), otherwise CUDA
            when available, else CPU.

    Raises:
        ValueError: unknown ``dataset_name`` or inconsistent input files.
        SystemExit: (exit status 1, after a UserWarning) for an unknown ``id``
            or ``loss``. This is checked before any data is loaded.
    """
    # Validate the configuration up front so a typo does not cost a data load.
    if id not in ('uniform', 'm_aupo', 'dope'):
        _exit_with_error(f"[❌ ERROR] Unrecognized agent type: {id}. Must be 'uniform' or 'm_aupo' or 'dope'.")
    if loss not in ('pl', 'rb'):
        _exit_with_error(f"[❌ ERROR] Unrecognized loss type: {loss}. Must be 'pl' or 'rb'.")

    if device is None:
        # Backward compatibility: the script entry point defines a module-level ``device``.
        device = globals().get('device')
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    B = 3.  # bound on the l2 norm of the true theta
    K = assortment_size
    C = 5000  # number of contexts kept from the input files
    evaluation_frequency = 500

    embeddings, mistral_score = _load_inputs(dataset_name, output_path, C, device)
    n = int(embeddings.shape[0])
    embedding_dim = int(embeddings.shape[2])

    ### initialize ###
    sqrt2 = torch.sqrt(torch.tensor(2.0))
    h_scale = 6 * sqrt2 * (1 + 3 * sqrt2)  # initial H = h_scale * I
    if id == 'dope':
        # Theta is drawn eagerly for every context so the global RNG stream
        # (and thus results for a given seed) is unchanged.
        epsilon = 1.0
        target_value = 1.0 / torch.sqrt(torch.tensor(embedding_dim, dtype=torch.float32))
        low, high = target_value - epsilon, target_value + epsilon
        theta_dict = {}
        for c in range(n):
            theta_values = torch.empty(embedding_dim).uniform_(low, high)
            theta_dict[c] = theta_values / theta_values.norm(p=2)
        # V^{-1} and H are deterministic, so they are created lazily on first
        # visit. Allocating all n of them eagerly costs O(n * d^2) memory.
        V_inv_dict = {}
        H_dict = {}
    else:
        # Non-dope agents share a single H and theta across all contexts.
        H = (h_scale * torch.eye(embedding_dim)).to(device)
        cur_theta = torch.ones(embedding_dim, dtype=torch.float32, device=device)
        cur_theta = cur_theta / cur_theta.norm(p=2)

    eta = ((1 + 3 * sqrt2 * B) / 2).to(device)

    simple_regret_ls = []
    selected_size_ls = []

    # Context sampler: Exponential(rate=0.1) favours small indices.
    exp_dist = torch.distributions.Exponential(rate=0.1)

    for t in tqdm(range(num_epochs)):
        index = exp_dist.sample((1,)).clamp(max=n - 1).long()
        ctx = index.item()

        X = embeddings[index].squeeze(0)  # candidates of this context, (N, d)
        N = X.shape[0]
        true_score = mistral_score[index].squeeze(0)

        if id == 'uniform':
            selected_indices = torch.randperm(N, device=X.device)[:K]
        elif id == 'm_aupo':
            selected_indices = choose_S_rand_ref(X, H, K)
        else:  # 'dope' (validated above)
            if ctx not in V_inv_dict:
                V_inv_dict[ctx] = torch.eye(embedding_dim, dtype=torch.float32)
            if ctx not in H_dict:
                H_dict[ctx] = h_scale * torch.eye(embedding_dim, dtype=torch.float32)
            V_inv = V_inv_dict[ctx].to(device)
            cur_theta = theta_dict[ctx].to(device)
            H = H_dict[ctx].to(device)
            z_dict = {(j, k): X[j] - X[k] for j in range(N) for k in range(j + 1, N)}
            V_inv, selected_indices = choose_S_dopewolfe(X, K, V_inv, z_dict, device=device)
            V_inv_dict[ctx] = V_inv.cpu()

        if torch.is_tensor(selected_indices):
            selected_indices = selected_indices.to(device=true_score.device)
        else:
            selected_indices = torch.tensor(selected_indices, device=true_score.device)

        # Feedback: the selected candidates ordered by true score, best first.
        subset_scores = true_score[selected_indices]
        ranking = selected_indices[torch.argsort(-subset_scores)]

        if loss == 'pl':
            cur_theta = online_update_pl(cur_theta, ranking, X, H, eta, B=B)
            H += pl_hessian(cur_theta, X, ranking)
        else:  # 'rb' (validated above)
            pairs = break_into_pairs(ranking)
            cur_theta = online_update_rb(cur_theta, pairs, X, H, eta, B=B)
            H += rb_hessian(cur_theta, X, pairs)

        if id == 'dope':
            theta_dict[ctx] = cur_theta.cpu()
            H_dict[ctx] = H.cpu()

        # Simple regret of the arm with the highest estimated mean.
        means = (X @ cur_theta).squeeze()
        current_best_index = torch.topk(means, k=1, largest=True).indices
        opt_val = np.max(true_score.detach().cpu().numpy())
        pl_val = true_score[current_best_index].detach().cpu().numpy()
        simple_regret_ls.append(opt_val - pl_val)

        selected_size_ls.append(len(selected_indices))

        if (t + 1) % evaluation_frequency == 0:
            avg_simple_regret = np.mean(simple_regret_ls)
            print(f"[{id}_K={K} @ Round {t+1}] Simple Regret: {avg_simple_regret:.2f}")

    simple_regret_dir = os.path.join(output_path, 'simple_regret')
    assortment_size_dir = os.path.join(output_path, 'assortment_size')
    os.makedirs(simple_regret_dir, exist_ok=True)
    os.makedirs(assortment_size_dir, exist_ok=True)

    # NOTE: np.savetxt writes plain text despite the '.npy' suffix. Read these
    # files back with np.loadtxt, not np.load. The name is kept unchanged for
    # existing consumers.
    suffix = f'{dataset_name}_{id}_K={assortment_size}_epochs={num_epochs}_loss={loss}_{info}_{seed}.npy'
    np.savetxt(os.path.join(simple_regret_dir, f'simple_regret_{suffix}'), np.array(simple_regret_ls))
    np.savetxt(os.path.join(assortment_size_dir, f'assortment_size_{suffix}'), np.array(selected_size_ls))
        

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # ``choices`` makes argparse reject typos immediately, with a usage message
    # and exit status 2. Otherwise the run would only fail later inside
    # run_experiment, possibly after loading data.
    parser.add_argument('-d', '--dataset', default='msmarco', type=str, choices=['msmarco', 'nectar'], help='Name of the dataset. Options: [\'msmarco\', \'nectar\']')
    parser.add_argument('-o', '--output', default='.', type=str, help='Output folder')
    parser.add_argument('-e', '--epochs', default=20000, type=int, help='Number of epochs')
    parser.add_argument('-i', '--id', type=str, default='m_aupo', choices=['uniform', 'm_aupo', 'dope'], help='Options: [\'uniform\', \'m_aupo\', \'dope\'] ')
    parser.add_argument('-l', '--loss', type=str, default='pl', choices=['pl', 'rb'], help='Options: [\'pl\', \'rb\'] ')
    parser.add_argument('-k', '--assortment_size', default=2, type=int, help='Maximum assortment size')
    parser.add_argument('-s', '--seed', default=41, type=int, help='Random seed')
    parser.add_argument('-f', '--info', type=str, default='', help='Additional text to be added at the end of filenames outputted.')

    args = parser.parse_args()

    # Kept at module scope because run_experiment looks up ``device`` there.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device:", device)

    # The RNGs are seeded with args.seed + 41. The output file names use the
    # unshifted args.seed.
    seed = args.seed + 41
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    run_experiment(args.dataset, args.epochs, args.output, args.id, args.loss, args.assortment_size, args.seed, args.info)


    