from logging import Logger
from torch.utils.data import DataLoader
import torch.distributed as dist
import torch
import numpy as np
from functools import partial
import random

import io
import os
import os.path as osp
import shutil
import warnings
from collections.abc import Mapping, Sequence
from mmcv.utils import Registry, build_from_cfg
from torch.utils.data import Dataset
import copy
import os.path as osp
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict, defaultdict
import os.path as osp
import mmcv
import numpy as np
import torch
import tarfile
from .pipeline import *
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from mmcv.parallel import collate
import pandas as pd
import clip

# Registry instance named 'pipeline'.
PIPELINES = Registry('pipeline')

# Keyword arguments forwarded to the 'Normalize' pipeline step in
# build_dataloader. Presumably per-channel RGB statistics on a 0-255
# scale -- confirm against the Normalize transform.
img_norm_cfg = {
    'mean': [123.675, 116.28, 103.53],
    'std': [58.395, 57.12, 57.375],
    'to_bgr': False,
}


class BaseDataset(Dataset, metaclass=ABCMeta):
    """Abstract base for video datasets returning sampled clips and labels.

    Subclasses implement :meth:`load_annotations`. For every item this class
    runs the processing pipeline on one annotation entry, attaches the
    positive label together with sampled negative label ids, cuts the
    pipeline output ``imgs`` into sliding windows of ``clip_len`` entries
    along dimension 0 and keeps a few of those windows plus their relative
    position in the sequence (``progress``).
    """

    # Number of windows drawn per sample when ``progressor`` sampling is used.
    _PROGRESSOR_CLIPS = 3
    # In test mode ``__len__`` reports this many passes over the annotations.
    _TEST_LENGTH_FACTOR = 10

    def __init__(self,
                 ann_file,
                 pipeline,
                 repeat=1,
                 data_prefix=None,
                 test_mode=False,
                 multi_class=False,
                 num_classes=None,
                 start_index=1,
                 modality='RGB',
                 sample_by_class=False,
                 power=0,
                 dynamic_length=False,
                 neg_label_num=1000,
                 clip_num=2,
                 clip_len=8,
                 hard_neg=False,
                 hard_neg_num=20,
                 frame_step=4,
                 use_clip_encode=False,
                 augmentation_factor=100):
        """Store the configuration, build the pipeline and load annotations.

        Args:
            ann_file: Path of the annotation file read by
                :meth:`load_annotations`.
            pipeline: Sequence of transform configs handed to ``Compose``.
            repeat: When ``> 1`` the training path runs the pipeline a second
                time and concatenates both outputs.
            data_prefix: Optional root joined in front of every annotated
                path. A ``".tar"`` substring switches ``use_tar_format`` on
                and is stripped from the prefix. May be ``None``.
            test_mode: Selects :meth:`prepare_test_frames` in ``__getitem__``.
            multi_class: Labels are lists of class ids (converted to a
                multi-hot vector of size ``num_classes``).
            num_classes: Total number of classes; needed for negative-label
                sampling and for ``multi_class``.
            start_index: Stored in every sample dict as ``'start_index'``.
            modality: Stored in every sample dict as ``'modality'``.
            sample_by_class: Group annotations per label and compute
                ``class_prob``. Incompatible with ``multi_class``.
            power: Exponent applied to each class frequency before the
                frequencies are normalised into ``class_prob``.
            dynamic_length: Stored only; not used by this class.
            neg_label_num: Number of negative labels attached to a sample
                (capped by the number of other classes); ``0`` disables
                negative sampling.
            clip_num: Number of windows kept by the non-weighted samplers.
            clip_len: Number of entries per window.
            hard_neg: The pipeline label holds the positive id followed by
                hard-negative ids.
            hard_neg_num: Maximum number of hard negatives taken from the
                label.
            frame_step: Stride inside a window on the test ``sampleclip``
                path.
            use_clip_encode: Load CLIP ``ViT-L/14`` on CPU and replace
                ``imgs`` by ``encode_image(imgs)``.
            augmentation_factor: In training, ``__len__`` reports the
                annotation count multiplied by this factor.
        """
        super().__init__()
        # ``data_prefix`` defaults to None, so every string operation on it
        # has to be guarded.
        self.use_tar_format = data_prefix is not None and ".tar" in data_prefix
        if data_prefix is not None:
            data_prefix = data_prefix.replace(".tar", "")
            if osp.isdir(data_prefix):
                data_prefix = osp.realpath(data_prefix)
        self.data_prefix = data_prefix

        self.ann_file = ann_file
        self.repeat = repeat
        self.test_mode = test_mode
        self.multi_class = multi_class
        self.num_classes = num_classes
        self.start_index = start_index
        self.modality = modality
        self.sample_by_class = sample_by_class
        self.power = power
        self.dynamic_length = dynamic_length
        self.neg_label_num = neg_label_num
        self.clip_num = clip_num
        self.clip_len = clip_len
        self.hard_neg = hard_neg
        self.hard_neg_num = hard_neg_num
        self.frame_step = frame_step
        self.use_clip_encode = use_clip_encode
        self.augmentation_factor = augmentation_factor
        if use_clip_encode:
            self.clipmodel, _ = clip.load("ViT-L/14", device="cpu")

        assert not (self.multi_class and self.sample_by_class)

        self.pipeline = Compose(pipeline)
        self.video_infos = self.load_annotations()
        if self.sample_by_class:
            self.video_infos_by_class = self.parse_by_class()

            # Class frequency raised to ``power``, then normalised to sum 1.
            total = len(self.video_infos)
            weights = [(len(samples) / total) ** self.power
                       for samples in self.video_infos_by_class.values()]
            norm = sum(weights)
            self.class_prob = dict(
                zip(self.video_infos_by_class, (w / norm for w in weights)))

    @abstractmethod
    def load_annotations(self):
        """Load the annotation according to ann_file into video_infos."""

    # JSON annotations already look like video_infos, so this loader is the
    # same for every dataset.
    def load_json_annotations(self):
        """Load a JSON annotation file and return the list of video infos.

        The path field (``'frame_dir'`` if present in the first entry,
        otherwise ``'filename'``) is prefixed with ``data_prefix``. Unless
        ``multi_class`` is set, every ``'label'`` must be a one-element
        sequence and is unwrapped to that element.
        """
        video_infos = mmcv.load(self.ann_file)
        if not video_infos:
            # Nothing to inspect; avoids indexing entry 0 of an empty list.
            return video_infos

        path_key = 'frame_dir' if 'frame_dir' in video_infos[0] else 'filename'
        for info in video_infos:
            if self.data_prefix is not None:
                info[path_key] = osp.join(self.data_prefix, info[path_key])
            if self.multi_class:
                assert self.num_classes is not None
            else:
                assert len(info['label']) == 1
                info['label'] = info['label'][0]
        return video_infos

    def parse_by_class(self):
        """Group ``video_infos`` by their ``'label'`` value.

        Returns:
            defaultdict(list): label -> entries carrying that label, in
            annotation order.
        """
        video_infos_by_class = defaultdict(list)
        for item in self.video_infos:
            video_infos_by_class[item['label']].append(item)
        return video_infos_by_class

    @staticmethod
    def label2array(num, label):
        """Return a float32 vector of length ``num`` with 1 at ``label``."""
        arr = np.zeros(num, dtype=np.float32)
        arr[label] = 1.
        return arr

    @staticmethod
    def dump_results(results, out):
        """Dump data to json/yaml/pickle strings or files."""
        return mmcv.dump(results, out)

    # ------------------------------------------------------------------ #
    # Shared helpers for prepare_train_frames / prepare_test_frames
    # ------------------------------------------------------------------ #
    def _run_pipeline(self, idx):
        """Copy annotation ``idx``, add modality/start index, run pipeline.

        Returns:
            tuple: ``(results, sample)`` where ``results`` is the dict passed
            to the pipeline and ``sample`` is the pipeline output.
        """
        results = copy.deepcopy(self.video_infos[idx])
        results['modality'] = self.modality
        results['start_index'] = self.start_index
        # e.g. {'filename': 'ROOT/67353.mp4', 'label': 87, 'tar': False,
        #       'modality': 'RGB', 'start_index': 0}

        # Multi-class list labels become a multi-hot vector before the
        # pipeline runs.
        if self.multi_class and isinstance(results['label'], list):
            onehot = torch.zeros(self.num_classes)
            onehot[results['label']] = 1.
            results['label'] = onehot
        return results, self.pipeline(results)

    def _sample_neg_labels(self, pos_label, hard_neg_labels=None):
        """Return ``[pos_label, neg_1, ..., neg_k]`` as a 1-D tensor.

        Negatives are the first ``k`` entries of a random permutation of all
        class ids other than ``pos_label``, where ``k`` is ``neg_label_num``
        capped by the number of available classes. The cap is kept local so
        fetching an item never changes the dataset configuration.

        When ``hard_neg_labels`` is given they overwrite the leading
        negatives (truncated to ``k`` entries so the assignment cannot fail).
        NOTE(review): the remaining random negatives may repeat a hard
        negative, as in the original implementation.
        """
        perm = torch.randperm(self.num_classes)
        candidates = perm[perm != pos_label]
        neg_num = min(self.neg_label_num, len(candidates))
        negatives = candidates[:neg_num]
        if hard_neg_labels is not None:
            n_hard = min(len(hard_neg_labels), neg_num)
            negatives[:n_hard] = hard_neg_labels[:n_hard]
        # reshape(1) accepts both 0-dim and one-element positive labels.
        positive = torch.as_tensor(pos_label).reshape(1)
        return torch.cat((positive, negatives), 0)

    def _attach_hard_neg_labels(self, sample):
        """Split ``sample['label']`` into positive id and hard negatives."""
        label = sample['label']
        pos_label = label[0]
        hard_neg_labels = label[1:1 + self.hard_neg_num]
        sample['neg_labels'] = self._sample_neg_labels(pos_label,
                                                       hard_neg_labels)
        sample['label'] = pos_label

    def _encode_with_clip(self, sample):
        """Replace ``imgs`` by CLIP image features if ``use_clip_encode``."""
        if self.use_clip_encode:
            with torch.no_grad():
                sample['imgs'] = self.clipmodel.encode_image(sample['imgs'])

    def _unfold_clips(self, frames, frame_step):
        """Cut ``frames`` into stride-1 sliding windows along dimension 0.

        Dimension 0 is presumably time (frame index) -- TODO confirm against
        the pipeline's FormatShape step. The window axis produced by
        ``unfold`` is moved right behind the window-index axis, so a
        ``(T, C, H, W)`` input yields ``(n, L, C, H, W)`` and a ``(T, D)``
        feature input yields ``(n, L, D)``; every ``frame_step``-th entry of
        each window is kept.
        """
        windows = frames.unfold(0, self.clip_len * frame_step, 1)
        last = windows.dim() - 1
        windows = windows.permute(0, last, *range(1, last))
        return windows[:, ::frame_step]

    @staticmethod
    def _sample_weighted_pair(num_windows, exponent):
        """Sample two window indices whose gap ``d`` has weight ``1/d**exponent``.

        ``d`` ranges over ``1 .. num_windows - 1``; the first index is then
        drawn uniformly from the positions that keep the second one in range.
        With fewer than two windows the only window is returned twice.
        """
        if num_windows < 2:
            return torch.zeros(2, dtype=torch.long)
        distances = torch.arange(1, num_windows)
        probs = 1.0 / distances.float() ** exponent
        probs = probs / probs.sum()
        # multinomial yields an *index* into ``probs``; map it back to the
        # distance value so a gap of 0 is never produced.
        distance = distances[torch.multinomial(probs, 1)].item()
        first_idx = torch.randint(0, num_windows - distance, (1,)).item()
        return torch.tensor([first_idx, first_idx + distance])

    @staticmethod
    def _store_clips(sample, windows, sample_id):
        """Write the chosen windows and their relative positions."""
        sample['imgs'] = windows[sample_id]
        sample['progress'] = sample_id.float() / windows.shape[0]

    # ------------------------------------------------------------------ #
    # Public sample preparation
    # ------------------------------------------------------------------ #
    def prepare_train_frames(self, idx, sampleclip=True, weightedsample=True,
                             progressor=True, onlyinit=False):
        """Prepare the frames for training given the index.

        Args:
            idx: Index into ``video_infos``.
            sampleclip: When ``progressor`` is off, sample windows (two
                weighted ones or ``clip_num`` uniform ones).
            weightedsample: Use the distance-weighted pair sampler
                (weight ``1/d**2``) instead of a uniform sorted sample.
            progressor: Sample up to three sorted random windows; takes
                precedence over ``sampleclip``.
            onlyinit: Force the first sampled window to be window 0
                (``sampleclip`` path only).

        Returns:
            dict: pipeline output extended with ``neg_labels`` (plus
            ``neg_flag`` when ``neg_label_num == 0``) and, if windows were
            sampled, ``progress``. With ``repeat > 1`` only ``imgs`` and
            ``label`` are returned.
        """
        results, aug1 = self._run_pipeline(idx)

        if self.neg_label_num == 0:
            # A two-element label carries (label, neg_flag).
            if aug1['label'].shape[0] == 2:
                aug1['neg_flag'] = aug1['label'][1]
                aug1['label'] = aug1['label'][0]
            else:
                aug1['neg_flag'] = 0
            aug1['neg_labels'] = aug1['label']
        elif self.hard_neg:
            self._attach_hard_neg_labels(aug1)
        else:
            aug1['neg_labels'] = self._sample_neg_labels(aug1['label'])

        self._encode_with_clip(aug1)

        if progressor:
            windows = self._unfold_clips(aug1['imgs'], 1)
            sample_id = torch.randperm(
                windows.shape[0])[:self._PROGRESSOR_CLIPS].sort()[0]
            self._store_clips(aug1, windows, sample_id)
        elif sampleclip:
            windows = self._unfold_clips(aug1['imgs'], 1)
            num_windows = windows.shape[0]
            if weightedsample:
                sample_id = self._sample_weighted_pair(num_windows, 2)
            else:
                sample_id = torch.randperm(
                    num_windows)[:self.clip_num].sort()[0]
            if onlyinit:
                sample_id[0] = 0
            self._store_clips(aug1, windows, sample_id)

        if self.repeat > 1:
            # NOTE(review): ``aug1['imgs']`` may already be windowed here
            # while ``aug2['imgs']`` is raw pipeline output -- confirm the
            # two can be concatenated before relying on repeat > 1.
            aug2 = self.pipeline(results)
            return {
                "imgs": torch.cat((aug1['imgs'], aug2['imgs']), 0),
                "label": aug1['label'].repeat(2),
            }

        return aug1

    def prepare_test_frames(self, idx, sampleclip=True, weightedsample=True,
                            progressor=True):
        """Prepare the frames for testing given the index.

        Args:
            idx: Index into ``video_infos``.
            sampleclip: When ``progressor`` is off, sample windows built with
                stride ``frame_step`` inside each window.
            weightedsample: Use the distance-weighted pair sampler
                (weight ``1/d``) instead of ``clip_num`` evenly spaced
                windows.
            progressor: Sample up to three sorted random windows and pin the
                first/last to the first/last window; takes precedence over
                ``sampleclip``.

        Returns:
            dict: pipeline output extended with ``neg_labels`` and, if
            windows were sampled, ``progress``.
        """
        _, aug1 = self._run_pipeline(idx)

        if self.neg_label_num == 0:
            if aug1['label'].shape[0] > 1:
                aug1['label'] = aug1['label'][0]
            aug1['neg_labels'] = aug1['label']
        elif self.hard_neg and aug1['label'].shape[0] > 1:
            self._attach_hard_neg_labels(aug1)
        else:
            aug1['neg_labels'] = self._sample_neg_labels(aug1['label'])

        self._encode_with_clip(aug1)

        if progressor:
            windows = self._unfold_clips(aug1['imgs'], 1)
            num_windows = windows.shape[0]
            sample_id = torch.randperm(
                num_windows)[:self._PROGRESSOR_CLIPS].sort()[0]
            # Always cover the start and the end of the sequence.
            sample_id[0] = 0
            sample_id[-1] = num_windows - 1
            self._store_clips(aug1, windows, sample_id)
        elif sampleclip:
            windows = self._unfold_clips(aug1['imgs'], self.frame_step)
            num_windows = windows.shape[0]
            if weightedsample:
                sample_id = self._sample_weighted_pair(num_windows, 1)
            else:
                sample_id = torch.linspace(
                    0, num_windows - 1, self.clip_num).long()
            self._store_clips(aug1, windows, sample_id)

        return aug1

    def __len__(self):
        """Get the size of the dataset.

        The annotation count is multiplied by 10 in test mode and by
        ``augmentation_factor`` otherwise; ``__getitem__`` wraps indices
        back onto the annotations.
        """
        if self.test_mode:
            return len(self.video_infos) * self._TEST_LENGTH_FACTOR
        return len(self.video_infos) * self.augmentation_factor

    def __getitem__(self, idx):
        """Get the sample for either training or testing given index."""
        original_idx = idx % len(self.video_infos)
        if self.test_mode:
            return self.prepare_test_frames(original_idx)
        return self.prepare_train_frames(original_idx)

class VideoDataset(BaseDataset):
    """Dataset whose annotation file lists one video per line (or is JSON).

    Text annotations hold whitespace-separated fields: the video path
    followed by one integer label, or by several integer labels when
    ``multi_class`` or ``hard_neg`` is enabled.
    """

    def __init__(self, ann_file, pipeline, labels_file, start_index=0, **kwargs):
        """Initialise the base dataset and remember the label list file.

        Args:
            ann_file: Annotation file path (``.json`` or plain text).
            pipeline: Transform configs forwarded to the base class.
            labels_file: CSV file read by :attr:`classes`.
            start_index: Forwarded to the base class (default 0 here).
            **kwargs: Remaining ``BaseDataset`` options.
        """
        super().__init__(ann_file, pipeline, start_index=start_index, **kwargs)
        self.labels_file = labels_file

    @property
    def classes(self):
        """Rows of ``labels_file`` as a list of lists (re-read on access)."""
        return pd.read_csv(self.labels_file).values.tolist()

    def load_annotations(self):
        """Load annotation file to get video information.

        Returns:
            list[dict]: one ``dict(filename=..., label=..., tar=...)`` per
            non-empty annotation line; ``label`` is an ``int`` or, for
            ``multi_class`` / ``hard_neg``, a list of ``int``. JSON files
            are delegated to ``load_json_annotations``.
        """
        if self.ann_file.endswith('.json'):
            return self.load_json_annotations()

        # Both modes store every id after the path as a list.
        list_labels = self.multi_class or self.hard_neg
        video_infos = []
        with open(self.ann_file, 'r') as fin:
            for line in fin:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines instead of failing on the unpack.
                    continue
                if self.multi_class:
                    assert self.num_classes is not None
                if list_labels:
                    filename = fields[0]
                    label = [int(token) for token in fields[1:]]
                else:
                    filename, label = fields
                    label = int(label)
                if self.data_prefix is not None:
                    filename = osp.join(self.data_prefix, filename)
                video_infos.append(
                    dict(filename=filename, label=label,
                         tar=self.use_tar_format))
        return video_infos


class SubsetRandomSampler(torch.utils.data.Sampler):
    r"""Yield the given indices in a fresh random order on every iteration.

    Each index is produced exactly once per pass (sampling without
    replacement). The order comes from ``torch.randperm``, i.e. from the
    global torch RNG.

    Arguments:
        indices (sequence): a sequence of indices
    """

    def __init__(self, indices):
        self.indices = indices
        # Recorded by set_epoch; not consulted when shuffling.
        self.epoch = 0

    def __len__(self):
        return len(self.indices)

    def __iter__(self):
        # The permutation is drawn here, when the iterator is created.
        order = torch.randperm(len(self.indices))
        return (self.indices[position] for position in order)

    def set_epoch(self, epoch):
        """Remember the current epoch number."""
        self.epoch = epoch


def mmcv_collate(batch, samples_per_gpu=1):
    """Collate a batch of samples into batched structures.

    Args:
        batch: Sequence of samples produced by the dataset.
        samples_per_gpu: Forwarded to ``mmcv.parallel.collate`` for
            sequence-valued samples.

    Returns:
        * samples that are sequences: a list with one ``collate`` result per
          position of the transposed batch;
        * samples that are mappings: a dict with the same keys, each value
          collated recursively by this function;
        * anything else: the result of torch's ``default_collate``.

    Raises:
        TypeError: if ``batch`` is not a sequence.
    """
    if not isinstance(batch, Sequence):
        # Not every non-sequence has a ``dtype``; fall back to the type name
        # so the intended TypeError is raised instead of an AttributeError.
        kind = getattr(batch, 'dtype', type(batch).__name__)
        raise TypeError(f'{kind} is not supported.')

    first = batch[0]
    if isinstance(first, Sequence):
        return [collate(samples, samples_per_gpu) for samples in zip(*batch)]
    if isinstance(first, Mapping):
        return {
            key: mmcv_collate([sample[key] for sample in batch], samples_per_gpu)
            for key in first
        }
    return default_collate(batch)


def _build_eval_loader(dataset, batch_size, num_workers=16):
    """Wrap an evaluation dataset in a DataLoader for the current rank.

    The rank receives every ``world_size``-th index starting at its own rank,
    iterated in random order by ``SubsetRandomSampler``. Incomplete final
    batches are dropped (``drop_last=True``).
    """
    rank_indices = np.arange(dist.get_rank(), len(dataset), dist.get_world_size())
    return DataLoader(
        dataset, sampler=SubsetRandomSampler(rank_indices),
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
        collate_fn=partial(mmcv_collate, samples_per_gpu=batch_size),
    )


def build_dataloader(logger, config):
    """Build the train / validation datasets and their data loaders.

    Requires an initialised ``torch.distributed`` process group (rank and
    world size are queried).

    Args:
        logger: Unused; kept so existing callers keep working.
        config: Config object exposing the ``DATA``, ``AUG``, ``TRAIN`` and
            ``TEST`` attributes read below.

    Returns:
        tuple: ``(train_data, val_data, val_data_outdomain, common_data,
        rare_data, train_loader, val_loader, val_loader_outdomain,
        val_loader_common, val_loader_rare)``. Entries belonging to an
        optional split are ``None`` when ``DATA.TRAIN_FILE``,
        ``DATA.VAL_FILE_OUTDOMAIN`` or ``DATA.TRAIN_FILE_COMMON`` is ``None``.
    """
    input_size = config.DATA.INPUT_SIZE
    scale_resize = int(256 / 224 * input_size)
    data_frame_number = config.DATA.NUM_FRAMES

    train_pipeline = [
        dict(type='DecordInit'),
        dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=data_frame_number),
        dict(type='DecordDecode'),
        dict(type='Resize', scale=(-1, scale_resize)),
        dict(
            type='MultiScaleCrop',
            input_size=input_size,
            scales=(1, 0.95, 0.875, 0.8),
            random_crop=True,
            max_wh_scale_gap=1),
        dict(type='Resize', scale=(input_size, input_size), keep_ratio=False),
        dict(type='CenterCrop', crop_size=input_size),
        dict(type='ColorJitter', p=config.AUG.COLOR_JITTER),
        dict(type='GrayScale', p=config.AUG.GRAY_SCALE),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='FormatShape', input_format='NCHW'),
        dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
        dict(type='ToTensor', keys=['imgs', 'label']),
    ]
    val_pipeline = [
        dict(type='DecordInit'),
        dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=data_frame_number, test_mode=True),
        dict(type='DecordDecode'),
        dict(type='Resize', scale=(-1, scale_resize)),
        dict(type='CenterCrop', crop_size=input_size),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='FormatShape', input_format='NCHW'),
        dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
        dict(type='ToTensor', keys=['imgs', 'label'])
    ]

    # ---- training split (optional) ------------------------------------
    if config.DATA.TRAIN_FILE is not None:
        train_data = VideoDataset(
            ann_file=config.DATA.TRAIN_FILE, data_prefix=config.DATA.ROOT,
            labels_file=config.DATA.LABEL_LIST, pipeline=train_pipeline,
            num_classes=config.DATA.NUM_CLASSES, neg_label_num=config.DATA.NUM_NEGATIVE,
            clip_num=config.DATA.NUM_CLIPS, clip_len=config.DATA.NUM_FRAMES_CLIP,
            hard_neg=config.DATA.HARD_NEG, hard_neg_num=config.DATA.HARD_NEG_NUM)
        sampler_train = torch.utils.data.DistributedSampler(
            train_data, num_replicas=dist.get_world_size(),
            rank=dist.get_rank(), shuffle=True
        )
        train_loader = DataLoader(
            train_data, sampler=sampler_train,
            batch_size=config.TRAIN.BATCH_SIZE,
            num_workers=16,
            pin_memory=True,
            drop_last=True,
            collate_fn=partial(mmcv_collate, samples_per_gpu=config.TRAIN.BATCH_SIZE),
        )
    else:
        train_data = None
        train_loader = None

    # ---- test-time pipeline variants ----------------------------------
    # Steps are replaced by position: 1 = SampleFrames, 3 = Resize, 4 = crop.
    if config.TEST.NUM_CROP == 3:
        val_pipeline[3] = dict(type='Resize', scale=(-1, input_size))
        val_pipeline[4] = dict(type='ThreeCrop', crop_size=input_size)
    if config.TEST.NUM_CLIP > 1:
        val_pipeline[1] = dict(type='SampleFrames', clip_len=1, frame_interval=1,
                               num_clips=data_frame_number, multiview=config.TEST.NUM_CLIP)

    # Options shared by every evaluation dataset.
    eval_kwargs = dict(
        pipeline=val_pipeline,
        neg_label_num=config.DATA.NUM_NEGATIVE,
        clip_num=config.DATA.NUM_CLIPS_VAL,
        clip_len=config.DATA.NUM_FRAMES_CLIP,
        frame_step=config.DATA.CLIP_FRAME_STEP,
        test_mode=True,
    )
    eval_batch_size = config.TEST.BATCH_SIZE

    # ---- in-domain validation -----------------------------------------
    val_data = VideoDataset(
        ann_file=config.DATA.VAL_FILE, data_prefix=config.DATA.ROOT,
        labels_file=config.DATA.LABEL_LIST, num_classes=config.DATA.NUM_CLASSES,
        hard_neg=config.DATA.HARD_NEG, hard_neg_num=config.DATA.HARD_NEG_NUM,
        **eval_kwargs)
    val_loader = _build_eval_loader(val_data, eval_batch_size)

    # ---- out-of-domain validation (optional) --------------------------
    if config.DATA.VAL_FILE_OUTDOMAIN is not None:
        val_data_outdomain = VideoDataset(
            ann_file=config.DATA.VAL_FILE_OUTDOMAIN, data_prefix=config.DATA.ROOT_OUTDOMAIN,
            labels_file=config.DATA.LABEL_LIST_VAL, num_classes=config.DATA.NUM_CLASSES_VAL,
            **eval_kwargs)
        val_loader_outdomain = _build_eval_loader(val_data_outdomain, eval_batch_size)
    else:
        val_loader_outdomain = None
        val_data_outdomain = None

    # ---- common / rare evaluation subsets (optional) ------------------
    if config.DATA.TRAIN_FILE_COMMON is not None:
        common_data = VideoDataset(
            ann_file=config.DATA.TRAIN_FILE_COMMON, data_prefix=config.DATA.ROOT,
            labels_file=config.DATA.LABEL_LIST, num_classes=config.DATA.NUM_CLASSES,
            **eval_kwargs)
        rare_data = VideoDataset(
            ann_file=config.DATA.TRAIN_FILE_RARE, data_prefix=config.DATA.ROOT,
            labels_file=config.DATA.LABEL_LIST, num_classes=config.DATA.NUM_CLASSES,
            **eval_kwargs)
        val_loader_common = _build_eval_loader(common_data, eval_batch_size)
        val_loader_rare = _build_eval_loader(rare_data, eval_batch_size)
    else:
        val_loader_common = None
        val_loader_rare = None
        common_data = None
        rare_data = None

    return (train_data, val_data, val_data_outdomain, common_data, rare_data,
            train_loader, val_loader, val_loader_outdomain,
            val_loader_common, val_loader_rare)