import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import StepLR
from torch import distributions
import numpy as np
import sys
from src.utils import *
from src.q20game import Q20Env
from src.hyperparam import get_param

# Plotting switch. matplotlib is imported only when this flag is True, and every
# plotting call in this file is guarded by the same flag.
enable_plt = False
if enable_plt:
    import matplotlib.pyplot as plt


def symmetric_difference(Z):
    """Return the pairwise symmetric difference between the columns of ``Z``.

    ``Z`` is a 0/1 matrix whose columns are the arms (shape ``N_dim x N_arms``).

    Returns a pair ``(diff, diff_count)``:
      * ``diff[i, j, k]`` is 1 where arms ``i`` and ``j`` disagree on coordinate
        ``k`` and 0 where they agree (shape ``N_arms x N_arms x N_dim``);
      * ``diff_count[i, j]`` is the number of coordinates on which they disagree
        (shape ``N_arms x N_arms``).
    """
    # Re-encode {0, 1} as {-1, +1}, one row per arm.
    signed_arms = Z.transpose(0, 1) * 2 - 1
    # The product of two +/-1 codes is +1 on agreement and -1 on disagreement.
    agreement = signed_arms.unsqueeze(0) * signed_arms.unsqueeze(1)
    # Map +1 -> 0 and -1 -> 1.
    pairwise_diff = (agreement - 1) / (-2.)
    pairwise_diff_count = pairwise_diff.sum(dim=-1)
    return pairwise_diff, pairwise_diff_count


# Module-level configuration, executed at import time.
# The game mode comes from the first command-line argument, so importing this
# module without that argument raises IndexError on the next line.
game_mode = sys.argv[1]
assert game_mode in ["1d", "1d100", "q20", "jester"], "game mode is either 1d or 1d100 or q20 or jester"
# get_param (src.hyperparam) supplies the fifteen settings unpacked below; they
# are used as globals by the rest of this file.
IT_THRESH, N_it, N_dim, N_budget, Z, LAMBDA_ENTROPY, LAMBDA_ENTROPY_BINARY, LAMBDA_SAMPLE_ENTROPY, M, L, N_GEN, \
R_lim, N_it_overall, N_pulls, eval_range = get_param(game_mode)
# Z is used as an (N_dim x N_arms) matrix throughout (e.g. torch.mm(theta, Z)).
N_arms = Z.shape[1]
# Pairwise symmetric differences between arms, precomputed once for c_fn.
Z_sym_diff, Z_sym_diff_ct = symmetric_difference(Z)


class PolicyNet(nn.Module):
    """Five-layer MLP mapping the rollout state to one logit per coordinate.

    The input has ``3 * N_dim`` features and the output has ``N_dim`` logits;
    callers apply softmax / log-softmax to the output themselves.
    """

    def __init__(self):
        super().__init__()
        width = 256
        # Attribute names l1..l5 are kept: they are the state-dict keys.
        self.l1 = nn.Linear(3 * N_dim, width)
        self.l2 = nn.Linear(width, width)
        self.l3 = nn.Linear(width, width)
        self.l4 = nn.Linear(width, width)
        self.l5 = nn.Linear(width, N_dim)
        weight_initialize_nn(self, sigma=.01)

    def forward(self, X):
        """Return the un-normalised logits for input batch ``X``."""
        hidden = X
        for layer in (self.l1, self.l2, self.l3, self.l4):
            hidden = F.leaky_relu(layer(hidden))
        # No activation on the output layer: raw logits.
        return self.l5(hidden)


class Sampler(nn.Module):
    """Holds a learnable vector of ``num + 1`` logits, initialised to zero.

    ``forward`` takes no input and simply returns that vector; callers turn it
    into a distribution with softmax.
    """

    def __init__(self, num=N_dim):
        super().__init__()
        zero_logits = torch.zeros(num + 1).float().cuda()
        self.weight = nn.Parameter(zero_logits)

    def forward(self):
        """Return the raw logit vector."""
        return self.weight


def c_fn(theta, accumulated_theta=None, target_l=None):
    """Closed-form difficulty score and allocation for a batch of instances.

    Args:
        theta: ``N x N_dim`` batch of instances.
        accumulated_theta: optional precomputed ``theta @ Z`` (``N x N_arms``).
        target_l: optional log2 target difficulty.

    Returns:
        If ``target_l`` is None: ``(difficulty, allocation)`` where
        ``difficulty`` has shape ``N`` and ``allocation`` (``N x N_dim``) is the
        per-coordinate score normalised to sum to one.
        Otherwise: ``(theta * scale, scale)`` with
        ``scale = sqrt(difficulty / 2 ** target_l)``.
    """
    arm_values = torch.mm(theta, Z) if accumulated_theta is None else accumulated_theta  # N x N_arms
    best_value, best_arm = torch.max(arm_values, dim=-1)  # N, N
    # Squared gap of every arm to the best arm, floored to avoid division by zero.
    sq_gap = torch.clamp((arm_values - best_value.unsqueeze(-1)) ** 2, min=1e-8)  # N x N_arms
    # (#coordinates differing from the best arm) / gap^2, per arm.
    per_arm = Z_sym_diff_ct[best_arm] / sq_gap  # N x N_arms
    # Indicator of which coordinates differ between the best arm and each arm.
    differs = Z_sym_diff[best_arm]  # N x N_arms x N_dim
    # Contraction with the identity matrix, kept from the original computation.
    eye = torch.eye(N_dim).cuda()
    differs = torch.tensordot(differs, eye, dims=([-1], [-1]))  # N x N_arms x N_dim
    # For each coordinate take the largest score among arms that differ on it.
    per_coord = torch.max(per_arm.unsqueeze(-1) * differs, dim=1)[0]  # N x N_dim
    diff_c = torch.sum(per_coord, dim=-1)  # N
    if target_l is not None:
        scale = torch.sqrt(diff_c / 2 ** target_l)
        return theta * scale.unsqueeze(-1), scale
    return diff_c, per_coord / per_coord.sum(dim=-1, keepdim=True)


def rho_optimal(theta, n_it=1000, visualize=False, epsilon=1e-2, difficulty=False, optional_Z=None):
    """Iteratively optimise a per-coordinate allocation for each instance.

    For every row of ``theta`` an allocation ``lambda`` (one weight per
    coordinate, started uniform at ``1 / N_dim``) is updated for ``n_it`` steps
    to reduce ``max over arms of sum_k (Z_diff_k ** 2 / lambda_k) / gap ** 2``.
    Each step moves towards the one-hot vector at the coordinate with the most
    negative gradient using step size ``2 / (it + 3)`` (a Frank-Wolfe-style
    update).

    Args:
        theta: ``N x N_dim`` batch of instances.
        n_it: number of update steps.
        visualize: if True, print the normalised loss each step and record the
            gradient norms; the curves are plotted only when ``enable_plt`` is set.
        epsilon: floor applied to the allocation in the final difficulty
            computation (used only when ``difficulty`` is True).
        difficulty: if True, also return the objective value re-evaluated with
            the original (non-detached) ``theta``.
        optional_Z: not used by this function.

    Returns:
        The allocation (``N x N_dim``); when ``difficulty`` is True, a pair
        ``(allocation, rho_max)`` with ``rho_max`` of shape ``N``.
    """
    if difficulty:
        # Keep the original tensor for the final evaluation and run the
        # optimisation itself on detached data.
        theta_grad = theta
        theta = theta.data
    lambda_pre = torch.ones(theta.size()).float().cuda() / N_dim  # N * N_dim
    ls = []
    norms = []
    i_star = torch.argmax(torch.mm(theta, Z), dim=-1)  # N
    Zs = Z.transpose(0, 1)  # N_arms * N_dim
    Z_star = Zs[i_star, :]  # N * N_dim
    Z_diff = Z_star.unsqueeze(1) - Zs.unsqueeze(0)  # N * N_arms * N_dim
    # Squared gap between the best arm and every arm under theta.
    denom = torch.sum(Z_diff * theta.unsqueeze(1), dim=-1) ** 2  # N * N_arms
    for it in range(n_it):
        # lambda_pre is a fresh leaf every iteration (it is rebuilt from .data below).
        lambda_pre.requires_grad = True
        lambda_star = 1 / lambda_pre  # N * N_dim
        numer = torch.sum((Z_diff ** 2) * lambda_star.unsqueeze(1), dim=-1)  # N * N_arms
        rho = numer / torch.clamp(denom, 1e-8)  # N * N_arms
        loss = torch.sum(torch.max(rho, dim=-1)[0])
        if visualize:
            ls.append(loss.data.cpu().numpy() / N_dim / theta.size(0))
            sys.stdout.write('\r% 4d  % 9.8f' %
                             (it, ls[-1]))
        loss.backward()
        if visualize:
            norms.append(torch.mean(lambda_pre.grad ** 2, dim=-1))
        # Direction: one-hot at the coordinate with the smallest gradient
        # (to_one_hot comes from src.utils; presumably returns an N * N_dim one-hot matrix).
        grad = to_one_hot(torch.argmin(lambda_pre.grad, dim=-1), theta.size(-1))
        gamma = 2. / (it + 3)
        # Convex combination of the current allocation and the chosen vertex.
        lambda_pre = lambda_pre.data * (1 - gamma) + gamma * grad
    if enable_plt:
        if visualize:
            plt.plot(ls)
            plt.show()
            plt.plot(norms)
            plt.ylim(0, 1e4)
            plt.show()
    if difficulty:
        lambda_pre = lambda_pre.data
        # Floor the allocation so that 1 / lambda stays bounded.
        lambda_star = 1 / torch.clamp(lambda_pre, epsilon)
        # Re-evaluate with theta_grad so the result stays connected to the
        # caller's computation graph.
        denom = torch.sum(Z_diff * theta_grad.unsqueeze(1), dim=-1) ** 2
        numer = torch.sum((Z_diff ** 2) * lambda_star.unsqueeze(1), dim=-1)
        rho = numer / torch.clamp(denom, 1e-8)
        return lambda_pre, torch.max(rho, dim=-1)[0]
    else:
        return lambda_pre.data


def sample_theta(sample_size):
    """Draw ``sample_size * N_arms`` initial instances with entries in (0, 1].

    The result (moved to the GPU) stacks two groups:
      * ``(sample_size - sample_size // 3) * N_arms`` rows drawn uniformly
        from [0, 1);
      * ``(sample_size // 3) * N_arms`` rows drawn uniformly from [-1, 1) by
        rejection, accepted only while their best arm has been chosen fewer
        than ``sample_size`` times, then rescaled to [0, 1).
    Every entry is finally clamped from below at 1e-20.
    """
    arms_cpu = Z.cpu()
    n_rejection = (sample_size // 3) * N_arms
    n_uniform = (sample_size - sample_size // 3) * N_arms
    hits_per_arm = [0] * N_arms
    accepted = torch.zeros(n_rejection, N_dim).float()
    n_accepted = 0
    while n_accepted < n_rejection:
        candidate = torch.rand(1, N_dim).float() * 2 - 1
        best_arm = int(torch.argmax(torch.mm(candidate, arms_cpu).squeeze()))
        # NOTE(review): the per-arm cap is sample_size, not sample_size // 3.
        if hits_per_arm[best_arm] >= sample_size:
            continue
        hits_per_arm[best_arm] += 1
        accepted[n_accepted] = candidate[0]
        n_accepted += 1
    # Drawn after the rejection loop, as in the original, to keep RNG order.
    uniform_part = torch.rand(n_uniform, N_dim).float()
    stacked = torch.cat((uniform_part, (accepted + 1) / 2), dim=0)
    return torch.clamp(stacked, 1e-20).cuda()


class Training(nn.Module):
    """Couples the policy, the learnable problem instances and the instance sampler.

    ``forward`` simulates ``N_budget`` rounds of the game for a batch of
    instances: each round the policy picks coordinates, noisy +/-1 observations
    of the chosen coordinates of ``theta`` are drawn, and at the end the arm
    with the best empirical value is compared with the truly best arm. The
    method returns the losses for the instance parameters (``g_loss``), the
    sampler (``s_loss``) and the policy (``p_loss``).

    A separate set of instances (``eval_pre_theta``) with its own optimiser is
    kept for evaluation.
    """

    def __init__(self, policy, p_optim, generator, g_optim, sampler, s_optim, l_upper, sample_size, binary,
                 num_samples, eval_approx, eval_num_theta, eval_sample_size, n_pulls=1, finetune=False, overall=None):
        """Store the components, creating any that are passed as None.

        Args:
            policy / p_optim: policy network and its optimiser (default: a new
                ``PolicyNet`` on the GPU with Adam, lr 1e-4).
            generator / g_optim: instance parameters in logit space and their
                optimiser (default: logits of ``sample_theta`` draws with Adam).
            sampler / s_optim: ``Sampler`` over instances and its optimiser.
            l_upper: upper bound on log2 difficulty, or None to disable the
                difficulty constraint.
            sample_size: number of rollouts per instance during training.
            binary: if True optimise identification accuracy, else the reward gap.
            num_samples: number of training instances.
            eval_approx: if True evaluation uses ``c_fn`` for the difficulty,
                otherwise ``rho_optimal``.
            eval_num_theta: ``sample_theta`` argument used for the evaluation set.
            eval_sample_size: number of rollouts per instance during evaluation.
            n_pulls: number of pulls per round.
            finetune: accepted but not used in this class.
            overall: None, the string ``"f_star"``, or an iterable of baseline
                policies (see ``forward`` / ``run_overall``).
        """
        super(Training, self).__init__()
        self.policy = policy
        self.generator = generator
        self.sampler = sampler
        self.l_upper = l_upper
        self.sample_size = sample_size
        self.binary = binary
        self.n_pulls = n_pulls
        self.num_samples = num_samples
        self.overall = overall
        if policy is None:
            self.policy = PolicyNet().cuda()
        if p_optim is None:
            self.p_optim = Adam(self.policy.parameters(), lr=1e-4)
        else:
            self.p_optim = p_optim
        if generator is None:
            # Both branches of this conditional yield IT_THRESH.
            self.it_thresh = IT_THRESH if overall is not None and game_mode == "q20" else IT_THRESH
            initial_theta = sample_theta(num_samples // N_arms)
            # Logit transform, so that sigmoid(pre_theta) recovers the sampled values.
            initial_theta = torch.log(initial_theta) - torch.log(1 - initial_theta)
            self.pre_theta = nn.Parameter(initial_theta)
        else:
            self.it_thresh = 0
            self.pre_theta = generator
        if g_optim is None:
            lr = 1e-3
            if game_mode == "q20" and overall is not None:
                # To satisfy barrier loss
                lr = 3e-3
            self.g_optim = Adam([self.pre_theta], lr=lr)
        else:
            self.g_optim = g_optim
        if l_upper is not None and sampler is None:
            # Sampler allocates num + 1 logits, i.e. one per training instance.
            self.sampler = Sampler(num=num_samples - 1).cuda()
        if s_optim is None:
            # NOTE(review): if sampler and l_upper are both None, self.sampler is
            # still None here and .parameters() raises AttributeError.
            self.s_optim = Adam(self.sampler.parameters(), lr=1e-3)
        else:
            self.s_optim = s_optim
        self.lambda_entropy = LAMBDA_ENTROPY_BINARY if binary else LAMBDA_ENTROPY
        # Both branches of this conditional yield LAMBDA_SAMPLE_ENTROPY.
        self.lambda_sample_entropy = LAMBDA_SAMPLE_ENTROPY if overall is not None and game_mode != '1d' else LAMBDA_SAMPLE_ENTROPY
        self.lambda_barrier = 1e3
        self.lambda_random_walk = 0

        # Evaluate
        self.eval_approx = eval_approx
        self.eval_num_theta = eval_num_theta
        self.eval_sample_size = eval_sample_size
        self.eval_pre_theta = sample_theta(eval_num_theta)
        self.eval_pre_theta = torch.log(self.eval_pre_theta) - torch.log(1 - self.eval_pre_theta)
        self.eval_pre_theta = nn.Parameter(self.eval_pre_theta)
        self.eval_lr = 1e-2 if game_mode != "1d" else 1e-3
        self.eval_optim = Adam([self.eval_pre_theta], lr=self.eval_lr)

        if self.overall is not None:
            self.set_levels()

    def init_eval(self, approx, sample_size):
        """Reset the evaluation state: resample the evaluation instances and rebuild their optimiser."""
        self.eval_approx = approx
        self.eval_sample_size = sample_size
        self.eval_pre_theta.data = sample_theta(self.eval_num_theta)
        self.eval_pre_theta.data = torch.log(self.eval_pre_theta.data) - torch.log(1 - self.eval_pre_theta.data)
        self.eval_optim = Adam([self.eval_pre_theta], lr=self.eval_lr)

    def run_overall(self, theta, num_theta):
        """Roll out every baseline policy in ``self.overall`` on ``theta`` without gradients.

        ``theta`` holds one row per rollout; rows are grouped as
        ``(-1, num_theta)`` when averaging. Returns the element-wise best
        (largest) mean reward and mean accuracy over the baseline policies,
        each of shape ``num_theta``.
        """
        best_acc = None
        best_reward = None
        for policy in self.overall:
            with torch.no_grad():
                # Running sum of +/-1 observations and pull counts per coordinate.
                samples = torch.zeros(theta.size()).cuda().float()
                count = torch.zeros(theta.size()).cuda().float()
                for i in range(N_budget):
                    # 1 / count where count is non-zero, 0 elsewhere.
                    inverse_count = count.clone().detach()
                    inverse_count[inverse_count != 0] = 1. / inverse_count[inverse_count != 0]
                    X = torch.cat((samples, count, inverse_count), dim=-1)
                    p = policy(X)

                    Y_ind = torch.multinomial(torch.softmax(p.data, dim=-1), 1).squeeze()

                    expand_theta = theta
                    theta_sample = torch.gather(expand_theta, 1, Y_ind.unsqueeze(1))
                    # theta in [-1, 1] is mapped to a Bernoulli success probability.
                    ob = (theta_sample.data + 1) / 2
                    ob = torch.bernoulli(ob)
                    sample = torch.gather(samples, 1, Y_ind.unsqueeze(1))
                    ob = ob * 2 - 1 + sample
                    samples.scatter_(1, Y_ind.unsqueeze(1), ob)

                    Y_vec = to_one_hot(Y_ind, N_dim)
                    count = count + Y_vec
                # Avoid division by zero for coordinates that were never pulled.
                count[count == 0] = 1
                samples = samples / count
                accumulated_theta = torch.mm(theta, Z)
                accumulated_samples = torch.mm(samples, Z)

                max_ind = argmax_uniform(accumulated_samples, dim=-1)
                # Value of the chosen arm minus the value of the best arm (<= 0).
                reward_max = torch.gather(accumulated_theta, 1, max_ind.unsqueeze(1)).squeeze() - \
                             torch.max(accumulated_theta, dim=-1)[0]
                acc_max = (max_ind == torch.argmax(accumulated_theta, dim=-1)).float() * R_lim
                reward_max = torch.mean(reward_max.view(-1, num_theta), dim=0)
                acc_max = torch.mean(acc_max.view(-1, num_theta), dim=0)
                if best_reward is None:
                    best_reward, best_acc = reward_max, acc_max
                else:
                    torch.max(best_reward, reward_max, out=best_reward)
                    torch.max(best_acc, acc_max, out=best_acc)
        return best_reward, best_acc

    # def run_overall_oracle(self, theta, num_theta, static_alloc):
    #     with torch.no_grad():
    #         count = torch.distributions.multinomial.Multinomial(self.n_pulls * N_budget,
    #                                                             static_alloc).sample()
    #         expand_theta = theta
    #         sample = torch.distributions.binomial.Binomial(total_count=count, probs=(expand_theta + 1) / 2).sample()
    #         samples = sample * 2 - count
    #         count[count == 0] = 1
    #         samples = samples / count
    #         accumulated_theta = torch.mm(theta, Z)
    #         accumulated_samples = torch.mm(samples, Z)
    #
    #         max_ind = argmax_uniform(accumulated_samples, dim=-1)
    #         reward_max = torch.gather(accumulated_theta, 1, max_ind.unsqueeze(1)).squeeze() - \
    #                      torch.max(accumulated_theta, dim=-1)[0]
    #         acc_max = (max_ind == torch.argmax(accumulated_theta, dim=-1)).float() * R_lim
    #         reward_max = torch.mean(reward_max.view(-1, num_theta), dim=0)
    #         acc_max = torch.mean(acc_max.view(-1, num_theta), dim=0)
    #
    #     return reward_max, acc_max

    def set_levels(self):
        """Assign a per-instance difficulty cap by cycling through a grid of levels.

        The grid runs from a game-specific start up to ``self.l_upper`` and is
        tiled over the training instances (``train_upper``) and the evaluation
        instances (``eval_upper``).
        """
        # NOTE(review): there is no branch for game_mode "1d100", so `levels`
        # would be unbound in that mode.
        if game_mode == "1d":
            levels = torch.arange(3, self.l_upper + .1, .5).float().cuda()
        elif game_mode == "q20":
            # levels = torch.arange(3, self.l_upper + .1, .5).float().cuda()
            levels = torch.arange(3.5, self.l_upper + .1, .5).float().cuda()
        elif game_mode == "jester":
            levels = torch.arange(3, self.l_upper + .1, 1).float().cuda()
        self.train_upper = levels.repeat(self.pre_theta.size(0) // levels.size(0) + 1)[:self.pre_theta.size(0)]
        self.eval_upper = levels.repeat(self.eval_pre_theta.size(0) // levels.size(0) + 1)[:self.eval_pre_theta.size(0)]

    @staticmethod
    def run_overall_f_star(log_difficulty):
        """Look up a hard-coded baseline performance for each log2 difficulty.

        The tables hold one row per baseline and one column per difficulty
        level; the column-wise maximum is used. ``log_difficulty`` is rounded up
        to the next level and clamped to the table range. Returns
        ``(reward_max, acc_max)``; the entry that does not apply to the game
        mode is all zeros.
        """
        # NOTE(review): there is no branch for game_mode "1d100", so the return
        # values would be unbound in that mode.
        with torch.no_grad():
            if game_mode == "1d":
                # levels = [3, 3.5, 4, 4.5, 5, 5.5, 6, 6.5, 7]
                perfs = [[0.9462, 0.7913, 0.6105, 0.4502, 0.3317, 0.2481, 0.1712, 0.1353, 0.1161],
                         [0.9451, 0.8505, 0.7093, 0.5452, 0.4061, 0.3195, 0.2474, 0.1942, 0.1553],
                         [0.9245, 0.8412, 0.7223, 0.5594, 0.4427, 0.3370, 0.2590, 0.1889, 0.1610],
                         [0.8947, 0.7992, 0.6916, 0.5863, 0.4598, 0.3528, 0.2672, 0.2148, 0.1706],
                         [0.8727, 0.7617, 0.6780, 0.5901, 0.4845, 0.3655, 0.2869, 0.2340, 0.1828],
                         [0.8434, 0.7178, 0.6323, 0.5634, 0.4753, 0.3762, 0.3005, 0.2422, 0.2002],
                         [0.8518, 0.7372, 0.6235, 0.5246, 0.4422, 0.3669, 0.3068, 0.2412, 0.2026],
                         [0.8460, 0.7153, 0.6216, 0.5141, 0.4275, 0.3511, 0.2816, 0.2359, 0.2003],
                         [0.8514, 0.7193, 0.5983, 0.4947, 0.4060, 0.3287, 0.2780, 0.2300, 0.1919]]
                perfs = torch.from_numpy(np.max(np.array(perfs), axis=0)).cuda().float() * R_lim
                # Levels start at 3 in steps of 0.5: index = ceil(2 * l - 6).
                log_difficulty_ind = torch.clamp(torch.ceil(log_difficulty * 2 - 6), 0, 8).long()
                acc_max = perfs[log_difficulty_ind]
                reward_max = torch.zeros(log_difficulty.size()).cuda().float()
            elif game_mode == "q20":
                # levels = [3, 3.5, 4, 4.5, 5]
                # perfs = [[0.4538, 0.2773, 0.1234, 0.0642, 0.0366],
                #          [0.3898, 0.2381, 0.1332, 0.0677, 0.0426],
                #          [0.3436, 0.2267, 0.1387, 0.0791, 0.0409],
                #          [0.3115, 0.2145, 0.1192, 0.0748, 0.0476]]
                # perfs = torch.from_numpy(np.max(np.array(perfs), axis=0)).cuda().float() * R_lim
                # log_difficulty_ind = torch.clamp(torch.ceil(log_difficulty * 2 - 6), 0, 4).long()
                perfs = [[0.2773, 0.1234, 0.0642, 0.0366],
                         [0.2381, 0.1332, 0.0677, 0.0426],
                         [0.2267, 0.1387, 0.0791, 0.0409],
                         [0.2145, 0.1192, 0.0748, 0.0476]]
                perfs = torch.from_numpy(np.max(np.array(perfs), axis=0)).cuda().float() * R_lim
                # Levels start at 3.5 in steps of 0.5: index = ceil(2 * l - 7).
                log_difficulty_ind = torch.clamp(torch.ceil(log_difficulty * 2 - 7), 0, 3).long()
                acc_max = perfs[log_difficulty_ind]
                reward_max = torch.zeros(log_difficulty.size()).cuda().float()
            elif game_mode == "jester":
                # levels = [3, 4, 5, 6, 7]
                perfs = [[-2.065, -4.888, -6.113, -6.957, -7.322],
                         [-4.223, -6.960, -7.450, -8.144, -8.422],
                         [-2.380, -5.103, -6.041, -6.215, -6.433],
                         [-4.566, -7.000, -7.428, -8.118, -8.308]]
                # `perf` lists the column-wise maxima of the table above by hand;
                # the table itself is overwritten on the next line.
                perf = [-2.065, -4.888, -6.041, -6.215, -6.433]
                perfs = torch.from_numpy(np.array(perf)).cuda().float()
                log_difficulty_ind = torch.clamp(torch.ceil(log_difficulty - 3), 0, 4).long()
                acc_max = torch.zeros(log_difficulty.size()).cuda().float()
                reward_max = perfs[log_difficulty_ind]
        return reward_max, acc_max

    def forward(self, sample_ind, it, sample_entropy=None, evaluate=False, fix_theta=None):
        """Simulate the game for the selected instances and build the losses.

        Args:
            sample_ind: indices into ``pre_theta`` (training) or
                ``eval_pre_theta`` (evaluation).
            it: iteration counter; used for logging and compared with
                ``self.it_thresh`` to decide whether the sampler is active.
            sample_entropy: value printed in the training log line.
            evaluate: if True use the evaluation instances / sample size and run
                the policy without gradients.
            fix_theta: optional fixed instances to evaluate; when given
                (together with ``evaluate``), statistics are printed and a numpy
                array is returned.

        Returns:
            * ``fix_theta`` given: numpy array with the per-instance mean
              accuracy (binary) or mean reward.
            * ``evaluate``: ``(return1, g_loss, in_range, theta.data)``.
            * otherwise: ``(return1, g_loss, s_loss, p_loss, theta.data)`` where
              ``s_loss`` is None while ``it < self.it_thresh``.

        NOTE(review): ``in_range``, ``barrier_loss`` and ``c`` are only assigned
        when ``self.l_upper`` is not None, yet the evaluate return value and the
        ``enable_plt`` print below use them unconditionally.
        """
        if evaluate and fix_theta is not None:
            num_theta = fix_theta.size(0)
        else:
            num_theta = sample_ind.size(0)
        if evaluate:
            if fix_theta is None:
                select_pre_theta = self.eval_pre_theta[sample_ind]
            else:
                # The fixed instances are used both as theta and as the input
                # of the log-sigmoid terms below.
                theta = fix_theta
                select_pre_theta = fix_theta
        else:
            if it >= self.it_thresh:
                sample_weight = self.sampler().view(-1)
                log_sample_prob = torch.log_softmax(sample_weight, dim=-1)
            select_pre_theta = self.pre_theta[sample_ind]
        if fix_theta is None:
            # Map the unconstrained parameters to theta in (-1, 1).
            theta = torch.sigmoid(select_pre_theta) * 2 - 1
        # log P(observation = 1) and log P(observation = 0), since
        # (theta + 1) / 2 = sigmoid(pre_theta).
        log_theta = F.logsigmoid(select_pre_theta)
        log_theta_neg = F.logsigmoid(-select_pre_theta)
        if self.l_upper is not None:
            if self.eval_approx or not evaluate:
                c, static_alloc = c_fn(theta)
            else:
                static_alloc, c = rho_optimal(theta, difficulty=True)
            # log2 difficulty of each instance.
            l = torch.log2(c)
            if self.overall is None:
                l_upper = self.l_upper
            else:
                if evaluate:
                    l_upper = self.eval_upper[sample_ind]
                else:
                    l_upper = self.train_upper[sample_ind]
            in_range = l < l_upper
            # Penalise only the excess above the cap.
            barrier_loss = (l - l_upper) * (l > l_upper).float()

            # We believe the random walk is not having an effect for the training.
            # But to reproduce our results, we keep it here.
            random_walk = F.mse_loss(theta,
                                     theta.data + torch.randn(theta.size()).float().cuda(),
                                     reduction='none')
            barrier_loss = barrier_loss * self.lambda_barrier + torch.sum(random_walk, dim=-1) * self.lambda_random_walk
            barrier_loss[in_range] = 0

        if evaluate:
            sample_size = self.eval_sample_size
        else:
            sample_size = self.sample_size
            log_sum_p = 0
            entropy = 0
        if self.overall is not None:
            # Baseline performance, tiled to one entry per rollout.
            if self.overall == "f_star":
                reward_baseline, acc_baseline = self.run_overall_f_star(l.data)
                reward_baseline = reward_baseline.unsqueeze(0).repeat(sample_size, 1).view(-1)
                acc_baseline = acc_baseline.unsqueeze(0).repeat(sample_size, 1).view(-1)
            else:
                reward_baseline, acc_baseline = self.run_overall(theta.repeat(sample_size, 1), num_theta)
                reward_baseline = reward_baseline.unsqueeze(0).repeat(sample_size, 1).view(-1)
                acc_baseline = acc_baseline.unsqueeze(0).repeat(sample_size, 1).view(-1)
        # One row per (rollout, instance) pair: running observation sums and pull counts.
        samples = torch.zeros(num_theta * sample_size, N_dim).cuda().float()
        count = torch.zeros(num_theta * sample_size, N_dim).cuda().float()
        # Accumulated log-likelihood of the observations under theta.
        log_sum_g = 0
        for i in range(N_budget):
            # Policy input: observation sums, counts, and 1 / count (0 where unseen).
            inverse_count = count.clone().detach()
            inverse_count[inverse_count != 0] = 1. / inverse_count[inverse_count != 0]
            X = torch.cat((samples, count, inverse_count), dim=-1)
            if evaluate:
                with torch.no_grad():
                    p = self.policy(X)
            else:
                p = self.policy(X)

            if not evaluate:
                log_probs = torch.log_softmax(p, dim=-1)
                # Policy entropy, averaged over rollouts and over the budget.
                entropy = entropy - torch.mean(
                    torch.sum(log_probs * torch.softmax(p, dim=-1), dim=-1).view(sample_size, num_theta),
                    dim=0) / N_budget
            if self.n_pulls == 1:
                # Single pull: sample one coordinate per row.
                Y_ind = torch.multinomial(torch.softmax(p.data, dim=-1), 1).squeeze()
                if not evaluate:
                    log_prob = torch.gather(log_probs, 1, Y_ind.unsqueeze(1)).squeeze()
                    log_sum_p = log_sum_p + log_prob
                expand_theta = theta.repeat(sample_size, 1)
                expand_log_theta = log_theta.repeat(sample_size, 1)
                expand_log_theta_neg = log_theta_neg.repeat(sample_size, 1)
                theta_sample = torch.gather(expand_theta, 1, Y_ind.unsqueeze(1))
                log_theta_smaple = torch.gather(expand_log_theta, 1, Y_ind.unsqueeze(1))
                log_theta_neg_sample = torch.gather(expand_log_theta_neg, 1, Y_ind.unsqueeze(1))
                # Bernoulli observation with success probability (theta + 1) / 2.
                ob = (theta_sample.data + 1) / 2
                ob = torch.bernoulli(ob)
                log_prob = log_theta_smaple * ob + log_theta_neg_sample * (1 - ob)
                log_sum_g = log_sum_g + log_prob.view(-1)
                sample = torch.gather(samples, 1, Y_ind.unsqueeze(1))
                # Convert the 0/1 outcome to +/-1 and add it to the running sum.
                ob = ob * 2 - 1 + sample
                samples.scatter_(1, Y_ind.unsqueeze(1), ob)

                Y_vec = to_one_hot(Y_ind, N_dim)
                count.add_(Y_vec)
            else:
                # Several pulls per round: draw pull counts per coordinate.
                Y_vec = torch.distributions.multinomial.Multinomial(self.n_pulls,
                                                                    torch.softmax(p.data, dim=-1)).sample()
                if not evaluate:
                    log_prob = torch.sum(log_probs * Y_vec, dim=-1) / self.n_pulls
                    log_sum_p = log_sum_p + log_prob
                expand_theta = theta.repeat(sample_size, 1)
                expand_log_theta = log_theta.repeat(sample_size, 1)
                expand_log_theta_neg = log_theta_neg.repeat(sample_size, 1)
                # Number of positive outcomes among the pulls of each coordinate.
                sample = torch.distributions.binomial.Binomial(total_count=Y_vec,
                                                               probs=(expand_theta + 1) / 2).sample()
                sample_neg = Y_vec - sample
                log_prob = torch.sum(sample * expand_log_theta + sample_neg * expand_log_theta_neg, dim=-1)
                log_sum_g = log_sum_g + log_prob / self.n_pulls
                # positives - negatives = 2 * positives - pulls.
                samples.add_(sample * 2 - Y_vec)
                count.add_(Y_vec)
        # Empirical mean per coordinate (never-pulled coordinates stay at 0).
        count[count == 0] = 1
        samples.div_(count)
        # NOTE(review): expand_theta is only assigned inside the loop above, so
        # N_budget must be at least 1.
        accumulated_theta = torch.mm(expand_theta, Z)
        accumulated_samples = torch.mm(samples, Z)

        max_ind = argmax_uniform(accumulated_samples, dim=-1)
        # True value of the recommended arm minus the true best value (<= 0).
        reward_max = torch.gather(accumulated_theta, 1, max_ind.unsqueeze(1)).squeeze() - \
                     torch.max(accumulated_theta, dim=-1)[0]
        # reward_max = reward_max * 1e2
        acc_max = (max_ind == torch.argmax(accumulated_theta, dim=-1)).float() * R_lim

        if fix_theta is not None:
            print()
            print(torch.mean(reward_max).data.cpu().numpy(), torch.min(reward_max).data.cpu().numpy())
            print("Accuracy:", torch.mean(acc_max).data.cpu().numpy() / R_lim,
                  torch.min(acc_max).data.cpu().numpy() / R_lim)
            if self.binary:
                return torch.mean(acc_max.view(sample_size, num_theta), dim=0).cpu().numpy() / R_lim
            else:
                return torch.mean(reward_max.view(sample_size, num_theta), dim=0).cpu().numpy()

        if self.overall is not None:
            # Measure performance relative to the baseline.
            acc_max = acc_max - acc_baseline
            reward_max = reward_max - reward_baseline

        # Progress line, rewritten in place via '\r'.
        if evaluate:
            if self.l_upper is None:
                sys.stdout.write('\r% 4d  % 9.8f' %
                                 (it, torch.mean(acc_max).data.cpu().numpy() / R_lim if self.binary
                                 else torch.mean(reward_max).data.cpu().numpy()))
            else:
                sys.stdout.write('\r% 4d  % 9.8f  % 9.8f  % 5d' %
                                 (it,
                                  torch.mean(acc_max).data.cpu().numpy() / R_lim if self.binary
                                  else torch.mean(reward_max).data.cpu().numpy(),
                                  torch.min(torch.mean(
                                      (acc_max / R_lim if self.binary else reward_max).view(sample_size, num_theta),
                                      dim=0) + (1 - in_range.float()) * 1000).data.cpu().numpy(),
                                  torch.sum(in_range).data.cpu().numpy()))
        else:
            if self.l_upper is None:
                sys.stdout.write('\r% 5d  % 9.8f  % 9.8f' %
                                 (it,
                                  torch.mean(acc_max).data.cpu().numpy() / R_lim if self.binary
                                  else torch.mean(reward_max).data.cpu().numpy(),
                                  entropy.data.cpu().numpy(),
                                  ))
            else:
                sys.stdout.write('\r% 5d  % 9.8f  % 9.8f  % 9.8f  % 5d' %
                                 (it,
                                  torch.mean(acc_max).data.cpu().numpy() / R_lim if self.binary
                                  else torch.mean(reward_max).data.cpu().numpy(),
                                  torch.mean(entropy.data).cpu().numpy(),
                                  sample_entropy if it >= self.it_thresh else 0,
                                  torch.sum(in_range).data.cpu().numpy()
                                  ))
        if enable_plt:
            print(" ", torch.max(c).data.cpu().numpy())
        sys.stdout.flush()
        # Instance loss: averaged over rollouts, one value per instance. The
        # log_sum_g factor carries the gradient through the sampled observations.
        if self.binary:
            avg_acc_max = torch.mean(acc_max.view(sample_size, num_theta), dim=0)
            return1 = avg_acc_max.data / R_lim
            g_loss = torch.mean((acc_max * log_sum_g).view(sample_size, num_theta), dim=0)
        else:
            avg_reward_max = torch.mean(reward_max.view(sample_size, num_theta), dim=0)
            return1 = avg_reward_max.data
            g_loss = torch.mean((reward_max + reward_max.data * log_sum_g).view(sample_size, num_theta),
                                dim=0)

        if self.l_upper is not None:
            g_loss = barrier_loss + g_loss
        if evaluate:
            return return1, g_loss, in_range, theta.data
        else:
            if it >= self.it_thresh:
                # Sampler loss: performance weighted by the log-probability of
                # having drawn each instance.
                log_sample_prob = log_sample_prob[sample_ind]
                s_loss = (avg_acc_max if self.binary else avg_reward_max.data) * log_sample_prob
            else:
                s_loss = None
            # Policy loss: negated performance times the log-probability of the
            # chosen actions, minus the entropy bonus.
            p_loss = torch.mean(
                (-(acc_max if self.binary else reward_max.data) * log_sum_p).view(self.sample_size, num_theta),
                dim=0) - entropy * self.lambda_entropy
            return return1, g_loss, s_loss, p_loss, theta.data


def train(policy, p_optim, generator, g_optim, num_theta, n_it, sample_size=1,
          l_upper=None, sampler=None, s_optim=None, binary=False, num_checks=10,
          fname='./minimax_policy.dict', num_samples=1, finetune=False, overall=None, n_pulls=1):
    """Alternately update the instances, the instance sampler and the policy.

    Each iteration draws ``num_theta`` instance indices (uniformly before
    ``PT.it_thresh``, from the sampler afterwards), runs ``Training.forward``
    and takes one optimiser step per component. About ``num_checks`` times
    during training (and on the last iteration) the policy is evaluated with
    ``evaluate_model_sha``; whenever the result equals the maximum seen so far,
    the model and optimisers are saved to ``fname``.

    Args:
        policy, p_optim, generator, g_optim, sampler, s_optim: existing
            components to continue from, or None to create fresh ones.
        num_theta: number of instances per iteration.
        n_it: number of training iterations.
        sample_size: rollouts per instance.
        l_upper: difficulty cap passed to ``Training``.
        binary: optimise accuracy instead of reward.
        num_checks: number of evaluation checkpoints spread over training.
        fname: checkpoint path.
        num_samples: number of training instances.
        finetune, overall, n_pulls: forwarded to ``Training``.

    Returns:
        ``(sum of the last ten per-iteration returns / 10, policy, p_optim,
        generator, g_optim, sampler, s_optim)``.
        NOTE(review): the last six entries are the arguments as passed in
        (possibly None), not the objects created inside ``Training``.
    """
    PT = Training(policy, p_optim, generator, g_optim, sampler, s_optim, l_upper, sample_size,
                  binary, num_samples, True, 16000 if game_mode == "1d" else 1600, 100, n_pulls=n_pulls,
                  finetune=finetune, overall=overall).cuda()
    # Periods between checkpoints / plots. Clamped to at least 1 so that short
    # runs (n_it < num_checks) no longer hit a modulo-by-zero.
    check_every = max(1, n_it // max(1, num_checks))
    plot_every = max(1, n_it // max(1, 2 * num_checks))
    rewards = []
    worst_rewards = []
    for it in range(n_it):
        if it >= PT.it_thresh:
            # Draw instances from the learned sampler and track its entropy.
            sample_weight = PT.sampler().view(-1)
            log_sample_prob = torch.log_softmax(sample_weight, dim=-1)
            sample_prob = torch.softmax(sample_weight, dim=-1)
            sample_entropy = torch.sum(-log_sample_prob * sample_prob)
            sample_entropy_numpy = sample_entropy.data.cpu().numpy()
            sample_ind = torch.multinomial(sample_prob.data, num_theta, replacement=True)
        else:
            # Warm-up phase: uniform instance selection, sampler not trained.
            sample_entropy_numpy = None
            sample_ind = torch.randint(num_samples, (num_theta,)).cuda()
        return1, g_loss, s_loss, p_loss, theta = PT(sample_ind, it, sample_entropy=sample_entropy_numpy)
        rewards.append(return1.cpu().numpy())

        # Instance update. The graph is retained because the sampler and
        # policy losses below are back-propagated afterwards.
        PT.g_optim.zero_grad()
        g_loss = torch.sum(g_loss)
        g_loss.backward(retain_graph=True)
        PT.g_optim.step()

        # Sampler update, with an entropy bonus.
        if it >= PT.it_thresh:
            PT.s_optim.zero_grad()
            s_loss = torch.mean(s_loss) - sample_entropy * PT.lambda_sample_entropy
            s_loss.backward()
            PT.s_optim.step()

        # Policy update; skipped during warm-up when a baseline is in use.
        PT.p_optim.zero_grad()
        p_loss = torch.mean(p_loss)
        p_loss.backward()
        if it >= PT.it_thresh or overall is None:
            PT.p_optim.step()

        if enable_plt:
            if (it % plot_every == 0) or it == PT.it_thresh - 1:
                plt.hist(argmax_uniform(torch.mm(theta, Z), dim=-1).data.cpu().numpy(), N_arms)
                plt.show()
        if it != 0 and (it % check_every == 0 or it == n_it - 1):
            # Drop references to the training losses before the evaluation.
            if it >= PT.it_thresh:
                del s_loss
            del g_loss, p_loss
            print()
            # worst_rewards.append(evaluate_model_sampling(PT, 100, 3000, plot=True, l_upper=l_upper)[0])
            worst_rewards.append(
                evaluate_model_sha(PT, 1000 if game_mode == '1d' else 100, .25, 3, 100, 100, plot=True, approx=True)[
                    0])
            print(worst_rewards[-1])
            # Save whenever the newest worst-case result ties the maximum so far.
            if worst_rewards[-1] == max(worst_rewards):
                print('Saved')
                torch.save({'policy': PT.policy,
                            'generator': PT.pre_theta,
                            'sampler': PT.sampler,
                            'p_optim': PT.p_optim,
                            'g_optim': PT.g_optim,
                            's_optim': PT.s_optim},
                           fname)
    if enable_plt:
        plt.plot(rewards)
        plt.show()
        plt.plot(worst_rewards)
        plt.show()
        for i in range(20):
            plt.subplot(5, 4, i + 1)
            plt.plot(theta[i].data.cpu().numpy())
        plt.show()
    return sum(rewards[-10:]) / 10, policy, p_optim, generator, g_optim, sampler, s_optim


def evaluate_model_sampling(PT, sample_size, n_it, plot=False, l_upper=None, fix_theta=None, approx=True):
    """Search for the worst-case evaluation instance by gradient steps on all of them.

    The evaluation instances of ``PT`` are re-initialised and optimised for
    ``n_it`` steps. The instance with the lowest mean return over the last 50
    steps is selected; when ``l_upper`` is given, only instances flagged
    in range on the final step qualify.

    If ``fix_theta`` is given, a single evaluation pass on those instances is
    run and its result is returned directly.

    Returns:
        ``worst_reward`` alone, or ``(worst_reward, worst_theta, rewards)`` when
        ``plot`` is True, where ``rewards`` is the full per-step return history.
    """
    PT.init_eval(approx, sample_size)
    inds = torch.arange(PT.eval_num_theta).cuda()
    if fix_theta is not None:
        # Fixed instances: one forward pass, no optimisation.
        return PT(inds, 0, evaluate=True, fix_theta=fix_theta)

    reward_history = []
    step_sizes = []
    loss_history = []
    for it in range(n_it):
        return1, g_loss, in_range, theta = PT(inds, it, evaluate=True, fix_theta=fix_theta)
        reward_history.append(return1.cpu().numpy())
        PT.eval_optim.zero_grad()
        g_loss = torch.sum(g_loss)
        loss_history.append(g_loss.data.cpu().numpy())
        g_loss.backward()
        PT.eval_optim.step()
        if plot:
            # Largest absolute gradient entry scaled by the learning rate.
            step_sizes.append(torch.sqrt(torch.max(PT.eval_pre_theta.grad ** 2)).cpu().numpy() * PT.eval_lr)
    del g_loss
    print()

    rewards = np.array(reward_history)
    in_range = in_range.cpu().numpy() if l_upper is not None else np.ones(rewards.shape[-1])
    # Out-of-range instances get a score of 1000 so they are never selected.
    recent_mean = np.mean(rewards[-50:, :], axis=0)
    min_theta = np.argmin(recent_mean * in_range + 1000 * (1 - in_range))
    if enable_plt and plot:
        for curve in (rewards[:, min_theta], step_sizes, theta[min_theta].data.cpu().numpy(), loss_history):
            plt.plot(curve)
            plt.show()
    worst_reward = np.mean(rewards[-50:, min_theta])
    if not plot:
        return worst_reward
    return worst_reward, theta[min_theta].data, rewards


def evaluate_model_sha(PT, num_theta, reduction, n_rung, sample_size, n_it, plot=False, approx=True):
    """Search for the worst-case evaluation instance with successive halving.

    Starts from ``num_theta / reduction ** (n_rung - 1)`` evaluation instances.
    In each of the ``n_rung`` rungs every surviving instance is optimised for
    ``n_it`` steps (in batches of ``num_theta``), then only the fraction
    ``reduction`` with the lowest recent return is kept and ``n_it`` is divided
    by ``reduction``.

    Args:
        PT: the ``Training`` object to evaluate.
        num_theta: batch size, and the number of instances kept after the
            last pruning step.
        reduction: fraction of instances kept between rungs (e.g. 0.25).
        n_rung: number of rungs.
        sample_size: rollouts per instance.
        n_it: optimisation steps per batch in the first rung.
        plot: if True also return the worst instance and the final scores.
        approx: forwarded to ``PT.init_eval``.

    Returns:
        ``worst_reward`` alone, or ``(worst_reward, worst_theta, rewards)`` when
        ``plot`` is True.
    """
    PT.init_eval(approx, sample_size)
    worst_theta = None
    worst_rewards = None
    losses = []
    it = 0
    batch_num_theta = num_theta
    # Initial population size, chosen so the last rung is left with batch_num_theta.
    num_theta = int(num_theta / (reduction ** (n_rung - 1)))
    inds = torch.arange(num_theta).cuda()
    for i in range(n_rung):
        rewards = []
        # Process the surviving instances in batches of batch_num_theta.
        for j in range((num_theta - 1) // batch_num_theta + 1):
            ind_set = inds[j * batch_num_theta:min((j + 1) * batch_num_theta, num_theta)]
            rs = []
            for _ in range(n_it):
                it += 1
                return1, g_loss, in_range, theta = PT(ind_set, it, evaluate=True)
                rs.append(return1.data.unsqueeze(-1))
                PT.eval_optim.zero_grad()
                g_loss = torch.sum(g_loss)
                losses.append(g_loss.data.cpu().numpy())
                g_loss.backward()
                PT.eval_optim.step()
            # in_range comes from the last step of this batch; instances that
            # are out of range get a score of 1000 on every step.
            in_range = in_range.float()
            rewards.append(torch.cat(rs, dim=-1) * in_range.unsqueeze(-1) + 1000 * (1 - in_range.unsqueeze(-1)))
        # Score = mean return over the last min(n_it // 4, 100) steps.
        rewards = torch.cat(rewards, dim=0)[:, -min(n_it // 4, 100):]
        rewards = torch.mean(rewards, dim=-1)
        num_theta = int(num_theta * reduction)
        if i != n_rung - 1:
            # Keep the num_theta lowest-scoring instances; rewards is ordered
            # like inds, so the top-k positions index into inds directly.
            inds = inds[torch.topk(rewards, num_theta, largest=False, sorted=False)[1]]
            n_it = int(n_it / reduction)
    del g_loss
    rewards = rewards.cpu().numpy()
    min_theta = np.argmin(rewards)
    worst_reward = rewards[min_theta]
    print()
    if plot:
        # Map the stored logits back to theta in (-1, 1).
        worst_theta = torch.sigmoid(PT.eval_pre_theta[inds[int(min_theta)]].data) * 2 - 1
        worst_rewards = rewards
    if plot:
        return worst_reward, worst_theta, worst_rewards
    else:
        return worst_reward


if __name__ == "__main__":
    # Command line: game_mode model_id l_upper.
    # NOTE(review): sys.argv[1] is already read at module level, so a missing
    # game_mode fails there with IndexError before this usage check runs.
    assert len(sys.argv) == 4, "Usage: python minimax_policy.py game_mode model_id l_upper"
    l_upper = float(sys.argv[3])
    model_id = int(sys.argv[2])
    # float2str comes from src.utils (star import).
    fname = '../model/minimax_policy_%s_%d_%s.dict' % (game_mode, model_id, float2str(l_upper))
    print("Saving file to", fname)
    # Train from scratch; M instances per iteration, L rollouts per instance,
    # N_arms * N_GEN training instances.
    reward, p, po, g, go, s, so = train(None, None, None, None, M, N_it, sample_size=L, binary=False,
                                        l_upper=l_upper, fname=fname, num_samples=N_arms * N_GEN,
                                        n_pulls=N_pulls)  # , overall="oracle"
    # Reload the checkpoint written by train(); the result is not used further here.
    minimax_dict = torch.load(fname)
