"""Implements dataloaders for AFFECT data."""
import logging
import pickle
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
# Names exported by ``from <module> import *``: only the loader factory.
__all__ = ['MMDataLoader']
# Module-level logger, registered under the fixed name 'MMSA'.
logger = logging.getLogger('MMSA')


class MMDataset(Dataset):
    """In-memory multimodal dataset backed by pickled feature files.

    The main pickle is read from ``args.featurePath`` and is indexed first by
    split name (``mode``) and then by field: ``'text'`` or ``'text_bert'``,
    ``'vision'``, ``'audio'``, ``'raw_text'``, ``'id'`` and
    ``'regression_labels'`` (plus ``'audio_lengths'`` / ``'vision_lengths'``
    when ``args.need_data_aligned`` is false). Single modalities can be
    replaced by the files named in ``args.feature_T``, ``args.feature_A`` and
    ``args.feature_V``.

    Args:
        args: Configuration object. It must support attribute access as well
            as ``in`` membership tests (presumably an EasyDict-like mapping;
            verify against callers). Fields read here: ``dataset_name``,
            ``featurePath``, ``feature_T``, ``feature_A``, ``feature_V``,
            ``feature_dims``, ``need_data_aligned`` and, optionally,
            ``use_bert`` and ``need_normalized``.
        mode: Split key used to index the pickles (``'train'`` by default).
        use_video: Include the ``'vision'`` entry in returned samples.
        use_audio: Include the ``'audio'`` entry in returned samples.
        use_text: Include the ``'text'`` entry in returned samples.
        sample_rate: When not ``None`` and ``mode == 'train'``, fraction of
            each rounded-label class to keep (see ``__stratified_sampling``).

    Raises:
        KeyError: If ``args.dataset_name`` is neither ``'mosi'`` nor
            ``'mosei'``.
    """

    def __init__(self, args, mode='train', use_video=True, use_audio=True, use_text=True, sample_rate=None):
        self.mode = mode
        self.args = args
        self.use_video = use_video
        self.use_audio = use_audio
        self.use_text = use_text
        self.sample_rate = sample_rate
        # Dispatch table: dataset name -> loader method.
        DATASET_MAP = {
            'mosi': self.__init_mosi,
            'mosei': self.__init_mosei,
        }
        DATASET_MAP[args.dataset_name]()

        # Sub-sampling only ever applies to the training split.
        if mode == 'train' and sample_rate is not None:
            self.__stratified_sampling()

    def __init_mosi(self):
        """Load features, labels and metadata for ``self.mode`` from disk.

        Side effects: may overwrite entries of ``self.args.feature_dims``
        when a per-modality override file is used.
        """
        # NOTE(security): pickle.load can execute arbitrary code while
        # deserialising. Only point the feature paths at trusted files.
        with open(self.args.featurePath, 'rb') as f:
            data = pickle.load(f)
        if 'use_bert' in self.args and self.args.use_bert:
            self.text = data[self.mode]['text_bert'].astype(np.float32)
        else:
            self.text = data[self.mode]['text'].astype(np.float32)
        self.vision = data[self.mode]['vision'].astype(np.float32)
        self.audio = data[self.mode]['audio'].astype(np.float32)
        self.raw_text = data[self.mode]['raw_text']
        self.ids = data[self.mode]['id']

        # Optional per-modality overrides; an empty string means "not set".
        if self.args.feature_T != "":
            with open(self.args.feature_T, 'rb') as f:
                data_T = pickle.load(f)
            if 'use_bert' in self.args and self.args.use_bert:
                self.text = data_T[self.mode]['text_bert'].astype(np.float32)
                self.args.feature_dims[0] = 768
            else:
                self.text = data_T[self.mode]['text'].astype(np.float32)
                self.args.feature_dims[0] = self.text.shape[2]
        if self.args.feature_A != "":
            with open(self.args.feature_A, 'rb') as f:
                data_A = pickle.load(f)
            self.audio = data_A[self.mode]['audio'].astype(np.float32)
            self.args.feature_dims[1] = self.audio.shape[2]
        if self.args.feature_V != "":
            with open(self.args.feature_V, 'rb') as f:
                data_V = pickle.load(f)
            self.vision = data_V[self.mode]['vision'].astype(np.float32)
            self.args.feature_dims[2] = self.vision.shape[2]

        # Labels always come from the main feature file, never from overrides.
        self.labels = {
            'M': np.array(data[self.mode]['regression_labels']).astype(np.float32)
        }

        logger.info(f"{self.mode} samples: {self.labels['M'].shape}")

        if not self.args.need_data_aligned:
            # Lengths are taken from the same file the modality came from.
            # data_A / data_V exist here because they are bound under the
            # very same feature_A / feature_V conditions above.
            if self.args.feature_A != "":
                self.audio_lengths = list(data_A[self.mode]['audio_lengths'])
            else:
                self.audio_lengths = data[self.mode]['audio_lengths']
            if self.args.feature_V != "":
                self.vision_lengths = list(data_V[self.mode]['vision_lengths'])
            else:
                self.vision_lengths = data[self.mode]['vision_lengths']
        # Replace -inf audio entries with 0.
        self.audio[self.audio == -np.inf] = 0

        if 'need_normalized' in self.args and self.args.need_normalized:
            self.__normalize()

    def __init_mosei(self):
        """Load MOSEI data; it goes through the same loader as MOSI."""
        return self.__init_mosi()

    def __stratified_sampling(self):
        """Keep the first ``sample_rate`` fraction of every rounded-label class.

        Labels are rounded with ``np.round`` and bucketed into the integer
        classes -3..3. From each bucket the first
        ``int(len(bucket) * sample_rate)`` samples (in dataset order) are
        kept, so the selection is deterministic, not random. All per-sample
        attributes are then restricted to the kept indices.

        Raises:
            KeyError: If a rounded label falls outside the range [-3, 3].
        """
        rounded_labels = np.round(self.labels['M'])
        label_indices = {label: [] for label in range(-3, 4)}
        for idx, label in enumerate(rounded_labels):
            label_indices[int(label)].append(idx)

        selected_indices = []
        for label in range(-3, 4):
            indices = label_indices[label]
            n_samples = int(len(indices) * self.sample_rate)
            selected_indices.extend(indices[:n_samples])

        # Restore the original dataset order across classes.
        selected_indices = sorted(selected_indices)
        self.text = self.text[selected_indices]
        self.vision = self.vision[selected_indices]
        self.audio = self.audio[selected_indices]
        self.raw_text = [self.raw_text[i] for i in selected_indices]
        self.ids = [self.ids[i] for i in selected_indices]
        self.labels['M'] = self.labels['M'][selected_indices]

        if not self.args.need_data_aligned:
            self.audio_lengths = [self.audio_lengths[i] for i in selected_indices]
            self.vision_lengths = [self.vision_lengths[i] for i in selected_indices]

        logger.info(f"After sampling {self.sample_rate*100}%, {self.mode} samples: {len(selected_indices)}")

    def __truncate(self):
        """Crop text, audio and vision along axis 1 to ``args.seq_lens``.

        ``args.seq_lens`` is unpacked as ``(text, audio, video)`` lengths.
        For every sample the window starts at the first time step that is
        not all zeros, but never so late that fewer than ``length`` steps
        remain. This helper is not called by any other method of the class.
        """
        def do_truncate(modal_features, length):
            total_steps = modal_features.shape[1]
            # The requested window already covers the whole sequence.
            if length >= total_steps:
                return modal_features
            # Latest start index that still leaves `length` time steps.
            last_start = total_steps - length
            windows = []
            for instance in modal_features:
                start = last_start
                for index in range(last_start):
                    # An all-zero row counts as leading padding and is skipped.
                    if not (instance[index] == 0).all():
                        start = index
                        break
                # Slice by the requested length (was a hard-coded 20).
                windows.append(instance[start:start + length])
            return np.array(windows)

        text_length, audio_length, video_length = self.args.seq_lens
        self.vision = do_truncate(self.vision, video_length)
        self.text = do_truncate(self.text, text_length)
        self.audio = do_truncate(self.audio, audio_length)

    def __normalize(self):
        """Collapse axis 1 of vision and audio to its mean.

        After this call both arrays have size 1 along axis 1; NaN values in
        the means are replaced by 0.
        """
        # Move axis 1 to the front so it can be averaged as axis 0.
        self.vision = np.transpose(self.vision, (1, 0, 2))
        self.audio = np.transpose(self.audio, (1, 0, 2))
        self.vision = np.mean(self.vision, axis=0, keepdims=True)
        self.audio = np.mean(self.audio, axis=0, keepdims=True)

        # `x != x` is true only for NaN entries.
        self.vision[self.vision != self.vision] = 0
        self.audio[self.audio != self.audio] = 0

        # Swap the first two axes back to the original layout.
        self.vision = np.transpose(self.vision, (1, 0, 2))
        self.audio = np.transpose(self.audio, (1, 0, 2))

    def __len__(self):
        """Return the number of samples, i.e. the number of 'M' labels."""
        return len(self.labels['M'])

    def get_seq_len(self):
        """Return the sequence-axis sizes of the enabled modalities.

        Order is text, audio, video; disabled modalities are omitted. Text
        uses ``shape[2]`` when ``args.use_bert`` is set and ``shape[1]``
        otherwise (presumably because BERT inputs put the token axis last;
        TODO confirm against the feature files).
        """
        lens = []
        if self.use_text:
            if 'use_bert' in self.args and self.args.use_bert:
                lens.append(self.text.shape[2])
            else:
                lens.append(self.text.shape[1])
        if self.use_audio:
            lens.append(self.audio.shape[1])
        if self.use_video:
            lens.append(self.vision.shape[1])
        return tuple(lens)

    def get_feature_dim(self):
        """Return ``shape[2]`` of the enabled modalities as a tuple.

        Order is text, audio, video; disabled modalities are omitted.
        """
        dims = []
        if self.use_text:
            dims.append(self.text.shape[2])
        if self.use_audio:
            dims.append(self.audio.shape[2])
        if self.use_video:
            dims.append(self.vision.shape[2])
        return tuple(dims)

    def __getitem__(self, index):
        """Return one sample as a dict.

        The dict always holds ``'index'``, ``'id'`` and ``'labels'`` (each
        label flattened to a 1-D tensor). ``'text'``, ``'audio'`` and
        ``'vision'`` tensors are added for the enabled modalities, and
        ``'audio_lengths'`` / ``'vision_lengths'`` are added when
        ``args.need_data_aligned`` is false.
        """
        sample = {
            'index': index,
            'id': self.ids[index],
            'labels': {k: torch.Tensor(v[index].reshape(-1)) for k, v in self.labels.items()}
        }

        if self.use_text:
            # raw_text stays on the dataset but is not part of the sample.
            sample['text'] = torch.Tensor(self.text[index])

        if self.use_audio:
            sample['audio'] = torch.Tensor(self.audio[index])
            if not self.args.need_data_aligned:
                sample['audio_lengths'] = self.audio_lengths[index]

        if self.use_video:
            sample['vision'] = torch.Tensor(self.vision[index])
            if not self.args.need_data_aligned:
                sample['vision_lengths'] = self.vision_lengths[index]
        return sample


def MMDataLoader(args, train_ratio=1.0, use_video=True, use_audio=True, use_text=True):
    """Create the train/valid/test ``DataLoader`` objects.

    Args:
        args: Configuration passed on to ``MMDataset``. ``batch_size`` and
            ``num_workers`` are read here. If ``'seq_lens'`` is present it
            is overwritten with the sequence lengths of the training set.
        train_ratio: Passed to the training dataset as ``sample_rate``.
        use_video: Enable the video modality.
        use_audio: Enable the audio modality.
        use_text: Enable the text modality.

    Returns:
        Dict with keys ``'train'``, ``'valid'`` and ``'test'`` mapping to
        ``DataLoader`` instances; only the training loader shuffles.
    """
    modality_flags = {'use_video': use_video, 'use_audio': use_audio, 'use_text': use_text}

    datasets = {}
    for split in ('train', 'valid', 'test'):
        # Only the training split receives a sampling rate.
        extra = {'sample_rate': train_ratio} if split == 'train' else {}
        datasets[split] = MMDataset(args, mode=split, **modality_flags, **extra)

    # Report which modalities are active, in text/audio/video order.
    labelled_flags = (('Text', use_text), ('Audio', use_audio), ('Video', use_video))
    modalities = [label for label, enabled in labelled_flags if enabled]
    print(f"Using modalities: {', '.join(modalities)}")

    if 'seq_lens' in args:
        args['seq_lens'] = datasets['train'].get_seq_len()

    loaders = {}
    for split, dataset in datasets.items():
        loaders[split] = DataLoader(
            dataset,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            shuffle=(split == 'train'),
        )
    return loaders
