from omegaconf import DictConfig
from TimeVaryingCausalModel import TimeVaryingCausalModel
from pytorch_lightning import seed_everything

import torch
from torch import nn

# Dataset
from utils import CancerDataset
from torch.utils.data import DataLoader, Dataset, Subset

class RMSNtrain(TimeVaryingCausalModel):
    
    def __init__(self, args: DictConfig):
        """Build the trainer from a configuration object.

        The passed-in config is mutated: ``args.dataset.treatment_mode`` is
        overwritten with ``"multilabel"`` before any setting is read. The
        remaining settings are copied onto the instance by ``init_specific``.

        Args:
            args: Configuration providing the ``dataset``, ``exp`` and
                ``model`` sections that ``init_specific`` reads.
        """
        super().__init__()

        # Force the treatment encoding on the caller's config object.
        dataset_cfg = args.dataset
        dataset_cfg.treatment_mode = "multilabel"

        self.init_specific(args)

        # train_RMSN runs its own zero_grad / backward / step loop, so the
        # automatic-optimisation flag is switched off.
        self.automatic_optimization = False

    def init_specific(self, args):
        """Copy dataset, device and optimisation settings from ``args`` onto self.

        Reads ``self.model_type`` (not assigned in this method) to select the
        sub-section of ``args.model`` that holds the optimiser, batch-size and
        gradient-clipping settings. Prints the resolved treatment dimension
        and the device string as a side effect.

        Args:
            args: Configuration object with ``dataset``, ``exp`` and ``model``
                sections.
        """
        dataset_cfg = args.dataset

        # ---------------------------------------------------------------
        # Dataset description
        # ---------------------------------------------------------------
        self.num_treatments = dataset_cfg.dim_treatments
        self.num_dosage_samples = dataset_cfg.dim_n_dosage_samples
        self.MAX_CANCER_VOLUME = dataset_cfg.max_cancer_volume
        self.projection_horizon = dataset_cfg.projection_horizon
        self.treatment_mode = dataset_cfg.treatment_mode

        # 'multilabel' mode pins the treatment dimension to 2; every other
        # mode takes it from the dataset configuration.
        self.dim_treatments = (
            2 if self.treatment_mode == 'multilabel' else dataset_cfg.dim_treatments
        )
        print(self.dim_treatments)

        self.dim_dosages = dataset_cfg.dim_dosages
        self.dim_static_features = dataset_cfg.dim_static_features
        self.dim_outcome = dataset_cfg.dim_outcomes

        # ---------------------------------------------------------------
        # Device
        # ---------------------------------------------------------------
        # The device string is always built as a CUDA device from the
        # configured GPU index; there is no CPU branch here.
        self.device_ori = f'cuda:{args.exp.gpu}'
        print("use gpu:", self.device_ori)

        # ---------------------------------------------------------------
        # Learning
        # ---------------------------------------------------------------
        self.max_epochs = args.exp.max_epochs

        model_cfg = args.model[self.model_type]
        optim_cfg = model_cfg.optimizer
        self.optimizer_cls = optim_cfg.optimizer_cls
        self.learning_rate = optim_cfg.learning_rate
        self.weight_decay = optim_cfg.weight_decay
        self.lr_scheduler = optim_cfg.lr_scheduler

        self.train_batch_size = model_cfg.batch_size
        self.max_grad_norm = model_cfg.max_grad_norm
    
    def train_RMSN(self):
        """Run the manual training loop for ``self.max_epochs`` epochs.

        Seeds the RNGs with a fixed seed, puts the module in training mode,
        builds the optimiser (plus the LR scheduler when ``self.lr_scheduler``
        is truthy) and then, for every batch of ``self.train_dataloader()``,
        computes ``self.training_step``, back-propagates, clips the gradient
        norm to ``self.max_grad_norm`` and steps the optimiser. The scheduler,
        when present, is stepped once per epoch. One summary line is printed
        per epoch.

        Returns:
            dict: Maps the 1-based epoch number to that epoch's accumulated
            loss as a Python float, i.e. the sum over batches of the batch
            loss multiplied by the sum of the batch's ``active_entries``.
        """
        seed_everything(10)

        # validation_step / test_step call self.eval() and never switch back,
        # so make sure training-mode behaviour is active before optimising.
        self.train()

        ###########################################################
        # get optimizers                                          #
        ###########################################################
        # Decide once whether a scheduler is in use so the unpacking below
        # and the per-epoch step can never disagree.
        use_scheduler = bool(self.lr_scheduler)
        if use_scheduler:
            optimizer, lr_scheduler = self.configure_optimizers()
        else:
            optimizer = self.configure_optimizers()

        ###########################################################
        # learning                                                #
        ###########################################################
        all_loss = {}
        for epoch in range(1, self.max_epochs + 1):
            epoch_loss = 0.0

            for batch in self.train_dataloader():
                loss = self.training_step(batch=batch)

                # Reset stale gradients, back-propagate, clip, then update.
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.parameters(), self.max_grad_norm)
                optimizer.step()

                # Book-keeping only: detach so no autograd graph is built and
                # let .item() hand back a plain Python number.
                n_active = batch['active_entries'].sum()
                epoch_loss += (loss.detach() * n_active).item()

            all_loss[epoch] = epoch_loss

            if use_scheduler:
                lr_scheduler.step()

            # Loss
            print("epoch: {:>3d}\tLoss: {:>10.0f}".format(epoch, epoch_loss))

        return all_loss

            
    # --------------------------------------------------------------------------------------
    # Optimizers
    # --------------------------------------------------------------------------------------            
    def configure_optimizers(self):
        """Create the optimiser and, when enabled, its LR scheduler.

        Returns:
            The optimiser alone when ``self.lr_scheduler`` is falsy, otherwise
            the tuple ``(optimizer, lr_scheduler)``.
        """
        named_params = list(self.named_parameters())
        optimizer = self._get_optimizers(named_params)

        if not self.lr_scheduler:
            return optimizer

        return optimizer, self._get_lr_schedulers(optimizer)

            
    # --------------------------------------------------------------------------------------
    # DataLoader
    # --------------------------------------------------------------------------------------  
    def train_dataloader(self) -> DataLoader:
        """Return a shuffled loader over ``CancerDataset(self.train_f)``.

        Batches hold ``self.train_batch_size`` items and a trailing incomplete
        batch is dropped. A new dataset and loader are built on every call.
        """
        train_set = CancerDataset(self.train_f)
        return DataLoader(
            train_set,
            batch_size=self.train_batch_size,
            shuffle=True,
            drop_last=True,
        )
    
    def val_dataloader(self) -> DataLoader:
        """Return an unshuffled loader with batches of up to 1000 items.

        No samples are dropped, so the final batch may be smaller.
        """
        # NOTE(review): this wraps self.train_f, the same data the training
        # loader uses -- confirm whether a separate validation split was
        # intended here.
        eval_set = CancerDataset(self.train_f)
        return DataLoader(
            eval_set,
            batch_size=1000,
            shuffle=False,
            drop_last=False,
        )
    
    def test_step(self):
        """Switch to eval mode and return the multi-step counterfactual RMSE.

        The module is left in eval mode afterwards.

        Returns:
            The first of the two values produced by
            ``self.get_multi_step_counterfactual_rmse()``; the second one is
            discarded.
        """
        self.eval()
        multi_step_rmse, _ = self.get_multi_step_counterfactual_rmse()
        return multi_step_rmse
    
    def validation_step(self):
        """Switch to eval mode, then print and return the validation RMSE.

        When ``self.isDecoder`` is falsy, the one-step factual RMSE is
        computed twice -- once with default arguments and once for
        ``"test_cf"`` -- both values are printed and the first is returned.
        Otherwise the first result of ``self.get_multi_step_factual_rmse()``
        (a mapping) is printed entry by entry and returned.

        The module is left in eval mode afterwards.
        """
        self.eval()

        # One-step path: a single scalar per split.
        if not self.isDecoder:
            one_step_valid = self.get_one_step_factual_rmse()
            one_step_test = self.get_one_step_factual_rmse("test_cf")
            print("\t validation: {:.5f}  test: {:.5f}".format(one_step_valid, one_step_test))
            return one_step_valid

        # Multi-step path: one value per key, printed on a single line.
        multi_step_valid, _ = self.get_multi_step_factual_rmse()
        print("\t validation   F: ", end="")
        for step, rmse in multi_step_valid.items():
            print("{}: {:.3f}".format(step, rmse), end=', ')
        print("")
        return multi_step_valid