import torch
import time
from tqdm import tqdm

from util_scripts.wandb_logger import WandbLogger
from util_scripts.train_callbacks import ModelSaverLoaderCallback

class ModelTrainer():
    """Drive training, validation, logging and checkpointing for a multimodal model.

    The wrapped ``model`` must provide ``parameters()``, ``training_step``,
    ``validation_step`` and ``get_reconstruct_parameters``; ``data_module`` must
    provide ``train_dataloader()`` and ``val_dataloader()``. ``opt`` supplies
    ``device``, ``learning_rate``, ``result_path``, ``ckpt_metric`` and ``n_epochs``.
    """

    # Dataset families, grouped by the batch layout handled in _prepare_batch.
    _SEQUENCE_DATASETS = ('mosi', 'mosei')
    _IMAGE_TEXT_DATASETS = ('mmimdb', 'food101', 'hatememes')

    def __init__(self, model, dataset, data_module, opt):
        self.opt = opt
        self.model = model
        self.dataset = dataset
        self.data_module = data_module
        # device
        self.device = opt.device
        # optimizer and scheduler
        self.optimizer, self.scheduler = self.configure_optimizers()
        # logger
        self.logger = WandbLogger(opt)
        # callback
        self.callback = ModelSaverLoaderCallback(opt.result_path, 'model', opt=opt)

    def configure_optimizers(self):
        """Build an Adam optimizer and a ReduceLROnPlateau scheduler (mode='min')."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.opt.learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=20, factor=0.1, verbose=True)
        return optimizer, scheduler

    def _prepare_batch(self, batch):
        """Unpack one loader batch and move its tensors to ``self.device``.

        Returns:
            ``(data, target_data, mask_data)`` where ``data`` is the list of
            modality tensors in the order the model receives them:
            ``[text, audio, vision]`` for mosi/mosei, ``[text, image]`` for
            mmimdb/food101/hatememes.

        Raises:
            ValueError: if ``self.dataset`` is not a supported dataset name.
        """
        if self.dataset in self._SEQUENCE_DATASETS:
            # batch is (batch_X, batch_Y, <unused>); batch_X's first entry is unused here.
            batch_X, batch_Y = batch[0], batch[1]
            _, text, audio, vision = batch_X
            modalities = [text, audio, vision]
        elif self.dataset in self._IMAGE_TEXT_DATASETS:
            batch_X, batch_Y = batch
            image, text = batch_X
            modalities = [text, image]
        else:
            raise ValueError(f"Unsupported dataset: {self.dataset!r}")

        target_data, mask_data = batch_Y
        # Targets are cast to float for every dataset, in training and validation
        # alike, so the model sees the same target dtype in both phases.
        target_data = target_data.float().squeeze(-1).to(self.device)
        mask_data = mask_data.squeeze(-1).to(self.device)
        data = [modality.to(self.device) for modality in modalities]
        return data, target_data, mask_data

    @staticmethod
    def _detach_logs(output):
        """Return a copy of ``output`` with every tensor value detached from autograd."""
        return {
            key: value.detach() if torch.is_tensor(value) else value
            for key, value in output.items()
        }

    @staticmethod
    def _average_logs(outputs):
        """Average each logged value across a list of per-batch dicts.

        Keys are taken from the first dict. Values are detached, so the averages
        carry no autograd history; non-floating values are promoted to float so
        ``mean`` is defined. Returns ``{}`` for an empty list.
        """
        if not outputs:
            return {}
        averages = {}
        for log_key in outputs[0]:
            stacked = torch.stack(
                [torch.as_tensor(output[log_key]).detach() for output in outputs]
            )
            if not stacked.is_floating_point():
                stacked = stacked.float()
            averages[log_key] = stacked.mean()
        return averages

    def training_step(self, batch, epoch):
        """Run the model's training step on one batch and return its log dict."""
        data, target_data, mask_data = self._prepare_batch(batch)
        _, tqdm_dict = self.model.training_step(data, target_data, mask_data, self.opt, epoch)
        return tqdm_dict

    def training_epoch_end(self, epoch, outputs):
        """Log epoch-averaged training values under ``train/<key>``.

        Returns the dict of averaged values (``{}`` and nothing logged when
        ``outputs`` is empty).
        """
        if not outputs:
            return {}
        averages = self._average_logs(outputs)
        for log_key, avg_batch_log in averages.items():
            self.logger.add_log(f"train/{log_key}", avg_batch_log)
        self.logger.write_log(epoch)
        return averages

    def validation_step(self, batch, epoch):
        """Run the model's validation step on one batch and return its output dict."""
        data, target_data, mask_data = self._prepare_batch(batch)
        return self.model.validation_step(data, target_data, mask_data, self.opt, epoch)

    def validation_epoch_end(self, epoch, outputs):
        """Log epoch-averaged validation values under ``val/<key>``.

        The value whose key equals ``opt.ckpt_metric`` is passed to the
        checkpoint callback. Returns the dict of averaged values (``{}`` and
        nothing logged when ``outputs`` is empty).
        """
        if not outputs:
            return {}
        averages = self._average_logs(outputs)
        for log_key, avg_batch_log in averages.items():
            self.logger.add_log(f"val/{log_key}", avg_batch_log)
            if log_key == self.opt.ckpt_metric:
                self.callback.save_cpkt(self.model, avg_batch_log)
        self.logger.write_log(epoch)
        return averages

    def param_toggle(self, reconstruct=False):
        """Set which parameters are trainable.

        ``reconstruct=False``: every model parameter requires grad.
        ``reconstruct=True``: only ``model.get_reconstruct_parameters()`` require grad.
        """
        for param in self.model.parameters():
            param.requires_grad = not reconstruct
        if reconstruct:
            for param in self.model.get_reconstruct_parameters():
                param.requires_grad = True

    def fit(self):
        """Train for ``opt.n_epochs`` epochs, validating after each epoch.

        Each batch does two backward passes before a single optimizer step:
        ``rec_loss`` with only the reconstruction parameters trainable, then
        ``loss`` (skipped when it equals 0) with all parameters trainable. The
        LR scheduler is stepped on the epoch-averaged validation ``loss``.
        """
        start_time = time.time()
        for epoch in range(self.opt.n_epochs):
            # Train
            self.model.train()
            train_outputs = []
            for batch in tqdm(self.data_module.train_dataloader()):
                self.optimizer.zero_grad()
                output = self.training_step(batch, epoch)
                # train unique branch; retain_graph because the shared-branch
                # backward below reuses the same forward graph
                self.param_toggle(reconstruct=True)
                output['rec_loss'].backward(retain_graph=True)
                # train shared branch
                self.param_toggle(reconstruct=False)
                if output["loss"] != 0:
                    output['loss'].backward()
                # update parameters
                self.optimizer.step()
                # Store detached values only: keeping the live loss tensors would
                # keep their autograd graphs (including buffers retained by the
                # retain_graph backward) alive for the rest of the epoch.
                train_outputs.append(self._detach_logs(output))

            self.training_epoch_end(epoch, train_outputs)

            # Validation
            self.model.eval()
            val_outputs = []
            with torch.no_grad():
                for batch in tqdm(self.data_module.val_dataloader()):
                    val_outputs.append(self.validation_step(batch, epoch))
            val_logs = self.validation_epoch_end(epoch, val_outputs)
            # Step on the epoch-mean validation loss rather than a single batch's loss.
            if val_logs:
                self.scheduler.step(val_logs["loss"])

        print(f"Training time: {time.time() - start_time}s")