from .runner import Runner
from pytorch_tabnet.tab_model import TabNetClassifier
from pytorch_tabnet.pretraining import TabNetPretrainer

from misc.utils import getColIdx, getDimPerIdx
from torch import optim
import numpy as np
import os
import torch
import random
import pandas as pd
import logging
from types import SimpleNamespace
from typing import Tuple
from typing import List, Dict, Any, Union
from numpy.typing import NDArray
from runners.model_types import DreamquarkTabNet

class DreamquarkTabNetRunner(Runner):
    def __init__(self, 
                config: SimpleNamespace, 
                data: pd.DataFrame, 
                labels: pd.Series, 
                numeric_cols: List[str], 
                category_cols: List[str], 
                logger: logging.Logger
        ) -> None:
        """Set up the runner and derive the categorical-feature metadata.

        Args:
            config: Experiment configuration. ``model.fast_dev_run`` and
                ``runner_option.do_pretraining`` are read here;
                ``model.max_epochs`` and ``optuna.n_trials`` are overwritten
                in place when ``fast_dev_run`` is truthy.
            data: Feature table, forwarded to the base ``Runner``.
            labels: Target values, forwarded to the base ``Runner``.
            numeric_cols: Names of the numeric columns.
            category_cols: Accepted for signature parity and discarded; the
                categorical indices come from ``getColIdx`` instead.
            logger: Logger forwarded to the base ``Runner``.
        """
        # Not used by this runner; dropped explicitly so that is obvious.
        del category_cols

        super().__init__(config, data, labels, numeric_cols, logger)

        # getColIdx returns a pair; only its second element is kept here
        # (presumably the positions of the non-numeric columns -- confirm
        # against misc.utils).
        _unused_idx, categorical_idx = getColIdx(numeric_cols, self.data)
        self.category_idx = categorical_idx
        self.category_dims = getDimPerIdx(categorical_idx, self.data)

        # Populated later by set_hparams() when the hparams carry a ratio.
        self.pretraining_ratio = None

        cfg = self.config

        # Smoke-test mode: shrink the run to one epoch and one Optuna trial.
        if cfg.model.fast_dev_run:
            cfg.model.max_epochs = 1
            cfg.optuna.n_trials = 1

        # An empty path means "no pretrained checkpoint cached yet".
        if cfg.runner_option.do_pretraining:
            self.pretrained_model_path = ""


    def set_hparams(self, 
                    hparams: Dict[str, Any]
        ) -> Dict[str, Any]:
        hparams['cat_idxs'] = self.category_idx
        hparams['cat_dims'] = self.category_dims
        hparams['n_a'] = hparams['n_d']
        
        hparams['verbose'] = self.config.model.verbose
        hparams['device_name'] = self.config.model.device
        hparams['seed'] = self.random_seed

        if hparams['scheduler_fn'] == 'StepLR':
            hparams['scheduler_fn'] = optim.lr_scheduler.StepLR
            hparams['scheduler_hparams'] = {
                                            'step_size' : 10,
                                            'gamma' : 0.9,
                                        }
        elif hparams['scheduler_fn'] == 'CosineAnnealingLR':
            hparams['scheduler_fn'] = optim.lr_scheduler.CosineAnnealingLR
            hparams['scheduler_hparams'] = {
                                            'T_max' : 10,
                                            'eta_min' : 0.001
                                        }

        if 'optimizer_hparams' not in hparams:
            hparams['optimizer_hparams'] = dict(lr = hparams['lr'])

            del hparams['lr']

        if 'pretraining_ratio' in hparams:
            self.pretraining_ratio = hparams['pretraining_ratio']
            del hparams['pretraining_ratio'] 

        return hparams

    def get_model(self, 
                    hparams: Dict[str, Any], 
                    for_pretraining: bool = True
        ) -> Union[TabNetPretrainer, TabNetClassifier]:
        """Build a TabNet model from raw hyper-parameters.

        ``hparams`` is first normalised by :meth:`set_hparams`, which edits
        the dict in place.

        Args:
            hparams: Raw hyper-parameters.
            for_pretraining: When true (the default) a ``TabNetPretrainer``
                is built; otherwise a ``TabNetClassifier``.

        Returns:
            The freshly constructed, unfitted model.
        """
        hparams = self.set_hparams(hparams)
        model_cls = TabNetPretrainer if for_pretraining else TabNetClassifier
        return model_cls(**hparams)

    def fit_model(self, 
                    model: Union[TabNetPretrainer, TabNetClassifier], 
                    train_idx: np.ndarray, 
                    test_idx: np.ndarray, 
                    hparams: Dict[str, Any] = None, 
                    pseudo_data: pd.DataFrame = None, 
                    pseudo_label: pd.Series = None, 
                    fold_idx: int = 0
        ) -> TabNetClassifier:
        """Fit a ``TabNetClassifier`` on one split, optionally from a pretrainer.

        When ``config.runner_option.do_pretraining`` is set, the first call
        fits ``model`` (the pretrainer) on every row not in ``test_idx``,
        saves it, and snapshots the RNG state. Later calls reload that
        checkpoint and restore the snapshot instead of pretraining again.
        The classifier is always built fresh from ``hparams``.

        Args:
            model: Pretrainer to fit. Only used on the first pretraining
                call; ignored otherwise.
            train_idx: Index labels (used with ``.loc``) of the training rows.
            test_idx: Index labels of the evaluation rows.
            hparams: Raw hyper-parameters for the classifier. Required,
                despite the ``None`` default kept for signature
                compatibility.
            pseudo_data: Optional pseudo-labelled rows. Only their index is
                used; the rows are looked up in ``self.data``.
            pseudo_label: Labels for ``pseudo_data``, appended to the
                training labels.
            fold_idx: Unused.

        Returns:
            The fitted classifier.

        Raises:
            TypeError: If ``hparams`` is ``None``.
        """
        del fold_idx  # accepted for interface parity only

        # Fail before the expensive pretraining rather than after it; the
        # exception type matches what set_hparams(None) would raise.
        if hparams is None:
            raise TypeError("fit_model() requires `hparams` to build the classifier")

        metric_aliases = {
            'accuracy_score': 'accuracy',
            'balanced_accuracy_score': 'balanced_accuracy',
        }
        eval_metric = metric_aliases.get(self.config.data.metric, 'auc')

        X_eval = self.data.loc[test_idx, :].values
        y_eval = self.label[test_idx].values

        fit_hparams = {
            'X_train' : self.data.loc[train_idx, :].values,
            'y_train' : self.label[train_idx].values,
            'eval_set' : [(X_eval, y_eval)],
            'eval_metric' : [eval_metric],
            'max_epochs' : self.config.model.max_epochs,
            # Per the pytorch-tabnet docs, 1 enables inverse-class-frequency
            # sampling.
            'weights' : 1,
            'patience' : self.config.model.early_stopping_patience,
            'batch_size' : self.config.model.batch_size,
            'num_workers' : self.config.n_jobs,
        }

        if self.config.runner_option.do_pretraining:
            if self.pretrained_model_path == "":
                # Exclude the evaluation rows by *label*, matching the .loc
                # lookups used everywhere else. Positional removal is only
                # correct when the index happens to be 0..n-1.
                held_out = self.data.index.isin(test_idx)
                pretrain_kwargs = {
                    'X_train' : self.data.loc[~held_out, :].values,
                    'eval_set' : [X_eval],
                    'max_epochs' : self.config.model.max_epochs,
                    'patience' : self.config.model.pretrain_early_stopping_patience,
                    'batch_size' : self.config.model.batch_size,
                    'num_workers' : self.config.n_jobs,
                }
                # Forward the ratio only when one was configured, so the
                # library default applies otherwise.
                if self.pretraining_ratio is not None:
                    pretrain_kwargs['pretraining_ratio'] = self.pretraining_ratio

                model.fit(**pretrain_kwargs)
                self.pretrained_model_path = self.save_model(model, f'temporary_ckpt_data/{self.start_time}-{self.config.data.target}-{self.config.runner_option.model}-pretrained')
                self.caching_pretrined_random_state()
            else:
                # Reuse the cached pretrainer and resume from the RNG state
                # captured right after it was trained.
                model = self.load_model(model_path=self.pretrained_model_path)
                self.load_pretrained_random_state()

            fit_hparams['from_unsupervised'] = model

        clf = self.get_model(hparams, False)

        if pseudo_data is not None and pseudo_label is not None:
            fit_hparams['y_train'] = np.concatenate((self.label[train_idx].values, pseudo_label), axis=0)

            # The pseudo rows are fetched from self.data via their index.
            train_idx = np.concatenate((train_idx, pseudo_data.index.values))
            fit_hparams['X_train'] = self.data.loc[train_idx, :].values

        clf.fit(**fit_hparams)

        return clf

    def predict(self, 
                model: TabNetClassifier, 
                test_idx: np.array = None,
                X_test: pd.DataFrame = None, 
                return_prob: bool = False, 
                fold_idx: int = 0
        ) -> Union[Tuple[NDArray[np.int_], NDArray[np.int_]], np.array]:
        del fold_idx

        if test_idx is not None:
            X_test = self.data.loc[test_idx, :]
        
        preds = model.predict_proba(X_test.values)

        if return_prob:
            return preds.argmax(1), preds
        else:
            return preds.argmax(1)
        

    def save_model(self, 
                    model: Union[TabNetPretrainer, TabNetClassifier], 
                    saving_path: str = None, 
                    fold_idx: int = None
        ) -> str:
        if saving_path == None:
            saving_path = f"model/{self.start_time}-{self.config.data.target}-{self.config.runner_option.model}"
        
        if 'pretrained' in saving_path:
            torch.save(model, saving_path)
            return saving_path

        if fold_idx != None:
            saving_path += '-fold%d'
            model.save_model(saving_path % fold_idx)
        else:
            model.save_model(saving_path)
        
        return saving_path + '.zip'

        
        
    
    def load_model(self, 
                    model_path: str, 
                    fold_idx: int = None
        ) -> TabNetClassifier:
        if 'pretrained' in model_path:
            model = torch.load(model_path)
            return model

        model = TabNetClassifier()
        if fold_idx is not None:
            model.load_model(model_path % fold_idx)
        else:
            model.load_model(model_path)
        
        return model

    def clear_pretrained_model(self) -> None:
        self.clear_model_cache(self.pretrained_model_path)
        self.pretrained_model_path = ""
        
    def clear_model_cache(self, 
                            path: str
        ) -> None:
        try:
            for fold_idx in range(self.config.KFold.n_splits):
                if os.path.exists(path % fold_idx):
                    os.remove(path % fold_idx)
        except TypeError as e:
            pass
        if os.path.exists(path):
            os.remove(path)
    
    def caching_pretrined_random_state(self) -> None:
        self.random_rs = random.getstate()
        self.np_rs = np.random.RandomState().get_state()
        self.torch_rs = torch.get_rng_state()
        self.torch_cuda_rs = torch.cuda.get_rng_state()
        self.torch_cuda_rs_all = torch.cuda.get_rng_state_all()
        
        

    def load_pretrained_random_state(self) -> None:
        random.setstate(self.random_rs)
        np.random.set_state(self.np_rs)
        torch.set_rng_state(self.torch_rs)
        torch.cuda.set_rng_state(self.torch_cuda_rs)
        torch.cuda.set_rng_state_all(self.torch_cuda_rs_all)