import comet_ml

import sys
import os
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)
    
from utils import Config
from utils import RecordManager
import utils
import data
import model

import argparse
import torch
import lightning as L 
from lightning.pytorch.loggers import CometLogger
from comet_ml.integration.pytorch import log_model

class ExperimentFactory:
    """Build and run a training or evaluation experiment from JSON configs.

    An experiment-specific config file is layered on top of a basic config.
    The factory then wires up the dataloaders, model, Comet logger and
    Lightning trainer that the merged config describes.

    When no ``config_file_path`` is given, representation learning is disabled
    (``n_rl``). Evaluation then runs fraud detection directly on the data,
    using only the basic config.
    """

    # Lightning strategy used when the ``trainer`` config section does not set
    # one (e.g. set ``"strategy": "ddp_notebook"`` there for notebooks).
    DEFAULT_STRATEGY = 'ddp_find_unused_parameters_true'

    def __init__(self, config_file_path="", basic_config='experiment/default.json', model_folder='dataset/exp_result/models', eval=False, upload=False):
        """Store run options and load the basic config.

        Args:
            config_file_path: experiment config applied on top of
                ``basic_config``. Empty/None disables representation learning.
            basic_config: path of the base config passed to ``Config``.
            model_folder: directory holding ``<experiment_name>.ckpt``
                checkpoints restored in eval mode.
            eval: evaluate instead of training.
            upload: log evaluation metrics to the Comet experiment.
        """
        self.config_file_path = config_file_path
        self.basic_config = basic_config
        self.model_folder = model_folder
        self.eval = eval
        self.upload = upload
        # No experiment config means there is no representation-learning model.
        self.n_rl = not config_file_path
        self.config = Config(basic_config)
        self.record_mn = RecordManager()

    def create_dataloaders(self, features=None):
        """Build the train/valid/test datasets and the data module wrapping them.

        Args:
            features: optional feature selection forwarded to every dataset
                constructor. Falsy values are not forwarded at all.
        """
        dataset_params = self.config.get_dict('data.dataset')
        dataloader_params = self.config.get_dict('data.dataloader')
        dataset_class = getattr(data, dataset_params['name'])
        dataloader_class = getattr(data, dataloader_params['name'])
        # Only pass ``features`` when a selection was requested, so the
        # dataset's own default applies otherwise.
        extra = {'features': features} if features else {}
        ds = [
            dataset_class(**extra, **dataset_params[split])
            for split in ('train_data', 'valid_data', 'test_data')
        ]
        if features:
            print(f"Choosen features: {features}")
        dl = dataloader_class(datasets=ds, **dataloader_params)
        self.data_module = dl
        self.train_dataloader = dl.train_dataloader()
        self.val_dataloader = dl.val_dataloader()
        self.test_dataloader = dl.test_dataloader()

        print("Create dataloader successfully.")

    def _checkpoint_path(self, experiment_name):
        """Return the checkpoint file for ``experiment_name`` inside ``model_folder``."""
        return os.path.join(self.model_folder, experiment_name + ".ckpt")

    def create_model(self):
        """Instantiate the configured model, or restore it from a checkpoint in eval mode.

        In eval mode the checkpoint is ``<model_folder>/<experiment_name>.ckpt``.
        The experiment name comes from the ``comet`` config section.
        """
        self.model_params = self.config.get_dict('model')
        self.model_name = self.model_params['name']
        model_class = getattr(model, self.model_name)

        if not self.eval:
            self.model = model_class(**self.model_params)
            print(f"Load {self.model_name} model successfully.")
            return

        # Read the comet section directly if create_logger() has not run yet,
        # instead of failing on a missing ``logger_params`` attribute.
        logger_params = getattr(self, 'logger_params', None) or self.config.get_dict('comet')
        ckpt_path = self._checkpoint_path(logger_params['experiment_name'])
        self.model = model_class.load_from_checkpoint(ckpt_path, **self.model_params)
        print(f"Load {self.model_name} model from checkpoint successfully.")

    def create_logger(self):
        """Create the Comet logger from the ``comet`` config section.

        In eval mode, the previously recorded experiment with the same name is
        looked up through the RecordManager. Its key (``exp_info[2]``) is passed
        as ``experiment_key`` so results attach to that existing experiment.
        """
        self.logger_params = self.config.get_dict('comet')
        if self.eval:
            self.exp_info = self.record_mn.get_exp_info_by_name(self.logger_params['experiment_name'])
            self.logger = CometLogger(experiment_key=self.exp_info[2], **self.logger_params)
            print("Create Comet logger from existing experiment successfully.")
        else:
            self.logger = CometLogger(**self.logger_params)
            print("Create Comet logger successfully.")

    @classmethod
    def _trainer_kwargs(cls, trainer_params):
        """Return Trainer kwargs: ``DEFAULT_STRATEGY`` overridden by ``trainer_params``.

        Merging into a new dict lets a config-provided ``strategy`` win instead
        of raising a duplicate-keyword TypeError. The input is not mutated.
        """
        return {'strategy': cls.DEFAULT_STRATEGY, **trainer_params}

    def create_trainer(self):
        """Create the Comet logger and the Lightning trainer from the ``trainer`` config."""
        trainer_params = self.config.get_dict('trainer')
        self.create_logger()
        self.trainer = L.Trainer(logger=self.logger, **self._trainer_kwargs(trainer_params))
        print("Create trainer successfully.")

    def start_training(self, config_file_path):
        """Apply ``config_file_path``, train the model, then run the test loop.

        Both the basic and the experiment config files are logged to the Comet
        experiment as assets.
        """
        self.config.update_config(config_file_path)
        self.create_dataloaders()
        self.create_trainer()
        self.create_model()
        self.logger.experiment.log_asset(self.basic_config)
        self.logger.experiment.log_asset(config_file_path)
        print(f"{os.getpid()} start training:")
        self.trainer.fit(self.model, datamodule=self.data_module)
        # The model instance is passed without ckpt_path, so the test loop uses
        # the in-memory weights left at the end of fit().
        self.trainer.test(self.model, self.test_dataloader)

    def fraud_detection(self):
        """Run the configured fraud detector and score its output.

        The detector class and the scoring function are looked up by name in
        ``utils`` from the ``detection`` config section. With representation
        learning enabled, the current model is passed to the detector as well.
        """
        detection_params = self.config.get_dict('detection')
        detection_class = getattr(utils, detection_params['name'])
        self.create_dataloaders(detection_params['features'])
        if self.n_rl:
            detect_result = detection_class(datamodule=self.data_module, **detection_params)
        else:
            detect_result = detection_class(model=self.model, datamodule=self.data_module, **detection_params)
        evaluation = getattr(utils, detection_params['eval_matrix'])
        results = evaluation(detect_result, detection_params['name'])
        if self.upload:
            self._upload_results(results)
        print(f"{os.getpid()} finish fraud detection evaluation.")

    def _upload_results(self, results):
        """Log ``results`` to the Comet logger, or warn and skip when none exists.

        The representation-learning-disabled path of evaluate() never creates a
        logger. Reaching this point without one must not crash after all the
        detection work has been done.
        """
        logger = getattr(self, 'logger', None)
        if logger is None:
            print("No Comet logger has been created; skip uploading the results.")
            return
        logger.log_metrics(results)
        print("Upload all the results to the existing comet experiment.")

    def evaluate(self, config_file_path):
        """Evaluate fraud detection, optionally on a model restored from a checkpoint.

        With representation learning disabled, no experiment config, model,
        or logger is used. Otherwise ``config_file_path`` is applied, the
        recorded Comet experiment is resumed, and the checkpointed model is
        evaluated.
        """
        print("Start evaluating process:")
        if self.n_rl:
            print("Disable representation learning.")
            self.fraud_detection()
            return
        self.config.update_config(config_file_path)
        self.create_logger()
        try:
            self.create_model()
            # Switch the restored model to inference mode before detection.
            self.model.eval()
            self.fraud_detection()
        finally:
            # Close the resumed Comet experiment even if evaluation fails.
            self.logger.finalize(status=None)

    def run(self):
        """Entry point: evaluate when ``eval`` is set, otherwise train."""
        # Trade some float32 matmul precision for speed on supporting hardware.
        torch.set_float32_matmul_precision('high')
        if self.eval:
            self.evaluate(self.config_file_path)
        else:
            self.start_training(self.config_file_path)

        
def build_arg_parser():
    """Return the command-line parser for running an experiment.

    Option destinations match ExperimentFactory's keyword arguments, so the
    parsed namespace can be splatted straight into it.
    """
    parser = argparse.ArgumentParser(description='Fraud Detection Model')
    parser.add_argument('-c', '--config_file_path', type=str, help='Path to the config file')
    parser.add_argument('-b', '--basic_config', type=str, default='experiment/default.json', help='Path to the basic config file')
    parser.add_argument('-m', '--model_folder', type=str, default='dataset/exp_result/models', help='Directory containing <experiment_name>.ckpt checkpoints used for evaluation')
    parser.add_argument('-e', '--eval', action='store_true', help='Whether to evaluate the model without training')
    parser.add_argument('-u', '--upload', action='store_true', help='Whether to upload the evalutation results to comet')
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    mf = ExperimentFactory(**vars(args))
    mf.run()