import os
import time
from abc import ABC, abstractmethod
from collections.abc import Iterable

import joblib
from sklearn.decomposition import PCA


class AbstractModel:
    """Base class for abstraction models.

    Every instance starts with two empty lists, ``initial`` and ``final``.
    """

    def __init__(self):
        # Fresh list objects per instance, so instances never share them.
        self.initial, self.final = [], []
        

class Reduction(ABC):
    """Abstract interface for data-reduction strategies."""

    @abstractmethod
    def do_reduction(self, data):
        """Reduce ``data``; concrete subclasses must override this method.

        The base implementation does nothing and returns ``None``.
        """
    

class PCA_R(Reduction):
    """PCA-backed reduction that also records the value range of its output.

    Attributes:
        top_components: Number of principal components requested.
        pca: The fitted ``PCA`` object, or ``None`` when no projection is
            applied (nothing fitted yet, or the data already has at most
            ``top_components`` features).
        pca_min, pca_max: Per-column minimum / maximum of the data returned
            by ``create_pca``.
        explained_variance_ratio: ``explained_variance_ratio_`` of the fitted
            PCA; only set when a PCA is actually fitted.
    """

    def __init__(self, top_components):
        self.top_components = top_components
        self.pca = None
        self.pca_min = None
        self.pca_max = None
        self.explained_variance_ratio = None

    def create_pca(self, all_observations):
        """Fit the reduction on ``all_observations`` and return the reduced data.

        Args:
            all_observations: Non-empty collection of observations. It must
                support ``len``, indexing, ``.shape`` on its first element
                and column-wise ``np.min`` / ``np.max``.

        Returns:
            ``(data, min, max)``. When ``top_components`` is at least the
            size of the last axis of the first observation, no PCA is fitted:
            the input is returned unchanged with its raw per-column bounds.
            Otherwise the data is projected and the bounds are rounded to
            four decimals.

        Raises:
            AssertionError: If ``all_observations`` is empty.
        """
        assert len(all_observations) > 0, "all_observations must not be empty"

        if self.top_components >= all_observations[0].shape[-1]:
            # Nothing to reduce: keep the data as is and drop any fitted model
            # so that do_reduction() becomes the identity as well.
            self.pca = None
            self.pca_min = np.min(all_observations, axis=0)
            self.pca_max = np.max(all_observations, axis=0)
            return all_observations, self.pca_min, self.pca_max

        print("build a PCA model...")
        print(">>>original:{}".format(all_observations.shape))
        self.pca = PCA(n_components=self.top_components)
        self.pca.fit(all_observations)
        pca_data = self.pca.transform(all_observations)
        self.pca_min = np.around(np.min(pca_data, axis=0), 4)
        self.pca_max = np.around(np.max(pca_data, axis=0), 4)
        self.explained_variance_ratio = self.pca.explained_variance_ratio_
        return pca_data, self.pca_min, self.pca_max

    def do_reduction(self, data):
        """Project ``data`` with the fitted PCA.

        When no PCA is fitted (``self.pca is None``) the data is returned
        unchanged.
        """
        if self.pca is None:
            return data
        return self.pca.transform(data)
        
class Grid(AbstractModel):
    """Abstraction model that turns concrete states into string labels.

    The constructor records a grid geometry (bounds, cells per dimension,
    cell width). ``state_abstract`` labels each state by stringifying its
    contents; of the stored geometry it only consults ``self.dim``.
    """

    def __init__(self, min_val, max_val, grid_num, clipped=True):
        """Store the grid geometry.

        Args:
            min_val: Lower bounds; must support ``max_val - min_val``.
            max_val: Upper bounds; ``max_val.shape[0]`` is taken as the
                number of dimensions.
            grid_num: Number of cells per dimension.
            clipped: Stored as ``self.clipped``; not read by
                ``state_abstract``.
        """
        super().__init__()
        self.min = min_val
        self.max = max_val
        self.k = grid_num
        self.dim = max_val.shape[0]
        self.total_states = pow(grid_num, self.dim)
        self.unit = (max_val - min_val) / self.k
        self.clipped = clipped

    def state_abstract(self, con_states):
        """Return one string label per concrete state.

        When ``self.dim`` is 0 every state is passed through ``str``.
        Otherwise each state is rendered as its components joined by ``"_"``:

        * an object with a ``__dict__`` is replaced by that dict;
        * a dict contributes its values, flattened recursively;
        * ``None`` becomes ``"None"``;
        * anything else is assumed to be a flat iterable of components.

        Args:
            con_states: Iterable of concrete states.

        Returns:
            List of label strings, in input order.
        """

        def deep_flatten(obj):
            if isinstance(obj, str):
                # A str iterates into further strs, so descending into it
                # never terminates; emit it whole. Empty strings contribute
                # nothing, which is what iterating them yields anyway.
                if obj:
                    yield obj
                return
            is_scalar_array = isinstance(obj, np.ndarray) and obj.ndim == 0
            if isinstance(obj, Iterable) and not is_scalar_array:
                for item in obj:
                    yield from deep_flatten(item)
            else:
                yield obj

        def stringify_con_state(item):
            if hasattr(item, "__dict__"):
                item = item.__dict__
            if isinstance(item, dict):
                flat = deep_flatten(item.values())
            elif item is None:
                flat = [None]
            else:
                flat = item  # assumed to be a flat iterable
            return "_".join(map(str, flat))

        if self.dim == 0:
            return list(map(str, con_states))

        return [stringify_con_state(item) for item in con_states]

    def extract_abs_trace(self, dones, abs_states, abs_rewards, abs_values):
        """Split flat per-step sequences into per-episode slices.

        An episode ends at every index where ``abs(dones) == 1``; that index
        is included in the episode. Steps after the last such index are
        dropped.

        Args:
            dones: Array-like of per-step termination flags.
            abs_states: Per-step state labels (sliceable).
            abs_rewards: Per-step rewards (sliceable).
            abs_values: Per-step values (sliceable).

        Returns:
            ``(all_traces, all_rewards, all_values)``: three lists holding
            one slice per episode.
        """
        end_idx = np.where(np.abs(dones) == 1)[0]
        all_traces = []
        all_rewards = []
        all_values = []

        start = 0
        for cur_end in end_idx:
            stop = cur_end + 1
            all_rewards.append(abs_rewards[start:stop])
            all_values.append(abs_values[start:stop])
            all_traces.append(abs_states[start:stop])
            start = stop

        return all_traces, all_rewards, all_values

    def profiling_all_traces(self, all_traces):
        """Count state and edge occurrences, split by trace outcome.

        A trace counts as successful when its last element equals ``'S'``.
        Every step of a trace except its last element is counted, so a state
        or edge visited several times in one trace is counted several times.

        Args:
            all_traces: Iterable of non-empty traces.

        Returns:
            ``(state_dic, edge_dic)``. Each maps a state, or a
            ``(state, next_state)`` pair, to ``[fail_count, success_count]``.
        """
        state_dic = defaultdict(list)  # {state: [fail_stat, success_stat]}
        edge_dic = defaultdict(list)
        for cur_trace in all_traces:
            # The bool is used directly as the index: False -> 0, True -> 1.
            isSuccess = cur_trace[-1] == 'S'
            for idx in range(len(cur_trace) - 1):
                cur_state = cur_trace[idx]
                cur_edge = (cur_state, cur_trace[idx + 1])
                state_dic.setdefault(cur_state, [0, 0])[isSuccess] += 1
                edge_dic.setdefault(cur_edge, [0, 0])[isSuccess] += 1
        return state_dic, edge_dic

def traj_stat_analysis(trajectory_dic, keep=2):
    """Load a trajectory dump and trim it to the last finished episode.

    Args:
        trajectory_dic: Whatever ``joblib.load`` accepts; the loaded object
            must provide the keys ``'all_dones'``, ``'all_rwds'``,
            ``'all_states'``, ``'all_values'`` and ``'avg_performace'``.
        keep: Not used by this function.

    Returns:
        ``(all_dones, all_rewards, all_states, all_values, avg_performace)``.
        Dones and rewards are flattened, and all four sequences are cut just
        after the last index where ``abs(done) == 1``.

    Raises:
        IndexError: If no entry of ``all_dones`` has absolute value 1.
    """
    loaded = joblib.load(trajectory_dic)
    dones = loaded['all_dones'].flatten()
    rewards = loaded['all_rwds'].flatten()
    states = loaded['all_states']
    values = loaded['all_values']
    avg_performace = loaded['avg_performace']

    # Drop the unfinished tail after the final terminal step.
    terminal_positions = np.where(np.abs(dones) == 1)[0]
    cutoff = terminal_positions[-1] + 1

    return (
        dones[:cutoff],
        rewards[:cutoff],
        states[:cutoff],
        values[:cutoff],
        avg_performace,
    )
    


def pca_analysis(n_components, pcaModelPath, all_observations, save=False):
    """Reduce observations with a cached PCA model, or fit a new one.

    If ``pcaModelPath`` exists, the model stored there is loaded and applied;
    ``n_components`` and ``save`` are ignored in that case. Otherwise a new
    ``PCA_R`` is fitted on ``all_observations`` and, when ``save`` is true,
    written to ``pcaModelPath``.

    Returns:
        ``(pcaModel, pca_dic)`` where ``pca_dic`` has the keys
        ``'pca_data'``, ``'pca_min'`` and ``'pca_max'``.
    """
    if not os.path.exists(pcaModelPath):
        pcaModel = PCA_R(top_components=n_components)
        pca_data, pca_min, pca_max = pcaModel.create_pca(all_observations)
        if save:
            joblib.dump(pcaModel, pcaModelPath)
    else:
        pcaModel = joblib.load(pcaModelPath)
        # The bounds come from the data the cached model was fitted on.
        pca_min = pcaModel.pca_min
        pca_max = pcaModel.pca_max
        pca_data = pcaModel.do_reduction(all_observations)

    return pcaModel, {
        'pca_data': pca_data,
        'pca_min': pca_min,
        'pca_max': pca_max,
    }


def grid_abs_analysis(pca_dic, all_rewards, all_values, all_dones, grid_num, abs_profiling_file):
    """Abstract the reduced states, split them into episodes and save the result.

    Args:
        pca_dic: Dict with the keys ``'pca_min'``, ``'pca_max'`` and
            ``'pca_data'``.
        all_rewards: Per-step rewards, aligned with ``pca_dic['pca_data']``.
        all_values: Per-step values, aligned the same way.
        all_dones: Per-step termination flags; an episode ends where
            ``abs(done) == 1``.
        grid_num: Number of grid cells per dimension.
        abs_profiling_file: Destination passed to ``joblib.dump``; it is
            always (re)written.

    Returns:
        The saved dict, with the keys ``'dtmc'`` (the abstraction model),
        ``'abs_states'``, ``'abs_traces'``, ``'all_rewards'`` and
        ``'all_values'``.
    """
    pca_min, pca_max, pca_data = pca_dic['pca_min'], pca_dic['pca_max'], pca_dic['pca_data']

    # Grid is the abstraction model defined in this module. The name `DTMC`
    # that was used here is neither defined nor imported in this file.
    dtmc = Grid(pca_min, pca_max, grid_num)
    abs_states = dtmc.state_abstract(con_states=pca_data)
    abs_traces, rewards, values = dtmc.extract_abs_trace(
        all_dones, abs_states, all_rewards, all_values
    )

    profiling_dic = {
        'dtmc': dtmc,
        'abs_states': abs_states,
        'abs_traces': abs_traces,
        'all_rewards': rewards,
        'all_values': values,
    }
    joblib.dump(profiling_dic, abs_profiling_file)
    return profiling_dic

def con_2_dis(obv_list):
    """Map each observation to its abstract state label.

    The observations are reduced with ``pca_analysis`` and every reduced row
    is then abstracted on its own.

    NOTE(review): ``config`` is not imported anywhere in the visible part of
    this file - confirm where it is expected to come from.

    Returns:
        List with one abstract label per row of the reduced data.
    """
    _model, pca_dic = pca_analysis(config.n_components, config.pcaModelPath, obv_list)
    # Each row is duplicated into a two-row array before abstraction and
    # only the first resulting label is kept.
    return [
        concrete_to_abstract(pca_dic, None, config.grid_num, np.array([row, row]))[0]
        for row in pca_dic['pca_data']
    ]

def concrete_to_abstract(pca_dic, all_rewards, grid_num, concrete_state):
    """Abstract ``concrete_state`` with a grid built from ``pca_dic``'s bounds.

    Args:
        pca_dic: Dict providing ``'pca_min'`` and ``'pca_max'``.
        all_rewards: Not used; kept so existing callers keep working.
        grid_num: Number of grid cells per dimension.
        concrete_state: Iterable of concrete states to label.

    Returns:
        List of abstract state labels, one per entry of ``concrete_state``.
    """
    # Grid is the abstraction model defined in this module. The name `DTMC`
    # that was used here is neither defined nor imported in this file.
    grid = Grid(pca_dic['pca_min'], pca_dic['pca_max'], grid_num)
    return grid.state_abstract(con_states=concrete_state)
    
    

def calc_coverage(abs_states, grid_num, n_componenct):
    """Compute the unique-state count and the grid size, then discard both.

    Nothing is printed or returned (the result is ``None``). A function with
    the same name is defined again further down in this module, and that
    later definition is the one bound once the module has finished loading.
    """
    distinct_labels = np.unique(np.array(abs_states))
    uniq = len(distinct_labels)
    total = pow(grid_num, n_componenct)

from collections import defaultdict


def analyze_abstraction(abs_states, abs_traces):
    """Summarise how abstract states are used across traces.

    A trace counts as successful when its last element equals ``'S'``. Each
    trace must hold at least two elements.

    Returns:
        ``(avgFreq, abs2con, state_dic, transistion_dic)``:

        * ``avgFreq`` - total number of transitions divided by the number of
          distinct states seen;
        * ``abs2con`` - label -> list of positions in ``abs_states`` carrying
          that label;
        * ``state_dic`` - state -> ``[fail_traces, success_traces]``, counting
          each trace at most once per state;
        * ``transistion_dic`` - state -> list of every observed successor.
    """
    abs2con = defaultdict(list)
    for position, label in enumerate(abs_states):
        abs2con[label].append(position)

    state_dic = defaultdict(list)
    transistion_dic = defaultdict(list)
    counts = []
    for trace in abs_traces:
        # The bool is used directly as the index: False -> 0, True -> 1.
        outcome = trace[-1] == 'S'
        counts.append(len(trace) - 1)

        visited = set()
        for source, target in zip(trace[:-1], trace[1:]):
            transistion_dic[source].append(target)
            visited.add(source)
        visited.add(trace[-2])

        for state in visited:
            state_dic.setdefault(state, [0, 0])[outcome] += 1

    avgFreq = np.sum(np.array(counts)) / len(state_dic)
    return avgFreq, abs2con, state_dic, transistion_dic
    

def fetchCriticalState(state_dic, lowest, threshold, savepath):
    """Select states whose failure rate is extreme.

    If ``savepath`` is given and exists, the result stored there is loaded
    and returned as is; ``state_dic``, ``lowest`` and ``threshold`` are not
    compared against the stored values. Otherwise the result is computed and,
    when ``savepath`` is not ``None``, written to it.

    Args:
        state_dic: Mapping state -> ``[fail_count, success_count]``.
        lowest: Minimum number of observations a state needs to be considered.
        threshold: A state is "good" when its fail rate is ``<= threshold``
            and "bad" when it is ``>= 1 - threshold``.
        savepath: Cache file, or ``None`` to neither load nor save.

    Returns:
        ``(critical, good_list, bad_list)``. ``critical`` holds
        ``[state, fail_count, success_count, fail_rate]`` rows sorted by fail
        rate, highest first. A state that satisfies both conditions appears
        in both lists and twice in ``critical``.
    """
    if savepath is not None and os.path.exists(savepath):
        print("Load critical states info from {} successfully!".format(savepath))

        critical_dic = joblib.load(savepath)
        good_list, bad_list = critical_dic['good_list'], critical_dic['bad_list']
        print('good state:', len(good_list))
        print('bad state:', len(bad_list))

        return critical_dic['critical'], good_list, bad_list

    critical = []
    good_list = []
    bad_list = []
    for key, val in state_dic.items():
        fail_count, success_count = val[0], val[1]
        total = fail_count + success_count
        # Filter before dividing: a state without observations has no fail
        # rate and must not trigger a division by zero.
        if total == 0 or total < lowest:
            continue
        failRate = fail_count / total
        if failRate <= threshold:
            good_list.append(key)
            critical.append([key, fail_count, success_count, failRate])
        if failRate >= 1 - threshold:
            bad_list.append(key)
            critical.append([key, fail_count, success_count, failRate])
    print("lowest:", lowest)
    print("failRate threshold:", threshold)
    print('good state:', len(good_list))
    print('bad state:', len(bad_list))
    critical = sorted(critical, key=lambda row: row[-1], reverse=True)
    critical_dic = {
        'critical': critical,
        'good_list': good_list,
        'bad_list': bad_list,
        'lowest': lowest,
        'threshold': threshold,
    }
    if savepath is not None:
        joblib.dump(critical_dic, savepath)
        print("Save critical states info to ", savepath)
    return critical, good_list, bad_list

import numpy as np
def get_abs_center(concrete_states, abstract_dic, state_list):
    """Return the mean concrete state of each requested abstract state.

    Args:
        concrete_states: Array indexable by a list of row indices.
        abstract_dic: Mapping abstract state -> row indices into
            ``concrete_states``.
        state_list: Abstract states to compute centers for.

    Returns:
        Array with one row per entry of ``state_list``, each the mean over
        axis 0 of that state's concrete rows.
    """
    centers = [
        np.mean(concrete_states[abstract_dic[state]], axis=0)
        for state in state_list
    ]
    return np.array(centers)

def calc_coverage(abs_states, grid_num, n_componenct):
    """Print how many of the possible grid states were actually visited.

    Three lines are printed: the number of distinct labels in ``abs_states``,
    the grid size ``grid_num ** n_componenct`` and their ratio with two
    decimals. Nothing is returned.
    """
    uniq = len(np.unique(np.array(abs_states)))
    total = pow(grid_num, n_componenct)
    coverage = float(uniq) / total
    report = (
        ("The number of unique states:{}", uniq),
        ("The number of total states:{}", total),
        ("Coverage:{:.2f}", coverage),
    )
    for template, value in report:
        print(template.format(value))

def abs_state_eval(abs_traces):
    """Count, per state and per edge, the traces that contain it.

    A trace counts as successful when its last element equals ``'S'``. Within
    one trace a state or edge is counted once, however often it occurs; the
    last element of a trace is never counted as a state.

    Returns:
        ``(state_dic, edge_dic)``, mapping a state or a
        ``(state, next_state)`` pair to ``[fail_traces, success_traces]``.
    """
    state_dic = defaultdict(list)  # {state: [fail_stat, success_stat]}
    edge_dic = defaultdict(list)
    for trace in abs_traces:
        # The bool is used directly as the index: False -> 0, True -> 1.
        outcome = trace[-1] == 'S'
        seen_states = set()
        seen_edges = set()
        for origin, target in zip(trace[:-1], trace[1:]):
            edge = (origin, target)
            state_dic.setdefault(origin, [0, 0])
            seen_states.add(origin)
            edge_dic.setdefault(edge, [0, 0])
            seen_edges.add(edge)

        for state in seen_states:
            state_dic[state][outcome] += 1
        for edge in seen_edges:
            edge_dic[edge][outcome] += 1

    return state_dic, edge_dic

def fetchStateOrder(abs_traces):
    """Rank elements by the number of successful / failed traces containing them.

    Traces ending in ``'S'`` feed the first ranking and traces ending in
    ``'F'`` the second; other traces are ignored. Every distinct element of a
    trace, its final marker included, is counted once per trace.

    Returns:
        ``(good_list, bad_list)``: lists of ``(element, trace_count)`` pairs
        sorted by count, highest first.
    """

    def rank_by_presence(traces):
        presence = defaultdict(int)
        for trace in traces:
            for element in set(trace):
                presence[element] += 1
        return sorted(presence.items(), key=lambda pair: pair[1], reverse=True)

    success_traces = [trace for trace in abs_traces if trace[-1] == 'S']
    fail_traces = [trace for trace in abs_traces if trace[-1] == 'F']

    good_list = rank_by_presence(success_traces)
    bad_list = rank_by_presence(fail_traces)
    return good_list, bad_list

def abs_state_eval2(abs_traces):
    """Count every state and edge occurrence, split by trace outcome.

    A trace counts as successful when its last element equals ``'S'``. Each
    step except the last element of a trace is counted, so repeated visits
    within one trace are counted repeatedly.

    Returns:
        ``(state_dic, edge_dic)``, mapping a state or a
        ``(state, next_state)`` pair to ``[fail_count, success_count]``.
    """
    state_dic = defaultdict(list)  # {state: [fail_stat, success_stat]}
    edge_dic = defaultdict(list)
    for trace in abs_traces:
        # The bool is used directly as the index: False -> 0, True -> 1.
        outcome = trace[-1] == 'S'
        for origin, target in zip(trace[:-1], trace[1:]):
            state_dic.setdefault(origin, [0, 0])[outcome] += 1
            edge_dic.setdefault((origin, target), [0, 0])[outcome] += 1
    return state_dic, edge_dic

def fetchStateOrder2(abs_traces):
    """Rank states by total visits in successful / failed traces.

    Traces ending in ``'S'`` feed the first ranking and traces ending in
    ``'F'`` the second; other traces are ignored. Every occurrence of a state
    is counted, and the final element of each trace is skipped.

    Returns:
        ``(good_list, bad_list)``: lists of ``(state, visit_count)`` pairs
        sorted by count, highest first; ties keep first-seen order.
    """

    def rank_by_visits(traces):
        visits = defaultdict(int)
        for trace in traces:
            for state in trace[:-1]:
                visits[state] += 1
        return sorted(visits.items(), key=lambda pair: pair[1], reverse=True)

    success_traces = [trace for trace in abs_traces if trace[-1] == 'S']
    fail_traces = [trace for trace in abs_traces if trace[-1] == 'F']

    good_list = rank_by_visits(success_traces)
    bad_list = rank_by_visits(fail_traces)
    return good_list, bad_list

