from torch.utils.data.distributed import DistributedSampler
from utils.misc import get_rank, get_world_size, is_dist_avail_and_initialized
import onnxruntime as ort


class ConfigBuilder(object):
    """
    Configuration Builder.

    """
    def __init__(self, **params):
        """
        Set the default configuration for the configuration builder.

        Parameters
        ----------
        
        params: the configuration parameters.
        """
        super(ConfigBuilder, self).__init__()
        self.model_params = params.get('model', {})
        self.dataset_params = params.get('dataset', {'data_dir': 'data'})
        self.dataloader_params = params.get('dataloader', {})

        self.logger = params.get('logger', None)
    
    def get_model(self,model_params = None):
        """
        Get the model from configuration.

        Parameters
        ----------
        
        model_params: dict, optional, default: None. If model_params is provided, then use the parameters specified in the model_params to build the model. Otherwise, the model parameters in the self.params will be used to build the model.
        """
        '''
        from models.FNP import FNP
        from models.ConvCNP import ConvCNP
        from models.Adas import Adas
        '''
        
        from models.diffDA import diffDA
        
        if model_params is None:
            model_params = self.model_params
        type = model_params.get('type', 'FNP')
        print('model_params',model_params)
        if type == 'diffDA':
            model = diffDA(**model_params)
        elif type == 'diffDPO':
            model = diffDPO(**model_params)
        elif type == 'ladiffDA':
            model = ladiffDA(**model_params)
        else:
            raise NotImplementedError('Invalid model type.')
        
        return model
    
    def get_forecast(self, local_rank):

        # Set the behavier of onnxruntime
        options = ort.SessionOptions()
        options.enable_cpu_mem_arena=False
        options.enable_mem_pattern = False
        options.enable_mem_reuse = False
        # Increase the number for faster inference and more memory consumption
        options.intra_op_num_threads = 1

        # Set the behavier of cuda provider
        cuda_provider_options = {'device_id': local_rank, 'arena_extend_strategy':'kSameAsRequested',}

        # Initialize onnxruntime session for Pangu-Weather Models
        ort_session = ort.InferenceSession('../ddpm/models/FengWu_128x256.onnx', sess_options=options, providers=[('CUDAExecutionProvider', cuda_provider_options)])
        
        for i, input_info in enumerate(ort_session.get_inputs()):
            print(f"Input {i} name: {input_info.name}, shape: {input_info.shape}, type: {input_info.type}")
        for i, output_info in enumerate(ort_session.get_outputs()):
            print(f"Output {i} name: {output_info.name}, shape: {output_info.shape}, type: {output_info.type}")
        return ort_session

    def get_dataset(self, dataset_params = None, split = 'train'):
        """
        Get the dataset from configuration.

        Parameters
        ----------
        
        dataset_params: dict, optional, default: None. If dataset_params is provided, then use the parameters specified in the dataset_params to build the dataset. Otherwise, the dataset parameters in the self.params will be used to build the dataset;
        
        split: str in ['train', 'test'], optional, default: 'train', the splitted dataset.

        Returns
        -------
        
        A torch.utils.data.Dataset item.
        """
        
        from datasets.era5_npy_f32 import era5_npy_f32
        from datasets.era5_latent import era5_latent
        
        if dataset_params is None:
            dataset_params = self.dataset_params
        dataset_params = dataset_params.get(split, None)
        print(dataset_params)
        '''
        if dataset_params is None:
            return None
        dataset = era5_npy_f32(split = split, **dataset_params)
        '''
        #Jing-An
        if split == "train":
            dataset = era5_latent(data_dir = "/mnt/hwfile/48era5_vaelatent_79-15_mode")
        if split == "valid":
            dataset = era5_npy_f32(split = split, **dataset_params)
            #era5_latent(data_dir = "/mnt/hwfile/48era5_vaelatent_16-17_mode")
        if split == "test":
            dataset = era5_npy_f32(split = split, **dataset_params)
            #dataset = era5_latent(data_dir = "/mnt/petrelfs/sunjingan/DiffDpo-DA/ladiff/era5_latent_18-18")
        
        return dataset
    
    def get_sampler(self, dataset, split = 'train', drop_last=False):
        shuffle = False
        if split=='valid':
            shuffle = True
            
        if is_dist_avail_and_initialized():
            rank = get_rank()
            num_gpus = get_world_size()
        else:
            rank = 0
            num_gpus = 1
        sampler = DistributedSampler(dataset, rank=rank, shuffle=False, num_replicas=num_gpus, drop_last=drop_last)

        return sampler
   

    def get_dataloader(self, dataset_params = None, split = 'train', batch_size = 1, dataloader_params = None, drop_last = True):
        """
        Get the dataloader from configuration.

        Parameters
        ----------
        
        dataset_params: dict, optional, default: None. If dataset_params is provided, then use the parameters specified in the dataset_params to build the dataset. Otherwise, the dataset parameters in the self.params will be used to build the dataset;
        
        split: str in ['train', 'test'], optional, default: 'train', the splitted dataset;
        
        batch_size: int, optional, default: None. If batch_size is None, then the batch size parameter in the self.params will be used to represent the batch size (If still not specified, default: 4);
        
        dataloader_params: dict, optional, default: None. If dataloader_params is provided, then use the parameters specified in the dataloader_params to get the dataloader. Otherwise, the dataloader parameters in the self.params will be used to get the dataloader.

        Returns
        -------
        
        A torch.utils.data.DataLoader item.
        """
        from torch.utils.data import DataLoader

        # if split != "train":
        #     drop_last = True
        if dataloader_params is None:
            dataloader_params = self.dataloader_params
        dataset = self.get_dataset(dataset_params, split)
        print('dataset:',dataset)
        if dataset is None:
            return None
        sampler = self.get_sampler(dataset, split, drop_last=drop_last)
        print("batch_size",batch_size)

        return DataLoader(
            dataset,
            batch_size = batch_size,
            sampler=sampler,
            drop_last=drop_last,
            **dataloader_params
        )


def get_optimizer(model, optimizer_params = None, resume = False, resume_lr = None):
    """
    Get the optimizer from configuration.
    
    Parameters
    ----------
    
    model: a torch.nn.Module object, the model.
    
    optimizer_params: dict, optional, default: None. If optimizer_params is provided, then use the parameters specified in the optimizer_params to build the optimizer. Otherwise, the optimizer parameters in the self.params will be used to build the optimizer;
    
    resume: bool, optional, default: False, whether to resume training from an existing checkpoint;

    resume_lr: float, optional, default: None, the resume learning rate.
    
    Returns
    -------
    
    An optimizer for the given model.
    """
    from torch.optim import SGD, Adam, AdamW
    type = optimizer_params.get('type', 'AdamW')
    params = optimizer_params.get('params', {})

    if resume:
        network_params = [{'params': model.parameters(), 'initial_lr': resume_lr}]
        params.update(lr = resume_lr)
    else:
        network_params = model.parameters()
    if type == 'SGD':
        optimizer = SGD(network_params, **params)
    elif type == 'Adam':
        optimizer = Adam(network_params, **params)
    elif type == 'AdamW':
        optimizer = AdamW(network_params, **params)
    else:
        raise NotImplementedError('Invalid optimizer type.')
    return optimizer

def get_lr_scheduler(optimizer, lr_scheduler_params = None, resume = False, resume_epoch = None, total_steps = None):
    """
    Get the learning rate scheduler from configuration.
    
    Parameters
    ----------
    
    optimizer: an optimizer;
    
    lr_scheduler_params: dict, optional, default: None. If lr_scheduler_params is provided, then use the parameters specified in the lr_scheduler_params to build the learning rate scheduler. Otherwise, the learning rate scheduler parameters in the self.params will be used to build the learning rate scheduler;

    resume: bool, optional, default: False, whether to resume training from an existing checkpoint;

    resume_epoch: int, optional, default: None, the epoch of the checkpoint.
    
    Returns
    -------

    A learning rate scheduler for the given optimizer.
    """
    from torch.optim.lr_scheduler import MultiStepLR, ExponentialLR, CyclicLR, CosineAnnealingLR, StepLR, OneCycleLR
    type = lr_scheduler_params.get('type', '')
    params = lr_scheduler_params.get('params', {})
    if resume:
        params.update(last_epoch = resume_epoch)
    if type == 'MultiStepLR':
        scheduler = MultiStepLR(optimizer, **params)
    elif type == 'ExponentialLR':
        scheduler = ExponentialLR(optimizer, **params)
    elif type == 'CyclicLR':
        scheduler = CyclicLR(optimizer, **params)
    elif type == 'CosineAnnealingLR':
        scheduler = CosineAnnealingLR(optimizer, **params)
    elif type == 'StepLR':
        scheduler = StepLR(optimizer, **params)
    elif type == 'OneCycleLR':
        scheduler = OneCycleLR(optimizer, total_steps=total_steps, **params)
    elif type == '':
        scheduler = None
    else:
        raise NotImplementedError('Invalid learning rate scheduler type.')
    return scheduler
