from columns import CTypes as columns
import pandas as pd
import numpy as np

from warnings import catch_warnings
from warnings import simplefilter

import math
import random

import numpy as np
from time import time
from typing import List, Dict, Any

from grammar.mcts.state import State
from grammar.mcts.mcts import MCTS
from grammar.mcts.rollout import average_policy
from grammar.mcts.selection import UpperConfidenceBound
from grammar.mcts.state import TreeState

from grammar.builder.tree import PipelineNameTree
from grammar.scheduler.base import Scheduler

from tqdm.auto import tqdm
tqdm.pandas()


class AverageRanker:
    """Baseline ranker that orders elements by their mean score over the training groups.

    ``fit`` turns every training rank into a score, averages the scores per
    element across all ranking groups and stores the elements sorted from best
    to worst. ``predict`` replays that stored order inside each group.
    """

    def __init__(self, column_types):
        """Resolve the ranking columns from ``column_types``.

        Args:
            column_types: Mapping of column name to a ``columns`` type tag.

        Raises:
            AssertionError: If there is not exactly one ranking target column.
        """
        group_columns = []
        target_columns = []
        element_columns = []
        for name, kind in column_types.items():
            if kind == columns.RANKING_GROUP_ID_COLUMN:
                group_columns.append(name)
            if kind == columns.RANKING_TARGET_COLUMN:
                target_columns.append(name)
            if kind == columns.RANKING_ELEMENT_ID_COLUMN:
                element_columns.append(name)

        self.RANKING_GROUP_ID_COLUMN = group_columns
        self.RANKING_TARGET_COLUMN = target_columns
        self.RANKING_ELEMENT_ID_COLUMN = element_columns

        assert len(self.RANKING_TARGET_COLUMN) == 1

        # From here on the target is addressed by name, not by a one-item list.
        self.RANKING_TARGET_COLUMN = self.RANKING_TARGET_COLUMN[0]

        self.RANK_SCORE = 'rank_score'
        self.rank_ranking = None
        self.max_ranking_lenght = None

    def fit(self, X, y=None, X_val=None, y_val=None, fit_params={}):
        """Learn the average-score ordering of the elements.

        Args:
            X: Frame holding the group-id and element-id columns.
            y: Target ranks, concatenated column-wise next to ``X``.
            X_val, y_val, fit_params: Accepted but not used by this method.
        """
        train = pd.concat([X, y], axis=1)

        # Size of the largest ranking group; get_rank_score reads it.
        group_sizes = train[self.RANKING_GROUP_ID_COLUMN].value_counts()
        self.max_ranking_lenght = group_sizes.max()

        train[self.RANK_SCORE] = self.get_rank_score(train)

        # One row per element, one column per group.
        score_table = train.pivot(
            index=self.RANKING_ELEMENT_ID_COLUMN,
            columns=self.RANKING_GROUP_ID_COLUMN,
            values=self.RANK_SCORE,
        )
        mean_scores = score_table.mean(axis=1)
        self.rank_ranking = mean_scores.sort_values(ascending=False)

    def get_rank_score(self, X):
        """Return ``max_ranking_lenght - rank`` so that smaller ranks score higher."""
        ranks = X[self.RANKING_TARGET_COLUMN]
        return self.max_ranking_lenght - ranks

    def predict(self, X):
        """Assign ranks 1..n inside every group following the fitted ordering.

        Elements absent from the fitted ordering are dropped by the index
        intersection and therefore come back as NaN after the final reindex.
        The method relies on ``reset_index()`` producing a column named
        ``'index'``.

        Args:
            X: Frame holding the group-id and element-id columns.

        Returns:
            The target column, aligned to the index of ``X``.
        """
        original_index = X.index
        ranked_groups = []

        group_ids = X.set_index(self.RANKING_GROUP_ID_COLUMN).index.unique()
        for group_id in group_ids:
            in_group = X[self.RANKING_GROUP_ID_COLUMN].T.squeeze() == group_id
            members = X[in_group].reset_index()
            members = members.set_index(self.RANKING_ELEMENT_ID_COLUMN)

            # Keep only known elements, in best-to-worst order.
            known = self.rank_ranking.index.intersection(members.index)
            members = members.reindex(known)

            # Restore the original row labels before numbering the rows.
            members = members.set_index(['index'])
            members[self.RANKING_TARGET_COLUMN] = range(1, len(members) + 1)
            ranked_groups.append(members)

        combined = pd.concat(ranked_groups)
        return combined.reindex(original_index).pop(self.RANKING_TARGET_COLUMN)
    

class AverageScorer(AverageRanker):
    """AverageRanker variant whose per-row score is read from the score column."""

    def __init__(self, column_types):
        """Resolve the score column, then run the AverageRanker initialisation.

        Args:
            column_types: Mapping of column name to a ``columns`` type tag.

        Raises:
            AssertionError: If there is not exactly one ranking score column
                (or, from the parent, not exactly one ranking target column).
        """
        score_columns = []
        for name, kind in column_types.items():
            if kind == columns.RANKING_SCORE_COLUMN:
                score_columns.append(name)
        self.RANKING_SCORE_COLUMN = score_columns

        assert len(self.RANKING_SCORE_COLUMN) == 1

        super().__init__(column_types)

    def get_rank_score(self, X):
        """Return the score column(s) of ``X``, selected with the one-item column list."""
        score_values = X[self.RANKING_SCORE_COLUMN]
        return score_values
    
    
class Random:
    """Baseline ranker that shuffles the elements of every ranking group."""

    def __init__(self, column_types, random_state=42):
        """Resolve the ranking columns and store the shuffle seed.

        Args:
            column_types: Mapping of column name to a ``columns`` type tag.
            random_state: Seed handed to ``DataFrame.sample`` for every group.

        Raises:
            AssertionError: If there is not exactly one ranking target column.
        """
        group_columns = []
        target_columns = []
        element_columns = []
        for name, kind in column_types.items():
            if kind == columns.RANKING_GROUP_ID_COLUMN:
                group_columns.append(name)
            if kind == columns.RANKING_TARGET_COLUMN:
                target_columns.append(name)
            if kind == columns.RANKING_ELEMENT_ID_COLUMN:
                element_columns.append(name)

        self.RANKING_GROUP_ID_COLUMN = group_columns
        self.RANKING_TARGET_COLUMN = target_columns
        self.RANKING_ELEMENT_ID_COLUMN = element_columns

        assert len(self.RANKING_TARGET_COLUMN) == 1

        # From here on the target is addressed by name, not by a one-item list.
        self.RANKING_TARGET_COLUMN = self.RANKING_TARGET_COLUMN[0]
        self.random_state = random_state

    def fit(self, X, y=None, X_val=None, y_val=None, fit_params={}):
        """Do nothing; the random baseline has no state to learn."""
        return None

    def predict(self, X):
        """Assign ranks 1..n inside every group in a shuffled order.

        The same ``random_state`` is passed to ``sample`` for every group.
        The method relies on ``reset_index()`` producing a column named
        ``'index'``.

        Args:
            X: Frame holding the group-id and element-id columns.

        Returns:
            The target column, aligned to the index of ``X``.
        """
        original_index = X.index
        shuffled_groups = []

        group_ids = X.set_index(self.RANKING_GROUP_ID_COLUMN).index.unique()
        for group_id in group_ids:
            in_group = X[self.RANKING_GROUP_ID_COLUMN].T.squeeze() == group_id
            members = X[in_group].reset_index()
            members = members.set_index(self.RANKING_ELEMENT_ID_COLUMN)

            # frac=1 keeps every row and only permutes the order.
            members = members.sample(frac=1, random_state=self.random_state)

            # Restore the original row labels before numbering the rows.
            members = members.set_index(['index'])
            members[self.RANKING_TARGET_COLUMN] = range(1, len(members) + 1)
            shuffled_groups.append(members)

        combined = pd.concat(shuffled_groups)
        return combined.reindex(original_index).pop(self.RANKING_TARGET_COLUMN)
    

class LTR:
    """Learning-to-rank wrapper around a regressor.

    ``fit`` converts every rank into a regression target (higher is better)
    and trains ``regressor_model`` on the element-id and numerical columns.
    ``predict`` scores the rows and numbers them 1..n inside each ranking
    group, best predicted score first.
    """

    def __init__(self, column_types, regressor_model):
        """Resolve the ranking/feature columns and store the regressor.

        Args:
            column_types: Mapping of column name to a ``columns`` type tag.
            regressor_model: Object exposing ``fit`` and ``predict``.

        Raises:
            AssertionError: If there is not exactly one ranking target column.
        """
        self.RANKING_GROUP_ID_COLUMN = [col for col, tipo in column_types.items() if tipo == columns.RANKING_GROUP_ID_COLUMN]
        self.RANKING_TARGET_COLUMN = [col for col, tipo in column_types.items() if tipo == columns.RANKING_TARGET_COLUMN]
        self.RANKING_ELEMENT_ID_COLUMN = [col for col, tipo in column_types.items() if tipo == columns.RANKING_ELEMENT_ID_COLUMN]

        self.DATA_COLUMNS = [col for col, tipo in column_types.items() if tipo == columns.NUMERICAL]

        assert len(self.RANKING_TARGET_COLUMN) == 1

        # From here on the target is addressed by name, not by a one-item list.
        self.RANKING_TARGET_COLUMN = self.RANKING_TARGET_COLUMN[0]
        # Columns handed to the regressor: element ids followed by the numerical features.
        self.INPUT = self.RANKING_ELEMENT_ID_COLUMN + self.DATA_COLUMNS

        self.REGRESSOR_SCORE = 'ltr_score'
        self.regressor_model = regressor_model
        self.max_ranking_lenght = None

    def get_regressor_score(self, X):
        """Return ``max_ranking_lenght - rank`` so that smaller ranks score higher."""
        return self.max_ranking_lenght - X[self.RANKING_TARGET_COLUMN]

    def fit(self, X, y=None, X_val=None, y_val=None, fit_params=None):
        """Train the regressor on rank-derived scores.

        When both ``X_val`` and ``y_val`` are given, the regressor is first
        fitted with an ``eval_set`` and early stopping (the keyword names look
        LightGBM-style; confirm against the models actually used). If that
        call fails, or no validation data is given, the regressor is fitted
        with a plain ``fit(X, y, **fit_params)``.

        Args:
            X: Training frame with the group-id and ``INPUT`` columns.
            y: Training ranks, concatenated column-wise next to ``X``.
            X_val: Optional validation frame.
            y_val: Optional validation ranks.
            fit_params: Extra keyword arguments for the plain ``fit`` call.
        """
        fit_params = {} if fit_params is None else fit_params
        has_validation = X_val is not None and y_val is not None

        Xy = pd.concat([X, y], axis=1)
        self.max_ranking_lenght = Xy[self.RANKING_GROUP_ID_COLUMN].value_counts().max()

        Xy_val = None
        if has_validation:
            Xy_val = pd.concat([X_val, y_val], axis=1)
            self.max_ranking_lenght = max(
                self.max_ranking_lenght,
                Xy_val[self.RANKING_GROUP_ID_COLUMN].value_counts().max(),
            )

        # Scores are computed only after the final maximum group size is known,
        # so a given rank maps to the same target in training and validation.
        Xy[self.REGRESSOR_SCORE] = self.get_regressor_score(Xy)

        if has_validation:
            Xy_val[self.REGRESSOR_SCORE] = self.get_regressor_score(Xy_val)
            try:
                self.regressor_model.fit(
                    Xy[self.INPUT], Xy[self.REGRESSOR_SCORE],
                    eval_set=[(Xy_val[self.INPUT], Xy_val[self.REGRESSOR_SCORE])],
                    early_stopping_rounds=50, categorical_feature='auto'
                )
            except Exception:
                # Deliberate best effort: regressors that reject these keyword
                # arguments fall back to the plain fit below. KeyboardInterrupt
                # and SystemExit are no longer swallowed.
                pass
            else:
                return

        self.regressor_model.fit(Xy[self.INPUT], Xy[self.REGRESSOR_SCORE], **fit_params)

    def predict(self, X_test):
        """Rank the rows of every group by predicted score, best first.

        ``X_test`` is not modified. The method relies on ``reset_index()``
        producing a column named ``'index'``.
        NOTE(review): for a single-row input ``.T.squeeze()`` collapses to a
        scalar and the group mask is no longer a Series - confirm callers
        never pass one row.

        Args:
            X_test: Frame with the group-id, element-id and ``INPUT`` columns.

        Returns:
            The target column (ranks 1..n per group), aligned to the index of
            ``X_test``.
        """
        X = X_test.copy()
        backup_index = X.index

        X[self.REGRESSOR_SCORE] = self.regressor_model.predict(X[self.INPUT])
        # Global sort; the per-group selection below keeps this order.
        X = X.sort_values(self.REGRESSOR_SCORE, ascending=False)

        arr_tasks = []
        for group in X.set_index(self.RANKING_GROUP_ID_COLUMN).index.unique():
            X_group = X[X[self.RANKING_GROUP_ID_COLUMN].T.squeeze() == group].reset_index().set_index(self.RANKING_ELEMENT_ID_COLUMN)
            # Restore the original row labels before numbering the rows.
            X_group = X_group.set_index(['index'])
            X_group[self.RANKING_TARGET_COLUMN] = range(1, len(X_group) + 1)
            arr_tasks.append(X_group)
        return pd.concat(arr_tasks).reindex(backup_index).pop(self.RANKING_TARGET_COLUMN)
    

class LTRScorer(LTR):
    """LTR variant whose regression target is read from the score column."""

    def __init__(self, column_types, regressor_model):
        """Resolve the score column, then run the LTR initialisation.

        Args:
            column_types: Mapping of column name to a ``columns`` type tag.
            regressor_model: Regressor forwarded to the parent initialiser.

        Raises:
            AssertionError: If there is not exactly one ranking score column
                (or, from the parent, not exactly one ranking target column).
        """
        score_columns = []
        for name, kind in column_types.items():
            if kind == columns.RANKING_SCORE_COLUMN:
                score_columns.append(name)
        self.RANKING_SCORE_COLUMN = score_columns

        assert len(self.RANKING_SCORE_COLUMN) == 1

        super().__init__(column_types, regressor_model)

    def get_regressor_score(self, X):
        """Return the score column(s) of ``X``, selected with the one-item column list."""
        score_values = X[self.RANKING_SCORE_COLUMN]
        return score_values
    
    
class BayesianScheduler:
    """Sequential, surrogate-driven scheduler over a table of candidates.

    The scheduler is an iterator: every step scores a window of the remaining
    candidates with the surrogate, yields one of them as a one-row frame and
    removes it from ``X`` in place. Iteration stops when ``X`` is empty.
    """

    def __init__(self, X, surrogate, sample_size=10, randomize=False, pretrain_data=None):
        """Store the candidate pool and the selection settings.

        Args:
            X: Candidate frame; rows are dropped from it in place as they are
                yielded.
            surrogate: Object exposing ``predict`` and ``fit``.
            sample_size: Number of leading candidates scored per step.
            randomize: If true, pick a random candidate from the window
                instead of the one with the lowest predicted score.
            pretrain_data: Optional column (or columns) used to sort the pool
                in descending order before taking the window.
        """
        self.X = X
        self.surrogate = surrogate
        self.sample_size = sample_size
        self.randomize = randomize
        self.pretrain_data = pretrain_data

    def __iter__(self):
        """Return the scheduler itself; it is its own iterator."""
        return self

    def __next__(self):
        """Return the next candidate chosen during an optimization epoch.

        Raises:
            StopIteration: When no candidates are left.
        """
        if len(self.X) == 0:
            raise StopIteration
        return self.optimize_acquisition(self.X)

    def acquisition(self, Xsamples):
        """Score the candidates; the acquisition value is the raw surrogate prediction."""
        return self.surrogate_function(Xsamples)

    def optimize_acquisition(self, X):
        """Pick one candidate from the leading window and drop it from the pool.

        Args:
            X: Remaining candidates.

        Returns:
            One-row frame holding the selected candidate.
        """
        if self.pretrain_data is not None:
            X = X.sort_values(self.pretrain_data, ascending=False)
        # Only the first ``sample_size`` rows are considered at each step.
        Xsamples = X.iloc[:self.sample_size]
        scores = self.acquisition(Xsamples)
        if self.randomize:
            best_i = Xsamples.sample(1)
        else:
            # The lowest predicted score wins.
            best_i = Xsamples.iloc[[np.argmin(scores)]]
        self.X.drop(best_i.index, inplace=True)
        return best_i

    def surrogate_function(self, Xsamples):
        """Predict a score for every candidate with warnings silenced.

        A single-row window is duplicated before prediction and only the first
        prediction is kept, so the surrogate never sees a one-row frame.

        Args:
            Xsamples: Candidates to score.

        Returns:
            The surrogate predictions, one per row of ``Xsamples``.
        """
        with catch_warnings():
            simplefilter("ignore")
            if len(Xsamples) == 1:
                # The original code computed this value but never returned it.
                return self.surrogate.predict(Xsamples.sample(2, replace=True))[:1]
            return self.surrogate.predict(Xsamples)

    def add_data(self, x, y, x_v=None, y_v=None, **kwargs):
        """Refit the surrogate on ``x``/``y``, forwarding validation data and extras."""
        self.surrogate.fit(x, y, x_v, y_v, **kwargs)
        


class BORanker:
    """Orders the elements of each ranking group with a BayesianScheduler.

    For every group the scheduler repeatedly picks the next element, the
    observed target of that element is appended to the training data, and the
    regressor is refitted. The order of selection becomes the predicted
    sequence order.
    """

    def __init__(self, column_types, regressor_model, component_names, random_state=None):
        """Resolve the ranking/feature columns and store the regressor.

        Args:
            column_types: Mapping of column name to a ``columns`` type tag.
            regressor_model: Surrogate exposing ``fit`` and ``predict``.
            component_names: Accepted but not used by this class.
            random_state: Seed passed to ``np.random.seed`` before each group.

        Raises:
            AssertionError: If there is not exactly one ranking target column.
        """
        self.RANKING_GROUP_ID_COLUMN = [col for col, tipo in column_types.items() if tipo == columns.RANKING_GROUP_ID_COLUMN]
        self.RANKING_TARGET_COLUMN = [col for col, tipo in column_types.items() if tipo == columns.RANKING_TARGET_COLUMN]
        self.RANKING_ELEMENT_ID_COLUMN = [col for col, tipo in column_types.items() if tipo == columns.RANKING_ELEMENT_ID_COLUMN]

        self.DATA_COLUMNS = [col for col, tipo in column_types.items() if tipo == columns.NUMERICAL]

        assert len(self.RANKING_TARGET_COLUMN) == 1

        # From here on the target is addressed by name, not by a one-item list.
        self.RANKING_TARGET_COLUMN = self.RANKING_TARGET_COLUMN[0]
        self.INPUT = self.RANKING_ELEMENT_ID_COLUMN + self.DATA_COLUMNS

        self.regressor_model = regressor_model
        self.max_ranking_lenght = None

        self.random_state = random_state

        # Column the scheduler sorts the candidates by (descending).
        self.PRETRAINED_DATA = 'rank_mean'
        # Name of the output column holding the selection order.
        self.SEQUENCE_ORDER = 'sequence_order'

    def fit(self, X, y=None, X_val=None, y_val=None, fit_params=None):
        """Keep a copy of the training data and fit the regressor on it.

        Args:
            X: Training frame containing the group-id column.
            y: Training targets.
            X_val, y_val, fit_params: Accepted but not used by this method.
        """
        self.max_ranking_lenght = X[self.RANKING_GROUP_ID_COLUMN].value_counts().max()
        # Copies are kept because predict_order grows the training set per group.
        self.X_train = X.copy()
        self.y_train = y.copy()
        self.regressor_model.fit(X, y)

    def predict(self, X_test, y_test):
        """Compute the selection order of every row, group by group.

        Args:
            X_test: Frame with the group-id column and the scheduler inputs.
            y_test: Observed targets, indexed like ``X_test``.

        Returns:
            The ``sequence_order`` column, aligned to the index of ``X_test``.
        """
        arr_tasks = []

        for group in tqdm(X_test.set_index(self.RANKING_GROUP_ID_COLUMN).index.unique()):
            # Select the rows of the group once and reuse them for both arguments.
            X_group = X_test[X_test[self.RANKING_GROUP_ID_COLUMN].T.squeeze() == group]
            arr_tasks.append(self.predict_order(X_group, y_test.loc[X_group.index]))

        return pd.concat(arr_tasks).loc[X_test.index][self.SEQUENCE_ORDER]

    def predict_order(self, X_test, y_test):
        """Run the scheduler over one group and record the order of selection.

        Seeds NumPy's global random state with ``random_state`` and refits
        ``regressor_model`` after every selected element.

        Args:
            X_test: Rows of a single ranking group.
            y_test: Observed targets of those rows.

        Returns:
            Frame with the observed targets in selection order plus a
            ``sequence_order`` column numbered from 1.
        """
        sample_size = 10

        # The scheduler drops rows in place, so it works on a copy.
        X = X_test.copy()
        y = y_test.copy()

        y_result = pd.DataFrame()

        np.random.seed(self.random_state)

        scheduler = BayesianScheduler(X, self.regressor_model, sample_size=sample_size, randomize=False, pretrain_data=self.PRETRAINED_DATA)

        x_s = self.X_train.copy()
        y_s = self.y_train.copy()

        for next_point in scheduler:
            next_value = y.loc[next_point.index]
            # Grow the training data with the newly observed point and refit.
            x_s = pd.concat([x_s, next_point], ignore_index=False)
            y_s = pd.concat([y_s, next_value], ignore_index=False)
            scheduler.add_data(x_s, y_s)
            y_result = pd.concat([y_result, next_value], ignore_index=False)

        y_result[self.SEQUENCE_ORDER] = list(range(1, len(y_result) + 1))
        return y_result
    
    
class BOScorer(BORanker):
    """BORanker variant that sorts candidates by ``'test_score_mean'``."""

    def __init__(self, column_types, regressor_model, component_names, random_state=None):
        """Resolve the score column, run the BORanker initialisation, switch the sort column.

        Args:
            column_types: Mapping of column name to a ``columns`` type tag.
            regressor_model: Surrogate forwarded to the parent initialiser.
            component_names: Forwarded to the parent initialiser.
            random_state: Forwarded to the parent initialiser.

        Raises:
            AssertionError: If there is not exactly one ranking score column
                (or, from the parent, not exactly one ranking target column).
        """
        score_columns = []
        for name, kind in column_types.items():
            if kind == columns.RANKING_SCORE_COLUMN:
                score_columns.append(name)
        self.RANKING_SCORE_COLUMN = score_columns

        assert len(self.RANKING_SCORE_COLUMN) == 1

        super().__init__(column_types, regressor_model, component_names, random_state)

        # Overrides the 'rank_mean' value set by the parent initialiser.
        self.PRETRAINED_DATA = 'test_score_mean'

    

class MCTSRanker:
    """Orders the elements of each ranking group with a Monte Carlo tree search.

    ``fit`` stores one history row per (model, exploration step) seen in
    training. ``predict`` runs an MCTS over a ``PipelineNameTree`` built from
    that history for every group, and the order in which the search returns
    the pipelines becomes the predicted sequence order.
    """

    def __init__(self, column_types, regressor_model, component_names, random_state=None, rollout_policy=average_policy):
        """Resolve the ranking/feature columns and store the search settings.

        Args:
            column_types: Mapping of column name to a ``columns`` type tag.
            regressor_model: Stored on the instance; not used by this class's
                own methods.
            component_names: Passed to ``PipelineNameTree.build_tree``.
            random_state: Seed for ``np.random`` and ``random`` before each search.
            rollout_policy: Rollout policy handed to ``MCTS``.

        Raises:
            AssertionError: If there is not exactly one ranking target column.
        """
        self.RANKING_GROUP_ID_COLUMN = [col for col, tipo in column_types.items() if tipo == columns.RANKING_GROUP_ID_COLUMN]
        self.RANKING_TARGET_COLUMN = [col for col, tipo in column_types.items() if tipo == columns.RANKING_TARGET_COLUMN]
        self.RANKING_ELEMENT_ID_COLUMN = [col for col, tipo in column_types.items() if tipo == columns.RANKING_ELEMENT_ID_COLUMN]

        self.DATA_COLUMNS = [col for col, tipo in column_types.items() if tipo == columns.NUMERICAL]

        assert len(self.RANKING_TARGET_COLUMN) == 1

        # From here on the target is addressed by name, not by a one-item list.
        self.RANKING_TARGET_COLUMN = self.RANKING_TARGET_COLUMN[0]
        self.INPUT = self.RANKING_ELEMENT_ID_COLUMN + self.DATA_COLUMNS

        self.rollout_policy = rollout_policy

        self.regressor_model = regressor_model
        self.max_ranking_lenght = None

        self.component_names = component_names
        self.random_state = random_state

        # Column of the training data used as the reward of each pipeline.
        self.PRETRAINED_DATA = 'rank_mean'
        # Name of the output column holding the selection order.
        self.SEQUENCE_ORDER = 'sequence_order'
        # Name of the reward column added to the training history.
        self.MCTS_REWARD = 'reward'

    def fit(self, X, y=None, X_val=None, y_val=None, fit_params=None):
        """Collect the per-pipeline history the tree search is built from.

        Args:
            X: Training frame with the group-id column plus ``'model_id'``,
                ``'exploration_step_name'``, ``'rank_mean'`` and
                ``'test_score_mean'``.
            y, X_val, y_val, fit_params: Accepted but not used by this method.
        """
        self.max_ranking_lenght = X[self.RANKING_GROUP_ID_COLUMN].value_counts().max()
        self.mean_scores = X[['exploration_step_name', self.PRETRAINED_DATA]].drop_duplicates().set_index('exploration_step_name')
        self.training_history = X[[
            'model_id', 'exploration_step_name', 'rank_mean', 'test_score_mean']
        ].drop_duplicates().reset_index()
        # This reward column is the value used in the UCB formula.
        self.training_history[self.MCTS_REWARD] = self.training_history[self.PRETRAINED_DATA]

    def predict(self, X_test, y_test):
        """Compute the selection order of every row, group by group.

        Args:
            X_test: Frame with the group-id, ``'model_id'`` and
                ``'exploration_step_name'`` columns.
            y_test: Observed targets, indexed like ``X_test``.

        Returns:
            The ``sequence_order`` column, aligned to the index of ``X_test``.
        """
        arr_tasks = []

        for group in tqdm(X_test.set_index(self.RANKING_GROUP_ID_COLUMN).index.unique()):
            # Select the rows of the group once and reuse them for both arguments.
            X_group = X_test[X_test[self.RANKING_GROUP_ID_COLUMN].T.squeeze() == group]
            arr_tasks.append(self.predict_order(X_group, y_test.loc[X_group.index]))

        return pd.concat(arr_tasks).loc[X_test.index][self.SEQUENCE_ORDER]

    def scheduler_get_sequences(self, bach_size, training_history, pipeline_metadata):
        """Run one MCTS search and return whatever ``MCTS.search`` returns.

        Seeds both NumPy's and the standard library's global random state
        with ``random_state`` before searching.

        Args:
            bach_size: Iteration limit handed to ``MCTS``.
            training_history: Frame with ``'exploration_step_name'``, the
                reward column and the ``PRETRAINED_DATA`` column.
            pipeline_metadata: Dict with the keys ``'component_names'`` and
                ``'selected_pipelines'``.

        Returns:
            The result of ``searcher.search()``; ``predict_order`` treats it
            as an iterable of nodes exposing ``.sequence``.
        """
        meta_X = training_history[['exploration_step_name']]
        meta_y = training_history[[self.MCTS_REWARD]]

        component_names = pipeline_metadata['component_names']
        selected_pipelines = pipeline_metadata['selected_pipelines']

        pipeline_name_tree = PipelineNameTree()
        pipeline_name_tree.build_tree(training_history.exploration_step_name.values, training_history[self.PRETRAINED_DATA].values, component_names)
        pipeline_name_tree.compute_average_scores()

        initial_state = TreeState(pipeline_name_tree.root,
                                  key=self.MCTS_REWARD, parent_sequence="")

        selection_policy = UpperConfidenceBound
        selection_policy_params = {'C': 1 / math.sqrt(2), 'objective': 1}

        np.random.seed(self.random_state)
        random.seed(self.random_state)

        searcher = MCTS(meta_X, meta_y, initial_state, iteration_limit=bach_size,
                        selection_list=selected_pipelines,
                        selection_policy=selection_policy, selection_policy_params=selection_policy_params,
                        rollout_policy=self.rollout_policy
                        )

        return searcher.search()

    def predict_order(self, X_test, y_test):
        """Search one group and record the order in which pipelines are selected.

        Args:
            X_test: Rows of a single ranking group.
            y_test: Observed targets of those rows.

        Returns:
            Frame with the observed targets in selection order plus a
            ``sequence_order`` column numbered from 1.
        """
        count_limit = 0
        limit = len(X_test)
        # The batch covers the whole group, so the loop below runs at most once.
        batch_size = limit

        selected_pipelines = []
        training_history = self.training_history.copy()
        training_history = training_history[training_history.model_id.isin(X_test.model_id.unique())]
        training_history = training_history.sort_values(self.MCTS_REWARD, ascending=False)

        # Models never seen in training get placeholder history rows.
        if len(set(X_test.model_id.unique()).difference(set(training_history.model_id.unique()))) > 0:
            training_history = self.complete_pre_trained_data(X_test, training_history)

        while count_limit < limit:
            selected_pipeline = self.scheduler_get_sequences(
                batch_size,
                training_history,
                {'component_names': self.component_names,
                 'selected_pipelines': selected_pipelines}
            )

            selected_pipelines.extend(selected_pipeline)

            count_limit += batch_size

        X = X_test.copy()
        y = y_test.copy()

        y_result = pd.DataFrame()

        selection_ranking = [tree_node.sequence for tree_node in selected_pipelines]
        for selected in selection_ranking:
            # Rows of the group whose step name matches the selected sequence.
            next_point = X[(X.exploration_step_name == selected)]
            next_value = y.loc[next_point.index]
            y_result = pd.concat([y_result, next_value], ignore_index=False)

        y_result[self.SEQUENCE_ORDER] = list(range(1, len(y_result) + 1))

        return y_result

    def complete_pre_trained_data(self, X_test, training_history):
        """Append placeholder history rows for models missing from ``training_history``.

        The placeholders receive the minimum ``rank_mean`` and
        ``test_score_mean`` found in ``training_history``; their reward is
        copied from the ``PRETRAINED_DATA`` column.

        Args:
            X_test: Rows of the group being predicted.
            training_history: History rows already known for the group.

        Returns:
            ``training_history`` followed by the placeholder rows.
        """
        known_ids = set(training_history.model_id.unique())
        unseen_ids = set(X_test.model_id.unique()).difference(known_ids)
        unseen_rows = X_test[X_test.model_id.isin(unseen_ids)]

        # Copy so the column assignments below do not write into a slice of X_test.
        new_history = unseen_rows[['model_id', 'exploration_step_name']].copy()
        new_history['rank_mean'] = training_history.rank_mean.min()
        new_history['test_score_mean'] = training_history.test_score_mean.min()
        new_history[self.MCTS_REWARD] = new_history[self.PRETRAINED_DATA]

        return pd.concat([training_history, new_history])


class MCTSScorer(MCTSRanker):
    """MCTSRanker variant that uses ``'test_score_mean'`` as the reward column."""

    def __init__(self, column_types, regressor_model, component_names, random_state=None, rollout_policy=average_policy):
        """Resolve the score column, run the MCTSRanker initialisation, switch the reward column.

        Args:
            column_types: Mapping of column name to a ``columns`` type tag.
            regressor_model: Forwarded to the parent initialiser.
            component_names: Forwarded to the parent initialiser.
            random_state: Forwarded to the parent initialiser.
            rollout_policy: Forwarded to the parent initialiser.

        Raises:
            AssertionError: If there is not exactly one ranking score column
                (or, from the parent, not exactly one ranking target column).
        """
        self.RANKING_SCORE_COLUMN = [col for col, tipo in column_types.items() if tipo == columns.RANKING_SCORE_COLUMN]

        assert len(self.RANKING_SCORE_COLUMN) == 1

        # rollout_policy was previously dropped here, so a caller-supplied
        # policy was silently replaced by the parent's default.
        super().__init__(
            column_types,
            regressor_model,
            component_names,
            random_state=random_state,
            rollout_policy=rollout_policy,
        )

        # Overrides the 'rank_mean' value set by the parent initialiser.
        self.PRETRAINED_DATA = 'test_score_mean'