import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import StepLR
from torch import distributions
import numpy as np
import sys
from src.utils import *
from src.q20game import Q20Env
from src.hyperparam import get_param
from src.jester_loader import get_z

# Flip to True to import matplotlib for plotting.
enable_plt = False
if enable_plt:
    import matplotlib.pyplot as plt

# Environment selector: "q20", "jester" or "1d".
game_mode = "1d"

IT_THRESH, N_it, N_dim, N_budget, Z, LAMBDA_ENTROPY, LAMBDA_ENTROPY_BINARY, LAMBDA_SAMPLE_ENTROPY, M, L, N_GEN, \
R_lim, N_it_overall, N_pulls, eval_range = get_param(game_mode)
# Overrides the iteration count returned by get_param().
N_it = 500

# theta_probs: one row per candidate theta; ParallelTraining.forward turns an
# entry into a Bernoulli probability via (theta + 1) / 2, so values are
# presumably in [-1, 1].
# people_probs: per-row weights used to aggregate evaluation scores.
if game_mode == "q20":
    QEnv = Q20Env()
    theta_probs, _ = QEnv.get_probs()
    theta_probs = torch.from_numpy(theta_probs).float().cuda()
    # Cast to float before the in-place normalization: an integer array
    # would raise a casting error on `/=`.
    people_probs = np.array([p[1] for p in QEnv.posList], dtype=np.float64)
    people_probs /= np.sum(people_probs)
    people_probs = torch.from_numpy(people_probs).float().cuda()
elif game_mode == "jester":
    theta_probs = get_z(return_data=True)[1]
    # NOTE(review): presumably rescales Jester ratings from [-10, 10] to [-1, 1].
    theta_probs = torch.from_numpy(theta_probs).float().cuda() / 10
    people_probs = torch.ones(theta_probs.size(0)).float().cuda() / theta_probs.size(0)
elif game_mode == "1d":
    # Map Z onto the +/-1 theta scale (assumes Z entries lie in [0, 1]).
    theta_probs = Z.t() * 2 - 1
    people_probs = torch.ones(theta_probs.size(0)).float().cuda() / theta_probs.size(0)
else:
    # Fail fast instead of hitting a NameError on theta_probs much later.
    raise ValueError("unknown game_mode: %r" % (game_mode,))
N_arms = Z.shape[1]


class PolicyNet(nn.Module):
    """Five-layer MLP producing one logit per observable coordinate.

    Input is the concatenation ``[samples, count, inverse_count]`` (width
    ``3 * N_dim``); output is ``N_dim`` unnormalized logits.
    """

    def __init__(self):
        super().__init__()
        width = 256
        # Attribute names l1..l5 are kept so saved modules/state dicts stay compatible.
        self.l1 = nn.Linear(N_dim * 3, width)
        self.l2 = nn.Linear(width, width)
        self.l3 = nn.Linear(width, width)
        self.l4 = nn.Linear(width, width)
        self.l5 = nn.Linear(width, N_dim)
        weight_initialize_nn(self, sigma=.01)

    def forward(self, X):
        hidden = X
        for layer in (self.l1, self.l2, self.l3, self.l4):
            hidden = F.leaky_relu(layer(hidden))
        # Final layer is linear: callers apply softmax / log_softmax themselves.
        return self.l5(hidden)


class ParallelTraining(nn.Module):
    """Batched policy rollouts with a REINFORCE loss.

    For every (theta, rollout) pair the policy chooses ``N_budget`` coordinates
    of theta to observe. Each observation is a +/-1 draw with
    P(+1) = (theta_i + 1) / 2. After the budget is spent, the empirical
    coordinate means are projected through ``Z`` and the highest-scoring column
    is recommended. It is scored against the true best column of
    ``theta @ Z``.

    Args:
        policy: Policy network; a fresh CUDA ``PolicyNet`` is built when None.
        p_optim: Optimizer for the policy; ``Adam(lr=1e-3)`` is built when None.
        sample_size: Rollouts per theta in training mode.
        binary: If True the reward is ``R_lim`` times the indicator of picking
            the best column; otherwise it is the (non-positive) gap to the best.
        eval_sample_size: Rollouts per theta in evaluation mode.
    """

    def __init__(self, policy, p_optim, sample_size, binary, eval_sample_size):
        super(ParallelTraining, self).__init__()
        self.policy = policy if policy is not None else PolicyNet().cuda()
        if p_optim is None:
            self.p_optim = Adam(self.policy.parameters(), lr=1e-3)
        else:
            self.p_optim = p_optim
        self.sample_size = sample_size
        self.binary = binary
        # Evaluation rollouts per theta; callers may reassign it before evaluating.
        self.eval_sample_size = eval_sample_size

    def forward(self, sample_ind, it, evaluate=False, h=None):
        """Roll out the policy for the thetas selected by ``sample_ind``.

        Args:
            sample_ind: 1-D index tensor of theta rows.
            it: Iteration number, only used in the progress line.
            evaluate: If True, use ``theta_probs`` (scaled by ``h`` if given),
                ``eval_sample_size`` rollouts, and no policy gradient.
            h: Optional multiplicative scale for theta in evaluation mode.

        Returns:
            Evaluation: the ``people_probs``-weighted sum of per-theta mean
            rewards as a NumPy scalar (divided by ``R_lim`` when binary).
            Training: ``(per_theta_reward, per_theta_loss)``, each of shape
            ``(num_theta,)``.
        """
        num_theta = sample_ind.size(0)
        if evaluate:
            theta = theta_probs[sample_ind]
            if h is not None:
                theta = theta * h
            sample_size = self.eval_sample_size
        else:
            theta = (Z.t() * 2 - 1)[sample_ind]
            sample_size = self.sample_size

        # Row r of every batch tensor belongs to theta[r % num_theta]; this
        # tiling is what makes the later .view(sample_size, num_theta) valid.
        # Hoisted out of the loop: it is loop-invariant and is needed after the
        # loop even when N_budget == 0.
        expand_theta = theta.repeat(sample_size, 1)
        samples = torch.zeros(num_theta * sample_size, N_dim).cuda().float()  # running sums of observations
        count = torch.zeros(num_theta * sample_size, N_dim).cuda().float()    # pulls per coordinate
        log_sum_p = 0
        for _ in range(N_budget):
            inverse_count = count.clone().detach()
            observed = inverse_count != 0
            inverse_count[observed] = 1. / inverse_count[observed]
            X = torch.cat((samples, count, inverse_count), dim=-1)
            if evaluate:
                with torch.no_grad():
                    p = self.policy(X)
            else:
                p = self.policy(X)

            # squeeze(1), not squeeze(): keep a 1-D index even when the batch has one row.
            Y_ind = torch.multinomial(torch.softmax(p.data, dim=-1), 1).squeeze(1)
            Y_col = Y_ind.unsqueeze(1)
            if not evaluate:
                log_probs = torch.log_softmax(p, dim=-1)
                log_sum_p = log_sum_p + torch.gather(log_probs, 1, Y_col).squeeze(1)

            # Draw a +/-1 observation with P(+1) = (theta + 1) / 2 and add it
            # to the running sum for the chosen coordinate.
            theta_sample = torch.gather(expand_theta, 1, Y_col)
            ob = torch.bernoulli((theta_sample.data + 1) / 2)
            ob = ob * 2 - 1 + torch.gather(samples, 1, Y_col)
            samples.scatter_(1, Y_col, ob)

            count = count + to_one_hot(Y_ind, N_dim)

        # Empirical mean per coordinate; never-observed coordinates stay 0.
        count[count == 0] = 1
        samples = samples / count
        accumulated_theta = torch.mm(expand_theta, Z)
        accumulated_samples = torch.mm(samples, Z)

        max_ind = argmax_uniform(accumulated_samples, dim=-1)
        reward_max = torch.gather(accumulated_theta, 1, max_ind.unsqueeze(1)).squeeze(1) - \
            torch.max(accumulated_theta, dim=-1)[0]
        acc_max = (max_ind == torch.argmax(accumulated_theta, dim=-1)).float() * R_lim

        if evaluate:
            if self.binary:
                avg_acc_max = torch.mean(acc_max.view(sample_size, num_theta), dim=0)
                return torch.sum(avg_acc_max * people_probs[sample_ind]).data.cpu().numpy() / R_lim
            avg_reward_max = torch.mean(reward_max.view(sample_size, num_theta), dim=0)
            return torch.sum(avg_reward_max * people_probs[sample_ind]).data.cpu().numpy()

        sys.stdout.write('\r% 5d  % 9.8f' %
                         (it,
                          torch.mean(acc_max).data.cpu().numpy() / R_lim if self.binary
                          else torch.mean(reward_max).data.cpu().numpy(),
                          ))
        sys.stdout.flush()
        if self.binary:
            avg_acc_max = torch.mean(acc_max.view(sample_size, num_theta), dim=0)
            return1 = avg_acc_max.data / R_lim
        else:
            avg_reward_max = torch.mean(reward_max.view(sample_size, num_theta), dim=0)
            return1 = avg_reward_max.data

        # REINFORCE surrogate: -reward * sum_t log pi(a_t), with the reward held constant.
        p_loss = torch.mean(
            (-(acc_max if self.binary else reward_max.data) * log_sum_p).view(sample_size, num_theta),
            dim=0)
        return return1, p_loss


def train(policy, p_optim, num_theta, N_it, sample_size=1, binary=False, num_checks=1, fname='./minimax_policy.dict'):
    """Train the policy with REINFORCE and checkpoint the best evaluated model.

    Args:
        policy: Policy network, or None to build a fresh one.
        p_optim: Optimizer for the policy, or None to build an Adam optimizer.
        num_theta: Thetas drawn uniformly with replacement per iteration.
        N_it: Number of training iterations.
        sample_size: Rollouts per theta per iteration.
        binary: Use the best-column indicator reward instead of the reward gap.
        num_checks: Approximate number of evaluations; the last iteration is
            always evaluated.
        fname: Path of the ``{'policy', 'p_optim'}`` checkpoint, rewritten
            whenever an evaluation matches or beats the best so far.

    Returns:
        ``(avg_reward, policy, p_optim)``. ``avg_reward`` is the element-wise
        mean of the per-theta training rewards over the last (up to) 10
        iterations; it is 0.0 when no iteration ran. ``policy``/``p_optim`` are
        the trained objects, including ones created internally when None was
        passed.
    """
    PT = ParallelTraining(policy, p_optim, sample_size, binary, 100).cuda()
    PTDP = PT  # separate handle for the forward pass (presumably a DataParallel placeholder)
    # At least 1 so num_checks > N_it does not make the modulus below zero.
    check_every = max(N_it // num_checks, 1)
    rewards = []
    eval_rewards = []
    for it in range(N_it):
        # multinomial over a 1-D weight vector already returns shape
        # (num_theta,); squeezing it would make num_theta == 1 a 0-d tensor.
        sample_ind = torch.multinomial(torch.ones(Z.size(1)).float().cuda(), num_theta, replacement=True)
        return1, p_loss = PTDP(sample_ind, it)
        rewards.append(return1.cpu().numpy())

        PT.p_optim.zero_grad()
        torch.mean(p_loss).backward()
        PT.p_optim.step()

        if it != 0 and (it % check_every == 0 or it == N_it - 1):
            print()
            eval_rewards.append(evaluate_model_sampling(PT, PTDP, 100))
            if eval_rewards[-1] == max(eval_rewards):
                torch.save({'policy': PT.policy,
                            'p_optim': PT.p_optim},
                           fname)
                print('Saved')
    recent = rewards[-10:]
    # Average over the actual window length, not a fixed 10.
    return sum(recent) / max(len(recent), 1), PT.policy, PT.p_optim


def evaluate_model_sampling(PT, PTDP, sample_size, chunk_size=1000):
    """Evaluate the policy on every row of ``theta_probs``.

    Args:
        PT: The ``ParallelTraining`` module; its ``eval_sample_size`` is set to
            ``sample_size``.
        PTDP: Callable that runs the rollouts (``PT`` itself in this file).
        sample_size: Rollouts per theta.
        chunk_size: Thetas evaluated per forward call (bounds batch size).

    Returns:
        The summed per-chunk evaluation score. In ``"1d"`` mode thetas are
        evaluated at several scales ``h`` (each score is printed). The score
        for the last scale, ``h = 1``, is returned.
    """
    PT.eval_sample_size = sample_size
    inds = torch.arange(theta_probs.size(0)).cuda()

    def _score(h=None):
        # Slicing past the end is clamped, so the final chunk may be shorter.
        return sum(
            PTDP(inds[start:start + chunk_size], 1, evaluate=True, h=h)
            for start in range(0, inds.size(0), chunk_size)
        )

    if game_mode == "1d":
        total = 0
        for h in [.5, .6, .7, .8, .9, 1]:
            print("h =", h)
            total = _score(h)
            print(total)
        return total

    total = _score()
    print(total)
    return total


if __name__ == "__main__":
    torch.manual_seed(12345)
    fname = '../model/policy_gradient_1.dict'
    print("Saving file to", fname)
    # Jester is trained on the real-valued reward gap; other games on the best-column indicator.
    use_binary = game_mode != "jester"
    reward, trained_policy, trained_optim = train(
        None, None, 500, N_it,
        sample_size=30,
        fname=fname,
        binary=use_binary,
    )
    # Reload the best checkpoint written by train(); the result is currently unused.
    # NOTE(review): the checkpoint pickles whole nn.Module/optimizer objects; on
    # PyTorch >= 2.6 torch.load defaults to weights_only=True and may reject it.
    checkpoint = torch.load(fname)
    print(reward)
