import pdb
import numpy as np
import matplotlib.pyplot as plt
import ot
import matplotlib as mpl
import matplotlib.colors as mcolors
import itertools
from scipy.optimize import curve_fit


def exp_decay(x, A, k, C):
    """Evaluate an offset exponential decay, ``A * exp(-k * x) + C``.

    Works element-wise when ``x`` is a NumPy array.
    """
    decay = np.exp(-k * x)
    return A * decay + C

# Parameters for plots (each assigned once; earlier duplicate/overwritten
# assignments of length_ticks and scatter_size were dead code).
length_ticks = 2
font_size = 9
linewidth = 1.2
scatter_size = 20
horizontal_size = 1  # figure width per panel, passed to figsize
vertical_size = 1  # figure height, passed to figsize

# Global matplotlib styling: small tick labels and no top/right spines.
mpl.rcParams.update({
    'font.size': font_size,
    'lines.linewidth': linewidth,
    'xtick.labelsize': font_size - 5,
    'ytick.labelsize': font_size - 5,
    'axes.spines.right': False,
    'axes.spines.top': False,
    'axes.titlesize': font_size - 2,
    'legend.fontsize': font_size - 2,
})

# Use the Tk interactive backend for plt.show().
mpl.use("TkAgg")

# Population size: n_particles ** 2 units in total; during simulation each of
# the two competing options is represented by half of them.
n_particles = 10
n_particles_total = n_particles * n_particles

# Reference population: every unit sits at time 0 with the largest reward
# magnitude (10000).
particles_reference = np.column_stack(
    (np.zeros(n_particles_total), np.full(n_particles_total, 10000.0))
)

# Hyperbolic discount rates applied to the normalised time and magnitude axes.
k_time = 6
k_magnitude = 12

# Uniform weights over the units for the source (a) and target (b)
# distributions of the transport map.
a = np.full(n_particles_total, 1 / n_particles_total)
b = a.copy()

# Candidate immediate/certain amounts: 20000 evenly spaced points in [0, 10000].
amounts_imediate = np.linspace(0, 10000, 20000)

# Conditions used in the human experiments being simulated.
probabilities_considered = np.array([0.05, 0.1, 0.4, 0.7, 0.9, 0.95])  # reward probabilities
times_considered = [1, 6, 12, 36, 60, 120]  # delays (months)
amounts_considered = [500, 10000]  # delayed / variable reward amounts

# Plot style: one format string and one marker fill colour per late amount.
symbols_plot = ["-o", ":^"]
markerfacecolor_plot = ["black", "grey"]
# One row with two panels: temporal discounting (left), probability
# discounting (right).
fig, ax = plt.subplots(1, 2, figsize=(2 * horizontal_size, vertical_size))

def _discounted_value(particles):
    """Return the summed discounted magnitude of a population of units.

    Each row of ``particles`` is one unit ``(time, magnitude)``. A unit
    contributes ``magnitude / ((1 + k_time * time / 120) *
    (1 + k_magnitude * magnitude / 10000))``: time is normalised by the
    longest delay considered (120) and magnitude by the largest amount
    (10000) before the two hyperbolic discounts are applied.
    """
    times = particles[:, 0]
    magnitudes = particles[:, 1]
    discount = ((1 + k_time * (times / 120))
                * (1 + k_magnitude * (magnitudes / 10000)))
    return np.sum(magnitudes / discount)


def _certain_equivalent(particles_alternative, covariance, n_samples,
                        tolerance=50):
    """Return the immediate amount whose value matches ``particles_alternative``.

    Candidate amounts from ``amounts_imediate`` are scanned from high to low.
    For each one, ``n_samples`` immediate units are freshly sampled around
    ``(0, amount)`` and their discounted value is compared with the (fixed)
    discounted value of ``particles_alternative``. The first amount whose
    values differ by less than ``tolerance`` is returned; ``None`` is
    returned when no candidate matches.
    """
    # The alternative's units do not change during the scan, so value them once.
    value_alternative = _discounted_value(particles_alternative)
    for amount in amounts_imediate[::-1]:
        particles_certain = np.random.multivariate_normal(
            np.array([0, amount]), covariance, n_samples)
        if np.abs(value_alternative - _discounted_value(particles_certain)) < tolerance:
            return amount
    return None


# Each of the two competing options is represented by half of the units.
n_samples_option = int(n_particles_total * 0.5)
# Units are sampled with a small isotropic spread around each option.
covariance = np.eye(2) * 0.01

for i_amount_late, amount_late in enumerate(amounts_considered):

    plot_kwargs = dict(color='black', markersize=4,
                       markerfacecolor=markerfacecolor_plot[i_amount_late],
                       label=f"Amount = {amount_late}")

    # --- Temporal discounting: immediate amount vs. amount_late after a delay.
    certain_equivalents_time = []
    time_used = []
    for delay in times_considered:
        # The delayed option's units are sampled once per delay.
        particles_late = np.random.multivariate_normal(
            np.array([delay, amount_late]), covariance, n_samples_option)
        ce = _certain_equivalent(particles_late, covariance, n_samples_option)
        if ce is not None:
            certain_equivalents_time.append(ce)
            time_used.append(delay)

    if not certain_equivalents_time:
        raise RuntimeError(
            f"No certain equivalent found for delayed amount {amount_late}")
    # Subjective values relative to the first (shortest) delay with a match.
    certain_equivalents_time = np.array(certain_equivalents_time)
    certain_equivalents_time = certain_equivalents_time / certain_equivalents_time[0]

    # Plot temporal discount curve
    ax[0].plot(time_used, certain_equivalents_time,
               symbols_plot[i_amount_late], **plot_kwargs)

    # --- Probability discounting: certain amount vs. amount_late with prob.
    certain_equivalents_probability = []
    probability_used = []
    for prob in probabilities_considered:
        # A fraction `prob` of the variable option's units sit at amount_late;
        # the remaining units sit at a zero reward.
        n_non_zero = int(n_particles_total * 0.5 * prob)
        samples_non_zero = np.random.multivariate_normal(
            np.array([0, amount_late]), covariance, n_non_zero)
        samples_zero = np.random.multivariate_normal(
            np.zeros(2), covariance, n_samples_option - n_non_zero)
        particles_variable = np.vstack([samples_non_zero, samples_zero])

        ce = _certain_equivalent(particles_variable, covariance, n_samples_option)
        if ce is not None:
            certain_equivalents_probability.append(ce)
            probability_used.append(prob)

    if not certain_equivalents_probability:
        raise RuntimeError(
            f"No certain equivalent found for variable amount {amount_late}")
    # Subjective values relative to the last (most probable) condition with a match.
    certain_equivalents_probability = np.array(certain_equivalents_probability)
    certain_equivalents_probability = (certain_equivalents_probability
                                       / certain_equivalents_probability[-1])
    probability_used = np.array(probability_used)
    odds_against = (1 - probability_used) / probability_used

    # Plot probability discount curve against the odds against reward
    ax[1].plot(odds_against, certain_equivalents_probability,
               symbols_plot[i_amount_late], **plot_kwargs)

# Label both panels; one legend on the right panel identifies the amounts.
ax[0].set_xlabel("Delay (months)")
ax[0].set_ylabel("Subjective Value")

ax[1].set_xlabel("Odds Against")
ax[1].set_ylabel("Subjective Value")
ax[1].legend()

# Display the figure. No debugger breakpoint follows: the script ends here.
plt.show()