import numpy as np
from envs.bandit_envs import *
from tqdm import trange, tqdm


from policy.bandit_policy import RandomMixingPolicy, TompsonSamplingPolicy
from policy.dueling_bandit_policy import RandomDuelingPolicy, DTS

import torch
from torch import nn
from torch.utils.data import Dataset

# Short alias for torch's sigmoid, for use on tensors.
tsig = torch.sigmoid


def sigmoid(x):
    """Numerically stable logistic function for scalars or numpy arrays.

    Evaluates 1 / (1 + exp(-x)) as exp(-log(1 + exp(-x))), with the inner
    log-sum computed by ``np.logaddexp`` so large |x| does not overflow.
    """
    log_one_plus_exp_neg = np.logaddexp(0, -x)
    return np.exp(-log_one_plus_exp_neg)

def BT(r_1,r_2):
    """Sample one Bradley-Terry comparison between values ``r_1`` and ``r_2``.

    Returns 0 with probability ``sigmoid(r_1 - r_2)`` and 1 otherwise; callers in
    this module treat 0 as "the first item wins".
    """
    prob_first_wins = sigmoid(r_1 - r_2)
    outcome = np.random.choice([0, 1], p=[prob_first_wins, 1 - prob_first_wins])
    return outcome

class BanditDataset(Dataset):
    """In-context dataset of bandit trajectories.

    Item ``idx`` is ``(sample_ds, actions, coefficients)``: the trajectory's feature
    rows prefixed by a single start row (``self.first``), plus the per-step actions
    and reweighting coefficients aligned with those rows.

    Args:
        dsets: per-trajectory feature arrays; presumably ``(T, num_A + 3)`` as built
            by the ``generate_*`` functions in this module — TODO confirm for other producers.
        actions: per-trajectory action arrays, indexed in step with ``dsets``.
        coefficients: per-trajectory coefficient arrays, indexed in step with ``dsets``.
        num_A: number of arms; fixes the start-row width ``num_A + 3``.
        horizon: maximum number of steps taken from each trajectory.
        shuffle: if True, the selected steps are returned in a random order.
    """

    def __init__(self, dsets, actions, coefficients, num_A=5, horizon=500, shuffle=True):
        self.dsets = dsets
        self.actions = actions
        self.coefficients = coefficients

        self.num_A = num_A
        self.horizon = horizon
        self.shuffle = shuffle

        # Start-of-sequence row: only column 0 is set.
        self.first = np.zeros((1, num_A + 3), dtype=np.float32)
        self.first[0, 0] = 1

    def __len__(self):
        return len(self.dsets)

    def __getitem__(self, idx):
        """Return ``(sample_ds, actions, coefficients)`` for trajectory ``idx``.

        Both modes use the first ``min(horizon, T)`` steps, so the returned length
        does not depend on ``shuffle``.
        """
        dset = self.dsets[idx]
        n = min(self.horizon, len(dset))
        if self.shuffle:
            # Shuffled in-context dataset to reduce overfitting.
            perm = np.random.permutation(n)
        else:
            perm = np.arange(n)
        sample_ds = np.concatenate((self.first, dset[:n][perm]))
        return sample_ds, self.actions[idx][:n][perm], self.coefficients[idx][:n][perm]
    
class DuelingBanditDataset(Dataset):
    """In-context dataset of dueling-bandit trajectories.

    Item ``idx`` is ``(sample_ds, winners, losers)``: the trajectory's feature rows
    prefixed by a start row (``self.first``), plus the winner and loser arms of each
    duel, aligned with those rows.

    Args:
        dsets: per-trajectory feature arrays; presumably ``(T, 2*num_A + 2)`` as built
            by ``generate_DB`` — TODO confirm for other producers.
        actions: per-trajectory ``[winners, losers]`` pairs of arrays (as produced by
            ``generate_DB``).
        num_A: number of arms; fixes the start-row width ``2*num_A + 2``.
        horizon: maximum number of steps taken from each trajectory.
        shuffle: if True, the selected steps are returned in a random order.
    """

    def __init__(self, dsets, actions, num_A=5, horizon=500, shuffle=True):
        self.dsets = dsets
        self.actions = actions

        self.num_A = num_A
        self.horizon = horizon
        self.shuffle = shuffle

        # Start-of-sequence row: only column 0 is set.
        self.first = np.zeros((1, 2 * num_A + 2), dtype=np.float32)
        self.first[0, 0] = 1

    def __len__(self):
        return len(self.dsets)

    def __getitem__(self, idx):
        """Return ``(sample_ds, winners, losers)`` for trajectory ``idx``.

        Both modes use the first ``min(horizon, T)`` steps. The horizon cut is
        applied to the time axis of the winner/loser arrays, not to the outer
        ``[winners, losers]`` pair.
        """
        dset = self.dsets[idx]
        winners, losers = self.actions[idx][0], self.actions[idx][1]
        n = min(self.horizon, len(dset))
        if self.shuffle:
            # Shuffled in-context dataset to reduce overfitting.
            perm = np.random.permutation(n)
        else:
            perm = np.arange(n)
        sample_ds = np.concatenate((self.first, dset[:n][perm]))
        return sample_ds, winners[:n][perm], losers[:n][perm]

class OptimalBanditDataset(Dataset):
    """In-context dataset whose action target is the true optimal arm.

    Item ``idx`` is ``(sample_ds, target, coefficients)``: the trajectory's feature
    rows prefixed by a start row (``self.first``), a target repeating
    ``argmax(true_mus[idx])`` once per step, and the per-step coefficients.

    Args:
        dsets: per-trajectory feature arrays; presumably ``(T, num_A + 3)`` — TODO confirm.
        actions: per-trajectory action arrays (stored but not returned).
        coefficients: per-trajectory coefficient arrays, indexed in step with ``dsets``.
        true_mus: per-trajectory true mean-reward vectors.
        num_A: number of arms; fixes the start-row width ``num_A + 3``.
        horizon: maximum number of steps taken from each trajectory.
        shuffle: if True, the selected steps are returned in a random order.
    """

    def __init__(self, dsets, actions, coefficients, true_mus, num_A=5, horizon=500, shuffle=True):
        self.dsets = dsets
        self.actions = actions
        self.coefficients = coefficients
        self.true_mus = true_mus

        self.num_A = num_A
        self.horizon = horizon
        self.shuffle = shuffle

        # Start-of-sequence row: only column 0 is set.
        self.first = np.zeros((1, num_A + 3), dtype=np.float32)
        self.first[0, 0] = 1

    def __len__(self):
        return len(self.dsets)

    def __getitem__(self, idx):
        """Return ``(sample_ds, target, coefficients)`` for trajectory ``idx``.

        Both modes use the first ``n = min(horizon, T)`` steps, and the target has
        length ``n``, so it always matches the number of returned trajectory rows.
        """
        dset = self.dsets[idx]
        n = min(self.horizon, len(dset))
        if self.shuffle:
            # Shuffled in-context dataset to reduce overfitting.
            perm = np.random.permutation(n)
        else:
            perm = np.arange(n)
        sample_ds = np.concatenate((self.first, dset[:n][perm]))
        target = np.argmax(self.true_mus[idx]).repeat(n)
        return sample_ds, target, self.coefficients[idx][:n][perm]

# Sampling Functions for Contexts/Envs
def sample_MultiArmedBandits(num_A=5, min_reward=0, max_reward=1):
    '''
    Draw a random mean-reward vector for a ``num_A``-armed bandit.

    Each arm's mean is sampled i.i.d. from a normal distribution with mean 0 and
    variance 1/num_A. ``min_reward`` and ``max_reward`` are accepted for interface
    compatibility but are currently unused.
    '''
    std = np.sqrt(1 / num_A)
    return np.random.normal(loc=0, scale=std, size=num_A)

def sample_LinearBandits(d=2):
    '''
    Draw the parameter vector theta of a ``d``-dimensional linear bandit.

    Components are i.i.d. normal with mean 0 and variance 1/d.
    '''
    std = np.sqrt(1 / d)
    return np.random.normal(loc=0, scale=std, size=d)

# Parallel version
# from eval import deploy_tf_collect_multi_env
def deploy_tf_collect_multi_env(envs, policy, horizon, device, num_A=5):
    """Roll out ``policy`` autoregressively on a batch of environments in parallel.

    The context tensor ``X`` has shape ``(num_envs, horizon + 1, num_A + 3)``:
    columns 0 and -2 are constant 1s, columns ``1..num_A`` hold the one-hot action
    and the last column holds the reward. Row 0 is a start row; after step ``h`` the
    action and reward are written into row ``h + 1``. At step ``h`` the action is
    sampled from ``softmax(policy(X)[:, h])``.

    Args:
        envs: objects exposing ``transit(action) -> reward``.
        policy: callable on ``X`` returning logits; assumed shape
            ``(num_envs, horizon + 1, num_A)`` — TODO confirm.
        horizon: number of steps to play.
        device: torch device on which ``X`` (and all per-step tensors) live.
        num_A: number of arms.

    Returns:
        ``(As, Rs)``: numpy arrays of actions and rewards, one row per environment
        (shape ``(num_envs, horizon)`` assuming ``transit`` returns a scalar).
    """
    num_envs = len(envs)
    X = torch.zeros((num_envs, horizon + 1, num_A + 3), device=device)
    X[:, :, 0] = 1
    X[:, :, -2] = 1
    env_idx = torch.arange(num_envs, device=device)

    Rs = []
    As = []
    with torch.no_grad():
        for h in range(horizon):
            probs = torch.softmax(policy(X)[:, h], -1)
            a = torch.multinomial(probs, num_samples=1).squeeze(-1)
            # One device->host copy per step instead of one per environment; envs
            # receive numpy integer actions, as they do from the classical policies.
            a_np = a.cpu().numpy()
            r = [env.transit(a_np[i]) for i, env in enumerate(envs)]
            As.append(a_np)
            Rs.append(r)

            X[env_idx, h + 1, a + 1] = 1
            # Place rewards on X's own device (previously hard-coded to .cuda()).
            X[:, h + 1, -1] = torch.as_tensor(np.asarray(r, dtype=np.float32), device=X.device)
    As = np.array(As).T
    Rs = np.array(Rs).T
    return As, Rs

def generate_LB_TF_multi_env(model, device, phi, batch_env = 128, num_A=10, d=2, N=40000, horizon=200, var=0.3):
    """Collect ``N`` linear-bandit trajectories by rolling out ``model`` in batches.

    Environments are created ``batch_env`` at a time (the last batch holds the
    remaining ``N % batch_env``, so exactly ``N`` trajectories are produced) and
    rolled out with ``deploy_tf_collect_multi_env``.

    Args:
        model: policy passed to ``deploy_tf_collect_multi_env``.
        device: torch device for the rollout.
        phi: feature matrix shared by all ``LinearBandits`` instances.
        batch_env: number of environments rolled out in parallel.
        num_A, d, horizon, var: arms, feature dimension, steps per trajectory and
            the ``var`` forwarded to ``LinearBandits``.
        N: total number of trajectories.

    Returns:
        ``(dsets, actions, coeffs, ora_coefficients, true_mus)``: feature arrays of
        shape ``(horizon, num_A + 3)``, action arrays, empirical coefficients
        ``max(mu_hat[a] - mean(r), 0)``, oracle coefficients
        ``max(mu[a] - mean(mu), 0)``, and the true mean vectors.
    """
    dsets, actions, ora_coefficients, coeffs = [], [], [], []
    true_mus = []

    for start in tqdm(range(0, N, batch_env)):
        n_envs = min(batch_env, N - start)
        envs = [LinearBandits(sample_LinearBandits(d=d), phi, var=var) for _ in range(n_envs)]

        As, Rs = deploy_tf_collect_multi_env(envs, policy=model, horizon=horizon, device=device, num_A=num_A)
        for env, a, r in zip(envs, As, Rs):
            actions.append(a)

            # Per-arm empirical means; arms never pulled keep a mean of 0.
            mus = np.zeros(num_A)
            for j in range(num_A):
                hits = r[a == j]
                if len(hits) > 0:
                    mus[j] = np.mean(hits)
            coeffs.append(np.maximum(mus[a] - r.mean(), 0))

            true_mu = env.mu
            true_mus.append(true_mu)
            ora_coefficients.append(np.maximum(true_mu[a] - true_mu.mean(), 0))

            a_one_hot = np.zeros((horizon, num_A))
            a_one_hot[np.arange(horizon), a] = 1

            X = np.zeros((horizon, num_A + 3), np.float32)
            X[:, 0] = 1
            X[:, 1:num_A + 1] = a_one_hot
            X[:, -2] = 1
            X[:, -1] = r
            dsets.append(X)

    return dsets, actions, coeffs, ora_coefficients, true_mus

# Non-parallel version; too slow
def generate_LB_TF(model, device, phi, num_A=10, d=2, N=40000, horizon=200, var=0.3):
    """Collect ``N`` linear-bandit trajectories by rolling out ``model`` one env at a time.

    Sequential counterpart of ``generate_LB_TF_multi_env``.
    NOTE(review): ``deploy_tf_collect`` is neither defined nor explicitly imported in
    this file's visible code; presumably it arrives via the star import from
    ``envs.bandit_envs`` — confirm.

    Returns:
        ``(dsets, actions, coeffs, ora_coefficients, true_mus)``: feature arrays of
        shape ``(horizon, num_A + 3)``, action arrays, empirical coefficients
        ``max(mu_hat[a] - mean(r), 0)``, oracle coefficients
        ``max(mu[a] - mean(mu), 0)``, and the true mean vectors.
    """
    dsets, actions, coeffs, ora_coefficients, true_mus = [], [], [], [], []

    for _ in tqdm(range(N)):
        env = LinearBandits(sample_LinearBandits(d=d), phi, var=var)
        true_mu = env.mu
        true_mus.append(true_mu)

        act_hist, rew_hist = deploy_tf_collect(env, model, horizon, device, num_A)
        acts = np.array(act_hist)
        rews = np.array(rew_hist)
        actions.append(acts)

        # Per-arm empirical means; arms never pulled keep a mean of 0.
        mu_hat = np.zeros(num_A)
        for arm in range(num_A):
            hits = rews[acts == arm]
            if len(hits) > 0:
                mu_hat[arm] = np.mean(hits)
        coeffs.append(np.maximum(mu_hat[acts] - rews.mean(), 0))
        ora_coefficients.append(np.maximum(true_mu[acts] - true_mu.mean(), 0))

        feats = np.zeros((horizon, num_A + 3), np.float32)
        feats[:, 0] = 1
        feats[:, 1:num_A + 1] = np.eye(num_A)[acts]
        feats[:, -2] = 1
        feats[:, -1] = rews
        dsets.append(feats)

    return dsets, actions, coeffs, ora_coefficients, true_mus

def generate_LB(num_A=10, d=2, N=40000, horizon=200, var=0.3, cov=0.0, bp='tompson',phi=None):
    """Simulate ``N`` linear-bandit trajectories under a classical behavior policy.

    Every instance shares the feature matrix ``phi`` (drawn with entries from
    N(0, 1/d) and shape ``(num_A, d)`` when not supplied) and gets its own
    parameter from ``sample_LinearBandits``.

    Args:
        num_A: number of arms.
        d: feature dimension.
        N: number of trajectories.
        horizon: steps per trajectory.
        var: forwarded to ``LinearBandits``.
        cov: forwarded to ``RandomMixingPolicy`` (only when ``bp == 'random'``).
        bp: behavior policy, ``'tompson'`` (Thompson sampling) or ``'random'``.
        phi: optional feature matrix.

    Returns:
        ``(dsets, actions, coeffs, ora_coefficients, true_mus)``: feature arrays of
        shape ``(horizon, num_A + 3)``, action arrays, empirical coefficients
        ``max(mu_hat[a] - mean(r), 0)``, oracle coefficients
        ``max(mu[a] - mean(mu), 0)``, and the true mean vectors.

    Raises:
        NotImplementedError: if ``bp`` is not a supported behavior policy.
    """
    # `NotImplemented` is a constant, not an exception: raising it would be a
    # TypeError. Validate before sampling anything so bad input fails fast.
    if bp not in ('tompson', 'random'):
        raise NotImplementedError('Go Finish the Work!')
    if phi is None:
        phi = np.random.normal(0, np.sqrt(1/d), size=(num_A, d))
    dsets, actions, ora_coefficients, coeffs = [], [], [], []
    true_mus = []

    for _ in tqdm(range(N)):
        # generate a linear bandit instance
        env = LinearBandits(sample_LinearBandits(d=d), phi, var=var)
        true_mu = env.mu
        true_mus.append(true_mu)
        if bp == 'tompson':
            policy = TompsonSamplingPolicy(num_A)
        else:
            policy = RandomMixingPolicy(num_A, opt_a=np.argmax(true_mu), cov=cov)

        # collect traj
        As, Rs = [], []
        for _h in range(horizon):
            a = policy.act(As, Rs)
            r = env.transit(a)
            policy.update(a, r)
            As.append(a)
            Rs.append(r)

        a = np.array(As)
        r = np.array(Rs)
        actions.append(a)

        # Per-arm empirical means; arms never pulled keep a mean of 0.
        mus = np.zeros(num_A)
        for arm in range(num_A):
            pulled = a == arm
            if np.any(pulled):
                mus[arm] = np.mean(r[pulled])
        coeffs.append(np.maximum(mus[a] - r.mean(), 0))
        ora_coefficients.append(np.maximum(true_mu[a] - true_mu.mean(), 0))

        a_one_hot = np.zeros((horizon, num_A))
        a_one_hot[np.arange(horizon), a] = 1

        X = np.zeros((horizon, num_A + 3), np.float32)
        X[:, 0] = 1
        X[:, 1:num_A + 1] = a_one_hot
        X[:, -2] = 1
        X[:, -1] = r
        dsets.append(X)

    return dsets, actions, coeffs, ora_coefficients, true_mus


def generate_MAB(num_A=5, N=80000, horizon=500, var=0.3, cov=0.0, bp='random'):
    '''
    Simulate ``N`` multi-armed-bandit trajectories under a behavior policy.

    Each trajectory uses a fresh mean vector from ``sample_MultiArmedBandits``,
    wrapped in ``MultiArmedBandits(..., var=var)``. The instance is played for
    ``horizon`` steps by Thompson sampling (``bp='tompson'``) or by
    ``RandomMixingPolicy`` (``bp='random'``, given the optimal arm and ``cov``).

    Returns ``(dsets, actions, coefficients, ora_coefficients, true_mus)``:
      - dsets: float32 arrays of shape (horizon, num_A + 3); columns 0 and -2 are 1,
        columns 1..num_A one-hot encode the action, the last column is the reward.
      - actions: played actions per trajectory.
      - coefficients: max(empirical mean of the played arm - mean reward, 0) per step.
      - ora_coefficients: max(true mean of the played arm - mean of true means, 0).
      - true_mus: the true mean vectors.

    Raises NotImplementedError for any other ``bp``.
    '''
    dsets, actions, coefficients, ora_coefficients = [], [], [], []
    true_mus = []
    for _traj in trange(N):
        env = MultiArmedBandits(sample_MultiArmedBandits(num_A), var=var)
        mu_true = env.mu
        true_mus.append(mu_true)
        best_arm = np.argmax(mu_true)

        if bp == 'tompson':
            policy = TompsonSamplingPolicy(num_A)
        elif bp == 'random':
            policy = RandomMixingPolicy(num_A, opt_a=best_arm, cov=cov)
        else:
            raise NotImplementedError('Go Finish the Work!')

        act_hist, rew_hist = [], []
        for _step in range(horizon):
            act = policy.act(act_hist, rew_hist)
            rew = env.transit(act)
            policy.update(act, rew)
            act_hist.append(act)
            rew_hist.append(rew)

        acts = np.array(act_hist)
        rews = np.array(rew_hist)
        actions.append(acts)

        # Per-arm empirical means; arms never pulled keep a mean of 0.
        mu_hat = np.zeros(num_A)
        for arm in range(num_A):
            pulled = acts == arm
            if pulled.any():
                mu_hat[arm] = rews[pulled].mean()
        coefficients.append(np.maximum(mu_hat[acts] - rews.mean(), 0))
        ora_coefficients.append(np.maximum(mu_true[acts] - mu_true.mean(), 0))

        feats = np.zeros((horizon, num_A + 3), dtype=np.float32)
        feats[:, 0] = 1
        feats[:, -2] = 1
        feats[np.arange(horizon), acts + 1] = 1
        feats[:, -1] = rews
        dsets.append(feats)
    return dsets, actions, coefficients, ora_coefficients, true_mus

def generate_DB(num_A=10, d=2, N=40000, horizon=200, var=0.3, cov=0.0, bp='random',bp_mode='random',phi=None):
    """Simulate ``N`` dueling-bandit trajectories on linear-bandit instances.

    At each step the behavior policy proposes two arms. The duel outcome is drawn
    with ``BT`` on the arms' true means (0 means the first proposed arm wins).

    Args:
        num_A: number of arms.
        d: feature dimension.
        N: number of trajectories.
        horizon: duels per trajectory.
        var: forwarded to ``LinearBandits``.
        cov: forwarded to ``RandomDuelingPolicy`` (only when ``bp == 'random'``).
        bp: behavior policy, ``'random'`` (``RandomDuelingPolicy``) or ``'dts'`` (``DTS``).
        bp_mode: forwarded to ``RandomDuelingPolicy`` as ``mode``.
        phi: optional ``(num_A, d)`` feature matrix; drawn with entries from N(0, 1/d)
            when not supplied.

    Returns:
        ``(dsets, actions, true_mus)``. ``dsets`` are float32 arrays of shape
        ``(horizon, 2*num_A + 2)``: column 0 is 1, columns ``1..num_A`` one-hot encode
        the winner, column ``num_A + 1`` is 1, and the remaining columns one-hot
        encode the loser. Each ``actions`` entry is ``[winners, losers]``.

    Raises:
        NotImplementedError: if ``bp`` is not a supported behavior policy.
    """
    # `NotImplemented` is a constant, not an exception: raising it would be a
    # TypeError. Validate before sampling anything so bad input fails fast.
    if bp not in ('random', 'dts'):
        raise NotImplementedError('Go Finish the Work!')
    if phi is None:
        phi = np.random.normal(0, np.sqrt(1/d), size=(num_A, d))
    dsets, actions, true_mus = [], [], []

    for _ in tqdm(range(N)):
        # generate a linear bandit instance
        env = LinearBandits(sample_LinearBandits(d=d), phi, var=var)
        true_mu = env.mu
        true_mus.append(true_mu)
        if bp == 'random':
            policy = RandomDuelingPolicy(num_A, opt_a=np.argmax(true_mu), cov=cov, mode=bp_mode)
        else:
            policy = DTS(num_A)

        # collect traj
        AWs, ALs = [], []
        for h in range(horizon):
            if bp == 'random':
                a_1, a_2 = policy.act(AWs, ALs)
            else:
                a_1, a_2 = policy.act(h)

            win = BT(true_mu[a_1], true_mu[a_2])  # Dueling according to BT model
            a_w, a_l = (a_1, a_2) if win == 0 else (a_2, a_1)

            if bp == 'dts':
                policy.update(a_w, a_l)

            AWs.append(a_w)
            ALs.append(a_l)

        a = [np.array(AWs), np.array(ALs)]
        actions.append(a)

        aw_one_hot = np.zeros((horizon, num_A))
        aw_one_hot[np.arange(horizon), a[0]] = 1
        al_one_hot = np.zeros((horizon, num_A))
        al_one_hot[np.arange(horizon), a[1]] = 1

        X = np.zeros((horizon, 2*num_A + 2), np.float32)
        X[:, 0] = 1
        X[:, 1:num_A + 1] = aw_one_hot
        X[:, num_A+1] = 1
        X[:, num_A+2:] = al_one_hot
        dsets.append(X)

    return dsets, actions, true_mus