import constants as C
import frac as F

import strings as S
from scipy.special import binom
import numpy as np

import pandas as pd
import json
import csv


def _generate_True_Distribution_choice(config: dict, subscenario: str):
    """Return a uniform distribution over consecutively lettered choices.

    ``config[S.prompt_parameter.NUMBER]`` choices are labelled with the
    characters starting at ``"A"`` (``"A"``, ``"B"``, ...), and every label is
    mapped to ``F.Frac(1, number_of_choices)``.

    ``subscenario`` is accepted only so that every generator shares the same
    signature; it does not influence the result.
    """
    count = config[S.prompt_parameter.NUMBER]
    first_code = ord("A")
    labels = (chr(code) for code in range(first_code, first_code + count))
    return {label: F.Frac(1, count) for label in labels}

def _generate_True_Distribution_stats(config: dict, subscenario: str):
    """Return the overall prevalence of each condition, in whole percent.

    Reads ``config["config"]["params"]``, which must provide:

    - ``conditions``: condition names; they become the output keys, in order.
    - ``proportions``: per-condition prevalence outside the subpopulation
      (percent, indexed like ``conditions``).
    - ``proportions_in_subgroup``: per-condition prevalence inside the
      subpopulation (percent, indexed like ``conditions``).
    - ``share_working_in_subpopulation``: share of people in the
      subpopulation, in percent.
    - ``asking_for_condition``: the queried condition; must be in
      ``conditions``.

    Each condition's value is the mixture
    ``p * (1 - s/100) + q * s/100`` rounded to a whole percent and returned
    as ``F.Frac(rounded, 100)``. ``round`` uses round-half-to-even, and each
    entry is rounded independently, so the values need not sum to exactly 1.

    ``subscenario`` is accepted for signature uniformity and is not used.

    Raises:
        KeyError: if a required parameter is missing.
        ValueError: if ``asking_for_condition`` is not one of ``conditions``.
        IndexError: if the proportion lists are shorter than ``conditions``.
    """
    params = config["config"]["params"]
    conditions = params["conditions"]
    proportions = params["proportions"]
    proportions_in_subgroup = params["proportions_in_subgroup"]
    share = params["share_working_in_subpopulation"] / 100

    # Fail early on a queried condition that is not part of the condition list.
    asking_for_condition = params["asking_for_condition"]
    if asking_for_condition not in conditions:
        raise ValueError(
            f"asking_for_condition {asking_for_condition!r} is not one of {conditions!r}"
        )

    # For now the person is never in the subpopulation; the conditional
    # (in-subpopulation) distribution is not implemented.
    person_is_in_subpop = False
    if person_is_in_subpop:
        raise NotImplementedError("person in the subpopulation is not supported yet")

    distribution = {}
    for idx, condition in enumerate(conditions):
        # Index (rather than zip) so mismatched list lengths fail loudly.
        base_rate = proportions[idx]
        subgroup_rate = proportions_in_subgroup[idx]
        mixed = base_rate * (1 - share) + subgroup_rate * share
        distribution[condition] = F.Frac(int(round(mixed, 0)), 100)

    return distribution


def _generate_True_Distribution_coins(config, subscenario: str):
    """Return the exact binomial distribution for a set of biased coins.

    Each of ``config[S.prompt_parameter.NUMBER]`` coins shows the favoured
    face with probability ``bias / (bias + 1)`` and the other face with
    ``1 / (bias + 1)``, where ``bias = config[S.prompt_parameter.BIAS]``.
    Keys are the stringified count of favoured faces. In the
    REPEATED_DEPENDENT subscenario the count is offset by
    ``int(config[S.prompt_parameter.PREVIOUS_RESULT])``.

    Returns:
        dict mapping ``str(count + base)`` to an ``F.Frac`` probability.
    """
    import math  # math.comb gives exact integer binomial coefficients

    if subscenario == S.subscenario_names.REPEATED_DEPENDENT:
        base = int(config[S.prompt_parameter.PREVIOUS_RESULT])
    else:
        base = 0

    bias = config[S.prompt_parameter.BIAS]
    number_of_coins = config[S.prompt_parameter.NUMBER]

    pdom = F.Frac(bias, bias + 1)
    oneminuspdom = F.Frac(1, 1 + bias)
    distribution = {}
    for i in range(number_of_coins + 1):
        # scipy.special.binom returns a float, so int(binom(...)) stops being
        # exact once the coefficient exceeds 2**53 (from roughly 57 coins);
        # math.comb keeps the Frac arithmetic exact.
        p = F.Frac(math.comb(number_of_coins, i), 1)
        p = p.mult(pdom.power(i))
        p = p.mult(oneminuspdom.power(number_of_coins - i))

        distribution[str(i + base)] = p

    return distribution


def _generate_True_Distribution_preference(config, subscenario: str):
    """Return a two-option distribution biased toward the focus option.

    The focus option receives ``bias / (1 + bias)`` and the single other
    option receives ``1 / (1 + bias)``, both as ``F.Frac``.

    ``subscenario`` is accepted for signature uniformity and is not used.

    Raises:
        ValueError: if ``options`` does not contain exactly one distinct
            value besides ``focus``. Otherwise the result would either get a
            ``None`` key or silently drop options.
    """
    options = config[S.prompt_parameter.OPTIONS]
    focus = config[S.prompt_parameter.FOCUS]
    bias = config[S.prompt_parameter.BIAS]

    # dict.fromkeys de-duplicates while keeping first-seen order.
    others = list(dict.fromkeys(option for option in options if option != focus))
    if len(others) != 1:
        raise ValueError(
            f"Preference scenario needs exactly one option besides focus {focus!r}; "
            f"got options {options!r}"
        )
    other = others[0]

    return {focus: F.Frac(bias, 1 + bias), other: F.Frac(1, 1 + bias)}


def _dice_total_matches_observations(value, observations, min_value, midval):
    """Return True when the dice total ``value`` is consistent with every observation."""
    for observation in observations:
        if observation == S.observation_names.SMALLER_THAN_MIDDLE and value >= midval:
            return False
        if observation == S.observation_names.LARGER_THAN_MIDDLE and value <= midval:
            return False
        if observation == S.observation_names.NOT_MIDDLE and value == midval:
            return False
        # NOT_ONE excludes the smallest attainable total.
        if observation == S.observation_names.NOT_ONE and value == min_value:
            return False
        if observation == S.observation_names.ODD and value % 2 == 0:
            return False
        if observation == S.observation_names.EVEN and value % 2 == 1:
            return False
    return True


def _generate_True_Distribution_dice(config, subscenario: str):
    """Return the distribution of the total of several dice.

    ``config[S.prompt_parameter.NUMBER]`` dice with
    ``config[S.prompt_parameter.FACES]`` faces each are enumerated
    exhaustively. Every face combination is counted once, and
    ``F.get_valid_frac_from_couting`` turns the counts into probabilities.

    - REPEATED_DEPENDENT: totals are offset by
      ``int(config[S.prompt_parameter.PREVIOUS_RESULT])``.
    - OBSERVATIONS: only totals consistent with every entry of
      ``config[S.prompt_parameter.OBSERVATIONS]`` are kept, and the
      remaining mass is renormalised to 1. "Middle" is
      ``(min_total + max_total) // 2``.

    The enumeration holds ``FACES ** NUMBER`` totals in memory.

    Returns:
        dict mapping the stringified total to an ``F.Frac`` probability.

    Raises:
        ValueError: if no total survives the observations.
    """
    number_of_dice = config[S.prompt_parameter.NUMBER]
    face_per_dice = config[S.prompt_parameter.FACES]

    dice = np.arange(1, face_per_dice + 1)

    if subscenario == S.subscenario_names.REPEATED_DEPENDENT:
        base = int(config[S.prompt_parameter.PREVIOUS_RESULT])
    else:
        base = 0

    # Seed with an integer array so every total stays integral; a float seed
    # (e.g. base * np.ones(1)) makes the totals floats, whose str() can be "7.0".
    values = np.array([base])
    for _ in range(number_of_dice):
        values = (values[None, :] + dice[:, None]).flatten()
    true_distrib = F.get_valid_frac_from_couting(values)

    if subscenario != S.subscenario_names.OBSERVATIONS:
        return {str(k): v for k, v in true_distrib.items()}

    observations = config[S.prompt_parameter.OBSERVATIONS]

    totals = [int(k) for k in true_distrib]
    min_value = min(totals)
    max_value = max(totals)
    midval = (min_value + max_value) // 2

    new_distrib = {}
    mass = F.Frac(num=0, den=1)
    for key, prob in true_distrib.items():
        if _dice_total_matches_observations(int(key), observations, min_value, midval):
            new_distrib[key] = prob
            mass = mass.add(prob)

    if not mass.num > 0:
        raise ValueError("No valid value for the distribution after observations")

    return {str(k): p.divide(other_frac=mass) for k, p in new_distrib.items()}


def generateTrueDistribution(config: dict):
    """Build the true distribution for the scenario described by ``config``.

    ``config`` carries the scenario name under ``S.prompt_parameter.SCENARIO``,
    the subscenario name under ``S.prompt_parameter.SUBSCENARIO``, and the
    scenario-specific parameters under the scenario name itself. The
    scenario's generator receives those parameters and the subscenario.

    Raises:
        ValueError: if the scenario is not one of the known scenario names.
    """
    scenario = config[S.prompt_parameter.SCENARIO]
    subscenario = config[S.prompt_parameter.SUBSCENARIO]
    subconfig = config[scenario]

    # Ordered (name, generator) pairs; the first name equal to the scenario wins.
    handlers = (
        (S.scenario_names.CHOICE, _generate_True_Distribution_choice),
        (S.scenario_names.COINS, _generate_True_Distribution_coins),
        (S.scenario_names.PREFERENCES, _generate_True_Distribution_preference),
        (S.scenario_names.DICE, _generate_True_Distribution_dice),
        (S.scenario_names.STATS, _generate_True_Distribution_stats),
    )
    for name, handler in handlers:
        if scenario == name:
            return handler(subconfig, subscenario)

    raise ValueError("Unknown scenario: " + scenario)


def find_maximum_likelihoo(distribution):
    """Locate the most likely outcome and report whether it is strictly unique.

    ``distribution`` maps outcomes to probability objects exposing
    ``bigger_than``. The first key holding the maximum value is taken as the
    argmax.

    Returns:
        ``(is_unique, argmax)``: ``is_unique`` is False when some other key's
        value is not strictly smaller than the argmax's value. For an empty
        distribution the result is ``(True, None)``.
    """
    best_key = None
    best_value = None
    for candidate in distribution:
        value = distribution[candidate]
        if best_key is None or value.bigger_than(best_value):
            best_key = candidate
            best_value = value

    has_tie = any(
        not best_value.bigger_than(distribution[other])
        for other in distribution
        if other != best_key
    )
    return not has_tie, best_key


if __name__ == "__main__":
    # Small demo: print the true distribution and its unique-argmax check for
    # one example configuration of each scenario.
    config_choice = {
        S.prompt_parameter.SCENARIO: S.scenario_names.CHOICE,
        S.prompt_parameter.SUBSCENARIO: S.subscenario_names.REPEATED_DEPENDENT,
        S.scenario_names.CHOICE: {
            S.prompt_parameter.NUMBER: 7,
            S.prompt_parameter.PREVIOUS_RESULT: "A",
        },
    }

    config_preferences = {
        S.prompt_parameter.SCENARIO: S.scenario_names.PREFERENCES,
        S.prompt_parameter.SUBSCENARIO: S.subscenario_names.REPEATED_DEPENDENT,
        S.scenario_names.PREFERENCES: {
            S.prompt_parameter.OPTIONS: ["Left", "Right"],
            S.prompt_parameter.FOCUS: "Left",
            S.prompt_parameter.BIAS: 3,
            S.prompt_parameter.PREVIOUS_RESULT: "Right",
        },
    }

    config_dice = {
        S.prompt_parameter.SCENARIO: S.scenario_names.DICE,
        S.prompt_parameter.SUBSCENARIO: S.subscenario_names.OBSERVATIONS,
        S.scenario_names.DICE: {
            S.prompt_parameter.NUMBER: 4,
            S.prompt_parameter.FACES: 4,
            S.prompt_parameter.PREVIOUS_RESULT: 0,
            S.prompt_parameter.OBSERVATIONS: [S.observation_names.EVEN],
        },
    }

    config_coins = {
        S.prompt_parameter.SCENARIO: S.scenario_names.COINS,
        S.prompt_parameter.SUBSCENARIO: S.subscenario_names.REPEATED_DEPENDENT,
        S.scenario_names.COINS: {
            S.prompt_parameter.NUMBER: 4,
            S.prompt_parameter.FOCUS: S.coin_faces.TAILS,
            S.prompt_parameter.BIAS: 2,
            S.prompt_parameter.PREVIOUS_RESULT: 7,
        },
    }

    # The stats generator reads config["config"]["params"]; without it the
    # demo raised KeyError. The values below are illustrative only.
    config_stats = {
        S.prompt_parameter.SCENARIO: S.scenario_names.STATS,
        S.prompt_parameter.SUBSCENARIO: S.subscenario_names.REGULAR,
        S.scenario_names.STATS: {
            "idx": 0,  # TMP (not read by the stats generator)
            "config": {
                "params": {
                    "conditions": ["A", "B", "C"],
                    "asking_for_condition": "A",
                    "proportions": [20, 30, 50],
                    "proportions_in_subgroup": [40, 30, 30],
                    "share_working_in_subpopulation": 25,
                },
            },
        },
    }

    for config in [config_dice, config_choice, config_preferences, config_coins, config_stats]:
        print(config[S.prompt_parameter.SCENARIO])

        d = generateTrueDistribution(config)
        for k in d:
            print(k, str(d[k]))
        print(find_maximum_likelihoo(d))

        print("===" * 10)
