import numpy as np
from gym import spaces
import torch
from src.utils import to_one_hot
import matplotlib.pyplot as plt

# Tab-separated data files read by Q20Env.__init__; the paths are relative, so
# they resolve against the current working directory.
_DATA_DIR = "data/"
action2qid_file = _DATA_DIR + "action2qid.question500.txt"
pos_file = _DATA_DIR + "pos.people1000.txt"
rightprob_file = _DATA_DIR + "rightProb.people1000.question500.txt"
wrongprob_file = _DATA_DIR + "wrongProb.people1000.question500.txt"


def XY_optimal(Z, n_it=1000, verbose=False):
    """Approximate XY-style design weights over the columns of ``Z`` by Frank-Wolfe.

    Over weights ``lam`` on the probability simplex (one weight per column of
    ``Z``) this minimises

        max_{i,j}  sum_d (Z[i,d] - Z[j,d])**2 / lam[d]
                   / max((sum_d (Z[i,d] - Z[j,d]) * (2*Z[i,d] - 1))**2, 1e-8)

    It starts from uniform weights and takes ``n_it`` Frank-Wolfe steps of size
    ``2 / (it + 3)``, each toward the coordinate with the most negative gradient.

    Args:
        Z: 2-D tensor of shape ``(N_arms, N_dim)``. Non-floating tensors (e.g.
            bool) are converted to float32; floating tensors keep their dtype.
        n_it: number of Frank-Wolfe iterations.
        verbose: if True, print the objective value at every iteration.

    Returns:
        Detached 1-D tensor of ``N_dim`` non-negative weights summing to 1.
    """
    if not Z.is_floating_point():
        # Subtraction is undefined for bool tensors, so move to float first.
        Z = Z.float()
    n_dim = Z.size(-1)
    # Match Z's dtype so tensordot below never mixes float32 and float64.
    lam = torch.ones(n_dim, dtype=Z.dtype) / n_dim  # (N_dim,)
    Z_diff = Z.unsqueeze(1) - Z.unsqueeze(0)  # (N_arms, N_arms, N_dim): Z[i] - Z[j]
    # Clamp keeps the diagonal (i == j, zero denominator) from dividing by zero.
    denom = torch.clamp(torch.sum(Z_diff * (2 * Z.unsqueeze(1) - 1), dim=-1) ** 2, 1e-8)
    Z_diff_sq = Z_diff ** 2
    for it in range(n_it):
        lam.requires_grad_(True)
        numer = torch.tensordot(Z_diff_sq, 1 / lam, dims=([-1], [-1]))  # (N_arms, N_arms)
        loss = torch.max(numer / denom)
        if verbose:
            print(loss.item())
        loss.backward()
        gamma = 2. / (it + 3)
        min_ind = torch.argmin(lam.grad)
        # Convex-combination step toward the simplex vertex e_{min_ind}; keeps sum == 1.
        lam = lam.detach() * (1 - gamma)
        lam[min_ind] += gamma
    return lam.detach()


class Q20Env:
    """Environment whose states are people and whose actions are questions.

    Construction reads the module-level data files ``pos_file``,
    ``action2qid_file``, ``rightprob_file`` and ``wrongprob_file`` (all
    tab-separated), then fixes a 100-question subset via
    ``get_probs(original=True)`` and sets up the action space.
    """

    # Sampling weights over the 500 question columns, used by ``get_probs`` to
    # draw the trimmed question subset. Per the author's notes they appear to have
    # been produced offline by
    # ``XY_optimal(torch.from_numpy(right_probs > .5).float())`` on the untrimmed
    # probabilities. The values are given to 4 decimals and sum to exactly 1.0,
    # which ``RandomState.choice`` requires. They are kept at class level so the
    # array is built once instead of on every call.
    _QUESTION_WEIGHTS = np.array([
        0.0015, 0.0017, 0.0022, 0.0016, 0.0014, 0.0012, 0.0007, 0.0005, 0.0004,
        0.0018, 0.0007, 0.0020, 0.0014, 0.0012, 0.0018, 0.0009, 0.0008, 0.0020,
        0.0022, 0.0010, 0.0006, 0.0014, 0.0007, 0.0011, 0.0010, 0.0011, 0.0023,
        0.0009, 0.0002, 0.0010, 0.0040, 0.0022, 0.0022, 0.0018, 0.0013, 0.0009,
        0.0016, 0.0030, 0.0012, 0.0004, 0.0037, 0.0020, 0.0009, 0.0013, 0.0017,
        0.0009, 0.0010, 0.0015, 0.0009, 0.0020, 0.0009, 0.0014, 0.0019, 0.0012,
        0.0024, 0.0008, 0.0017, 0.0026, 0.0040, 0.0010, 0.0031, 0.0006, 0.0023,
        0.0020, 0.0085, 0.0006, 0.0037, 0.0004, 0.0034, 0.0016, 0.0009, 0.0008,
        0.0009, 0.0006, 0.0011, 0.0005, 0.0013, 0.0015, 0.0042, 0.0128, 0.0007,
        0.0009, 0.0020, 0.0031, 0.0020, 0.0028, 0.0015, 0.0008, 0.0010, 0.0003,
        0.0017, 0.0032, 0.0020, 0.0034, 0.0009, 0.0008, 0.0024, 0.0010, 0.0009,
        0.0006, 0.0004, 0.0020, 0.0011, 0.0020, 0.0030, 0.0031, 0.0015, 0.0009,
        0.0009, 0.0006, 0.0016, 0.0013, 0.0128, 0.0030, 0.0034, 0.0011, 0.0025,
        0.0006, 0.0019, 0.0028, 0.0023, 0.0017, 0.0022, 0.0011, 0.0063, 0.0016,
        0.0011, 0.0003, 0.0005, 0.0013, 0.0034, 0.0015, 0.0020, 0.0008, 0.0019,
        0.0017, 0.0085, 0.0019, 0.0003, 0.0008, 0.0032, 0.0041, 0.0006, 0.0031,
        0.0007, 0.0022, 0.0032, 0.0004, 0.0031, 0.0508, 0.0012, 0.0003, 0.0037,
        0.0010, 0.0016, 0.0017, 0.0037, 0.0012, 0.0005, 0.0010, 0.0019, 0.0005,
        0.0034, 0.0016, 0.0064, 0.0019, 0.0005, 0.0030, 0.0008, 0.0012, 0.0008,
        0.0005, 0.0006, 0.0014, 0.0023, 0.0030, 0.0030, 0.0009, 0.0022, 0.0004,
        0.0015, 0.0023, 0.0014, 0.0016, 0.0017, 0.0021, 0.0007, 0.0014, 0.0007,
        0.0003, 0.0019, 0.0011, 0.0028, 0.0022, 0.0009, 0.0019, 0.0006, 0.0005,
        0.0084, 0.0025, 0.0012, 0.0030, 0.0021, 0.0017, 0.0005, 0.0002, 0.0022,
        0.0004, 0.0032, 0.0023, 0.0014, 0.0010, 0.0017, 0.0127, 0.0041, 0.0084,
        0.0019, 0.0012, 0.0015, 0.0008, 0.0004, 0.0027, 0.0031, 0.0007, 0.0025,
        0.0009, 0.0032, 0.0009, 0.0012, 0.0022, 0.0004, 0.0008, 0.0009, 0.0007,
        0.0021, 0.0023, 0.0006, 0.0013, 0.0015, 0.0013, 0.0014, 0.0008, 0.0008,
        0.0021, 0.0009, 0.0019, 0.0009, 0.0034, 0.0041, 0.0003, 0.0010, 0.0004,
        0.0017, 0.0019, 0.0011, 0.0010, 0.0005, 0.0014, 0.0010, 0.0010, 0.0017,
        0.0021, 0.0016, 0.0011, 0.0006, 0.0008, 0.0037, 0.0016, 0.0019, 0.0127,
        0.0011, 0.0009, 0.0064, 0.0064, 0.0015, 0.0010, 0.0015, 0.0016, 0.0013,
        0.0016, 0.0009, 0.0016, 0.0027, 0.0006, 0.0023, 0.0041, 0.0016, 0.0010,
        0.0007, 0.0023, 0.0019, 0.0016, 0.0010, 0.0011, 0.0013, 0.0022, 0.0019,
        0.0029, 0.0005, 0.0021, 0.0013, 0.0011, 0.0012, 0.0009, 0.0016, 0.0012,
        0.0007, 0.0006, 0.0019, 0.0086, 0.0016, 0.0010, 0.0003, 0.0033, 0.0009,
        0.0016, 0.0010, 0.0007, 0.0005, 0.0029, 0.0004, 0.0016, 0.0011, 0.0027,
        0.0019, 0.0015, 0.0009, 0.0014, 0.0013, 0.0003, 0.0033, 0.0006, 0.0016,
        0.0010, 0.0018, 0.0064, 0.0017, 0.0019, 0.0016, 0.0004, 0.0011, 0.0008,
        0.0011, 0.0008, 0.0033, 0.0018, 0.0033, 0.0014, 0.0033, 0.0016, 0.0015,
        0.0013, 0.0036, 0.0016, 0.0035, 0.0009, 0.0022, 0.0008, 0.0031, 0.0015,
        0.0015, 0.0009, 0.0027, 0.0022, 0.0027, 0.0017, 0.0009, 0.0004, 0.0036,
        0.0021, 0.0009, 0.0063, 0.0018, 0.0013, 0.0009, 0.0036, 0.0015, 0.0010,
        0.0023, 0.0041, 0.0012, 0.0013, 0.0007, 0.0024, 0.0024, 0.0023, 0.0031,
        0.0019, 0.0029, 0.0026, 0.0017, 0.0007, 0.0013, 0.0006, 0.0008, 0.0007,
        0.0031, 0.0007, 0.0020, 0.0013, 0.0029, 0.0011, 0.0025, 0.0014, 0.0024,
        0.0013, 0.0006, 0.0027, 0.0009, 0.0031, 0.0005, 0.0016, 0.0006, 0.0017,
        0.0010, 0.0011, 0.0036, 0.0013, 0.0027, 0.0007, 0.0009, 0.0012, 0.0013,
        0.0020, 0.0009, 0.0017, 0.0026, 0.0017, 0.0017, 0.0035, 0.0030, 0.0063,
        0.0023, 0.0016, 0.0009, 0.0040, 0.0024, 0.0038, 0.0003, 0.0006, 0.0040,
        0.0027, 0.0085, 0.0031, 0.0008, 0.0008, 0.0018, 0.0015, 0.0005, 0.0012,
        0.0029, 0.0018, 0.0014, 0.0023, 0.0018, 0.0010, 0.0006, 0.0012, 0.0016,
        0.0028, 0.0026, 0.0020, 0.0010, 0.0015, 0.0016, 0.0008, 0.0033, 0.0003,
        0.0024, 0.0009, 0.0009, 0.0033, 0.0026, 0.0018, 0.0026, 0.0028, 0.0023,
        0.0012, 0.0030, 0.0040, 0.0038, 0.0012, 0.0018, 0.0034, 0.0004, 0.0021,
        0.0032, 0.0030, 0.0063, 0.0017, 0.0012, 0.0023, 0.0012, 0.0011, 0.0017,
        0.0013, 0.0012, 0.0032, 0.0018, 0.0031,
    ])

    def __init__(self):
        """Load all data files and initialise state/action bookkeeping."""
        self.state_pid_dict = None  # state index -> pid
        self.pid_state_dict = None  # pid -> state index
        self.action_qid_dict = None  # action index -> qid
        self.qid_action_dict = None  # qid -> action index

        self.posList = None  # [[pid, pos], ...] in file order
        self.rightProb = None  # pid -> list of per-question floats
        self.wrongProb = None  # pid -> list of per-question floats
        self.arm_pid = None

        # Order matters: the state dicts are built from posList, and the prob
        # loaders use len(action_qid_dict) to validate each row.
        self.load_pos(pos_file)
        self.load_state_pid_dict()
        self.load_action_qid_dict(action2qid_file)
        self.load_rightprob(rightprob_file)
        self.load_wrongprob(wrongprob_file)

        # state, action for policy network
        self.arm_pid = np.arange(len(self.posList))
        self.q_inds = np.arange(len(self.action_qid_dict))
        # Overwrites self.q_inds with the sampled question subset.
        self.get_probs(original=True)
        self.state = [p[1] for p in self.posList]
        self.action_space = spaces.Discrete(len(self.q_inds))

        self.threshold_prob = 1.0
        self.threshold_times = 100
        self.max_step_num = 20
        self.turns = 0

        self.guess_idx = np.random.randint(0, len(self.state))

    def get_probs(self, original=False, trim=True):
        """Return normalised right-answer probabilities and their binarised arms.

        Row ``i`` corresponds to person ``state_pid_dict[i]``. Each entry is
        ``right / (right + wrong)``, and the matching arm is ``prob > 0.5``.

        Args:
            original: when True (and ``trim`` is True), store the sampled
                question indices in ``self.q_inds``.
            trim: when False, return every question column for every state,
                without re-indexing by ``self.arm_pid``. When True, keep 100
                question columns drawn without replacement using
                ``_QUESTION_WEIGHTS`` and a fixed seed (123), so the subset is
                the same on every call. Rows are then selected by
                ``self.arm_pid``.

        Returns:
            ``(right_probs, arms)``: a float array and a bool array of the same
            shape.
        """
        pids = [self.state_pid_dict[i] for i in range(len(self.state_pid_dict))]
        right_probs = np.array([self.rightProb[pid] for pid in pids])
        wrong_probs = np.array([self.wrongProb[pid] for pid in pids])
        if not trim:
            right_probs = right_probs / (right_probs + wrong_probs)
            arms = right_probs > .5
            return right_probs, arms
        # A fresh fixed-seed generator gives the same subset on every call and
        # leaves numpy's global RNG untouched.
        rand = np.random.RandomState(123)
        weights = self._QUESTION_WEIGHTS
        prob_inds = rand.choice(np.arange(len(weights)), size=100, replace=False, p=weights)
        right_probs = right_probs[:, prob_inds]
        wrong_probs = wrong_probs[:, prob_inds]
        right_probs = right_probs / (right_probs + wrong_probs)
        arms = right_probs > .5
        if original:
            self.q_inds = prob_inds
        return right_probs[self.arm_pid], arms[self.arm_pid]

    @staticmethod
    def _read_tsv_rows(path):
        """Yield the tab-separated fields of each whitespace-stripped line of ``path``.

        The file is opened in a ``with`` block, so it is closed once iteration
        finishes.
        """
        with open(path) as f:
            for line in f:
                yield line.strip().split("\t")

    def load_state_pid_dict(self):
        """Build the state-index <-> pid mappings from ``self.posList`` order."""
        self.state_pid_dict = {}
        self.pid_state_dict = {}
        for i, p in enumerate(self.posList):
            self.state_pid_dict[i] = p[0]
            self.pid_state_dict[p[0]] = i

    def load_action_qid_dict(self, action2qid_file):
        """Read ``action<TAB>qid`` lines into the action <-> qid mappings.

        Lines without exactly two fields are skipped. For a repeated action the
        first line wins.
        """
        self.action_qid_dict = {}
        self.qid_action_dict = {}
        for items in self._read_tsv_rows(action2qid_file):
            if len(items) != 2:
                continue
            action = int(items[0])
            qid = int(items[1])
            if action not in self.action_qid_dict:
                self.action_qid_dict[action] = qid
                self.qid_action_dict[qid] = action

    def load_pos(self, pos_file):
        """Read ``pid<TAB>pos`` lines into ``self.posList`` as ``[pid, pos]`` pairs.

        File order is preserved. Lines without exactly two fields are skipped.
        """
        self.posList = [[int(items[0]), float(items[1])]
                        for items in self._read_tsv_rows(pos_file)
                        if len(items) == 2]

    def _load_prob_table(self, path):
        """Map pid -> list of floats, read from ``pid<TAB>v1<TAB>...<TAB>vN`` lines.

        ``N`` must equal ``len(self.action_qid_dict)``. Lines with any other
        field count are skipped, and for a repeated pid the first row wins.
        """
        n_fields = len(self.action_qid_dict) + 1
        table = {}
        for items in self._read_tsv_rows(path):
            if len(items) != n_fields:
                continue
            pid = int(items[0])
            if pid not in table:
                table[pid] = [float(x) for x in items[1:]]
        return table

    def load_rightprob(self, rightprob_file):
        """Load the per-person right-answer table into ``self.rightProb``."""
        self.rightProb = self._load_prob_table(rightprob_file)

    def load_wrongprob(self, wrongprob_file):
        """Load the per-person wrong-answer table into ``self.wrongProb``."""
        self.wrongProb = self._load_prob_table(wrongprob_file)


if __name__ == "__main__":
    # Manual smoke run: constructing the environment reads every data file;
    # the trimmed question subset is then drawn once more and discarded.
    env = Q20Env()
    _right_probs, _arms = env.get_probs()
