import numpy as np
import torch
import torch.nn as nn
import os
import argparse
from bandit import Bandit
from options import Neural, Lin, KLEXP  

# Directory that experiment() writes saved regret arrays into.
# makedirs(exist_ok=True) instead of an exists()-then-mkdir() pair: several
# runs started at the same time (e.g. a parameter sweep) could both pass the
# existence check and one of them would then crash with FileExistsError.
os.makedirs('regrets', exist_ok=True)

# Seed used for numpy (hidden direction vector) and torch (teacher network)
# when experiment() builds the score function.
SEED_HIDDEN = 1000
# NOTE(review): truncated approximation of pi, used by the cosine score
# function "h3". Kept at 3.14 so the reward function stays identical to
# earlier runs; switching to np.pi would change h3 numerically.
PI = 3.14

def experiment(
    # model class
    neural_or_lin,          # 'neural' | 'lin'
    algo,              # 'UCB' | 'TS' | 'kl-exp'

    # score function
    h_str,
    noise_coef,

    # feature vector
    unif,
    n_arms,
    n_features,

    # combinatorial choices
    n_assortment,
    n_samples,

    # rounds per simulation
    total_rounds,

    # number of simulations
    n_sim,
    
    # coefficients
    reg_factor,
    delta,
    nu,
    gamma,

    # neural hyperparams (also reused by KL-EXP reward net)
    hidden_layer_width,     
    epochs,
    dropout,
    learning_rate,
    training_period,
    training_window,

    # KL-EXP specific
    eta,
    ref_policy,

    # save filename
    save
    ):
    """Run ``n_sim`` bandit simulations and return their cumulative regrets.

    A score function ``h`` is built from ``h_str``, a ``Bandit`` is created
    around it, and for every simulation a fresh model (``KLEXP``, ``Neural``
    or ``Lin``) is constructed and run.

    Args:
        neural_or_lin: 'neural' or 'lin'; selects ``Neural`` or ``Lin``.
            Not consulted for model selection when ``algo == 'kl-exp'``.
        algo: 'UCB', 'TS' or 'kl-exp'. 'kl-exp' selects ``KLEXP``; any other
            value is forwarded to ``Neural`` / ``Lin``.
        h_str: score function name. 'h1' is the dot product with a fixed
            unit vector ``a``, 'h2' its square, 'h3' ``cos(PI * <x, a>)``,
            'h4' a frozen two-layer ReLU teacher network.
        noise_coef, unif, n_arms, n_features, n_assortment, n_samples:
            forwarded to ``Bandit``.
        total_rounds: number of rounds per simulation (``T``).
        n_sim: number of simulations.
        reg_factor, delta, nu, gamma: forwarded to ``Neural`` / ``Lin``.
        hidden_layer_width: hidden width of the learner network, of the
            KL-EXP reward network and of the 'h4' teacher network.
        epochs, dropout, learning_rate, training_period, training_window:
            forwarded to ``Neural`` or ``KLEXP``.
        eta, ref_policy: forwarded to ``KLEXP``.
        save: file name under ``regrets/`` for ``np.save``; nothing is
            written when it is empty.

    Returns:
        ``np.ndarray`` of shape ``(n_sim, total_rounds)``; row ``i`` is the
        cumulative sum of ``model.regrets`` for simulation ``i``.

    Raises:
        ValueError: if ``h_str`` is not one of 'h1'..'h4', or if
            ``neural_or_lin`` is neither 'neural' nor 'lin' while ``algo``
            is not 'kl-exp'.
    """
    # Fail fast on an unsupported score function, before any seeding or
    # device work, so the caller gets a clear error instead of an unbound
    # `h` further down.
    if h_str not in ("h1", "h2", "h3", "h4"):
        raise ValueError(f"Unknown score function: {h_str}")

    T = total_rounds    
    p = dropout

    # device: prefer GPU for neural or kl-exp; CPU for pure linear
    if ((neural_or_lin == 'neural') or (algo == 'kl-exp')) and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # score function h: fixed hidden unit vector `a`, reproducible via SEED_HIDDEN
    np.random.seed(SEED_HIDDEN)
    tmp = np.random.randn(n_features)
    # NOTE(review): `a` is float64 (numpy default) and torch.dot needs both
    # operands to share a dtype -- confirm Bandit features are float64.
    a = torch.from_numpy(tmp / np.linalg.norm(tmp, ord=2)).to(device)

    if h_str == "h1":
        def h(x): return torch.dot(x, a).to(device)
    elif h_str == "h2":
        def h(x): return (torch.dot(x, a) ** 2).to(device)
    elif h_str == "h3":
        def h(x): return torch.cos(PI * torch.dot(x, a)).to(device)
    else:
        # h_str == "h4" (validated above).
        # Teacher MLP: n_features -> hidden_layer_width -> 1, ReLU, bias=False,
        # with frozen parameters.
        class _TeacherNet(nn.Module):
            def __init__(self, d, hidden):
                super().__init__()
                self.fc1 = nn.Linear(d, hidden, bias=False)
                self.act = nn.ReLU()
                self.fc2 = nn.Linear(hidden, 1, bias=False)
            def forward(self, x):
                return self.fc2(self.act(self.fc1(x)))
        # Seed torch so the teacher weights are reproducible.
        # NOTE(review): meant to differ from the learner's initialisation --
        # confirm the learner is not seeded with SEED_HIDDEN as well.
        torch.manual_seed(SEED_HIDDEN)
        teacher = _TeacherNet(n_features, hidden_layer_width).to(device).eval()
        for p_ in teacher.parameters():
            p_.requires_grad_(False)

        @torch.no_grad()
        def h(x):
            # x is reshaped to a single row and cast to float32 before the
            # forward pass; the (1, 1) output is squeezed to a scalar tensor.
            return teacher(x.reshape(1, -1).float()).squeeze()

    # round reward aggregator
    def F(x):
        # dim() must be called: comparing the bound method itself to 1 is
        # always False.
        if x.dim() == 1:
            return torch.sum(x)
        return torch.sum(x, dim=-1)

    # bandit
    bandit = Bandit(
        T,
        n_arms,
        n_features, 
        h,
        noise_coef=noise_coef,
        n_assortment=n_assortment,
        n_samples=n_samples,
        round_reward_function=F,
        device=device,
        n_sim=n_sim,
        unif=unif
    )

    regrets = np.empty((n_sim, T))

    for i in range(n_sim):
        # presumably re-initialises the bandit for simulation i -- see Bandit.reset
        bandit.reset(i)

        if algo == 'kl-exp':
            # KL-regularized exponential-weights policy
            model = KLEXP(
                bandit=bandit,
                eta=eta,
                ref_policy=ref_policy,                   # 'uniform' 
                reward_net_hidden=hidden_layer_width,    # reuse neural args
                reward_net_layers=2,
                reward_dropout=dropout,
                reward_lr=learning_rate,
                reward_weight_decay=0.0,
                training_period=training_period,
                training_window=training_window,
                epochs=epochs,
                device=device,
                throttle=100
            )

        else:
            # UCB / TS with Neural or Linear approximator
            if neural_or_lin == 'neural':
                model = Neural(
                    algo,                            # 'UCB' or 'TS'
                    bandit,
                    hidden_layer_width,
                    reg_factor=reg_factor,
                    delta=delta,
                    gamma=gamma,
                    nu=nu,
                    p=p,
                    training_window=training_window,
                    learning_rate=learning_rate,
                    epochs=epochs,
                    training_period=training_period,
                    device=device
                )
                model.set_init_param(model.model.parameters())

            elif neural_or_lin == 'lin':
                model = Lin(
                    algo,                            # 'UCB' or 'TS'
                    bandit,
                    reg_factor=reg_factor,
                    delta=delta,
                    gamma=gamma,
                    nu=nu,
                    device=device
                )
            else:
                raise ValueError("neural_or_lin must be one of {'neural','lin'}")

        model.run()
        # per-round regrets -> cumulative regret curve for this simulation
        regrets[i] = np.cumsum(model.regrets.to('cpu').numpy())

    if save:
        np.save('regrets/' + save, regrets)
    return regrets

#-------------------

if __name__ == '__main__':
    # Command-line entry point: parse the options and run one experiment.
    cli = argparse.ArgumentParser()

    # model class
    cli.add_argument('--neural_or_lin', type=str, default='neural',
                     choices=['neural', 'lin'],
                     help="'neural' or 'lin' (ignored when --algo=kl-exp)")
    # algorithm / scheme
    cli.add_argument('--algo', type=str, default='UCB',
                     choices=['UCB', 'TS', 'kl-exp'],
                     help="'UCB' | 'TS' | 'kl-exp'")
    # score function
    cli.add_argument('--score_ftn', type=str, default='h1',
                     choices=['h1', 'h2', 'h3', 'h4'],
                     help="h1 (linear), h2 (quadratic), h3 (cosine), h4 (teacher MLP)")

    # Options that only need a type and a default, registered in the order
    # they should appear in --help: (name, type, default).
    plain_options = (
        ('noise_coef', float, 0.01),
        # feature vector
        ('unif', str, 'False'),
        ('n_arms', int, 20),
        ('n_features', int, 80),
        # combinatorial selection / multiple sampling
        ('n_assortment', int, 1),
        ('n_samples', int, 1),
        # rounds per simulation
        ('total_rounds', int, 2000),
        # number of simulations
        ('n_sim', int, 20),
        # coefficients
        ('reg_factor', float, 1.0),
        ('delta', float, 0.1),
        ('nu', float, 1.0),
        ('gamma', float, 1.0),
        # neural hyperparams (also reused by KL-EXP reward net)
        ('hidden_layer_width', int, 100),
        ('epochs', int, 100),
        ('dropout', float, 0.0),
        ('learning_rate', float, 0.01),
        ('training_period', int, 10),
        ('training_window', int, 100),
    )
    for opt_name, opt_type, opt_default in plain_options:
        cli.add_argument('--' + opt_name, type=opt_type, default=opt_default)

    # KL-EXP specific
    cli.add_argument('--eta', type=float, default=1.0,
                     help="temperature/scale for exp(eta * R_hat)")
    cli.add_argument('--ref_policy', type=str, default='uniform',
                     choices=['uniform'],
                     help="reference policy pi_ref")

    # output filename
    cli.add_argument('--save', type=str, default='')

    opts = cli.parse_args()

    experiment(
        # model class
        neural_or_lin=opts.neural_or_lin,
        algo=opts.algo,
        # score function
        h_str=opts.score_ftn,
        noise_coef=opts.noise_coef,
        # feature vector; --unif is a string flag, only the exact text
        # 'True' enables it
        unif=(opts.unif == 'True'),
        n_arms=opts.n_arms,
        n_features=opts.n_features,
        # combinatorial choices
        n_assortment=opts.n_assortment,
        n_samples=opts.n_samples,
        # rounds per simulation
        total_rounds=opts.total_rounds,
        # number of simulations
        n_sim=opts.n_sim,
        # coefficients
        reg_factor=opts.reg_factor,
        delta=opts.delta,
        nu=opts.nu,
        gamma=opts.gamma,
        # neural hyperparams (reused for KL-EXP reward net)
        hidden_layer_width=opts.hidden_layer_width,
        epochs=opts.epochs,
        dropout=opts.dropout,
        learning_rate=opts.learning_rate,
        training_period=opts.training_period,
        training_window=opts.training_window,
        # KL-EXP specific
        eta=opts.eta,
        ref_policy=opts.ref_policy,
        # save
        save=opts.save,
    )
