import numpy as np
import pandas as pd
import re
import scipy
import copy

import warnings
from random import random
from math import log, ceil
from time import time, ctime
import csv
import json
import statsmodels.api as sm
import scipy.stats as sps
import warnings

import ConfigSpace as CS

def min_max_scaling(values):
	"""Linearly rescale ``values`` onto the [0, 1] interval.

	Args:
		values: any iterable of numbers; single-pass iterators (e.g. generators)
			are accepted because the input is materialized first.

	Returns:
		A list containing ``(x - min) / (max - min)`` for every element. When all
		values are equal every entry is ``0``; an empty input yields ``[]``.
	"""
	# Materialize once so that both min() and max() can iterate the data;
	# previously a generator was exhausted by min() and max() then raised.
	values = list(values)
	if not values:
		return []
	min_val = min(values)
	range_val = max(values) - min_val
	if range_val == 0:
		return [0] * len(values)
	return [(x - min_val) / range_val for x in values]


class SS:
	"""Hyperband-style search whose promotion step uses sub-sampling comparisons.

	Every bracket ``s`` evaluates a fixed set of configurations (drawn once via
	``get_params`` and cached in ``fixed_config_dict``). In each later round the
	configuration with the most observations (the "leader") is compared with
	the others: challengers that have very few observations, or whose mean
	metric beats the mean of at least one equally long window of the leader's
	observations, share the round's budget. If there is no such challenger the
	leader itself is extended instead.
	"""

	def __init__(self, get_params_function, try_params_function, max_iter=81, eta=3, skip_first=0):
		"""Store the search callbacks and derive the Hyperband budget constants.

		Args:
			get_params_function: zero-argument callable returning a new configuration.
			try_params_function: callable ``(iterations, config, criteria)`` whose
				result must support ``result[criteria]`` and item assignment.
			max_iter: maximum iterations per configuration.
			eta: configuration downsampling rate (default = 3).
			skip_first: number of leading successive-halving rounds folded into
				the initial evaluation of each bracket.
		"""
		self.get_params = get_params_function
		self.try_params = try_params_function

		# store precomputed probs for the categorical parameters
		# (not read anywhere in this class)
		self.cat_probs = []

		self.max_iter = max_iter	# maximum iterations per configuration
		self.eta = eta	# defines configuration downsampling rate
		self.skip_first = skip_first

		self.logeta = lambda x: log(x) / log(self.eta)
		# The epsilon guards against results such as
		# log(1000) / log(10) == 2.9999999999999996, which int() would
		# truncate to 2 and thereby silently drop an entire bracket.
		self.s_max = int(self.logeta(self.max_iter) + 1e-9)
		self.B = (self.s_max + 1) * self.max_iter

		self.samples = np.array([])

		self.counter = 0
		# bracket key 's_<s>' -> list of configurations for that bracket
		self.fixed_config_dict = dict()

	def __str__(self):
		return f"SS_Max_iter_{self.max_iter}_eta_{self.eta}"

	def run_fixed_configs(self, criteria='valid_accuracy', direction=None):
		"""Run every bracket on its fixed configuration set and collect results.

		Args:
			criteria: key of the metric read from each ``try_params`` result. A
				value starting with ``'wgh'`` must contain at least two integers;
				they are parsed (scaled by 0.1) and printed.
			direction: ``'Max'`` or ``'Min'`` - whether larger or smaller
				``criteria`` values are better.

		Returns:
			A list of result objects, each a shallow copy of a ``try_params``
			result annotated with ``'s'``, ``'params'`` and ``'n_iteration'``,
			followed by the best final-round result appended once more.

		Raises:
			ValueError: if ``direction`` is not ``'Max'``/``'Min'``, or a
				``'wgh'`` criteria string has fewer than two integers.
			IndexError: if no bracket reaches a final round (only brackets with
				``s > skip_first`` do).
		"""
		results = []
		final_results = []

		# dealing with special criteria
		if criteria.startswith('wgh'):
			matches = re.findall(r'\d+', criteria)
			if len(matches) < 2:
				raise ValueError("Not enough numbers found in criteria.")
			# NOTE(review): the weights are only printed; nothing here uses them.
			wgh1 = float(matches[0]) * 0.1
			wgh2 = float(matches[1]) * 0.1
			print(f"wgh1 = {wgh1}, wgh2 = {wgh2}")

		# Fail fast instead of discovering this after the whole budget is spent.
		if direction not in ('Max', 'Min'):
			raise ValueError(f"Invalid direction '{direction}'.")

		for s in reversed(range(self.s_max + 1)):
			# initial number of configurations
			n = int(ceil(self.B / self.max_iter / (s + 1) * self.eta ** s))
			# initial number of iterations per config; a single division keeps
			# this exact when max_iter is a power of eta
			r = self.max_iter / self.eta ** s

			ss_rsts = [[] for _ in range(n)]
			ss_criterias = [[] for _ in range(n)]
			# Invariant: ss_epochs[j] == len(ss_criterias[j]).
			ss_epochs = [0] * n

			# sub sampling arrange, first evaluation of every configuration
			b = int(r * self.eta ** self.skip_first)
			for j in range(n):
				config = self.get_sample(s, n, j)
				for _ in range(b):
					# NOTE(review): every initial call passes ``b`` rather than a
					# per-step epoch (the extension loops below pass the epoch
					# number) - confirm this is intended.
					rst = self.try_params(b, config, criteria)
					ss_rsts[j].append(rst)
					ss_criterias[j].append(rst[criteria])
				ss_epochs[j] = b

			for i in range(self.skip_first + 1, s + 1):
				# Run each of the n configs for <iterations>
				# and keep best (n_configs / eta) configurations
				n_configs = n * self.eta ** (-i + self.skip_first)
				n_iterations = r * self.eta ** i

				print("\n*** {} configurations x {:.1f} iterations each".format(
					n_configs, n_iterations))

				leader = int(np.argmax(ss_epochs))
				preference_set = self._preference_set(ss_epochs, ss_criterias, leader, direction)

				# (index, total number of observations to reach) for this round
				if preference_set:
					share = len(preference_set)
					plan = [
						(idx, ss_epochs[idx] + max(0, int((n_iterations - ss_epochs[idx]) / share)))
						for idx in preference_set
					]
				else:
					plan = [(leader, max(ss_epochs[leader], int(n_iterations)))]

				for idx, target in plan:
					config = self.get_sample(s, n, idx)
					# continue from the next unseen epoch up to and including target
					for epoch in range(ss_epochs[idx] + 1, target + 1):
						rst = self.try_params(epoch, config, criteria)
						ss_rsts[idx].append(rst)
						ss_criterias[idx].append(rst[criteria])
					ss_epochs[idx] = target

					# Copy before annotating so that we neither alias earlier entries
					# of ``results`` nor mutate objects owned by ``try_params``.
					result = copy.copy(ss_rsts[idx][-1])
					result['s'] = s
					result['params'] = config
					result['n_iteration'] = n_iterations
					results.append(result)
					# last round of successive halving
					if i == s:
						final_results.append(result)

		# rank final result
		ranked = sorted(final_results, key=lambda x: x[criteria], reverse=(direction == 'Max'))

		print(" ****** the best one ***** ")
		print(ranked[0])
		# append the best one to the last of rst
		results.append(ranked[0])

		return results

	@staticmethod
	def _preference_set(ss_epochs, ss_criterias, leader, direction):
		"""Return indices of challengers that should receive budget this round.

		A configuration ``j`` is preferred when it has fewer than
		``sqrt(log(most))`` observations (``most`` = the leader's count), or when
		the mean of all its observations beats, per ``direction``, the mean of at
		least one window of the leader's observations of the same length.
		Configurations with no observations, or with at least as many as the
		leader, are never preferred.
		"""
		most_epochs = ss_epochs[leader]
		leader_values = ss_criterias[leader]
		preferred = []
		for j, this_epochs in enumerate(ss_epochs):
			if this_epochs == 0:
				continue
			if this_epochs < np.sqrt(np.log(most_epochs)):
				preferred.append(j)
				continue
			if this_epochs >= most_epochs:
				continue
			this_average = np.mean(ss_criterias[j])
			for k in range(most_epochs - this_epochs + 1):
				window_mean = np.mean(leader_values[k : k + this_epochs])
				if (direction == 'Min' and this_average < window_mean) or \
						(direction == 'Max' and this_average > window_mean):
					preferred.append(j)
					break
		return preferred

	def get_sample(self, s, n, j):
		"""Return configuration ``j`` of bracket ``s``.

		On first access to bracket ``s``, ``n`` configurations are drawn from
		``get_params`` and cached under ``'s_<s>'`` so later calls (and later
		runs) reuse the same set.
		"""
		key = f's_{s}'
		if key not in self.fixed_config_dict:
			self.fixed_config_dict[key] = [self.get_params() for _ in range(n)]
		# Previously the first call returned element 0 regardless of ``j``.
		return self.fixed_config_dict[key][j]

	def get_fixed_config_dict(self, config_space):
		"""Return the cached configurations converted via ``get_dictionary()``.

		Only brackets ``skip_first`` .. ``s_max`` are exported, highest first.
		``config_space`` is not used here. The cached entries are presumably
		``CS.Configuration`` objects - TODO confirm for every caller.

		Raises:
			ValueError: if no configurations have been cached.
			KeyError: if an exported bracket has no cached configurations.
		"""
		if not self.fixed_config_dict:
			raise ValueError("config_dict is empty.")
		return {
			f's_{s}': [config.get_dictionary() for config in self.fixed_config_dict[f's_{s}']]
			for s in reversed(range(self.skip_first, self.s_max + 1))
		}

	def load_fixed_config_dict(self, file_path, config_space):
		"""Load ``'s_<s>' -> [value dict, ...]`` from JSON into ``CS.Configuration`` objects.

		Only brackets ``skip_first`` .. ``s_max`` are read; any previously cached
		configurations are discarded.
		"""
		with open(file_path, "r") as json_file:
			loaded_configuration_dict = json.load(json_file)
		self.fixed_config_dict = dict()
		for s in reversed(range(self.skip_first, self.s_max + 1)):
			self.fixed_config_dict[f's_{s}'] = [
				CS.Configuration(config_space, values=config)
				for config in loaded_configuration_dict[f's_{s}']
			]

	def get_fixed_config_dict_lcbench(self):
		"""Return the cached configuration mapping as-is.

		Raises:
			ValueError: if no configurations have been cached.
		"""
		if not self.fixed_config_dict:
			raise ValueError("config_dict is empty.")
		return self.fixed_config_dict

	def load_fixed_config_dict_lcbench(self, file_path):
		"""Replace the cached configuration mapping with the raw JSON content of ``file_path``.

		Raises:
			ValueError: if the file decodes to JSON ``null``.
		"""
		self.fixed_config_dict = None
		with open(file_path, "r") as json_file:
			self.fixed_config_dict = json.load(json_file)

		if self.fixed_config_dict is None:
			raise ValueError("Error in loading configuration dictionary.")

	def record_to_csv(self, results, record_file='./record.csv'):
		"""Write ``results`` to ``record_file`` as CSV (no index column)."""
		df = pd.DataFrame(results)
		df.to_csv(record_file, index=False)

