from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import scipy.stats as stats
import torch
from typing import Union
import gurobipy as gp 
from gurobipy import GRB
# Device on which predict_next_obs_mean places its tensors: first CUDA GPU if present, else the CPU.
TORCH_DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class Optimizer:
    """Base interface for optimizers such as ``CEMOptimizer``.

    Concrete optimizers override ``setup``, ``reset`` and ``obtain_solution``;
    the versions defined here only raise ``NotImplementedError``.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore arbitrary constructor arguments."""

    def setup(self, cost_function):
        """Attach ``cost_function`` to the optimizer; subclasses must override."""
        raise NotImplementedError("Must be implemented in subclass.")

    def reset(self):
        """Reset the optimizer's internal state; subclasses must override."""
        raise NotImplementedError("Must be implemented in subclass.")

    def obtain_solution(self, *args, **kwargs):
        """Run the optimization and return a solution; subclasses must override."""
        raise NotImplementedError("Must be implemented in subclass.")


class CEMOptimizer(Optimizer):
    """Cross-entropy-method optimizer over flattened action sequences with action constraints.

    A candidate solution is a vector of length ``sol_dim`` made of consecutive
    ``action_dim``-sized action blocks, one per planning step. Every iteration:
      1. Samples ``popsize`` candidates from a Gaussian truncated at +/- 2 std.
      2. Mixes the first ``arc_step`` action blocks toward an expert-action prior
         with weight ``beta``.
      3. Enforces ``constraint_type``.
      4. Refits the sampling mean/variance to the ``num_elites`` lowest-cost candidates.

    Supported ``constraint_type`` values:
        'no'    -- unconstrained sampling.
        'box'   -- every action entry within [lower_bound, upper_bound].
        'l2'    -- squared L2 norm of each action block <= constraint_ub.
        'HC_O'  -- sum of |a_i * w_i| <= constraint_ub, weights w read from the observation.
        'H_M'   -- sum of max(a_i * w_i, 0) <= constraint_ub, weights w read from the observation.
        'W_M'   -- sum of max(a_i * w_i, 0) <= constraint_ub, weights w read from the observation.
    See the corresponding ``*_filter`` / ``*_Projection`` methods for the exact observation indices.
    """

    # Constraint types understood by ``obtain_solution``.
    _CONSTRAINT_TYPES = ('no', 'box', 'l2', 'HC_O', 'H_M', 'W_M')

    def __init__(self, sol_dim, max_iters, popsize, num_elites, cost_function, _predict_next_obs, obs_postproc2,
                 constraint_type, constraint_ub, action_dim, obs_dim, beta, arc_step,
                 upper_bound=None, lower_bound=None, epsilon=0.001, alpha=0.25):
        """Creates an instance of this class.

        Arguments:
            sol_dim (int): The dimensionality of the problem space; assumed to be a whole
                number of ``action_dim`` blocks.
            max_iters (int): The maximum number of iterations to perform during optimization.
            popsize (int): The number of candidate solutions to be sampled at every iteration.
            num_elites (int): The number of top solutions that will be used to obtain the distribution
                at the next iteration.
            cost_function (callable): Maps a (popsize, sol_dim) float32 array of candidates to one
                cost per candidate; lower is better.
            _predict_next_obs (callable): ``(obs, actions) -> prediction`` on torch tensors placed on
                ``TORCH_DEVICE``; used by ``predict_next_obs_mean``.
            obs_postproc2 (callable): Turns the output of ``_predict_next_obs`` into next observations.
            constraint_type (str): One of ``'no'``, ``'box'``, ``'l2'``, ``'HC_O'``, ``'H_M'``, ``'W_M'``.
            constraint_ub (float): Constraint bound. For ``'l2'`` it bounds the squared L2 norm of each
                action block; for the state-dependent types it bounds the summed weighted terms.
            action_dim (int): Dimensionality of one action.
            obs_dim (int): Dimensionality of one observation.
            beta (float): Mixing weight toward the expert-action prior (0 = no mixing).
            arc_step (int): Number of leading action blocks that are mixed toward the expert prior.
            upper_bound (np.array): An array of upper bounds, one per solution entry.
            lower_bound (np.array): An array of lower bounds, one per solution entry.
            epsilon (float): A minimum variance. If the maximum variance drops below epsilon, optimization is
                stopped.
            alpha (float): Controls how much of the previous mean and variance is used for the next iteration.
                next_mean = alpha * old_mean + (1 - alpha) * elite_mean, and similarly for variance.

        Raises:
            ValueError: If ``num_elites > popsize``.
        """
        super().__init__()
        if num_elites > popsize:
            raise ValueError("Number of elites must be at most the population size.")
        self.sol_dim, self.max_iters, self.popsize, self.num_elites = sol_dim, max_iters, popsize, num_elites

        self.constraint_type = constraint_type
        self.ub, self.lb = upper_bound, lower_bound
        self.epsilon, self.alpha = epsilon, alpha
        self._predict_next_obs = _predict_next_obs
        self.obs_postproc2 = obs_postproc2
        self.cost_function = cost_function
        self.constraint_ub = constraint_ub
        self.action_dim = action_dim
        self.obs_dim = obs_dim
        self.beta = beta
        self.arc_step = arc_step
        # Expert-action prior (length sol_dim); filled in by obtain_solution.
        self.samples_expert_actions = None
        # Number of leading solution entries mixed toward the prior; obtain_solution sets it to 0
        # when no expert actions are supplied.
        self._expert_blend_dim = action_dim * arc_step

    def reset(self):
        """Nothing to reset: all per-call state is rebuilt by ``obtain_solution``."""

    # ------------------------------------------------------------------ projections

    def L2_Projection(self, action):
        """Euclidean projection of one action onto the box-and-L2-ball feasible set.

        Solves, with Gurobi,
            min  sum_i (a_i - action_i)^2
            s.t. lb[i] <= a_i <= ub[i],  sum_i a_i^2 <= constraint_ub.

        Arguments:
            action (array-like): One action of length ``action_dim`` (``action.shape[0]`` entries are read).

        Returns:
            torch.Tensor: The projected action.
        """
        num = action.shape[0]
        target = [float(action[i]) for i in range(num)]
        with gp.Env(empty=True) as env:
            env.setParam('OutputFlag', 0)
            env.start()
            with gp.Model(env=env) as model:
                # Keep the decision variables in a local list so nothing leaks past the model's lifetime.
                acts = [model.addVar(lb=self.lb[i], ub=self.ub[i], name=f"a{i}", vtype=GRB.CONTINUOUS)
                        for i in range(num)]
                model.setObjective(gp.quicksum((acts[i] - target[i]) ** 2 for i in range(num)), GRB.MINIMIZE)
                model.addConstr(gp.quicksum(a ** 2 for a in acts) <= self.constraint_ub)

                model.optimize()
                return torch.tensor([a.X for a in acts])

    def _project_weighted_terms(self, action, weights, term_ub, use_abs):
        """Euclidean projection onto a state-dependent weighted-action constraint.

        Solves, with Gurobi,
            min  sum_i (a_i - action_i)^2
            s.t. -1 <= a_i <= 1,
                 u_i = a_i * w_i,
                 p_i = |u_i|           if use_abs else  p_i = max(u_i, 0),
                 0 <= p_i <= term_ub,
                 sum_i p_i <= constraint_ub.

        Arguments:
            action (array-like): Action to project; its first ``len(weights)`` entries are read.
            weights (list of float): Per-dimension weights ``w_i``.
            term_ub (float): Upper bound on every ``p_i``.
            use_abs (bool): Use ``|u_i|`` (True) or ``max(u_i, 0)`` (False).

        Returns:
            torch.Tensor: The projected action of length ``len(weights)``.
        """
        num = len(weights)
        target = [float(action[i]) for i in range(num)]
        with gp.Env(empty=True) as env:
            env.setParam('OutputFlag', 0)
            env.start()
            with gp.Model(env=env) as model:
                acts = [model.addVar(lb=-1, ub=1, name=f"a{i + 1}", vtype=GRB.CONTINUOUS) for i in range(num)]
                total = model.addVar(ub=self.constraint_ub, name="v")
                prods = [model.addVar(lb=-GRB.INFINITY, name=f"u{i + 1}") for i in range(num)]
                terms = [model.addVar(ub=term_ub, name=f"abs_u{i + 1}") for i in range(num)]
                model.setObjective(gp.quicksum((acts[i] - target[i]) ** 2 for i in range(num)), GRB.MINIMIZE)

                for i in range(num):
                    model.addConstr(prods[i] == acts[i] * weights[i])
                for i in range(num):
                    rhs = gp.abs_(prods[i]) if use_abs else gp.max_(prods[i], 0)
                    model.addConstr(terms[i] == rhs)
                model.addConstr(gp.quicksum(terms) == total)

                model.optimize()
                return torch.tensor([a.X for a in acts])

    def HC_O_Projection(self, action, state):
        """Project a 6-D action onto the 'HC_O' set for observation ``state``.

        Feasible set: a in [-1, 1]^6, sum_i |a_i * state[11 + i]| <= constraint_ub,
        and each |a_i * state[11 + i]| <= 20.
        NOTE(review): state[11:17] are presumably joint velocities (a_i * w_i acting as power) -- confirm.

        Returns:
            torch.Tensor: The projected action (length 6).
        """
        weights = [float(state[i]) for i in range(11, 17)]
        return self._project_weighted_terms(action, weights, term_ub=20, use_abs=True)

    def H_M_Projection(self, action, state):
        """Project a 3-D action onto the 'H_M' set for observation ``state``.

        Feasible set: a in [-1, 1]^3, sum_i max(a_i * state[8 + i], 0) <= constraint_ub,
        and each max(a_i * state[8 + i], 0) <= 10.

        Returns:
            torch.Tensor: The projected action (length 3).
        """
        weights = [float(state[i]) for i in range(8, 11)]
        return self._project_weighted_terms(action, weights, term_ub=10, use_abs=False)

    def W_M_Projection(self, action, state):
        """Project a 6-D action onto the 'W_M' set for observation ``state``.

        Feasible set: a in [-1, 1]^6, sum_i max(a_i * state[12 + i], 0) <= constraint_ub,
        and each max(a_i * state[12 + i], 0) <= 10.

        Returns:
            torch.Tensor: The projected action (length 6).
        """
        weights = [float(state[i]) for i in range(12, 18)]
        return self._project_weighted_terms(action, weights, term_ub=10, use_abs=False)

    # ------------------------------------------------------------------ model rollout

    @torch.no_grad()
    def predict_next_obs_mean(self, obs_arr, samples, npart=20):
        """Predict the particle-averaged next observation for each (observation, action) row.

        Each row is replicated ``npart`` times, pushed one step through ``_predict_next_obs``
        followed by ``obs_postproc2``, and the ``npart`` predictions of each row are averaged.
        (The averaging presumably matters because the model is stochastic -- confirm.)

        Arguments:
            obs_arr (np.ndarray): (N, >= obs_dim) observations; the first ``obs_dim`` columns are used.
            samples (np.ndarray): (N, >= action_dim) actions; the first ``action_dim`` columns are used.
            npart (int): Number of particles per row (default 20).

        Returns:
            np.ndarray: (N, obs_dim) mean predicted next observation.
        """
        actions = torch.from_numpy(samples[:, :self.action_dim]).float().to(TORCH_DEVICE)
        obs = torch.from_numpy(obs_arr[:, :self.obs_dim]).float().to(TORCH_DEVICE)
        # Copies of one row stay adjacent: (N, d) -> (N * npart, d), so the reshape below groups them.
        actions = actions.repeat_interleave(npart, dim=0)
        obs = obs.repeat_interleave(npart, dim=0)
        next_obs = self.obs_postproc2(self._predict_next_obs(obs, actions))
        return next_obs.cpu().numpy().reshape(obs_arr.shape[0], npart, self.obs_dim).mean(axis=1)

    # ------------------------------------------------------------------ sampling

    def _blend_with_expert(self, batch, act_idx):
        """Mix ``batch`` (samples of the action block starting at ``act_idx``) toward the expert prior.

        Blocks starting at or beyond ``_expert_blend_dim`` are returned unchanged.
        """
        if act_idx >= self._expert_blend_dim:
            return batch
        prior = self.samples_expert_actions[act_idx:act_idx + self.action_dim]
        return (1 - self.beta) * batch + self.beta * prior

    def state_dependent_rejection_sampling(self, X, mean, constrained_var, cur_obs, filter, max_rejection_iters=10):
        """Sample a population step by step, rejecting actions that violate a state-dependent constraint.

        Action blocks are drawn in planning-step order. For each step, every pending
        candidate is resampled until ``filter`` accepts it against that candidate's
        observation. After ``max_rejection_iters`` rounds, whatever is still pending is
        accepted unchecked, so those candidates may violate the constraint. The next
        step's observation is the particle-averaged prediction of ``predict_next_obs_mean``.

        Arguments:
            X: Frozen scipy distribution; ``X.rvs(size=[n, action_dim])`` gives standardized noise.
            mean (np.ndarray): (sol_dim,) sampling mean.
            constrained_var (np.ndarray): (sol_dim,) sampling variance.
            cur_obs (np.ndarray): (obs_dim,) observation at the first planning step.
            filter (callable): ``filter(actions, obs) -> bool mask`` for actions (n, action_dim) and
                obs (n, obs_dim). The name shadows the builtin but is kept for keyword callers.
            max_rejection_iters (int): Rejection rounds per step before giving up (default 10).

        Returns:
            np.ndarray: (popsize, sol_dim) samples.
        """
        act_dim, obs_dim = self.action_dim, self.obs_dim
        samples = np.zeros((self.popsize, self.sol_dim))
        # Column block k holds the (predicted) observation at planning step k.
        obs_arr = np.zeros((self.popsize, obs_dim * (self.sol_dim // act_dim)))
        obs_arr[:, :obs_dim] = cur_obs
        for act_idx in range(0, self.sol_dim, act_dim):
            act_cols = slice(act_idx, act_idx + act_dim)
            obs_idx = (act_idx // act_dim) * obs_dim
            obs_cols = slice(obs_idx, obs_idx + obs_dim)
            std, mu = np.sqrt(constrained_var[act_cols]), mean[act_cols]
            pending = np.ones(self.popsize, dtype=bool)
            rounds = 0
            while pending.any():
                rounds += 1
                rows = np.flatnonzero(pending)
                batch = X.rvs(size=[rows.size, act_dim]) * std + mu
                batch = self._blend_with_expert(batch, act_idx)
                if rounds > max_rejection_iters:
                    # Fallback so the loop terminates; the constraint is NOT enforced for these rows.
                    accepted = np.ones(rows.size, dtype=bool)
                else:
                    accepted = filter(batch, obs_arr[rows, obs_cols])
                samples[rows[accepted], act_cols] = batch[accepted]
                pending[rows[accepted]] = False

            # Predict the observation for the next step (not needed after the last block).
            if act_idx < self.sol_dim - act_dim:
                next_cols = slice(obs_cols.stop, obs_cols.stop + obs_dim)
                obs_arr[:, next_cols] = self.predict_next_obs_mean(obs_arr[:, obs_cols], samples[:, act_cols])
        return samples

    def box_filter(self, sub_vectors):
        """Mask of candidates whose every action entry lies within the bounds.

        Uses the bounds of the first action block for all blocks, which assumes the
        solution bounds repeat every planning step.

        Arguments:
            sub_vectors (np.ndarray): (batch, num_sub_vectors, action_dim) candidate actions.

        Returns:
            np.ndarray: (batch,) boolean mask.
        """
        width = sub_vectors.shape[-1]
        # ravel also accepts scalar bounds; per-dimension bounds broadcast along the last axis.
        ub = np.ravel(self.ub)[:width]
        lb = np.ravel(self.lb)[:width]
        inside = np.logical_and(sub_vectors <= ub, sub_vectors >= lb)
        return np.all(inside, axis=(1, 2))

    def l2_filter(self, sub_vectors):
        """Mask of candidates whose every action block has squared L2 norm <= ``constraint_ub``.

        Arguments:
            sub_vectors (np.ndarray): (batch, num_sub_vectors, action_dim) candidate actions.

        Returns:
            np.ndarray: (batch,) boolean mask.
        """
        norms = np.linalg.norm(sub_vectors, axis=2)
        return np.all(norms <= self.constraint_ub ** 0.5, axis=1)

    def H_M_filter(self, sub_vectors, cur_obs):
        """Mask of actions with sum_i max(a_i * obs[8 + i], 0) <= constraint_ub (3-D actions).

        Arguments:
            sub_vectors (np.ndarray): (n, 3) actions.
            cur_obs (np.ndarray): (n, obs_dim) matching observations.
        """
        positive = np.maximum(cur_obs[:, 8:11] * sub_vectors, 0)
        return np.sum(positive, axis=1) <= self.constraint_ub

    def HC_O_filter(self, sub_vectors, cur_obs):
        """Mask of actions with sum_i |a_i * obs[-6 + i]| <= constraint_ub (6-D actions).

        Arguments:
            sub_vectors (np.ndarray): (n, 6) actions.
            cur_obs (np.ndarray): (n, obs_dim) matching observations.
        """
        return np.sum(np.abs(cur_obs[:, -6:] * sub_vectors), axis=1) <= self.constraint_ub

    def W_M_filter(self, sub_vectors, cur_obs):
        """Mask of actions with sum_i max(a_i * obs[-6 + i], 0) <= constraint_ub (6-D actions).

        Arguments:
            sub_vectors (np.ndarray): (n, 6) actions.
            cur_obs (np.ndarray): (n, obs_dim) matching observations.
        """
        positive = np.maximum(cur_obs[:, -6:] * sub_vectors, 0)
        return np.sum(positive, axis=1) <= self.constraint_ub

    def rejection_sampling(self, X, mean, constrained_var, filter, batch_size=600):
        """Sample a population block by block, keeping only actions accepted by a state-free filter.

        For each action block, batches of ``batch_size`` are drawn (and mixed toward the
        expert prior inside the expert arc) until ``popsize`` accepted samples exist. The
        first ``popsize`` accepted samples are kept, in draw order. This loops forever
        if ``filter`` never accepts anything.

        Arguments:
            X: Frozen scipy distribution; ``X.rvs(size=[n, action_dim])`` gives standardized noise.
            mean (np.ndarray): (sol_dim,) sampling mean.
            constrained_var (np.ndarray): (sol_dim,) sampling variance.
            filter (callable): Maps (batch, 1, action_dim) candidates to a (batch,) bool mask.
            batch_size (int): Candidates drawn per attempt.

        Returns:
            np.ndarray: (popsize, sol_dim) samples.
        """
        act_dim = self.action_dim
        samples = np.zeros((self.popsize, self.sol_dim))
        for act_idx in range(0, self.sol_dim, act_dim):
            act_cols = slice(act_idx, act_idx + act_dim)
            std, mu = np.sqrt(constrained_var[act_cols]), mean[act_cols]
            kept, num_kept = [], 0
            while num_kept < self.popsize:
                batch = X.rvs(size=[batch_size, act_dim]) * std + mu
                batch = self._blend_with_expert(batch, act_idx)
                valid = batch[filter(batch.reshape(batch_size, 1, act_dim))]
                kept.append(valid)
                num_kept += len(valid)
            samples[:, act_cols] = np.concatenate(kept)[:self.popsize]
        return samples

    # ------------------------------------------------------------------ CEM driver

    def _constrained_var(self, mean, var):
        """Cap the variance so that +/- 2 std around ``mean`` stays within [lb, ub]."""
        lb_dist, ub_dist = mean - self.lb, self.ub - mean
        return np.minimum(np.minimum(np.square(lb_dist / 2), np.square(ub_dist / 2)), var)

    def _prepare_expert_actions(self, mean, cur_obs, expert_actions):
        """Build the expert-action prior (shape of ``mean``) and make it satisfy the constraint.

        The flattened expert sequence fills the first ``arc_step * action_dim`` entries
        (zeros elsewhere). Then, depending on ``constraint_type``:
          - 'box': the prior is clipped to [lb, ub].
          - 'l2': it is projected block by block with Gurobi.
          - state-dependent types: it is projected block by block with Gurobi, rolling
            the observation forward with ``predict_next_obs_mean``.
        This method also sets ``_expert_blend_dim``, which is 0 (no mixing) when
        ``expert_actions`` is None.
        """
        act_dim = self.action_dim
        prior = np.zeros_like(mean)
        if expert_actions is None:
            self._expert_blend_dim = 0
            return prior
        self._expert_blend_dim = act_dim * self.arc_step
        flat = torch.as_tensor(expert_actions).detach().cpu().flatten().numpy().astype(np.float32)
        # Never write past the solution length, even if arc_step exceeds the horizon.
        prior_len = min(self._expert_blend_dim, flat.shape[0], prior.shape[0])
        prior[:prior_len] = flat[:prior_len]
        num_steps = min(self.arc_step, flat.shape[0] // act_dim, prior.shape[0] // act_dim)

        ctype = self.constraint_type
        if ctype == 'box':
            return np.clip(prior, a_min=self.lb, a_max=self.ub)
        if ctype == 'l2':
            for step in range(num_steps):
                block = slice(step * act_dim, (step + 1) * act_dim)
                prior[block] = self.L2_Projection(prior[block]).numpy().astype(np.float32)
            return prior
        projections = {'HC_O': self.HC_O_Projection, 'H_M': self.H_M_Projection, 'W_M': self.W_M_Projection}
        if ctype in projections:
            project = projections[ctype]
            obs = cur_obs
            for step in range(num_steps):
                block = slice(step * act_dim, (step + 1) * act_dim)
                prior[block] = project(prior[block], obs).numpy().astype(np.float32)
                # Each later block is projected against the observation predicted for its own step.
                obs = self.predict_next_obs_mean(np.expand_dims(obs, axis=0),
                                                 np.expand_dims(prior[block], axis=0))[0]
        return prior

    def _draw_samples(self, X, mean, constrained_var, cur_obs):
        """Draw one population for the current CEM iteration according to ``constraint_type``."""
        ctype = self.constraint_type
        if ctype == 'no':
            samples = X.rvs(size=[self.popsize, self.sol_dim]) * np.sqrt(constrained_var) + mean
            # Mix the leading expert-guided COLUMNS (action entries) of every candidate.
            k = min(self._expert_blend_dim, self.sol_dim)
            samples[:, :k] = (1 - self.beta) * samples[:, :k] + self.beta * self.samples_expert_actions[:k]
            return samples
        if ctype == 'box':
            return self.rejection_sampling(X, mean, constrained_var, filter=self.box_filter)
        if ctype == 'l2':
            return self.rejection_sampling(X, mean, constrained_var, filter=self.l2_filter)
        state_filters = {'HC_O': self.HC_O_filter, 'H_M': self.H_M_filter, 'W_M': self.W_M_filter}
        return self.state_dependent_rejection_sampling(X, mean, constrained_var, cur_obs,
                                                       filter=state_filters[ctype])

    def obtain_solution(self, init_mean, init_var, cur_obs, expert_actions=None):
        """Optimizes the cost function using the provided initial candidate distribution.

        Arguments:
            init_mean (np.ndarray): The mean of the initial candidate distribution (length sol_dim).
            init_var (np.ndarray): The variance of the initial candidate distribution (length sol_dim).
            cur_obs (np.ndarray): Current observation (length obs_dim); used by the state-dependent
                constraint types.
            expert_actions (torch.Tensor or array-like, optional): Expert action sequence. Once
                flattened, its first ``arc_step * action_dim`` entries form the prior that samples
                are mixed toward with weight ``beta``. If None, no mixing is done.

        Returns:
            tuple: ``(mean, constrained_var)``. ``mean`` is the final mean. ``constrained_var``
            is the bound-capped variance used for the last iteration's sampling; if no
            iteration ran, it is computed from the initial distribution.

        Raises:
            ValueError: If ``constraint_type`` is not supported.
        """
        if self.constraint_type not in self._CONSTRAINT_TYPES:
            raise ValueError(f"Unknown constraint_type {self.constraint_type!r}; "
                             f"expected one of {self._CONSTRAINT_TYPES}.")

        mean, var = init_mean, init_var
        self.samples_expert_actions = self._prepare_expert_actions(mean, cur_obs, expert_actions)

        # Standardized noise truncated at +/- 2 std: full-length for 'no', one action block otherwise.
        if self.constraint_type == 'no':
            X = stats.truncnorm(-2, 2, loc=np.zeros_like(mean), scale=np.ones_like(var))
        else:
            X = stats.truncnorm(-2, 2, loc=np.zeros(self.action_dim), scale=np.ones(self.action_dim))

        # Defined up front so the return value exists even if the loop body never runs.
        constrained_var = self._constrained_var(mean, var)
        t = 0
        while t < self.max_iters and np.max(var) > self.epsilon:
            constrained_var = self._constrained_var(mean, var)
            samples = self._draw_samples(X, mean, constrained_var, cur_obs).astype(np.float32)

            costs = self.cost_function(samples)
            elites = samples[np.argsort(costs)[:self.num_elites]]

            mean = self.alpha * mean + (1 - self.alpha) * np.mean(elites, axis=0)
            var = self.alpha * var + (1 - self.alpha) * np.var(elites, axis=0)
            t += 1

        return mean, constrained_var
