import time
import numpy
import jax
import jax.numpy as jnp
from jaxopt import ScipyBoundedMinimize
import itertools
import copy
import warnings
from enum import Enum

class OptimizationMethod(str, Enum):
    """Strategy used by ``StimDesigner.design_stim`` to build a stimulus.

    Members subclass ``str``, so each one compares equal to its string value
    (e.g. ``OptimizationMethod.JAXOPT == "jaxopt"``).
    """

    # Optimise the stimulus numerically with jaxopt's bounded L-BFGS-B.
    JAXOPT = "jaxopt"
    # Use ``equivalent_projection_matrix @ v`` directly as the stimulus.
    CHEAT_LOWD_VEC = "cheat_lowd_vec"
    # Switch on a single randomly chosen entry of the stimulus.
    CHEAT_HIGHD_VEC_SINGLE_NEURONS = "cheat_highd_vec_single_neurons"
    # Switch on ``max_l0_norm`` randomly chosen entries of the stimulus.
    CHEAT_HIGHD_VEC_MANY_NEURONS = "cheat_highd_vec_many_neurons"


class StimDesigner:
    def __init__(
            self,
            max_l0_norm=30,
            rng_seed=0,  # TODO: make this an rng
            should_log=False,
            lam_1=0.001,
            inter_stim_interval_generator=None,
            optimization_method=OptimizationMethod.JAXOPT,
            stim_timing_method='regular',
            initial_nostim_period=1,
            u_to_s_model_type='identity', # TODO: remove? it's used in sim_stim_design_stim
            n_identity_initialization=1,
    ):
        self.rng_seed = rng_seed
        self.rng = numpy.random.default_rng(rng_seed)
        assert max_l0_norm > 0
        self.max_l0_norm = max_l0_norm
        self.should_log = should_log
        self.lam_1 = lam_1
        self.u_to_s_model_type = u_to_s_model_type

        self.optimization_method: OptimizationMethod = optimization_method
        self.n_identity_initialization = n_identity_initialization
        self.stim_timing_method = stim_timing_method
        self.initial_nostim_period = initial_nostim_period

        if inter_stim_interval_generator is None:
            inter_stim_interval_generator = itertools.repeat(1)
        self.inter_stim_interval_generator = inter_stim_interval_generator
        self.last_stim_time = None
        self.current_isi = None

        self.log = []

        self.objective_history = []

    def stim_when_extreme(self, current_t, objective_value):
        self.objective_history.append(objective_value)
        return current_t > 50 and objective_value == numpy.nanmin(self.objective_history)

    def decide_whether_to_stim(self, current_t, **kwargs):
        if current_t < self.initial_nostim_period:
            return False

        if self.stim_timing_method == 'isi':  # or 'regular'
            if self.last_stim_time is None:
                self.last_stim_time = self.initial_nostim_period if self.initial_nostim_period is not None else 0
                self.current_isi = next(self.inter_stim_interval_generator)
            if current_t > self.last_stim_time + self.current_isi:
                self.last_stim_time = current_t
                self.current_isi = next(self.inter_stim_interval_generator)
                return True
            return False
        elif self.stim_timing_method == 'extreme':
            return self.stim_when_extreme(current_t, **kwargs)
        elif self.stim_timing_method == 'random':
            return kwargs['stim_time_rng'].random() < 1/next(self.inter_stim_interval_generator) * kwargs['input_array_dt']
        else:
            raise ValueError()


    @staticmethod
    def desired_stim_direction(equivalent_projection_matrix, stim_direction_type, rng):  # TODO: use built-in rng
        if stim_direction_type == 'first':
            desired_stim = numpy.zeros((equivalent_projection_matrix.shape[1], 1))
            desired_stim[0] = 1
        elif stim_direction_type == 'first2':
            desired_stim = numpy.zeros((equivalent_projection_matrix.shape[1], 2))
            desired_stim[0] = 1
            desired_stim[1] = 1
        elif stim_direction_type == 'col':
            desired_stim = numpy.zeros((equivalent_projection_matrix.shape[1], 1))
            desired_stim[rng.choice(equivalent_projection_matrix.shape[1]), 0] = 1
        elif stim_direction_type == 'random':
            desired_stim = rng.normal(size=(equivalent_projection_matrix.shape[1], 1))
            desired_stim = desired_stim / numpy.linalg.norm(desired_stim)
        elif stim_direction_type == 'random+':
            desired_stim_high_d = rng.normal(size=(equivalent_projection_matrix.shape[0], 1))
            desired_stim_high_d = desired_stim_high_d / numpy.linalg.norm(desired_stim_high_d)
            desired_stim_high_d = numpy.abs(desired_stim_high_d)

            desired_stim = equivalent_projection_matrix.T @ desired_stim_high_d
            desired_stim = desired_stim / numpy.linalg.norm(desired_stim)
        elif stim_direction_type == 'ones':
            desired_stim_high_d = numpy.ones((equivalent_projection_matrix.shape[0], 1))
            desired_stim = equivalent_projection_matrix.T @ desired_stim_high_d
            desired_stim = desired_stim / numpy.linalg.norm(desired_stim)
        elif stim_direction_type == '-ones':
            desired_stim_high_d = -numpy.ones((equivalent_projection_matrix.shape[0], 1))
            desired_stim = equivalent_projection_matrix.T @ desired_stim_high_d
            desired_stim = desired_stim / numpy.linalg.norm(desired_stim)
        else:
            raise ValueError()
        return desired_stim

    def register_stim(self):
        pass

    def design_stim_jaxopt(self, v, u_dimension, u_to_s_function=None):
        if u_to_s_function is None:
            u_to_s_function = lambda x: x

        u = self.rng.uniform(size=(u_dimension,)) * .1

        lb = jnp.zeros_like(u)
        ub = jnp.ones_like(u)
        bounds = (lb, ub)

        def objective(u):
            s = u_to_s_function(u)
            s_norm = jnp.linalg.norm(s)
            loss = self.lam_1 * (self.max_l0_norm - jnp.sum(jnp.abs(u)))
            loss += jnp.dot(s, v) / (s_norm + 1e-10)
            # new: loss += jnp.linalg.norm(jnp.dot(s, v))**2 / (s_norm + 1e-10)
            return -loss.reshape()

        runner = ScipyBoundedMinimize(fun=objective, method='l-bfgs-b')
        result = runner.run(u, bounds=bounds)
        u = numpy.array(result.params)

        if u.max() > 0:
            u = numpy.array(u / u.max())


        idx = numpy.argsort(u)
        u[idx[:-self.max_l0_norm]] = 0

        return u, {'s': u_to_s_function(u)}



    def design_stim(self, v, **kwargs):
        start_time = time.time()
        assert len(v.shape) == 2

        l = {}
        match self.optimization_method:
            case OptimizationMethod.JAXOPT:
                u, l = self.design_stim_jaxopt(v, kwargs['u_dimension'], kwargs['u_to_s_function'])
            case OptimizationMethod.CHEAT_LOWD_VEC:
                u = (kwargs['equivalent_projection_matrix'] @ v).flatten(),
            case OptimizationMethod.CHEAT_HIGHD_VEC_SINGLE_NEURONS:
                u = numpy.zeros(kwargs['equivalent_projection_matrix'].shape[0])
                u[self.rng.choice(kwargs['equivalent_projection_matrix'].shape[0])] = 1
            case OptimizationMethod.CHEAT_HIGHD_VEC_MANY_NEURONS:
                u = numpy.zeros(kwargs['equivalent_projection_matrix'].shape[0])
                u[self.rng.choice(kwargs['equivalent_projection_matrix'].shape[0], size=self.max_l0_norm, replace=False)] = 1
            case _:
                raise ValueError()


        if self.should_log:
            self.log.append({
                'optimization_time': time.time() - start_time,
                'v':v,
                'u':u,
                's': numpy.nan * v
            } | l)

        return u

    def sim_stim_design_stim(self, sr, stim_magnitude, desired_stim, equivalent_projection_matrix, current_t):
        self: StimDesigner
        optimization_method = self.optimization_method
        u_to_s_model_type = self.u_to_s_model_type
        if u_to_s_model_type == 'kernel_regressed' and sr.stim_reg.n_observed <= self.n_identity_initialization:
            u_to_s_model_type = 'identity'


        if optimization_method == 'jaxopt' and u_to_s_model_type == 'kernel_regressed':
            f = sr.stim_reg.make_jax_pred_f()
            pred = sr.autoreg.predict(n_steps=0)
            def u_to_s_function(u):
                return stim_magnitude * f(jax.numpy.hstack((pred, u)))
            designed_stim = self.design_stim(desired_stim, u_to_s_function=u_to_s_function, u_dimension=equivalent_projection_matrix.shape[0])
        elif optimization_method == 'jaxopt' and u_to_s_model_type == 'identity':
            def u_to_s_function(u):
                return stim_magnitude * equivalent_projection_matrix.T @ u
            designed_stim = self.design_stim(desired_stim, u_to_s_function=u_to_s_function, u_dimension=equivalent_projection_matrix.shape[0])
        elif optimization_method == 'cheat_lowd_vec' and u_to_s_model_type == 'identity':
            designed_stim = self.design_stim(desired_stim, equivalent_projection_matrix=equivalent_projection_matrix)
        elif optimization_method in {'cheat_highd_vec_single_neurons','cheat_highd_vec_many_neurons'} and u_to_s_model_type is None:
            designed_stim = self.design_stim(desired_stim, equivalent_projection_matrix=equivalent_projection_matrix)
        else:
            raise ValueError()

        self.log[-1]['stim_reg'] = copy.deepcopy(sr.stim_reg)
        self.log[-1]['time_of_stim'] = current_t
        self.log[-1]['equiv_proj_mat'] = equivalent_projection_matrix

        if (designed_stim == 0).all():
            designed_stim[0] = 1e-10
            warnings.warn("Stimulus was all zero!")  # TODO: handle this better

        return designed_stim
