import importlib
import time
import warnings

import torch
from torch import optim, nn
import lightning.pytorch as pl
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.callbacks import TQDMProgressBar
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

def _import_from_workspace(workspace, name):
    """Return ``name`` from package ``workspace``, like ``from workspace import name``.

    Mirrors the two-step lookup of the ``from ... import`` statement: an
    attribute of the package wins; otherwise ``workspace.name`` is imported
    as a submodule.
    """
    package = importlib.import_module(workspace)
    try:
        return getattr(package, name)
    except AttributeError:
        return importlib.import_module(f"{workspace}.{name}")


def objective_function(queue, configuration, plans, workspace, search_settings, debug=False):
    """Train one configuration and report its final validation performance.

    Loads ``data_preparation_0_<i>`` and ``modeling_0_<j>`` from the
    ``workspace`` package (``i``/``j`` taken from ``configuration[0]``),
    trains a ``pipeline`` with a Lightning ``Trainer`` and reads the last
    logged ``ave_val_perf_epoch``.

    Args:
        queue: receives ``"failed"`` immediately, then the objective value
            once training finishes successfully.
        configuration: ``configuration[0]`` holds the ``"data_preparation"``
            and ``"modeling"`` choices; ``configuration[2]`` holds training
            hyperparameters (``batch_size`` here, the rest read by ``pipeline``).
        plans: forwarded unchanged to ``pipeline``.
        workspace: dotted name of the package holding the generated modules.
        search_settings: mapping with ``max`` and, outside debug mode,
            ``min_delta``, ``patience``, ``max_epochs`` and ``check_pointing``;
            also forwarded to ``pipeline``. The caller's mapping is not modified.
        debug: train for one epoch on at most 5 train/val batches, with
            ``skip`` forced to 0 and no early stopping or mixed precision.

    Returns:
        A value to minimise: ``-ave_val_perf_epoch`` when
        ``search_settings["max"]`` is true, else ``ave_val_perf_epoch``.
    """
    # Queue a failure marker first so the consumer sees "failed" if anything
    # below raises before the real result is queued.
    queue.put("failed")
    start_time = time.time()

    connector_choice = 0
    data_preparation_choice = configuration[0]["data_preparation"]
    modeling_choice = configuration[0]["modeling"]
    batch_size = configuration[2]["batch_size"]

    # Regular imports instead of exec(..., globals()): nothing is injected into
    # module globals, so the model and loaders are freed when this returns.
    data_preparation = _import_from_workspace(
        workspace, f"data_preparation_{connector_choice}_{data_preparation_choice}"
    )
    modeling = _import_from_workspace(
        workspace, f"modeling_{connector_choice}_{modeling_choice}"
    )

    torch_model = modeling.generate_model()
    # Normalise to a built-in int (e.g. from a NumPy integer); presumably the
    # value reaches a torch DataLoader, whose BatchSampler rejects other types.
    train_loader, val_loader = data_preparation.generate_dataloader(
        batch_size=batch_size if batch_size is None else int(batch_size)
    )

    if debug:
        # Work on a copy so the caller's settings are not mutated.
        search_settings = {**search_settings, "skip": 0}
    else:
        warnings.filterwarnings("ignore", category=UserWarning)

    lightning_model = pipeline(torch_model, configuration, plans, workspace, search_settings, debug=debug)

    if debug:
        trainer = pl.Trainer(
            max_epochs=1,
            limit_train_batches=5,
            limit_val_batches=5,
            callbacks=[custom_callbacks(), custom_progressbar()],
        )
    else:
        early_stop_callback = EarlyStopping(
            monitor='ave_val_perf_epoch',
            min_delta=search_settings["min_delta"],
            patience=search_settings["patience"],
            verbose=True,
            mode="max" if search_settings["max"] else "min",
        )
        trainer = pl.Trainer(
            max_epochs=search_settings["max_epochs"],
            callbacks=[custom_callbacks(), custom_progressbar(), early_stop_callback],
            precision="16-mixed",
            enable_checkpointing=search_settings["check_pointing"],
        )

    trainer.fit(model=lightning_model, train_dataloaders=train_loader, val_dataloaders=val_loader)
    elapsed_time = time.time() - start_time

    ave_val_perf_epoch = float(trainer.logged_metrics['ave_val_perf_epoch'])
    print("Average validation performance:", ave_val_perf_epoch)
    print("Elapsed time:", elapsed_time)

    # Negate when maximising so the caller can always minimise.
    objective = -ave_val_perf_epoch if search_settings["max"] else ave_val_perf_epoch
    queue.put(objective)
    return objective



class pipeline(pl.LightningModule):
    """LightningModule that trains a generated torch model for one search trial.

    The task type and the number of input/output tensors per batch come from
    ``plans[0]["connector"]``. Optimizer and scheduler hyperparameters come
    from ``configuration[2]``. Running-average and step-patience settings come
    from ``search_settings``. The per-batch metric is computed by
    ``generate_evaluation`` from ``<workspace>.post_processing_0_0``.

    ``custom_callbacks`` reads and writes several of the attributes set here
    (``skip``, ``skip_count``, ``total_steps``, ``max``, ``min_delta_step``,
    ``patience_step``, ``patience_counter``, ``training_performance_list``,
    ``validation_performance_list``, ``validation_flag``), so they are plain
    public attributes.
    """

    # Candidate criteria per task type; ``loss_choice`` indexes into each tuple.
    _LOSS_FACTORIES = {
        "multi-class classification": (nn.CrossEntropyLoss,),
        "single-output regression": (nn.MSELoss, nn.L1Loss, nn.SmoothL1Loss),
        "binary classification": (nn.BCELoss, nn.BCEWithLogitsLoss),
        "multi-label classification": (nn.BCELoss, nn.BCEWithLogitsLoss),
        "multi-output regression": (nn.MSELoss, nn.L1Loss, nn.SmoothL1Loss),
    }

    def __init__(self, torch_model, configuration, plans, workspace, search_settings, debug=False):
        super().__init__()

        self.model = torch_model.to("cuda")

        self.connector_choice = 0
        self.workspace = workspace
        post_processing = importlib.import_module(
            f"{self.workspace}.post_processing_{self.connector_choice}_0"
        )
        # Held on the instance rather than exec'd into module globals.
        self._generate_evaluation = post_processing.generate_evaluation

        connector = plans[self.connector_choice]["connector"]
        self.basic_task_type_2 = connector[1]
        self.input_number = len(connector[4]["input"])
        self.output_number = len(connector[4]["output"])

        hyperparameters = configuration[2]
        self.loss_choice = 0
        self.optimizer_choice = hyperparameters["optimizer"]
        self.momentum = hyperparameters["momentum"]
        self.learning_rate = hyperparameters["learning_rate"]
        self.weight_decay = hyperparameters["weight_decay"]
        self.scheduler = hyperparameters["scheduler"]

        self.debug = debug

        self.skip = search_settings["skip"]
        # Set per epoch by custom_callbacks.on_train_epoch_start.
        self.skip_count = None
        self.total_steps = 0
        self.max = search_settings["max"]
        self.min_delta_step = search_settings["min_delta_step"]
        self.patience_step = search_settings["patience_step"]

        self.patience_counter = 0
        self.validation_performance_list = []
        self.training_performance_list = []
        self.total_train_loss = 0.0
        self.train_batch_count = 0
        self.total_train_performance = 0.0
        self.total_val_loss = 0.0
        self.val_batch_count = 0
        self.total_val_performance = 0.0
        self.validation_flag = False

    def _split_batch(self, batch):
        """Return (model inputs, first target) from ``batch``, tensors moved to CUDA."""
        inputs = batch[:self.input_number]
        targets = batch[self.input_number:self.input_number + self.output_number]
        inputs_cuda = [t.to("cuda") if isinstance(t, torch.Tensor) else t for t in inputs]
        y = targets[0]
        if isinstance(y, torch.Tensor):
            y = y.to("cuda")
        return inputs_cuda, y

    def _compute_loss(self, y_predict, y):
        """Apply the ``loss_choice``-th criterion for the task as ``criterion(prediction, target)``.

        Raises:
            ValueError: if ``basic_task_type_2`` is not a supported task type.
        """
        try:
            candidates = self._LOSS_FACTORIES[self.basic_task_type_2]
        except KeyError:
            raise ValueError(f"Unsupported task type: {self.basic_task_type_2!r}") from None
        criterion = candidates[self.loss_choice]()
        if isinstance(criterion, (nn.BCELoss, nn.BCEWithLogitsLoss)):
            # PyTorch raises if BCELoss runs inside an autocast region (which the
            # "16-mixed" trainer enables), so evaluate BCE in float32 with
            # autocast switched off.
            with torch.autocast(device_type=y_predict.device.type, enabled=False):
                return criterion(y_predict.float(), y.float())
        return criterion(y_predict, y)

    def _run_batch(self, batch, debug_label):
        """Forward one batch and return ``(loss, performance)``."""
        inputs_cuda, y = self._split_batch(batch)
        y_predict = self.model(*inputs_cuda)

        if self.debug:
            print(debug_label, "y_predict", y_predict, "y_predict_dtype", y_predict.dtype, "y", y, "y_dtype", y.dtype)

        loss = self._compute_loss(y_predict, y)
        performance = self._generate_evaluation(y_predict.detach().cpu(), y.detach().cpu())
        return loss, performance

    def _averaged_count(self, batch_count):
        """Return how many of the first ``batch_count`` batches feed the running average.

        Averaging starts at (1-based) batch ``max(skip_count, 1)``. Nothing is
        averaged while ``skip_count`` is still None.
        """
        if self.skip_count is None:
            return 0
        first = max(self.skip_count, 1)
        return batch_count - first + 1 if batch_count >= first else 0

    def training_step(self, batch, batch_idx):
        loss, performance = self._run_batch(batch, "Before loss:")

        self.train_batch_count += 1
        averaged = self._averaged_count(self.train_batch_count)
        if averaged:
            self.total_train_loss += loss.item()
            self.total_train_performance += performance
            self.log('ave_train_loss', self.total_train_loss / averaged, on_epoch=True, prog_bar=True, on_step=True)
            self.log('ave_train_perf', self.total_train_performance / averaged, on_epoch=True, prog_bar=True, on_step=True)

        return loss

    def on_train_start(self):
        print("\n")
        print("Begin training")

    def on_train_epoch_start(self):
        self.total_train_loss = 0.0
        self.train_batch_count = 0
        self.total_train_performance = 0.0

    def on_validation_epoch_start(self):
        self.total_val_loss = 0.0
        self.val_batch_count = 0
        self.total_val_performance = 0.0

    def validation_step(self, batch, batch_idx):
        loss, performance = self._run_batch(batch, "before loss")

        self.val_batch_count += 1
        averaged = self._averaged_count(self.val_batch_count)
        if averaged:
            self.total_val_loss += loss.item()
            self.total_val_performance += performance
            self.log('ave_val_loss', self.total_val_loss / averaged, on_epoch=True, prog_bar=True, on_step=True)
            self.log('ave_val_perf', self.total_val_performance / averaged, on_epoch=True, prog_bar=True, on_step=True)

        return loss

    def configure_optimizers(self):
        """Build the optimizer and optional LR scheduler from the stored hyperparameters.

        Raises:
            ValueError: for an unsupported optimizer or scheduler name.
        """
        if self.optimizer_choice == "adam":
            optimizer = optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        elif self.optimizer_choice == "sgd":
            optimizer = optim.SGD(self.parameters(), lr=self.learning_rate, momentum=self.momentum, weight_decay=self.weight_decay)
        elif self.optimizer_choice == "adamw":
            optimizer = optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        else:
            raise ValueError(f"Unsupported optimizer: {self.optimizer_choice!r}")

        if self.scheduler == "plateau":
            # Follow the metric's direction: with mode='max' on a minimised
            # metric the LR would be cut while the metric is improving.
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode="max" if self.max else "min", factor=0.1, patience=10, verbose=True,
                threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08,
            )
        elif self.scheduler == "cosine":
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0, last_epoch=-1)
        elif self.scheduler is None or self.scheduler == "none":
            return {"optimizer": optimizer}
        else:
            raise ValueError(f"Unsupported scheduler: {self.scheduler!r}")

        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "ave_val_perf_epoch"}



class custom_callbacks(Callback):
    """Step-level early stopping and validation bookkeeping for ``pipeline``.

    At the start of each training epoch it turns the module's fractional
    ``skip`` and ``patience_step`` settings into batch counts. After each batch
    from ``skip_count`` on, it records the running training performance. It
    stops training once ``patience_counter`` reaches ``patience_step_count``.
    The counter counts consecutive batches whose improvement is below
    ``min_delta_step`` times the previous value.
    """

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if batch_idx < pl_module.skip_count:
            return

        # Read callback_metrics, the dict Lightning exposes to callbacks. In
        # Lightning 2.x logged_metrics appears to be refreshed only after this
        # hook runs, so it lags one batch and lacks the key on the first one.
        metrics = trainer.callback_metrics
        value = metrics.get('ave_train_perf_step', metrics.get('ave_train_perf'))
        if value is None:
            return

        history = pl_module.training_performance_list
        history.append(float(value))
        if len(history) >= 2:
            previous, current = history[-2], history[-1]
            gain = current - previous if pl_module.max else previous - current
            if gain < pl_module.min_delta_step * previous:
                pl_module.patience_counter += 1
            else:
                pl_module.patience_counter = 0

        if pl_module.patience_counter >= pl_module.patience_step_count:
            print("\n")
            print("Early stopping due to slow improvement per step is triggered.")
            trainer.should_stop = True

    def on_train_epoch_start(self, trainer, pl_module):
        pl_module.total_steps = len(trainer.train_dataloader)
        pl_module.skip_count = int(pl_module.skip * pl_module.total_steps)
        pl_module.patience_step_count = int(pl_module.patience_step * pl_module.total_steps)

    def on_validation_epoch_end(self, trainer, pl_module):
        value = trainer.callback_metrics.get('ave_val_perf_epoch')
        if value is None:
            # Nothing logged, e.g. during the sanity check (skip_count is still
            # None then) or when no validation batch reached the skip threshold.
            return
        pl_module.validation_performance_list.append(value)
        pl_module.validation_flag = True


class custom_progressbar(TQDMProgressBar):
    """TQDM progress bar that mutes itself during validation and prints the epoch's validation averages."""

    def init_validation_tqdm(self):
        print("\nBegin validation")
        # Disable before delegating so the parent builds the validation bar in
        # the disabled state (presumably it consults ``is_disabled``). The bar
        # is re-enabled in on_validation_end.
        self.disable()
        return super().init_validation_tqdm()

    def on_validation_end(self, trainer, pl_module):
        metrics = trainer.logged_metrics
        try:
            print(f"Average validation loss for the epoch: {metrics['ave_val_loss_epoch']}")
            print(f"Average validation performance for the epoch: {metrics['ave_val_perf_epoch']}")
        except KeyError:
            # Epoch averages are missing until validation has logged them
            # (e.g. during the sanity check, before skip_count is set).
            pass
        self.enable()
        # Let the parent close the validation bar and refresh the training bar.
        super().on_validation_end(trainer, pl_module)
