import numpy as np
import os
import time
import random
import torch

from .cit4dtg import independence_test, Topo_layers, Topo2DAG
from .cam_prune import cam_pruning
from .eval import backRE

def log(logfile, str_):
    """Append ``str_`` plus a newline to ``logfile`` and echo it to stdout.

    Args:
        logfile: path of the log file; it is created if it does not exist.
        str_: message to record.
    """
    # Explicit encoding so the log's bytes do not depend on the platform locale.
    with open(logfile, 'a', encoding='utf-8') as f:
        f.write(str_ + '\n')
    print(str_)

class DHTCIT(object):
    """Experiment driver for the DHTCIT method.

    ``single`` calls ``independence_test``, ``Topo_layers``/``Topo2DAG`` and
    ``cam_pruning`` on one dataset; ``run`` repeats it over ``exps``
    datasets stored under ``./Data`` and writes artefacts and aggregate
    statistics under ``./Result``.
    """

    # Default hyper-parameters. Any key missing from the active configuration
    # is filled from here when ``run`` resolves its config.
    _DEFAULT_CONFIG = {
        'name': 'DHTCIT',
        'exps': 10,
        'alpha': 0.01,
        'cutoff': 0.001,
        'detail': True,
        'device': 'cpu',
    }

    def __init__(self) -> None:
        # Copy so per-instance edits never leak into the class-level defaults.
        self.config = dict(self._DEFAULT_CONFIG)

    def set_Configuration(self, config):
        """Replace the stored configuration dict.

        Keys missing from ``config`` fall back to the defaults when ``run``
        is called.
        """
        self.config = config

    def _resolve_config(self, config=None):
        """Return ``config`` (or the stored config) layered over the defaults."""
        if config is None:
            config = self.config
        return {**self._DEFAULT_CONFIG, **config}

    def run(self, CInd, dataPath, Pind=1, Mind=2, config=None):
        """Run ``single`` on every experiment of a dataset and save the results.

        For each ``exp`` in ``range(exps)`` this loads
        ``./Data/<dataPath>/<exp>/adjacency.npy`` (target DAG) and
        ``./Data/<dataPath>/<exp>/X.pt``, takes ``P = X[Pind]`` and
        ``M = X[Mind]``, and writes P, M, the target/predicted DAGs and
        ``alterDTG`` into ``./Result/<dataPath>/<name>/<exp>/``. It also
        writes ``results.csv`` and a summary log into
        ``./Result/<dataPath>/<name>/``.

        Args:
            CInd: forwarded unchanged to ``independence_test``.
            dataPath: dataset sub-directory under ``./Data`` and ``./Result``.
            Pind: index into the loaded ``X`` selecting ``P``. ``Pind == 0``
                makes CAM pruning use ``M`` instead of ``P``.
            Mind: index into the loaded ``X`` selecting ``M``.
            config: optional configuration dict; defaults to ``self.config``.
                Missing keys are filled from the class defaults.

        Returns:
            np.ndarray with one row per experiment: the values returned by
            ``backRE`` followed by the count of entries where
            ``pred_DAG - tar_DAG == 1``.
        """
        config = self._resolve_config(config)

        self.name = config['name']
        self.exps = config['exps']
        self.alpha = config['alpha']
        self.cutoff = config['cutoff']
        self.detail = config['detail']
        # NOTE(review): self.device is not used below; torch.load keeps the
        # tensors' saved placement. Confirm whether map_location is intended.
        self.device = torch.device(config['device'])

        IVPanel = Pind == 0

        results = []
        run_times = []  # wall-clock seconds spent in `single`, per experiment
        start_time = time.time()
        savePath = './Result/{}/{}/'.format(dataPath, self.name)
        os.makedirs(os.path.dirname(savePath), exist_ok=True)
        for exp in range(self.exps):
            path = './Data/{}/{}/'.format(dataPath, exp)
            REpath = '{}{}/'.format(savePath, exp)
            logfile = REpath + 'log.txt'
            os.makedirs(os.path.dirname(REpath), exist_ok=True)
            log(logfile, "Run {}-th exp: {} . ".format(exp, path))

            tar_DAG = np.load(path + 'adjacency.npy')
            # SECURITY: torch.load unpickles the file -- only load trusted data.
            X = torch.load(path + 'X.pt')

            mid_time = time.time()
            # Logged timestamps are cumulative since this run() started.
            log(logfile, "Begin {} - {}: {}s . ".format(self.name, exp, mid_time - start_time))

            P = X[Pind]
            M = X[Mind]
            _, _, alterDTG, pred_DAG = self.single(
                CInd, P, M, IVPanel, self.alpha, self.cutoff, self.detail)

            end_time = time.time()
            log(logfile, "End {} - {}: {}s . ".format(self.name, exp, end_time - start_time))
            run_times.append(end_time - mid_time)

            torch.save(P, REpath + 'P.pt')
            torch.save(M, REpath + 'M.pt')
            np.savetxt(REpath + "tar_DAG.csv", tar_DAG, fmt="%d", delimiter=',')
            np.savetxt(REpath + "pred_DAG.csv", pred_DAG, fmt="%d", delimiter=',')
            np.savetxt(REpath + "alterDTG.csv", alterDTG, fmt="%d", delimiter=',')

            # For 0/1 adjacency matrices this counts predicted edges that are
            # absent from the target DAG.
            prune_sum = np.sum((pred_DAG - tar_DAG) == 1)
            # list() so a tuple/ndarray from backRE is concatenated rather
            # than rejected or broadcast-added.
            results.append(list(backRE(tar_DAG, pred_DAG)) + [prune_sum])

        results = np.array(results)
        sum_time = np.sum(run_times)
        np.savetxt(savePath + "results.csv", results, fmt="%.2f", delimiter=',')

        mth_log = savePath + 'log.txt'

        log(mth_log, f"Method Name: {self.name}. ")
        log(mth_log, f"Sum Time: {sum_time}s. ")
        log(mth_log, "Mean: " + str(results.mean(0)))
        log(mth_log, "Std:  " + str(results.std(0)))

        return results

    def single(self, CInd, P, M, IVPanel=False, alpha=0.01, cutoff=0.001, detail=False):
        """Run the DHTCIT pipeline on one dataset.

        Args:
            CInd, P, M, detail: forwarded to ``independence_test``.
            IVPanel: if True, CAM pruning is run on ``M``; otherwise on ``P``.
            alpha: threshold applied to the independence-test matrix and
                passed to ``Topo_layers``.
            cutoff: forwarded to ``cam_pruning``.

        Returns:
            Tuple ``(LayerOrd, DTG, alterDTG, pred_DAG)``:
            ``LayerOrd`` from ``Topo2DAG``; ``DTG`` is 1 where the test matrix
            is below ``alpha``; ``alterDTG`` keeps DTG entries where
            ``LayerOrd == 1``; ``pred_DAG`` is the ``cam_pruning`` output.

        Raises:
            ValueError: if ``M`` is not 2-D.
        """
        if len(M.shape) != 2:
            raise ValueError("M must be 2-D, got shape {}".format(tuple(M.shape)))

        Matrix = independence_test(CInd, P, M, alpha, detail)
        layers = Topo_layers(Matrix, alpha)
        LayerOrd = Topo2DAG(layers)

        # Entries below alpha (presumably p-values -- confirm) become edges.
        DTG = (Matrix < alpha).astype(int)

        # Keep only DTG edges that LayerOrd also marks with 1.
        alterDTG = ((LayerOrd == 1) & (DTG == 1)).astype(int)

        pred_DAG = cam_pruning(alterDTG, M if IVPanel else P, cutoff)

        return LayerOrd, DTG, alterDTG, pred_DAG