import torch
import os
import numpy as np
import random
from torch.utils.data import Dataset
import pdb


# Upper bound on the number of label entries in a synthetic sequence assembled
# by align_batch; longer sequences have whole trailing segments dropped.
MAX_LENGTH = 13000


def collate(batch):
    """Regroup a sequence of samples into three parallel lists.

    Each sample is indexed at positions 0, 1 and 2 (features, labels and
    slices respectively); no stacking or padding is performed here.

    Returns:
        A tuple ``(features_list, labels_list, slice_list)`` of plain lists,
        each with one entry per sample, in the original sample order.
    """
    features_list = [sample[0] for sample in batch]
    labels_list = [sample[1] for sample in batch]
    slice_list = [sample[2] for sample in batch]
    return features_list, labels_list, slice_list


def align_batch(batch, num_classes, causal, max_length=None):
    """Pad a collated batch into dense tensors, optionally adding spliced samples.

    Args:
        batch: Triple ``(feature_list, label_list, slice_list)`` such as the
            one returned by ``collate``. Features are sliced as
            ``f[:, start:end]`` and labels as ``l[start:end]``, so features
            must be 2-D NumPy arrays with the sequence on axis 1 and labels
            NumPy arrays with the sequence on axis 0. Every slice entry is
            indexed as ``[start, end]``.
        num_classes: Size of the class axis of the returned mask.
        causal: When true, one synthetic sample per input sample is appended
            to the batch. Each synthetic sample is the concatenation of
            segments drawn at random (with replacement) from all segments of
            all samples in the batch.
        max_length: Maximum number of label entries in a synthetic sample;
            whole trailing segments are dropped until the sample fits.
            Defaults to the module-level ``MAX_LENGTH``.

    Returns:
        A tuple ``(batch_input_tensor, batch_target_tensor, mask)``:
        float features of shape ``(B, feature_dim, T)`` padded with zeros,
        long targets of shape ``(B, T)`` padded with ``-100``, and a float
        mask of shape ``(B, num_classes, T)`` that is 1 on valid positions.
        ``T`` is the longest label sequence in the (possibly augmented) batch.

    Raises:
        IndexError: If the batch contains no samples.
    """
    feature_list, label_list, slice_list = batch
    # Work on shallow copies so the augmentation below never grows the lists
    # owned by the caller.
    feature_list = list(feature_list)
    label_list = list(label_list)

    # causal intervention
    if causal:
        if max_length is None:
            max_length = MAX_LENGTH

        # Pool every segment of every sample in the batch.
        seg_features = []
        seg_labels = []
        for feat, lab, slc in zip(feature_list, label_list, slice_list):
            for bounds in slc:
                seg_features.append(feat[:, bounds[0]:bounds[1]])
                seg_labels.append(lab[bounds[0]:bounds[1]])

        num_originals = len(label_list)
        # Without any segment there is nothing to splice together.
        if seg_labels:
            # Average number of segments per sample; at least one so that
            # np.concatenate always receives a non-empty list.
            segs_per_sample = max(1, len(seg_labels) // num_originals)
            for _ in range(num_originals):
                order = random.choices(range(len(seg_labels)), k=segs_per_sample)
                aug_f = np.concatenate([seg_features[sn] for sn in order], axis=1)
                aug_l = np.concatenate([seg_labels[sn] for sn in order], axis=0)

                # Drop whole segments from the end until the sample fits.
                # The dropped length is accumulated and applied with a
                # non-negative stop index, so a zero-length trailing segment
                # cannot turn the slice into ``[:-0]`` (which is empty).
                seg_lengths = [seg_labels[sn].shape[0] for sn in order]
                total = aug_l.shape[0]
                dropped = 0
                while seg_lengths and total - dropped > max_length:
                    print("longer than MAX_LENGTH")
                    dropped += seg_lengths.pop()
                if dropped:
                    aug_f = aug_f[:, :aug_f.shape[1] - dropped]
                    aug_l = aug_l[:total - dropped]

                feature_list.append(aug_f)
                label_list.append(aug_l)

    batch_size = len(feature_list)
    feature_dim = feature_list[0].shape[0]
    max_len = max(map(len, label_list))

    batch_input_tensor = torch.zeros(batch_size, feature_dim, max_len, dtype=torch.float)
    # -100 marks padded target positions (presumably the loss's ignore index;
    # confirm against the training code).
    batch_target_tensor = torch.full((batch_size, max_len), -100, dtype=torch.long)
    mask = torch.zeros(batch_size, num_classes, max_len, dtype=torch.float)
    for i, (feat, lab) in enumerate(zip(feature_list, label_list)):
        batch_input_tensor[i, :, :feat.shape[1]] = torch.from_numpy(feat)
        batch_target_tensor[i, :lab.shape[0]] = torch.from_numpy(lab)
        mask[i, :, :lab.shape[0]] = 1.0
    return batch_input_tensor, batch_target_tensor, mask


class FeaturesDataset(Dataset):
    """In-memory dataset of per-video features, labels and label segments.

    The directory ``path`` must contain a split file (``train.txt`` or
    ``test.txt``) listing one video name per line, plus ``features/<name>.npy``
    and ``labels/<name>.npy`` for every listed video. Everything is loaded
    eagerly in the constructor.

    Each item is a tuple ``(feature, label, slices)`` where ``feature`` is the
    transposed feature array, ``label`` the label array as stored on disk, and
    ``slices`` a list of ``[start, end]`` pairs (end exclusive) delimiting the
    maximal runs of identical consecutive labels.
    """

    def __init__(self, path, train=True):
        """Load the split selected by ``train`` from the directory ``path``."""
        super().__init__()
        self.path = path
        self.txt_path = 'train.txt' if train else 'test.txt'
        self.video_list = []
        self.feature_list = []
        self.label_list = []
        self.slice_list = []
        self.read_txt()
        self.get_data()

    def read_txt(self):
        """Append the video names listed in the split file to ``video_list``."""
        with open(os.path.join(self.path, self.txt_path), "r") as f:
            for line in f:
                name = line.strip()
                # A blank (or whitespace-only) line would otherwise become an
                # empty video name and make get_data look for '.npy'.
                if name:
                    self.video_list.append(name)

    def get_data(self):
        """Load features and labels of every listed video and segment the labels."""
        for video in self.video_list:
            feature = np.load(os.path.join(self.path, 'features', video + '.npy'))
            label = np.load(os.path.join(self.path, 'labels', video + '.npy'))
            # Stored transposed; presumably (frames, dim) on disk so that the
            # sequence ends up on axis 1 -- confirm against the data files.
            self.feature_list.append(feature.transpose())
            self.label_list.append(label)
            self.slice_list.append(self._segment_bounds(label))

    @staticmethod
    def _segment_bounds(label):
        """Return ``[start, end]`` pairs for each run of equal consecutive labels.

        An empty label array yields the single empty segment ``[[0, 0]]``.
        """
        # Indices where a label differs from its predecessor start a new run.
        change_points = (np.flatnonzero(label[1:] != label[:-1]) + 1).tolist()
        starts = [0] + change_points
        ends = change_points + [len(label)]
        return [[start, end] for start, end in zip(starts, ends)]

    def __len__(self):
        """Return the number of loaded videos."""
        return len(self.label_list)

    def __getitem__(self, index):
        """Return ``(feature, label, slices)`` for the video at ``index``."""
        return self.feature_list[index], self.label_list[index], self.slice_list[index]
