import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

from data_generation import sample_LinearBandits
from envs.bandit_envs import *

from policy.dueling_bandit_policy import DTS

# Short alias for torch's element-wise sigmoid (tensor counterpart of the NumPy `sigmoid` below).
tsig = torch.sigmoid
def sigmoid(x):
    """Numerically stable logistic function ``1 / (1 + exp(-x))``.

    The value is computed as ``exp(-log(1 + exp(-x)))`` with the inner
    logarithm evaluated by ``np.logaddexp(0, -x)``. Accepts scalars or
    NumPy arrays and returns the element-wise result.
    """
    log_denominator = np.logaddexp(0, -x)
    return np.exp(-log_denominator)

def BT(r_1,r_2):
    """Sample the outcome of one duel between two rewards.

    Returns 0 with probability ``sigmoid(r_1 - r_2)`` and 1 otherwise
    (the name presumably refers to the Bradley-Terry preference model).
    Callers in this file read 0 as "the first arm won".
    """
    prob_first = sigmoid(r_1 - r_2)
    outcome_probs = [prob_first, 1 - prob_first]
    return np.random.choice([0, 1], p=outcome_probs)

def dueling_regret(mu,a):
    """Regret of playing arm ``a`` when the arm rewards are ``mu``.

    Defined as ``sigmoid(max(mu) - mu[a]) - 0.5``: zero for a best arm
    and positive otherwise.
    """
    gap_to_best = max(mu) - mu[a]
    return sigmoid(gap_to_best) - 0.5

def llm_act(prediction,num_A,mode='sample'):
    """Pick the pair of arms to duel from a per-arm prediction.

    Args:
        prediction: One value per arm. In ``'sample'`` mode it is passed to
            ``np.random.choice`` as ``p``, so it must be a probability
            vector of length ``num_A``. In ``'max'`` mode only the ordering
            of the values matters. It is ignored in ``'random'`` mode.
        num_A: Number of arms.
        mode: ``'max'`` returns the two highest-scoring arms (best first);
            ``'sample'`` draws both arms independently from ``prediction``
            (the two arms may coincide); ``'random'`` draws uniformly among
            all ordered pairs of distinct arms.

    Returns:
        A tuple ``(a_1, a_2)`` of arm indices.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    if mode == 'max':
        a_1, a_2 = np.argsort(prediction)[::-1][:2]
        return (a_1, a_2)

    if mode == 'sample':
        a_1 = np.random.choice(num_A, p=prediction)
        a_2 = np.random.choice(num_A, p=prediction)
        return (a_1, a_2)

    if mode == 'random':
        # Only this mode needs the enumeration of ordered distinct pairs.
        pairs = [(i, j) for i in range(num_A) for j in range(num_A) if i != j]
        a_1, a_2 = pairs[np.random.choice(len(pairs))]
        return (a_1, a_2)

    # Previously an unknown mode fell through and failed with an
    # UnboundLocalError at the return statement.
    raise ValueError(
        f"Unknown mode {mode!r}; expected 'max', 'sample' or 'random'."
    )

def deploy_DB(env, policy, horizon):
    """Run ``policy`` on ``env`` for ``horizon`` duels and record regrets.

    Every round asks ``policy.act(h)`` for two arms, scores both against
    ``env.mu`` with ``dueling_regret``, samples the duel winner with ``BT``
    and reports ``(winner, loser)`` back through ``policy.update``.

    Returns:
        ``(weak_regs, strong_regs)``: two lists of length ``horizon`` holding,
        per round, the smaller of the two arm regrets and their sum.
    """
    arm_means = env.mu
    weak_regs = []
    strong_regs = []

    for step in range(horizon):
        first, second = policy.act(step)

        regret_first = dueling_regret(arm_means, first)
        regret_second = dueling_regret(arm_means, second)
        weak_regs.append(min(regret_first, regret_second))
        strong_regs.append(regret_first + regret_second)

        # BT yields 0 when the first arm wins the duel, 1 otherwise.
        if BT(arm_means[first], arm_means[second]) == 0:
            winner, loser = first, second
        else:
            winner, loser = second, first

        policy.update(winner, loser)

    return (weak_regs, strong_regs)

def deploy_DB_offline(policy,aw_list, al_list, true_mus):
    """Score ``policy`` on a logged sequence of duel outcomes.

    For each step ``h`` the policy is reset, fed the first ``h`` logged
    (winner, loser) pairs via ``policy.batch_update`` and then asked to act.
    The chosen arms are scored against ``true_mus``; the logged history is
    never extended with the policy's own choices.

    Returns:
        ``(weak_regs, strong_regs)``: per-step minimum and sum of the two
        arm regrets, each a list of length ``len(aw_list)``.
    """
    weak_regs = []
    strong_regs = []
    num_steps = len(aw_list)

    for h in range(num_steps):
        # Rebuild the policy state from scratch on the history prefix.
        policy.reset()
        policy.batch_update(aw_list[:h], al_list[:h])
        first, second = policy.act(h)

        regret_first = dueling_regret(true_mus, first)
        regret_second = dueling_regret(true_mus, second)
        weak_regs.append(min(regret_first, regret_second))
        strong_regs.append(regret_first + regret_second)

    return (weak_regs, strong_regs)

def deploy_tf_DB(env, policy, horizon, device, mode, explore_horizon=50, num_A=5):
    """Roll out a sequence model as a dueling-bandit policy on ``env``.

    The model input ``X`` has shape ``(1, horizon + 1, 2*num_A + 2)``.
    Columns ``0`` and ``num_A + 1`` are set to 1 on every row; after each
    duel, row ``h + 1`` additionally gets a 1 at column ``aw + 1`` (winner)
    and at column ``num_A + 2 + al`` (loser).

    Action selection per step ``h``:
      * ``h < explore_horizon``: a uniformly random pair of distinct arms.
      * otherwise, ``mode == 'offline'``: the top-2 arms of the prediction
        at position ``explore_horizon``.
      * otherwise: ``llm_act`` with the prediction at position ``h`` and
        the given ``mode``.

    ``policy(X)`` is assumed to return per-position arm scores indexable as
    ``[0, h]`` -- TODO confirm the exact output shape against the model.

    Returns:
        ``(weak_regs, strong_regs)``: per-step minimum and sum of the two
        arm regrets, each a list of length ``horizon``.
    """
    feature_dim = 2*num_A + 2
    X = torch.zeros((1, horizon + 1, feature_dim), device=device)
    # These two columns are switched on for every position of the sequence.
    for always_on in (0, num_A + 1):
        X[0, :, always_on] = 1

    true_mus = env.mu
    weak_regs, strong_regs = [], []

    for h in range(horizon):
        # One forward pass over the whole (partially filled) sequence per step.
        with torch.no_grad():
            scores = policy(X)
            probs = torch.softmax(scores, -1).cpu().numpy()
        prediction = probs[0, h]

        if h < explore_horizon:
            act_mode = 'random'
        elif mode == 'offline':
            # Offline mode always reads the prediction at the end of exploration.
            prediction = probs[0, explore_horizon]
            act_mode = 'max'
        else:
            act_mode = mode
        a_1, a_2 = llm_act(prediction, num_A, act_mode)

        r_1 = dueling_regret(true_mus, a_1)
        r_2 = dueling_regret(true_mus, a_2)
        weak_regs.append(min(r_1, r_2))
        strong_regs.append(r_1 + r_2)

        # BT yields 0 when the first arm wins the duel, 1 otherwise.
        if BT(true_mus[a_1], true_mus[a_2]) == 0:
            aw, al = a_1, a_2
        else:
            aw, al = a_2, a_1

        # Record the duel outcome in the next input row.
        next_row = h + 1
        X[0, next_row, aw + 1] = 1
        X[0, next_row, num_A+2+al] = 1

    return (weak_regs, strong_regs)

def eval_offline_DB_p(actions,true_mus,num_A=5,alpha=1,pbar=None):
    """Evaluate a fresh ``DTS`` policy on each logged duel sequence.

    Args:
        actions: Indexable collection; ``actions[i]`` unpacks into
            ``(aw_list, al_list)`` for environment ``i``.
        true_mus: Indexable collection of arm rewards, one entry per
            environment.
        num_A: Number of arms handed to ``DTS``.
        alpha: ``alpha`` argument handed to ``DTS``.
        pbar: Optional ``range``-like callable (called with an int) used to
            iterate with a progress bar; plain ``range`` when ``None``.

    Returns:
        ``{'DTS': [...]}`` with one ``deploy_DB_offline`` result per
        environment.
    """
    iterate = range if pbar is None else pbar
    dts_regs = []
    for i in iterate(len(actions)):
        aw_list, al_list = actions[i]
        policy = DTS(num_A, alpha=alpha)
        dts_regs.append(deploy_DB_offline(policy, aw_list, al_list, true_mus[i]))
    return {'DTS': dts_regs}

def eval_online_DB_p(phi, d=2, num_A=5, alpha=1, trials = 20, horizon=200,pbar=None):
    """Evaluate ``DTS`` online on freshly sampled linear-bandit environments.

    Each trial samples parameters with ``sample_LinearBandits(d=d)``, builds
    a ``LinearBandits`` environment (``var=0.3``) from them and ``phi``, and
    runs a new ``DTS(num_A, alpha=alpha)`` policy for ``horizon`` duels.

    Args:
        pbar: Optional ``range``-like callable (called with an int) used to
            iterate with a progress bar; plain ``range`` when ``None``.

    Returns:
        ``{'DTS': [...]}`` with one ``deploy_DB`` result per trial.
    """
    iterate = range if pbar is None else pbar
    dts_regs = []
    for _ in iterate(trials):
        env_parameters = sample_LinearBandits(d=d)
        env = LinearBandits(env_parameters, phi, var=0.3)
        policy = DTS(num_A, alpha=alpha)
        dts_regs.append(deploy_DB(env, policy, horizon))
    return {'DTS': dts_regs}

## Online Evaluation
def eval_online_DB(model_dict, device, phi, mode='sample', d=2, trials = 20, horizon=200, explore_horizon=50, num_A=5,compare=True, pbar=None):
    """Evaluate every model in ``model_dict`` online on sampled environments.

    Each trial samples one ``LinearBandits`` environment (``var=0.3``) and
    rolls out every model on that same environment with ``deploy_tf_DB``.

    Args:
        model_dict: Mapping from model name to model.
        compare: Accepted but not used by this function.
        pbar: Optional ``range``-like callable (called with an int) used to
            iterate with a progress bar; plain ``range`` when ``None``.

    Returns:
        Dict mapping each model name to a list with one entry per trial.
        Each entry is ``(weak_regret, strong_regret)``, where both are
        regret lists of length ``horizon``.
    """
    regs = {model_name: [] for model_name in model_dict}

    iterate = range if pbar is None else pbar
    for _ in iterate(trials):
        env_parameters = sample_LinearBandits(d=d)
        env = LinearBandits(env_parameters, phi, var=0.3)
        for model_name, model in model_dict.items():
            rollout = deploy_tf_DB(
                env,
                model,
                horizon,
                num_A=num_A,
                device=device,
                mode=mode,
                explore_horizon=explore_horizon,
            )
            regs[model_name].append(rollout)
    return regs

def dueling_regret_batch(mu,a):
    '''
    Batch-version function for evaluating dueling regret
    Input:
    mu -> tensor of shape: (num_env x num_a)
    a -> integer index tensor of shape: (num_env x horizon)
    Output: regret for all actions, sigmoid(max(mu) - mu[a]) - 0.5
    r -> tensor of shape: (num_env x horizon)
    '''
    # keepdim=True gives shape (num_env x 1), which broadcasts across the
    # horizon axis without materialising a repeated copy.
    best_reward = torch.max(mu, dim=-1, keepdim=True).values
    chosen_reward = torch.gather(mu, dim=-1, index=a)

    return torch.sigmoid(best_reward - chosen_reward) - 0.5

## Offline Evaluation
def _paired_regrets(mu, first, second):
    """Return ``(weak, strong)`` regret arrays for paired arm choices under ``mu``."""
    reg_1 = dueling_regret_batch(mu, first)
    reg_2 = dueling_regret_batch(mu, second)
    weak = torch.min(reg_1, reg_2).detach().cpu().numpy()
    strong = (reg_1 + reg_2).detach().cpu().numpy()
    return weak, strong


def eval_offline_DB(model_dict, device, test_data, Ms,pbar=range, batch_size=16, mode='max'):
    """Evaluate logged actions and model predictions on a fixed test set.

    Args:
        model_dict: Mapping from model name to model. Each model is called
            on the batch input ``X`` and its output is sliced ``[:, :-1]``.
        device: Device that batches and arm rewards are moved to.
        test_data: Dataset yielding ``(X, AW, AL)`` triples.
        Ms: Arm rewards, one row per test sample, in dataset order. The
            loader is not shuffled so that batch ``b`` lines up with rows
            ``b * batch_size`` onwards.
        pbar: Accepted for interface compatibility; not used here.
        batch_size: Batch size of the evaluation loader.
        mode: ``'max'`` takes the two highest-scoring arms per position;
            ``'sample'`` draws two arms per position with
            ``torch.multinomial`` from the softmax of the scores.

    Returns:
        Dict whose ``'be'`` entry holds the regrets of the logged ``AW``/``AL``
        actions (presumably the behaviour policy -- confirm) and whose other
        entries hold the regrets per model name. Every value is a
        ``(weak, strong)`` pair of arrays concatenated over batches.

    Raises:
        ValueError: If a model is evaluated with an unsupported ``mode``.
    """
    test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
    Ms = np.array(Ms)
    num_samples = len(test_data)

    def batch_mu(batch):
        # Rows of Ms that belong to the given batch index.
        start = batch_size * batch
        stop = min(start + batch_size, num_samples)
        return torch.tensor(Ms[start:stop]).to(device)

    regs = dict()

    # Regret of the actions recorded in the dataset itself.
    weak_regret, strong_regret = [], []
    for batch, (_, AW, AL) in enumerate(test_dataloader):
        AW, AL = AW.to(device), AL.to(device)
        weak, strong = _paired_regrets(batch_mu(batch), AW, AL)
        weak_regret.append(weak)
        strong_regret.append(strong)

    regs['be'] = (np.concatenate(weak_regret, axis=0), np.concatenate(strong_regret, axis=0))

    # Running Deep Learning Policies
    for model_name in model_dict:
        model = model_dict[model_name]

        weak_regret, strong_regret = [], []
        with torch.no_grad():
            for batch, (X, AW, AL) in enumerate(test_dataloader):
                X = X.to(device)
                pred = model(X)[:, :-1]
                num_env, horizon = pred.shape[0], pred.shape[1]
                # generate batch LLM actions: num_env x horizon x 2
                if mode == 'sample':
                    probs = torch.softmax(pred.flatten(0, 1), dim=-1)
                    actions = torch.multinomial(probs, 2).reshape(num_env, horizon, -1)
                elif mode == 'max':
                    actions = torch.topk(pred, dim=-1, k=2).indices
                else:
                    # Previously an unsupported mode left `actions` unbound.
                    raise ValueError(
                        f"Unknown mode {mode!r}; expected 'sample' or 'max'."
                    )
                # evaluate regret
                weak, strong = _paired_regrets(
                    batch_mu(batch),
                    actions[:, :, 0].to(device),
                    actions[:, :, 1].to(device),
                )
                weak_regret.append(weak)
                strong_regret.append(strong)

        regs[model_name] = (np.concatenate(weak_regret, axis=0), np.concatenate(strong_regret, axis=0))

    return regs