import glob
from tqdm import tqdm
import os, sys
# Append "<directory of this file>/../" to sys.path; presumably so the
# project-level imports below resolve when this file is run as a script
# (TODO confirm which of them actually live in the parent directory).
current_file_path = os.path.abspath(__file__)
sys.path.append(os.path.dirname((current_file_path))+'/../')
import warnings
# Silence every warning for the whole process (applies to all later imports).
warnings.filterwarnings("ignore")
from adbench.run import RunPipeline
from pipeline import RunPipeline_unsup

from adbench.myutils import Utils

from fit import fit
from model import LitAutoEncoder
from gof import GOFScorer

from sklearn.cluster import MiniBatchKMeans
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import silhouette_score, davies_bouldin_score 

import torch
import numpy as np
import pandas as pd

from torch import nn, optim
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl

from utils import CustomDataset
from utils import RowScaler


class Galaxy():
    """Anomaly scorer: preprocessing -> optional autoencoder embedding -> clustering.

    The pipeline is configured through ``meta_data``:

    * ``'preprocess'``: ``'z-score'`` (``StandardScaler``), ``'row-norm'``
      (``RowScaler``) or ``None`` (no scaling). Set from the constructor.
    * ``'space'``: ``'deep'`` works on ``LitAutoEncoder`` embeddings; any other
      value works directly on the (preprocessed) inputs.
    * ``'clustering'``: ``'kmeans'``, ``'GMM'`` or ``'Galaxy-S'``.

    ``fit`` learns the scaler, the representation model and the clustering
    model; ``predict_score`` hands samples and the learned cluster centres to
    ``GOFScorer.get_score``.
    """

    # Class-level defaults. Every instance works on its own copy (see __init__).
    meta_data = {
        'preprocess': "z-score",
        'space': 'deep',
        'clustering': 'Galaxy-S',
    }

    def __init__(self,
                 seed:int=42,
                 model_name:str='Galaxy',
                 k:int=10,
                 l=0.01,
                 hidden_dim=128,
                 training_iters=3,
                 score_type='vector',
                 preprocess='z-score',
                 pretrain=True,
                 abla_clust=None,
                 ):
        """Store hyper-parameters and build the scaler, scorer and trainer.

        Args:
            seed: Random state handed to the clustering model.
            model_name: Stored on the instance; not otherwise used in this class.
            k: Number of clusters / mixture components.
            l: Forwarded to ``GalaxyBase_SMM`` as ``l``.
            hidden_dim: Passed as the third and fourth positional arguments of
                ``LitAutoEncoder``.
            training_iters: Forwarded to ``GalaxyBase_SMM`` as ``iters``.
            score_type: ``'vector'`` or ``'scalar'``; forwarded to
                ``GalaxyBase_SMM`` and to ``GOFScorer.get_score``.
            preprocess: ``'z-score'``, ``'row-norm'`` or ``None``.
            pretrain: Whether ``fit`` trains the autoencoder with the
                Lightning trainer before clustering.
            abla_clust: Forwarded to ``GalaxyBase_SMM`` unchanged.

        Raises:
            ValueError: If ``preprocess`` is not one of the supported values.
        """
        self.seed = seed
        self.utils = Utils()
        self.model_name = model_name
        self.hidden_dim = hidden_dim
        self.l = l
        self.training_iters = training_iters
        self.score_type = score_type
        # Copy before writing: mutating the class-level dict would leak this
        # instance's `preprocess` into every other (and future) instance.
        self.meta_data = dict(self.meta_data)
        self.meta_data['preprocess'] = preprocess
        self.pretrain = pretrain
        self.abla_clust = abla_clust

        self.k = k
        self.scorer = GOFScorer()

        if self.meta_data['space'] == 'deep':
            self.trainer = pl.Trainer(
                max_epochs=200,
                # 'auto' still selects the GPU when one is present, but no
                # longer fails on CPU-only machines.
                accelerator='auto',
                devices=1,
                enable_checkpointing=False,
                enable_progress_bar=False)

        if preprocess == 'z-score':
            self.scaler = StandardScaler()
        elif preprocess == 'row-norm':
            self.scaler = RowScaler()
        elif preprocess is None:
            self.scaler = None
        else:
            # Fail here rather than with an AttributeError on the first fit().
            raise ValueError(
                f"Unsupported preprocess {preprocess!r}; "
                "expected 'z-score', 'row-norm' or None."
            )

    def fit(self, X_train, y_train=None):
        """Fit the scaler, the representation model and the clustering model.

        Args:
            X_train: 2-D array of training samples (``X_train.shape[1]`` is
                used as the autoencoder input/output size).
            y_train: Ignored; the method never reads it.

        Returns:
            ``self``.

        Raises:
            ValueError: If ``meta_data['clustering']`` is not supported.
        """
        clustering = self.meta_data['clustering']
        # NOTE: 'Galaxy-V' (GalaxyBase_SMM with version='vector') is accepted by
        # predict_score but is not wired up for fitting here.
        if clustering not in ('kmeans', 'GMM', 'Galaxy-S'):
            raise ValueError(
                f"Unsupported clustering {clustering!r}; "
                "expected 'kmeans', 'GMM' or 'Galaxy-S'."
            )

        if self.meta_data['preprocess'] is not None:
            self.scaler = self.scaler.fit(X_train)
            X_train = self.scaler.transform(X_train)

        self.rep_model = LitAutoEncoder(X_train.shape[1], X_train.shape[1], self.hidden_dim, self.hidden_dim, learning_rate=3e-3)

        if self.meta_data['space'] == 'deep':
            train_loader = DataLoader(CustomDataset(X_train), batch_size=1024, shuffle=True)

            # pretrain
            if self.pretrain:
                self.trainer.fit(self.rep_model, train_loader)

                print('pretrain done !!!!!')

            # The embedding is detached right away, so skip building the
            # autograd graph for the whole training set.
            with torch.no_grad():
                X = self.get_embedding(X_train).detach().numpy()
        else:
            # Any non-'deep' space uses the preprocessed inputs directly,
            # matching how predict_score treats it.
            X = X_train

        if clustering == 'kmeans':
            self.clustering = MiniBatchKMeans(
                n_clusters=self.k,
                random_state=self.seed,
                n_init='auto',
            ).fit(X)
        elif clustering == 'GMM':
            self.clustering = GaussianMixture(
                n_components=self.k,
                random_state=self.seed,
            ).fit(X)
        else:  # 'Galaxy-S'
            from em import GalaxyBase_SMM
            # Receives the representation model and the preprocessed inputs
            # (not the embedding computed above).
            self.clustering = GalaxyBase_SMM(
                self.rep_model,
                k=self.k,
                l=self.l,
                seed=self.seed,
                iters=self.training_iters,
                version='scalar',
                abla_clust=self.abla_clust,
                score_type=self.score_type,
                ).fit(X_train)

        return self

    def get_embedding(self, X, device='cpu'):
        """Return the first output of ``rep_model`` for ``X``.

        Args:
            X: ``numpy.ndarray`` or ``torch.Tensor``; cast to ``float32``.
            device: Device the representation model is moved to. ``X`` itself
                is not moved.

        Returns:
            ``self.rep_model(X)[0]`` as a tensor (still attached to autograd
            unless the caller disabled it).
        """
        if isinstance(X, np.ndarray):
            X = torch.from_numpy(X)
        self.rep_model.to(device)
        embed = self.rep_model(X.to(torch.float32))[0]
        return embed

    def predict_score(self, X):
        """Score samples against the fitted cluster centres.

        Args:
            X: 2-D ``numpy.ndarray`` of samples in the original feature space.

        Returns:
            Whatever ``GOFScorer.get_score(X, centres, score_type=...)`` returns.

        Raises:
            ValueError: If ``score_type`` or ``meta_data['clustering']`` is
                not supported (previously surfaced as ``UnboundLocalError``).
            AssertionError: If the samples or the centres contain NaN.
        """
        if self.score_type not in ('vector', 'scalar'):
            raise ValueError(
                f"Unsupported score_type {self.score_type!r}; "
                "expected 'vector' or 'scalar'."
            )

        if self.meta_data['preprocess'] is not None:
            X = self.scaler.transform(X)

        if self.meta_data['space'] == 'deep':
            # Inference only: no autograd graph needed.
            with torch.no_grad():
                X = self.get_embedding(X).detach().numpy()

        clustering = self.meta_data['clustering']
        if clustering == 'GMM':
            c = self.clustering.means_
        elif clustering == 'kmeans':
            c = self.clustering.cluster_centers_
        elif clustering in ["Galaxy-S", "Galaxy-V"]:
            c = self.clustering.means
        else:
            raise ValueError(
                f"Unsupported clustering {clustering!r}; expected "
                "'kmeans', 'GMM', 'Galaxy-S' or 'Galaxy-V'."
            )

        if not isinstance(c, torch.Tensor):
            c = torch.from_numpy(c)
        c = c.cpu().detach()

        X = torch.from_numpy(X)

        assert not torch.isnan(X).any()
        assert not torch.isnan(c).any()

        return self.scorer.get_score(X, c, score_type=self.score_type)
    

def _save_results(results, path):
    """Write pipeline results to a CSV at ``path``, creating its directory first."""
    out_dir = os.path.dirname(path)
    if out_dir:
        # DataFrame.to_csv does not create missing parent directories.
        os.makedirs(out_dir, exist_ok=True)
    pd.DataFrame(results).to_csv(path, index=None)


def main():
    """Run the Galaxy benchmark in one of three modes, chosen by the switches below.

    * ``ONLY_TABULAR``: run ``RunPipeline_unsup.run_tabular`` and save the
      results to ``results/galaxy_tabular-<suffix>.csv``.
    * ``DEBUG``: smoke-test on random data, then run two datasets loaded from
      ``data/Classical`` through ``RunPipeline`` and print the results.
    * otherwise: run ``RunPipeline_unsup.run`` and save the results to
      ``results/galaxy_<suffix>.csv``.
    """
    pl.seed_everything(42)

    ONLY_TABULAR = True
    DEBUG = True
    suffix = "latest"

    if ONLY_TABULAR:
        pipeline = RunPipeline_unsup(suffix='ADBench', parallel='unsupervise', realistic_synthetic_mode=None, noise_type=None)
        results = pipeline.run_tabular(clf=Galaxy)

        _save_results(results, f'results/galaxy_tabular-{suffix}.csv')
        # Tabular-only mode stops here; the remaining modes are skipped.
        return

    if DEBUG:
        X_train = np.random.randn(1000, 20)
        # Labels are drawn but fit() is called without them (unsupervised).
        y_train = np.random.choice([0, 1], 1000)
        X_test = np.random.randn(100, 20)

        model = Galaxy(seed=42, model_name='Galaxy')
        model.fit(X_train, None)
        score = model.predict_score(X_test)
        print(score)

        # customized model on customized dataset
        # NOTE(review): this synthetic dataset is replaced by the thyroid data
        # right below; kept because drawing it advances NumPy's global RNG.
        dataset = {}
        dataset['X'] = np.random.randn(1000, 20)
        dataset['y'] = np.random.choice([0, 1], 1000)

        dataset = np.load('data/Classical/38_thyroid.npz')
        pipeline = RunPipeline(suffix='ADBench', parallel='unsupervise', realistic_synthetic_mode=None, noise_type=None)
        results = pipeline.run(dataset=dataset, clf=Galaxy)
        print(results)

        dataset = np.load('data/Classical/2_annthyroid.npz')
        pipeline = RunPipeline(suffix='ADBench', parallel='unsupervise', realistic_synthetic_mode=None, noise_type=None)
        results = pipeline.run(dataset=dataset, clf=Galaxy)
        print(results)
    else:
        pipeline = RunPipeline_unsup(suffix='ADBench', parallel='unsupervise', realistic_synthetic_mode=None, noise_type=None)
        results = pipeline.run(clf=Galaxy)
        print(results)

        _save_results(results, f'results/galaxy_{suffix}.csv')


if __name__ == "__main__":
    main()
       
    
        
