import numpy as np

from scipy import stats
import ipdb
import torch
from torch.distributions import Categorical

# Small constant added to a denominator (see ``rho``) so the ratio stays finite
# when the divisor is zero.
_eps = 1e-7

def random(mu, a, pt, temperature=1.0):
    """Uniform baseline scorer: give every sample in the batch the same weight.

    Only ``mu`` is inspected, to read the batch size. ``a``, ``pt`` and
    ``temperature`` are accepted for interface compatibility and are ignored.

    Args:
        mu: 3-D tensor unpacked as (ensembles, batch, classes).

    Returns:
        ``(scores, None, None)``, where ``scores`` is a float64 tensor of ones
        with one entry per batch element.
    """
    # The three-way unpack also rejects a ``mu`` that is not 3-D.
    _, batch_size, _ = mu.shape
    scores = torch.ones((batch_size), dtype=torch.float64)
    return scores, None, None

# Pair of action indices (0 and 2).
# NOTE(review): not referenced by the code visible here; ``tau`` hard-codes the
# same two indices. Confirm whether other modules import this.
two_actions = [0, 2]

def tau(mu, a, pt, temperature=1.0):
    """Ensemble disagreement about the value gap between actions 0 and 2.

    For each batch element, this computes the variance across ensemble
    members of ``mu[:, :, 0] - mu[:, :, 2]``, raised to ``1 / temperature``.
    ``a`` and ``pt`` are ignored.

    Args:
        mu: 3-D tensor unpacked as (ensembles, batch, classes); needs at
            least 3 classes because action index 2 is read.
        a: unused.
        pt: unused.
        temperature: the variance is raised to the power ``1 / temperature``.

    Returns:
        ``(scores, None, None)``, where ``scores`` is a float64 tensor with one
        entry per batch element.
    """
    ensembles, batch, classes = mu.shape  # shape check only
    # Action indices being contrasted (the same pair as ``two_actions``).
    first, second = 0, 2
    gap = mu[:, :, first] - mu[:, :, second]
    return (gap.var(0) ** (1 / temperature)).double(), None, None

def mu_f(mu, a, pt, temperature=1.0):
    """Ensemble variance of the value of the action actually taken.

    For each batch element ``i``, this returns
    ``mu[:, i, a[i]].var(0) ** (1 / temperature)``. It only handles actions
    0, 1 and 2; any other action contributes 0 before the power is applied.
    ``pt`` is ignored.

    Args:
        mu: 3-D tensor unpacked as (ensembles, batch, classes).
        a: integer action per batch element; squeezed before comparison, so
            shape (batch,) or (batch, 1) both work.
        pt: unused.
        temperature: the variance is raised to the power ``1 / temperature``.

    Returns:
        ``(scores, None, None)``, where ``scores`` is float64.
    """
    ensembles, batch, classes = mu.shape  # shape check only
    chosen_var = 0
    # Boolean masks select, per sample, the variance column of its action.
    for action in (2, 1, 0):
        taken = (a == action).squeeze()
        chosen_var = chosen_var + mu[:, :, action].var(0) * taken
    return (chosen_var ** (1 / temperature)).double(), None, None

def mu_adv(mu, a, pt, temperature=1.0):
    """Ensemble variance of the taken action's value relative to the action mean.

    Each ensemble member's values are first centred by subtracting that
    member's mean over actions. The score for sample ``i`` is then the
    variance, across members, of the centred value of action ``a[i]``, raised
    to ``1 / temperature``. It only handles actions 0, 1 and 2; any other
    action contributes 0. ``pt`` is ignored.

    Args:
        mu: 3-D tensor unpacked as (ensembles, batch, classes).
        a: integer action per batch element; squeezed before comparison.
        pt: unused.
        temperature: the variance is raised to the power ``1 / temperature``.

    Returns:
        ``(scores, None, None)``, where ``scores`` is float64.
    """
    ensembles, batch, classes = mu.shape  # shape check only
    # Remove each member's mean over actions so only relative values remain.
    centred = mu - mu.mean(dim=2).unsqueeze(-1)
    total = 0
    for action in (2, 1, 0):
        total = total + centred[:, :, action].var(0) * (a == action).squeeze()
    return (total ** (1 / temperature)).double(), None, None

def rho(mu, a, pt, temperature=1.0):
    """Ratio of action-gap disagreement (``tau``) to taken-action disagreement (``mu_f``).

    Computes ``tau / (mu_f + _eps)`` element-wise over the batch.

    Args:
        mu: 3-D tensor unpacked as (ensembles, batch, classes).
        a: integer action per batch element (forwarded to ``mu_f``).
        pt: forwarded unchanged.
        temperature: forwarded to both ``tau`` and ``mu_f``.

    Returns:
        ``(scores, None, None)``, where ``scores`` is float64.
    """
    ensembles, batch, classes = mu.shape  # shape check only
    # ``tau`` and ``mu_f`` each return a (score, None, None) tuple; only the
    # score tensor takes part in the arithmetic.
    tau_score = tau(mu, a, pt, temperature)[0]
    mu_f_score = mu_f(mu, a, pt, temperature)[0]
    # _eps keeps the ratio finite when the taken action's variance is zero.
    return (tau_score / (mu_f_score + _eps)).double(), None, None

def mu_rho(mu, a, pt, temperature=1.0):
    """Product of the ``mu_f`` score and the ``rho`` score.

    Args:
        mu: 3-D tensor unpacked as (ensembles, batch, classes).
        a: integer action per batch element.
        pt: forwarded unchanged.
        temperature: forwarded to both scorers.

    Returns:
        ``(scores, None, None)``.
    """
    ensembles, batch, classes = mu.shape  # shape check only
    # Both scorers return (score, None, None); multiply the scores, not the tuples.
    mu_f_score = mu_f(mu, a, pt, temperature)[0]
    rho_score = rho(mu, a, pt, temperature)[0]
    # TODO(review): original note asked whether this should use the advantage
    # score (``mu_adv``) instead of ``mu_f``.
    return mu_f_score * rho_score, None, None

def pi(mu, a, pt, temperature=1.0):
    """Improbability of the taken action: ``1 - pt[i, a[i]]`` for each sample.

    Args:
        mu: unused.
        a: integer action per batch element; squeezed before comparison, so
            shape (batch,) or (batch, 1) both work.
        pt: 2-D tensor indexed as ``pt[:, k]``. Presumably per-sample action
            probabilities; confirm with the caller.
        temperature: unused.

    Returns:
        ``(scores, None, None)``, where ``scores`` is float64. A sample whose
        action index has no column in ``pt`` scores 0.
    """
    taken = a.squeeze()
    # One mask per action column. Each mask selects exactly the samples that
    # took action k. Previously action 1's mask tested ``a == 0``, so action 1
    # was never scored and action 0 was counted twice.
    metric = sum((taken == k) * (1 - pt[:, k]) for k in range(pt.shape[1]))
    return metric.double(), None, None


def mu_pi(mu, a, pt, temperature=1.0):
    """Product of the ``mu_adv`` score and the ``pi`` (action improbability) score.

    Args:
        mu: 3-D tensor unpacked as (ensembles, batch, classes).
        a: integer action per batch element.
        pt: forwarded to ``pi``; presumably per-sample action probabilities.
        temperature: forwarded to both scorers.

    Returns:
        ``(scores, None, None)``.
    """
    # Both scorers return (score, None, None); multiply the scores, not the tuples.
    adv_score = mu_adv(mu, a, pt, temperature)[0]
    pi_score = pi(mu, a, pt, temperature)[0]
    return adv_score * pi_score, None, None

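# NOTE: each scorer returns a 3-tuple (score, aux, aux) whose aux entries may be
# None; any function above that combines scorers must unpack the score
# (element [0]) before doing arithmetic with it.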

def action_diffs(mu, a, probs):
    """Unfinished helper: always returns ``None``.

    Intended (per the original note) to compute the difference in ``mu``
    between the action taken according to ``probs`` and the other actions.
    At present it only gathers each ensemble member's value of action ``a``
    and discards it. ``probs`` is unused, and the function is not in the
    ``FUNCTIONS`` table.

    Args:
        mu: 3-D tensor unpacked as (ensembles, batch, classes).
        a: int64 action indices usable as a ``torch.gather`` index after the
            unsqueeze/repeat below.
        probs: unused.
    """
    ensembles, _, _ = mu.shape
    index = a.unsqueeze(0).unsqueeze(-1).repeat(ensembles, 1, 1)
    _taken_values = torch.gather(mu, 2, index).squeeze(-1)  # not yet used
    # TODO: subtract the other actions' values (or their mean) from
    # ``_taken_values`` and return the result.
    return None

########################################################################

def mu_pi_atari(mu, a, pt, temperature=1.0):
    """Product of the advantage-variance score and the action-improbability score.

    Combines ``mu_realadv_atari`` (ensemble variance of the taken action's
    advantage) with ``pi_atari`` (``1 - pt`` of the taken action).

    Returns:
        ``(product, advantage_score, improbability_score)``; both factors are
        the float64 first elements of the respective scorers.
    """
    adv_score, _, _ = mu_realadv_atari(mu, a, pt, temperature)
    improbability, _, _ = pi_atari(mu, a, pt, temperature)
    product = adv_score * improbability
    return product, adv_score, improbability

def pi_atari(mu, a, pt, temperature=1.0):
    """Improbability of the taken action, ``1 - pt[i, a[i]]``, via gather.

    Args:
        mu: unused.
        a: int64 action per batch element; used as a ``torch.gather`` index
            along dim 1 of ``pt``.
        pt: 2-D tensor gathered along dim 1. Presumably per-sample action
            probabilities; confirm with the caller.
        temperature: unused.

    Returns:
        ``(scores, None, prob_taken)``: ``scores`` is float64, and
        ``prob_taken`` is the gathered ``pt`` entry in ``pt``'s dtype.
    """
    index = a.unsqueeze(-1)
    prob_taken = pt.gather(1, index).squeeze(-1)
    scores = (1.0 - prob_taken).double()
    return scores, None, prob_taken

def bald(mu, a, pt, temperature=1.0):
    """BALD-style mutual information between the prediction and the ensemble member.

    Each member's values are turned into a distribution with a softmax over
    the last axis. The score is
    ``H(mean over members of p) - mean over members of H(p)``.
    ``a``, ``pt`` and ``temperature`` are ignored.

    Args:
        mu: 3-D tensor unpacked as (ensembles, batch, classes).

    Returns:
        ``(score.double(), score, None)``, with one entry per batch element.
    """
    ensembles, batch, classes = mu.shape  # shape check only
    member_probs = torch.softmax(mu, dim=2)
    entropy_of_mean = Categorical(member_probs.mean(dim=0)).entropy()
    mean_of_entropy = Categorical(member_probs).entropy().mean(dim=0)
    mutual_info = entropy_of_mean - mean_of_entropy
    return mutual_info.double(), mutual_info, None

def mu_probs_atari(mu, a, pt, temperature=1.0):
    """Ensemble variance of the softmax probability each member gives the taken action.

    Args:
        mu: 3-D tensor unpacked as (ensembles, batch, classes); softmaxed
            over the last axis per member.
        a: int64 action per batch element (a ``torch.gather`` index).
        pt: 2-D tensor gathered along dim 1 by ``a``.
        temperature: unused.

    Returns:
        ``(spread.double(), spread, prob_taken)``: ``spread`` is the
        across-member variance, and ``prob_taken`` is ``pt[i, a[i]]``.
    """
    ensembles, batch, classes = mu.shape
    member_probs = torch.softmax(mu, dim=2)

    prob_taken = torch.gather(pt, 1, a.unsqueeze(-1)).squeeze(-1)

    # Same action index for every ensemble member.
    index = a.unsqueeze(0).unsqueeze(-1).repeat(ensembles, 1, 1)
    member_prob_taken = torch.gather(member_probs, 2, index).squeeze(-1)
    spread = member_prob_taken.var(0)

    return spread.double(), spread, prob_taken

def mu_realadv_atari(mu, a, pt, temperature=1.0):
    """Ensemble variance of the taken action's advantage over a softmax-weighted value.

    For each member, the "state value" is the softmax(mu)-weighted average of
    that member's action values. The advantage is the taken action's value
    minus that state value. The score is its variance across members.
    ``pt`` and ``temperature`` are ignored.

    Args:
        mu: 3-D tensor unpacked as (ensembles, batch, classes).
        a: int64 action per batch element (a ``torch.gather`` index).

    Returns:
        ``(score.double(), score, None)``.
    """
    ensembles, batch, classes = mu.shape
    weights = torch.softmax(mu, dim=2)
    state_value = torch.sum(weights * mu, dim=-1)
    index = a.unsqueeze(0).unsqueeze(-1).repeat(ensembles, 1, 1)
    taken_value = torch.gather(mu, 2, index).squeeze(-1)
    advantage_var = (taken_value - state_value).var(0)
    return advantage_var.double(), advantage_var, None

def mu_meanadv_atari(mu, a, pt, temperature=1.0):
    """Ensemble variance of the taken action's value minus the ensemble-mean state value.

    The state value is the softmax(mu)-weighted action value, averaged over
    members, giving one number per sample. ``pt`` and ``temperature`` are
    ignored.

    NOTE(review): the subtracted term is the same for every member of a
    sample, so mathematically this equals the plain variance of the taken
    action's value (up to rounding).

    Args:
        mu: 3-D tensor unpacked as (ensembles, batch, classes).
        a: int64 action per batch element (a ``torch.gather`` index).

    Returns:
        ``(score.double(), score, None)``.
    """
    ensembles, batch, classes = mu.shape
    weights = torch.softmax(mu, dim=2)
    mean_value = (weights * mu).sum(dim=-1).mean(dim=0)
    index = a.unsqueeze(0).unsqueeze(-1).repeat(ensembles, 1, 1)
    taken_value = torch.gather(mu, 2, index).squeeze(-1)
    # Broadcasting the per-sample mean over the member axis gives the same
    # values as an explicit repeat.
    spread = (taken_value - mean_value).var(0)
    return spread.double(), spread, None


def mu_mean_atari(mu, a, pt, temperature=1.0):
    """Ensemble variance of the taken action's value after centring on its ensemble mean.

    Centring on the across-member mean does not change the variance
    mathematically, so this equals the plain variance of the taken action's
    value up to rounding. ``pt`` and ``temperature`` are ignored.

    Args:
        mu: 3-D tensor unpacked as (ensembles, batch, classes).
        a: int64 action per batch element (a ``torch.gather`` index).

    Returns:
        ``(score.double(), score, None)``.
    """
    ensembles, batch, classes = mu.shape
    index = a.unsqueeze(0).unsqueeze(-1).repeat(ensembles, 1, 1)
    taken = torch.gather(mu, 2, index).squeeze(-1)
    deviation = taken - taken.mean(dim=0)
    spread = deviation.var(0)
    return spread.double(), spread, None

def mu_both_atari(mu, a, pt, temperature=1.0):
    """Advantage variance for the greedy action of the ensemble-mean values.

    Like ``mu_realadv_atari``, but the action is the argmax of ``mu`` averaged
    over members rather than ``a``. ``a``, ``pt`` and ``temperature`` are all
    ignored.

    Args:
        mu: 3-D tensor unpacked as (ensembles, batch, classes).

    Returns:
        ``(score.double(), score, None)``.
    """
    ensembles, batch, classes = mu.shape
    weights = torch.softmax(mu, dim=2)
    state_value = (weights * mu).sum(dim=-1)
    greedy = mu.mean(dim=0).max(dim=1).indices
    index = greedy.unsqueeze(0).unsqueeze(-1).repeat(ensembles, 1, 1)
    greedy_value = torch.gather(mu, 2, index).squeeze(-1)
    spread = (greedy_value - state_value).var(0)
    return spread.double(), spread, None

def mu_indepadv_atari(mu, a, pt, temperature=1.0):
    """Variance of the greedy action's advantage, scaled by the greedy ensemble-mean value.

    The greedy action is the argmax of ``mu`` averaged over members, and its
    ensemble-mean value is the divisor. ``a``, ``pt`` and ``temperature`` are
    ignored.

    NOTE(review): there is no guard on the divisor; a greedy mean value of 0
    yields inf/nan.

    Args:
        mu: 3-D tensor unpacked as (ensembles, batch, classes).

    Returns:
        ``(score.double(), score, None)``.
    """
    ensembles, batch, classes = mu.shape
    weights = torch.softmax(mu, dim=2)
    state_value = (weights * mu).sum(dim=-1)
    greedy_mean, greedy = mu.mean(dim=0).max(dim=1)
    index = greedy.unsqueeze(0).unsqueeze(-1).repeat(ensembles, 1, 1)
    greedy_value = torch.gather(mu, 2, index).squeeze(-1)
    # ``greedy_mean`` (batch,) broadcasts across the member axis.
    scaled = (greedy_value - state_value) / greedy_mean
    spread = scaled.var(0)
    return spread.double(), spread, None

def mu_combo_atari(mu, a, pt, temperature=1.0):
    """Sum of two batch-normalised advantage-variance scores.

    It computes two scores: the variance of the greedy action's advantage
    (greedy action = argmax of the ensemble-mean ``mu``) and the variance of
    the taken action's advantage. Each is divided by its sum over the batch,
    and the two are added. Advantages are measured against the
    softmax(mu)-weighted value of each member. ``pt`` and ``temperature`` are
    ignored.

    Args:
        mu: 3-D tensor unpacked as (ensembles, batch, classes).
        a: int64 action per batch element (a ``torch.gather`` index).

    Returns:
        ``(score, score, None)``: the float64 combined score, repeated as the
        aux value. If either variance sums to 0 over the batch, the result
        contains NaN.
    """
    ensembles, batch, classes = mu.shape
    weights = torch.softmax(mu, dim=2)
    state_value = (weights * mu).sum(dim=-1)
    greedy = mu.mean(dim=0).max(dim=1).indices

    greedy_index = greedy.unsqueeze(0).unsqueeze(-1).repeat(ensembles, 1, 1)
    greedy_value = torch.gather(mu, 2, greedy_index).squeeze(-1)
    greedy_var = (greedy_value - state_value).var(0).double()

    data_index = a.unsqueeze(0).unsqueeze(-1).repeat(ensembles, 1, 1)
    data_value = torch.gather(mu, 2, data_index).squeeze(-1)
    data_var = (data_value - state_value).var(0).double()

    score = greedy_var / greedy_var.sum() + data_var / data_var.sum()
    # Previously the aux slot returned the module-level ``mu_adv`` function
    # (no local of that name existed); return the computed score instead.
    return score, score, None
    

def mu_max_atari(mu, a, pt, temperature=1.0):
    """Element-wise max of the greedy-action and taken-action advantage variances.

    Advantages are measured against the softmax(mu)-weighted value of each
    member. The greedy action is the argmax of the ensemble-mean ``mu``.
    ``pt`` and ``temperature`` are ignored.

    Args:
        mu: 3-D tensor unpacked as (ensembles, batch, classes).
        a: int64 action per batch element (a ``torch.gather`` index).

    Returns:
        ``(peak.double(), peak, None)``, where ``peak`` keeps ``mu``'s dtype.
    """
    ensembles, batch, classes = mu.shape
    weights = torch.softmax(mu, dim=2)
    state_value = (weights * mu).sum(dim=-1)
    greedy = mu.mean(dim=0).max(dim=1).indices

    greedy_index = greedy.unsqueeze(0).unsqueeze(-1).repeat(ensembles, 1, 1)
    greedy_var = (torch.gather(mu, 2, greedy_index).squeeze(-1) - state_value).var(0)

    data_index = a.unsqueeze(0).unsqueeze(-1).repeat(ensembles, 1, 1)
    data_var = (torch.gather(mu, 2, data_index).squeeze(-1) - state_value).var(0)

    peak = torch.max(greedy_var, data_var)
    # Previously the aux slot returned the module-level ``mu_adv`` function;
    # return the raw score, matching the other *_atari scorers.
    return peak.double(), peak, None

def mu_mult_atari(mu, a, pt, temperature=1.0):
    """Product of the greedy-action and taken-action advantage variances.

    Advantages are measured against the softmax(mu)-weighted value of each
    member. The greedy action is the argmax of the ensemble-mean ``mu``.
    ``pt`` and ``temperature`` are ignored.

    Args:
        mu: 3-D tensor unpacked as (ensembles, batch, classes).
        a: int64 action per batch element (a ``torch.gather`` index).

    Returns:
        ``(product.double(), product, None)``, where ``product`` keeps
        ``mu``'s dtype.
    """
    ensembles, batch, classes = mu.shape
    weights = torch.softmax(mu, dim=2)
    state_value = (weights * mu).sum(dim=-1)
    greedy = mu.mean(dim=0).max(dim=1).indices

    greedy_index = greedy.unsqueeze(0).unsqueeze(-1).repeat(ensembles, 1, 1)
    greedy_var = (torch.gather(mu, 2, greedy_index).squeeze(-1) - state_value).var(0)

    data_index = a.unsqueeze(0).unsqueeze(-1).repeat(ensembles, 1, 1)
    data_var = (torch.gather(mu, 2, data_index).squeeze(-1) - state_value).var(0)

    product = greedy_var * data_var
    # Previously the aux slot returned the module-level ``mu_adv`` function;
    # return the raw score, matching the other *_atari scorers.
    return product.double(), product, None

def mu_bald_combo_atari(mu, a, pt, temperature=1.0):
    """Greedy-action advantage variance plus the BALD mutual-information score.

    The greedy action is the argmax of the ensemble-mean ``mu``. Its advantage
    is measured against each member's softmax(mu)-weighted value. BALD is
    ``H(mean of p) - mean of H(p)`` with ``p = softmax(mu)`` per member.
    ``a``, ``pt`` and ``temperature`` are ignored.

    Args:
        mu: 3-D tensor unpacked as (ensembles, batch, classes).

    Returns:
        ``((adv_var + bald).double(), adv_var, None)``.
    """
    ensembles, batch, classes = mu.shape
    member_probs = torch.softmax(mu, dim=2)

    # Variance of the greedy action's advantage.
    state_value = (member_probs * mu).sum(dim=-1)
    greedy = mu.mean(dim=0).max(dim=1).indices
    index = greedy.unsqueeze(0).unsqueeze(-1).repeat(ensembles, 1, 1)
    greedy_value = torch.gather(mu, 2, index).squeeze(-1)
    adv_var = (greedy_value - state_value).var(0)

    # BALD mutual information.
    entropy_of_mean = Categorical(member_probs.mean(dim=0)).entropy()
    mean_entropy = Categorical(member_probs).entropy().mean(dim=0)
    info_gain = entropy_of_mean - mean_entropy

    return (adv_var + info_gain).double(), adv_var, None


# def mu_realadv_normed_atari(mu, a, pt, temperature=1.0):
#     ensembles, batch, classes = mu.shape
#     mu_probs = torch.softmax(mu, dim=2)
#     # mu_pred, action_pred = mu.mean(dim=0).max(dim=1)

#     adv = (mu_probs * mu).sum(dim=-1) # elt wise mult
#     mu_in_question = torch.gather(mu, 2, a.unsqueeze(0).unsqueeze(-1).repeat(ensembles, 1, 1)).squeeze(-1)
#     mu_adv = (mu_in_question - adv).var(0)
#     return mu_adv.double(), mu_adv, None


# def mu_halfboth_atari(mu, a, pt, temperature=1.0):
#     # ipdb.set_trace()
#     ensembles, batch, classes = mu.shape
#     mu_probs = torch.softmax(mu, dim=2)
#     adv = (mu_probs * mu).sum(dim=-1) # elt wise mult
#     mu_pred, action_pred = mu.mean(dim=0).max(dim=1)
#     mu_pred = mu_pred.reshape(1, -1).repeat(ensembles, 1)
#     mu_in_question = torch.gather(mu, 2, action_pred.unsqueeze(0).unsqueeze(-1).repeat(ensembles, 1, 1)).squeeze(-1)

#     mu_adv = (mu_in_question/mu_pred).var(0)
#     return mu_adv.double(), mu_adv, None

# def mu_realadv_atari(mu, a, pt, temperature=1.0):
#     ensembles, batch, classes = mu.shape
#     mu_probs = torch.softmax(mu, dim=2)
#     adv = (mu_probs * mu).sum(dim=-1) # elt wise mult
#     mu_in_question = torch.gather(mu, 2, a.unsqueeze(0).unsqueeze(-1).repeat(ensembles, 1, 1)).squeeze(-1)
#     # pt_in_question = torch.gather(pt, 1, a.unsqueeze(-1)).squeeze(-1)
#     avg = mu.mean(dim=0)
#     avg_in_question = torch.gather(avg, 1, a.unsqueeze(-1))
#     mu_adv = (mu_in_question - avg_in_question.squeeze()).var(0)

    # mu_adv = (mu_in_question - adv).var(0)
    # return mu_adv.double(), mu_adv, None #pt_in_question

def mu_adv_atari(mu, a, pt, temperature=1.0):
    """Ensemble variance of the taken action's value minus each member's mean over actions.

    ``pt`` and ``temperature`` are ignored.

    Args:
        mu: 3-D tensor unpacked as (ensembles, batch, classes).
        a: int64 action per batch element (a ``torch.gather`` index).

    Returns:
        ``(score.double(), score, None)``.
    """
    ensembles, batch, classes = mu.shape
    index = a.unsqueeze(0).unsqueeze(-1).repeat(ensembles, 1, 1)
    taken_value = torch.gather(mu, 2, index).squeeze(-1)
    member_mean = mu.mean(dim=2)
    spread = (taken_value - member_mean).var(0)
    return spread.double(), spread, None


# Registry mapping scorer names to scorer callables. Every scorer takes
# (mu, a, pt, temperature=1.0) and returns a 3-tuple whose first element is
# the per-sample score. Keys and their order are unchanged.
FUNCTIONS = {
    "random": random,
    "mu_f": mu_f,
    "mu_adv": mu_adv,
    "rho": rho,
    "mu-rho": mu_rho,
    "tau": tau,
    "pi": pi,
    "mu-pi": mu_pi,
    "mu_adv_atari": mu_adv_atari,          # variance of the taken action's value (vs. member mean)
    "mu_probs_atari": mu_probs_atari,      # variance of the taken action's softmax probability
    "pi_atari": pi_atari,                  # improbability of the taken action (1 - pt)
    "mu-pi_atari": mu_pi_atari,
    "mu_realadv_atari": mu_realadv_atari,
    "mu_indepadv_atari": mu_indepadv_atari,
    "mu_both_atari": mu_both_atari,
    "bald": bald,
    "mu_combo_atari": mu_combo_atari,
    "mu_bald_combo_atari": mu_bald_combo_atari,
    "mu_mean_atari": mu_mean_atari,
    "mu_meanadv_atari": mu_meanadv_atari,
    "mu_max_atari": mu_max_atari,
    "mu_mult_atari": mu_mult_atari,
}



# def max_minus_min(mu, t, pt, temperature):
#     ensembles, batch, classes = mu.shape
#     return torch.ones((batch), dtype=torch.float64)

# def average(mu, t, pt, temperature):
#     ensembles, batch, classes = mu.shape
#     return torch.ones((batch), dtype=torch.float64)

# def action_minus_average(mu, t, pt, temperature):
#     ensembles, batch, classes = mu.shape
#     return torch.ones((batch), dtype=torch.float64)

# def entropy(mu, t, pt, temperature):
#     ensembles, batch, classes = mu.shape
#     # torch.swapaxes(x, 0, 1)
#     torch.var(mu, unbiased=False, dim=0) # if true, bessel's correction
#     # mean/sum across actions?
#     return
