
import numpy as np
from scipy.optimize import minimize, NonlinearConstraint, LinearConstraint, Bounds
from numpy.linalg import matrix_power
import sys

import numpy as np


def check_T(T):
    for s in [0,1]:
        for d in [0,1]:
            for y in [0,1]:
                for x in range(4):
                    if sum(T[s][d][y][x]).round(4) != 1:
                        # print(T[s][d][y])
                        # print(sum(T[s][d][y][x]).round(4) )
                        print("!!!!! Does not sum up", s, d, y, x)
                        raise ValueError("Does not sum up")

def assemble_T(T000, T001, T010, T011, T100, T101, T110, T111):

    T00 = np.array([T000, T001])
    T01 = np.array([T010, T011])
    # careful, here copy pasting 0
    T10 = np.array([T100, T101])
    T11 = np.array([T110, T111])

    # Ts
    T0 = np.array([T00, T01])
    T1 = np.array([T10, T11])
    # T
    T = np.array([T0, T1])
    return T


def get_Pxx_given_default():
    # [s][d][y][x]
    # S=0, D= 0, Y=0
    T000 = np.array(
            [[0.9, 0.03333, 0.03333, 0.03333], [0.03333, 0.9, 0.03333, 0.03333], [0.03333, 0.03333, 0.9, 0.03333],
         [0.03333, 0.03333, 0.03333, 0.9]])
    # S=0, D= 0, Y=1
    T001 = np.array(
        [[0.9, 0.03333, 0.03333, 0.03333], [0.03333, 0.9, 0.03333, 0.03333], [0.03333, 0.03333, 0.9, 0.03333],
         [0.03333, 0.03333, 0.03333, 0.9]])

    # stays the same  - one-sided
    # S=1, D= 0, Y=0
    T100 = np.array(
        [[0.9, 0.03333, 0.03333, 0.03333], [0.03333, 0.9, 0.03333, 0.03333], [0.03333, 0.03333, 0.9, 0.03333],
         [0.03333, 0.03333, 0.03333, 0.9]])
    # S=1, D= 0, Y=1
    T101 = np.array(
        [[0.9, 0.03333, 0.03333, 0.03333], [0.03333, 0.9, 0.03333, 0.03333], [0.03333, 0.03333, 0.9, 0.03333],
         [0.03333, 0.03333, 0.03333, 0.9]])

    # changes
    # S=0, D= 1, Y=0
    T010 = np.array(
        [[0.9, 0.03333, 0.03333, 0.03333], [0.9, 0.03333, 0.03333, 0.03333], [0.9, 0.03333, 0.03333, 0.03333],
         [0.9, 0.03333, 0.03333, 0.03333]])
    # S=0, D= 1, Y=1
    T011 = np.array(
        [[0.33333, 0.6, 0.03333, 0.03333], [0.03333, 0.33333, 0.6, 0.03333], [0.03333, 0.03333, 0.33333, 0.6],
         [0.03333, 0.03333, 0.03333, 0.9]])
    # S=1, D= 1, Y=0
    T110 = np.array(
        [[0.9, 0.03333, 0.03333, 0.03333], [0.9, 0.03333, 0.03333, 0.03333], [0.9, 0.03333, 0.03333, 0.03333],
         [0.9, 0.03333, 0.03333, 0.03333]])
    # S=0, D= 1, Y=1
    T111 = np.array(
        [[0.53333, 0.4, 0.03333, 0.03333], [0.03333, 0.53333, 0.4, 0.03333], [0.03333, 0.03333, 0.53333, 0.4],
         [0.03333, 0.03333, 0.03333, 0.9]])

    T = assemble_T(T000, T001, T010, T011, T100, T101, T110, T111)
    check_T(T)
    return T

    return 0

def get_Pxx_given_oneslow():
    # [s][d][y][x]
    # S=0, D= 0, Y=0
    T000 = np.array(
        [[0.9, 0.03333, 0.03333, 0.03333], [0.03333, 0.9, 0.03333, 0.03333], [0.03333, 0.03333, 0.9, 0.03333],
         [0.03333, 0.03333, 0.03333, 0.9]])
    # S=0, D= 0, Y=1
    T001 = np.array(
        [[0.9, 0.03333, 0.03333, 0.03333], [0.03333, 0.9, 0.03333, 0.03333], [0.03333, 0.03333, 0.9, 0.03333],
         [0.03333, 0.03333, 0.03333, 0.9]])

    # stays the same  - one-sided
    # S=1, D= 0, Y=0
    T100 = np.array(
        [[0.9, 0.03333, 0.03333, 0.03333], [0.03333, 0.9, 0.03333, 0.03333], [0.03333, 0.03333, 0.9, 0.03333],
         [0.03333, 0.03333, 0.03333, 0.9]])
    # S=1, D= 0, Y=1
    T101 = np.array(
        [[0.9, 0.03333, 0.03333, 0.03333], [0.03333, 0.9, 0.03333, 0.03333], [0.03333, 0.03333, 0.9, 0.03333],
         [0.03333, 0.03333, 0.03333, 0.9]])

    # changes
    # S=0, D= 1, Y=0
    T010 = np.array(
        [[0.9, 0.03333, 0.03333, 0.03333], [0.9, 0.03333, 0.03333, 0.03333], [0.9, 0.03333, 0.03333, 0.03333],
         [0.9, 0.03333, 0.03333, 0.03333]])
    # S=0, D= 1, Y=1
    T011 = np.array(
        [[0.53333, 0.4, 0.03333, 0.03333], [0.03333, 0.53333, 0.4, 0.03333], [0.03333, 0.03333, 0.53333, 0.4],
         [0.03333, 0.03333, 0.03333, 0.9]])
    # S=1, D= 1, Y=0
    T110 = np.array(
        [[0.9, 0.03333, 0.03333, 0.03333], [0.9, 0.03333, 0.03333, 0.03333], [0.9, 0.03333, 0.03333, 0.03333],
         [0.9, 0.03333, 0.03333, 0.03333]])
    # S=0, D= 1, Y=1
    T111 = np.array(
        [[0.53333, 0.4, 0.03333, 0.03333], [0.03333, 0.53333, 0.4, 0.03333], [0.03333, 0.03333, 0.53333, 0.4],
         [0.03333, 0.03333, 0.03333, 0.9]])

    T = assemble_T(T000, T001, T010, T011, T100, T101, T110, T111)
    check_T(T)
    return T

    return 0

def get_Pxx_given_onefast():
    #[s][d][y][x]
    # S=0, D= 0, Y=0
    T000 = np.array([[0.9, 0.03333, 0.03333, 0.03333], [0.03333, 0.9,0.03333, 0.03333],[0.03333,0.03333, 0.9, 0.03333], [0.03333, 0.03333, 0.03333, 0.9] ])
    # S=0, D= 0, Y=1
    T001 = np.array([[0.9, 0.03333, 0.03333, 0.03333], [0.03333, 0.9,0.03333, 0.03333],[0.03333,0.03333,  0.9, 0.03333], [0.03333, 0.03333, 0.03333, 0.9] ])

    # stays the same  - one-sided
    # S=1, D= 0, Y=0
    T100 = np.array(
        [[0.9, 0.03333, 0.03333, 0.03333], [0.03333, 0.9, 0.03333, 0.03333], [0.03333, 0.03333, 0.9, 0.03333],
         [0.03333, 0.03333, 0.03333, 0.9]])
    # S=1, D= 0, Y=1
    T101 = np.array(
        [[0.9, 0.03333, 0.03333, 0.03333], [0.03333, 0.9, 0.03333, 0.03333], [0.03333, 0.03333, 0.9, 0.03333],
         [0.03333, 0.03333, 0.03333, 0.9]])


    # changes
    # S=0, D= 1, Y=0
    T010 = np.array([[0.9, 0.03333, 0.03333, 0.03333], [0.9, 0.03333, 0.03333, 0.03333],[0.9, 0.03333, 0.03333, 0.03333], [0.9, 0.03333, 0.03333, 0.03333] ])
    # S=0, D= 1, Y=1
    T011 = np.array([[0.13333, 0.8, 0.03333, 0.03333], [0.03333, 0.13333, 0.8, 0.03333],[0.03333, 0.03333, 0.13333, 0.8], [0.03333, 0.03333, 0.03333, 0.9] ])
    # S=1, D= 1, Y=0
    T110 = np.array([[0.9, 0.03333, 0.03333, 0.03333], [0.9, 0.03333, 0.03333, 0.03333],[0.9, 0.03333, 0.03333, 0.03333], [0.9, 0.03333, 0.03333, 0.03333] ])
    # S=0, D= 1, Y=1
    T111 = np.array([[0.13333, 0.8, 0.03333, 0.03333], [0.03333, 0.13333, 0.8, 0.03333],[0.03333, 0.03333, 0.13333, 0.8], [0.03333, 0.03333, 0.03333, 0.9] ])

    T = assemble_T(T000, T001, T010, T011, T100, T101, T110, T111)
    check_T(T)
    return T

def get_Pxx_given_onemed():

    #[s][d][y][x]
    # S=0, D= 0, Y=0
    T000 = np.array([[0.9, 0.03333, 0.03333, 0.03333], [0.03333, 0.9,0.03333, 0.03333],[0.03333,0.03333, 0.9, 0.03333], [0.03333, 0.03333, 0.03333, 0.9] ])
    # S=0, D= 0, Y=1
    T001 = np.array([[0.9, 0.03333, 0.03333, 0.03333], [0.03333, 0.9,0.03333, 0.03333],[0.03333,0.03333,  0.9, 0.03333], [0.03333, 0.03333, 0.03333, 0.9] ])

    # S=1, D= 0, Y=0
    T100 = np.array(
        [[0.9, 0.03333, 0.03333, 0.03333], [0.03333, 0.9, 0.03333, 0.03333], [0.03333, 0.03333, 0.9, 0.03333],
         [0.03333, 0.03333, 0.03333, 0.9]])
    # S=1, D= 0, Y=1
    T101 = np.array(
        [[0.9, 0.03333, 0.03333, 0.03333], [0.03333, 0.9, 0.03333, 0.03333], [0.03333, 0.03333, 0.9, 0.03333],
         [0.03333, 0.03333, 0.03333, 0.9]])


    # S=0, D= 1, Y=0
    T010 = np.array([[0.9, 0.03333, 0.03333, 0.03333], [0.9, 0.03333, 0.03333, 0.03333],[0.9, 0.03333, 0.03333, 0.03333], [0.9, 0.03333, 0.03333, 0.03333] ])
    # S=0, D= 1, Y=1
    T011 = np.array([[0.33333, 0.6, 0.03333, 0.03333], [0.03333, 0.33333, 0.6, 0.03333],[0.03333, 0.03333, 0.33333, 0.6], [0.03333, 0.03333, 0.03333, 0.9] ])
       # S=1, D= 1, Y=0
    T110 = np.array([[0.9, 0.03333, 0.03333, 0.03333], [0.9, 0.03333, 0.03333, 0.03333],[0.9, 0.03333, 0.03333, 0.03333], [0.9, 0.03333, 0.03333, 0.03333] ])
    # S=0, D= 1, Y=1
    T111 = np.array([[0.33333, 0.6, 0.03333, 0.03333], [0.03333, 0.33333, 0.6, 0.03333],[0.03333, 0.03333, 0.33333, 0.6], [0.03333, 0.03333, 0.03333, 0.9] ])

    T = assemble_T(T000, T001, T010, T011, T100, T101, T110, T111)
    check_T(T)
    return T

def get_Pxx_given_twomed_recourse():

    #[s][d][y][x]
    # S=0, D= 0, Y=0
    T000 = np.array([[0.7, 0.23333, 0.03333, 0.03333], [0.03333, 0.7,0.23333, 0.03333],[0.03333,0.03333, 0.7, 0.23333], [0.03333, 0.03333, 0.03333, 0.9] ])
    # S=0, D= 0, Y=1
    T001 = np.array([[0.7, 0.23333, 0.03333, 0.03333], [0.03333, 0.7,0.23333, 0.03333],[0.03333,0.03333,  0.7, 0.23333], [0.03333, 0.03333, 0.03333, 0.9] ])

    # S=1, D= 0, Y=0
    T100 = np.array(
        [[0.5, 0.43333, 0.03333, 0.03333], [0.03333, 0.5, 0.43333, 0.03333], [0.03333, 0.03333, 0.5, 0.43333],
         [0.03333, 0.03333, 0.03333, 0.9]])
    # S=1, D= 0, Y=1
    T101 = np.array(
        [[0.5, 0.43333, 0.03333, 0.03333], [0.03333, 0.5, 0.43333, 0.03333], [0.03333, 0.03333, 0.5, 0.43333],
         [0.03333, 0.03333, 0.03333, 0.9]])


    # S=0, D= 1, Y=0
    T010 = np.array([[0.9, 0.03333, 0.03333, 0.03333], [0.9, 0.03333, 0.03333, 0.03333],[0.9, 0.03333, 0.03333, 0.03333], [0.9, 0.03333, 0.03333, 0.03333] ])
    # S=0, D= 1, Y=1
    T011 = np.array([[0.33333, 0.6, 0.03333, 0.03333], [0.03333, 0.33333, 0.6, 0.03333],[0.03333, 0.03333, 0.33333, 0.6], [0.03333, 0.03333, 0.03333, 0.9] ])
       # S=1, D= 1, Y=0
    T110 = np.array([[0.9, 0.03333, 0.03333, 0.03333], [0.9, 0.03333, 0.03333, 0.03333],[0.9, 0.03333, 0.03333, 0.03333], [0.9, 0.03333, 0.03333, 0.03333] ])
    # S=0, D= 1, Y=1
    T111 = np.array([[0.33333, 0.6, 0.03333, 0.03333], [0.03333, 0.33333, 0.6, 0.03333],[0.03333, 0.03333, 0.33333, 0.6], [0.03333, 0.03333, 0.03333, 0.9] ])

    T = assemble_T(T000, T001, T010, T011, T100, T101, T110, T111)
    check_T(T)
    return T

def get_Pxx_given_twomed_discouraged():

    #[s][d][y][x]
    # S=0, D= 0, Y=0
    T000 = np.array([[0.9, 0.03333, 0.03333, 0.03333], [0.63333, 0.3, 0.03333, 0.03333],[0.13333,0.53333, 0.3, 0.03333], [0.03333, 0.23333, 0.43333, 0.3] ])
    # S=0, D= 0, Y=1
    T001 = np.array([[0.9, 0.03333, 0.03333, 0.03333], [0.63333, 0.3, 0.03333, 0.03333],[0.13333,0.53333, 0.3, 0.03333], [0.03333, 0.23333, 0.43333, 0.3] ])

    # S=1, D= 0, Y=0
    T100 = np.array(
        [[0.9, 0.03333, 0.03333, 0.03333], [0.43333, 0.5, 0.03333, 0.03333],[0.13333,0.33333, 0.5, 0.03333], [0.03333, 0.23333, 0.23333, 0.5] ])
    # S=1, D= 0, Y=1
    T101 = np.array(
        [[0.9, 0.03333, 0.03333, 0.03333], [0.43333, 0.5, 0.03333, 0.03333],[0.13333,0.33333, 0.5, 0.03333], [0.03333, 0.23333, 0.23333, 0.5] ])


    # S=0, D= 1, Y=0
    T010 = np.array([[0.9, 0.03333, 0.03333, 0.03333], [0.9, 0.03333, 0.03333, 0.03333],[0.9, 0.03333, 0.03333, 0.03333], [0.9, 0.03333, 0.03333, 0.03333] ])
    # S=0, D= 1, Y=1
    T011 = np.array([[0.33333, 0.6, 0.03333, 0.03333], [0.03333, 0.33333, 0.6, 0.03333],[0.03333, 0.03333, 0.33333, 0.6], [0.03333, 0.03333, 0.03333, 0.9] ])
       # S=1, D= 1, Y=0
    T110 = np.array([[0.9, 0.03333, 0.03333, 0.03333], [0.9, 0.03333, 0.03333, 0.03333],[0.9, 0.03333, 0.03333, 0.03333], [0.9, 0.03333, 0.03333, 0.03333] ])
    # S=0, D= 1, Y=1
    T111 = np.array([[0.33333, 0.6, 0.03333, 0.03333], [0.03333, 0.33333, 0.6, 0.03333],[0.03333, 0.03333, 0.33333, 0.6], [0.03333, 0.03333, 0.03333, 0.9] ])

    T = assemble_T(T000, T001, T010, T011, T100, T101, T110, T111)
    check_T(T)
    return T


def get_Pxx_random(num_cat=4, seed=6):

# fix numpy seed
    np.random.seed(seed)

    # get an empty array of dimensions (2, 2, 2, num_cat, num_cat).
    Pxx_given = np.zeros((2, 2, 2, num_cat, num_cat))


    for i in range(2):
        for j in range(2):
            for k in range(2):
                # get an empty  array named Test that has the dimension (num_cat,num_cat).
                Test = np.zeros((num_cat,num_cat))
                # fill the array with numbers between 0 and 1, such that each row sums to 1
                for m in range(num_cat):
                    Test[m,:] = np.random.dirichlet(np.ones(num_cat),size=1)


                Pxx_given[i][j][k] = Test


    check_T(Pxx_given)

    return Pxx_given

def get_Pxx_estimated(estimation):
    # print("Got PXX estimated")
    if estimation == '_est-random':
        return np.load('estimations/Pxx_est_random.npy')
    elif estimation == "_est-threshold":
        return np.load('estimations/Pxx_est_threshold.npy')
    elif estimation == "_est-biased":
        return np.load('estimations/Pxx_est_biased.npy')
    else:
        raise NotImplementedError