import os
import pandas
import torch
import numpy as np
from torch.utils.data import Dataset

class Sequence:
    """ Container for one video: its name, frames, boxes and text query """

    def __init__(self, video_name, frame_list, bboxes, nlp, idx):
        # Plain storage of the constructor arguments; note that
        # `frame_list` is exposed under the attribute name `frames`.
        self.video_name, self.frames = video_name, frame_list
        self.bboxes, self.nlp = bboxes, nlp
        self.idx = idx

        # The first box initialises the sequence; indexing here means an
        # empty `bboxes` raises before any per-frame info is built.
        self.init_info = dict(init_bbox=bboxes[0])
        self.frame_info = {}
        self.construct_info()

    def construct_info(self):
        """ Give every frame index an info dict; frame 0 gets `init_info` """
        frame_count = len(self.bboxes)
        for frame_no in range(frame_count):
            # setdefault keeps any entry that already exists for this frame.
            self.frame_info.setdefault(frame_no, {})
        # Frame 0 always points at the very same dict object as `init_info`.
        self.frame_info[0] = self.init_info


class LaSOT_Dataset(Dataset):
    """ Dataset returning one `Sequence` per name listed in a split file.

    Every sequence lives in ``<root>/<class>/<class>-<id>`` and is expected
    to contain ``groundtruth.txt`` (comma-separated box per line),
    ``nlp.txt`` (text description on the first line) and an ``img``
    directory holding the frames.
    """

    def __init__(self, cfg, split_file='ltr/data_specs/lasot_val_split.txt'):
        """
        args:
            cfg - config object; only `cfg.train.dataset.config.lasot.path`
                  (the dataset root directory) is read.
            split_file - path of a text file with one sequence name per line.
        """
        self.root = cfg.train.dataset.config.lasot.path
        self.split_file = split_file
        self.sequence_list = self._build_sequence_list()

    def _build_sequence_list(self):
        """ Read the split file and return its single column as a list """
        # header=None: the first line is data, not a column header.
        # squeeze(1) turns the one-column frame into a Series.
        names = pandas.read_csv(self.split_file, header=None).squeeze(1)
        return names.values.tolist()

    def _get_sequence_path(self, seq_id):
        """ Return (sequence name, sequence directory) for index `seq_id`.

        A name such as ``airplane-1`` maps to ``<root>/airplane/airplane-1``.
        Raises IndexError if the name contains no ``-``.
        """
        seq_name = self.sequence_list[seq_id]
        parts = seq_name.split('-')
        class_name, vid_id = parts[0], parts[1]

        return seq_name, os.path.join(self.root, class_name, class_name + '-' + vid_id)

    def _read_bb_anno(self, seq_path):
        """ Load ``groundtruth.txt`` as a float32 tensor, one row per line """
        bb_anno_file = os.path.join(seq_path, "groundtruth.txt")
        gt = pandas.read_csv(bb_anno_file, delimiter=',', header=None, dtype=np.float32,
                             na_filter=False, low_memory=False).values

        return torch.tensor(gt)

    def _get_nlp(self, seq_path):
        """ Return the first line of ``nlp.txt`` without its newline """
        nlp_text_file = os.path.join(seq_path, 'nlp.txt')
        with open(nlp_text_file, 'r') as f:
            return f.readline().replace('\n', '')

    def __len__(self) -> int:
        return len(self.sequence_list)

    def __getitem__(self, idx: int) -> "Sequence":
        """ Build the `Sequence` for the `idx`-th entry of the split file.

        Raises AssertionError when the number of annotation rows differs
        from the number of files found in the ``img`` directory.
        """
        seq_name, seq_path = self._get_sequence_path(idx)
        bbox = self._read_bb_anno(seq_path)
        nlp = self._get_nlp(seq_path)

        # Frames are ordered by a plain lexicographic sort of the file names
        # (presumably zero-padded frame numbers - TODO confirm).
        img_dir = os.path.join(seq_path, 'img')
        imgs = [os.path.join(img_dir, name) for name in sorted(os.listdir(img_dir))]

        assert len(bbox) == len(imgs), (
            f"{seq_name}: {len(bbox)} annotation rows but {len(imgs)} files in {img_dir}")

        # Every frame is kept, including ones whose box has non-positive
        # values in columns 2/3 (presumably width/height); no filtering is
        # applied here.
        frame_boxes = [box.numpy() for box in bbox]

        return Sequence(seq_name, imgs, frame_boxes, nlp, idx)



class OTB99_Dataset(Dataset):
    """ Dataset returning one `Sequence` per name listed in a split file.

    Every sequence lives in ``<root>/<name>`` and is expected to contain
    ``groundtruth_rect.txt`` (comma-separated box per line) and an ``img``
    directory holding the frames. The text description is read from a
    sibling tree obtained by replacing ``OTB_videos`` with
    ``OTB_query_test`` in the root path.
    """

    def __init__(self, cfg, split_file='ltr/data_specs/otb99_val_split.txt'):
        """
        args:
            cfg - config object; only `cfg.train.dataset.config.otb99.path`
                  (the dataset root directory) is read.
            split_file - path of a text file with one sequence name per line.
        """
        self.root = cfg.train.dataset.config.otb99.path
        self.split_file = split_file
        self.sequence_list = self._build_sequence_list()

    def _build_sequence_list(self):
        """ Read the split file and return its single column as a list """
        # header=None: the first line is data, not a column header.
        # squeeze(1) turns the one-column frame into a Series.
        names = pandas.read_csv(self.split_file, header=None).squeeze(1)
        return names.values.tolist()

    def _get_sequence_path(self, seq_id):
        """ Return (sequence name, ``<root>/<name>``) for index `seq_id` """
        seq_name = self.sequence_list[seq_id]

        return seq_name, os.path.join(self.root, seq_name)

    def _read_bb_anno(self, seq_path):
        """ Load ``groundtruth_rect.txt`` as a float32 tensor, one row per line """
        bb_anno_file = os.path.join(seq_path, "groundtruth_rect.txt")
        gt = pandas.read_csv(bb_anno_file, delimiter=',', header=None, dtype=np.float32,
                             na_filter=False, low_memory=False).values

        return torch.tensor(gt)

    def _get_nlp(self, seq_path):
        """ Return the first line of ``<query dir>/<sequence name>.txt`` """
        # basename instead of split('/') so the name is still extracted when
        # os.path.join produced a platform-specific separator.
        seq_name = os.path.basename(seq_path)
        # The query files are the *test* queries, stored next to the videos.
        query_dir = self.root.replace('OTB_videos', 'OTB_query_test')
        nlp_text_file = os.path.join(query_dir, seq_name + '.txt')
        with open(nlp_text_file, 'r') as f:
            return f.readline().replace('\n', '')

    def __len__(self) -> int:
        return len(self.sequence_list)

    def __getitem__(self, idx: int) -> "Sequence":
        """ Build the `Sequence` for the `idx`-th entry of the split file.

        Raises AssertionError when the number of annotation rows differs
        from the number of files found in the ``img`` directory.
        """
        seq_name, seq_path = self._get_sequence_path(idx)
        bbox = self._read_bb_anno(seq_path)
        nlp = self._get_nlp(seq_path)

        # Frames are ordered by a plain lexicographic sort of the file names
        # (presumably zero-padded frame numbers - TODO confirm).
        img_dir = os.path.join(seq_path, 'img')
        imgs = [os.path.join(img_dir, name) for name in sorted(os.listdir(img_dir))]

        assert len(bbox) == len(imgs), (
            f"{seq_name}: {len(bbox)} annotation rows but {len(imgs)} files in {img_dir}")

        # Every frame is kept, including ones whose box has non-positive
        # values in columns 2/3 (presumably width/height); no filtering is
        # applied here.
        frame_boxes = [box.numpy() for box in bbox]

        return Sequence(seq_name, imgs, frame_boxes, nlp, idx)



class TNL2K_Dataset(Dataset):
    """ Dataset returning one `Sequence` per entry listed in a split file.

    Every sequence lives in ``<root>/<entry>`` and is expected to contain
    ``groundtruth.txt`` (comma-separated box per line), ``language.txt``
    (text description on the first line) and an ``imgs`` directory holding
    the frames.
    """

    def __init__(self, cfg, split_file='ltr/data_specs/tnl2k_val_split.txt'):
        """
        args:
            cfg - config object; only `cfg.train.dataset.config.tnl2k.path`
                  is read.
            split_file - path of a text file with one sequence entry per line.
        """
        # The root is derived from the configured path by swapping 'train'
        # for 'test'.
        # NOTE(review): str.replace rewrites *every* occurrence of 'train' in
        # the path, not only the final component - confirm this is safe for
        # the deployed directory layout.
        self.root = cfg.train.dataset.config.tnl2k.path.replace('train', 'test')
        self.split_file = split_file
        self.sequence_list = self._build_sequence_list()

    def _build_sequence_list(self):
        """ Read the split file and return its single column as a list """
        # header=None: the first line is data, not a column header.
        # squeeze(1) turns the one-column frame into a Series.
        names = pandas.read_csv(self.split_file, header=None).squeeze(1)
        return names.values.tolist()

    def _get_sequence_path(self, seq_id):
        """ Return (sequence name, sequence directory) for index `seq_id`.

        The split entry may contain ``/``; the part after the last ``/`` is
        the sequence name, while the full entry is joined onto the root.
        """
        seq_name = self.sequence_list[seq_id]
        real_seq_name = seq_name.split('/')[-1]

        return real_seq_name, os.path.join(self.root, seq_name)

    def _read_bb_anno(self, seq_path):
        """ Load ``groundtruth.txt`` as a float32 tensor, one row per line """
        bb_anno_file = os.path.join(seq_path, "groundtruth.txt")
        gt = pandas.read_csv(bb_anno_file, delimiter=',', header=None, dtype=np.float32,
                             na_filter=False, low_memory=False).values

        return torch.tensor(gt)

    def _get_nlp(self, seq_path):
        """ Return the first line of ``language.txt`` without its newline """
        nlp_text_file = os.path.join(seq_path, "language.txt")
        with open(nlp_text_file, 'r') as f:
            return f.readline().replace('\n', '')

    def __len__(self) -> int:
        return len(self.sequence_list)

    def __getitem__(self, idx: int) -> "Sequence":
        """ Build the `Sequence` for the `idx`-th entry of the split file.

        Raises AssertionError when the number of annotation rows differs
        from the number of files found in the ``imgs`` directory.
        """
        seq_name, seq_path = self._get_sequence_path(idx)
        bbox = self._read_bb_anno(seq_path)
        nlp = self._get_nlp(seq_path)

        # Frames are ordered by a plain lexicographic sort of the file names
        # (presumably zero-padded frame numbers - TODO confirm).
        img_dir = os.path.join(seq_path, 'imgs')
        imgs = [os.path.join(img_dir, name) for name in sorted(os.listdir(img_dir))]

        assert len(bbox) == len(imgs), (
            f"{seq_name}: {len(bbox)} annotation rows but {len(imgs)} files in {img_dir}")

        # Every frame is kept; no per-frame filtering is applied here.
        frame_boxes = [box.numpy() for box in bbox]

        return Sequence(seq_name, imgs, frame_boxes, nlp, idx)


def im_to_torch(img):
    """
    move the last axis of a 3-axis numpy image to the front (C*H*W)
    and return it as a float pytorch tensor
    """
    channels_first = (2, 0, 1)  # axis order: last axis first, then the other two
    return torch.from_numpy(np.transpose(img, channels_first)).float()


