import pickle
import random

import pandas as pd
import torch
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset

import boto3
from diffusers.utils import logging
import sys

from utils.common_utils import read_video_to_tensor

logger = logging.get_logger(__name__)


class OpenVidDataset(Dataset):
    """Video/caption pairs streamed as ``.mp4`` objects from S3.

    Every CSV row needs a ``relpath`` and a ``text`` column. A ``relpath``
    ending in ``.pkl`` is rewritten to ``.mp4`` before the object is fetched
    from ``self.bucket`` under ``video_root``.

    Args:
        path_to_csv: Annotation CSV with at least ``relpath`` and ``text``.
        video_root: S3 key prefix every ``relpath`` is relative to. If it
            contains ``"ChronoMagic"``, clips are read with
            ``uniform_sampling=True``.
        sample_fps: Forwarded to ``read_video_to_tensor``.
        sample_frames: Forwarded to ``read_video_to_tensor``.
        sample_size: Size passed to both ``Resize`` and ``CenterCrop``; a
            ``(height, width)`` sequence, or a single int for a square output.
        max_retries: Number of consecutive failed loads ``__getitem__``
            tolerates (re-drawing a random index each time) before raising.
    """

    def __init__(
        self,
        path_to_csv,
        video_root="openvid/unzip",
        sample_fps=8,
        sample_frames=16,
        sample_size=(320, 512),  # tuple, not list: avoids a shared mutable default
        max_retries=100,
    ):
        self.video_root = video_root
        self.s3_client = boto3.client("s3")
        self.bucket = "BUCKET_NAME"
        self.max_retries = max_retries

        logger.info(f"loading annotations from {path_to_csv} ...")
        self.video_df = pd.read_csv(path_to_csv)
        self.length = len(self.video_df)
        logger.info(f"data scale: {self.length}")

        self.sample_fps = sample_fps
        self.sample_frames = sample_frames

        # Accept a single int (square output) or any (h, w) sequence.
        if isinstance(sample_size, int):
            sample_size = (sample_size, sample_size)
        else:
            sample_size = tuple(sample_size)
        self.pixel_transforms = transforms.Compose(
            [
                transforms.Resize(sample_size),
                transforms.CenterCrop(sample_size),
                # (x - 0.5) / 0.5 per channel, i.e. [0, 1] -> [-1, 1] assuming
                # read_video_to_tensor yields values in [0, 1] — TODO confirm.
                transforms.Normalize(
                    mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True
                ),
            ]
        )

    def get_video_text_pair(self, idx):
        """Download and decode the clip for CSV row ``idx``.

        Returns:
            ``(pixel_values, text, relpath)``. ``pixel_values`` is the raw
            output of ``read_video_to_tensor`` (not yet resized/normalized).

        Raises:
            Any error from the S3 request or from decoding the video.
        """
        row = self.video_df.iloc[idx].to_dict()
        relpath, text = row["relpath"], row["text"]
        video_key = f"{self.video_root}/{relpath.replace('.pkl', '.mp4')}"

        body = self.s3_client.get_object(Bucket=self.bucket, Key=video_key)["Body"]
        pixel_values = read_video_to_tensor(
            body,
            self.sample_fps,
            self.sample_frames,
            uniform_sampling="ChronoMagic" in self.video_root,
        )
        return pixel_values, text, relpath

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``{"mp4": pixels, "txt": caption, "relpath": relpath}``.

        A sample that fails to load is skipped by drawing a random index
        instead. After ``max_retries`` consecutive failures a ``RuntimeError``
        is raised, chained to the last load error. Without this bound, a
        systematic failure (bad credentials, wrong bucket) would loop forever.
        """
        last_error = None
        for _ in range(self.max_retries):
            try:
                pixel_values, text, relpath = self.get_video_text_pair(idx)
            except Exception as err:  # best-effort: skip unreadable clips
                last_error = err
                logger.warning("failed to load sample %s (%r); resampling", idx, err)
                idx = random.randint(0, self.length - 1)
            else:
                pixel_values = self.pixel_transforms(pixel_values)
                return dict(mp4=pixel_values, txt=text, relpath=relpath)
        raise RuntimeError(
            f"failed to load a video sample after {self.max_retries} attempts"
        ) from last_error


class OpenVidLatentDataset(Dataset):
    """Pre-computed latents plus captions, stored as pickles on S3.

    CSV columns: ``relpath`` and ``text`` are required. Optional columns are
    ``latent_root`` (per-row override of the S3 key prefix),
    ``use_motion_guide`` (defaults to True) and ``short_text`` (defaults to "").

    Args:
        path_to_csv: Annotation CSV described above.
        latent_root: S3 key prefix for rows without their own ``latent_root``.
        max_retries: Number of consecutive failed loads ``__getitem__``
            tolerates (re-drawing a random index each time) before raising.
    """

    def __init__(
        self,
        path_to_csv,
        latent_root="openvid/latents_revised",
        max_retries=100,
    ):
        self.s3_resource = boto3.resource("s3")
        self.bucket = "BUCKET_NAME"
        self.latent_root = latent_root
        self.max_retries = max_retries

        logger.info(f"loading annotations from {path_to_csv} ...")
        self.latent_df = pd.read_csv(path_to_csv)
        self.length = len(self.latent_df)
        logger.info(f"data scale: {self.length}")

    @staticmethod
    def _resolve_latent_key(row, default_root):
        """Return the S3 key for ``row``, honouring a per-row ``latent_root``.

        pandas reads empty cells of an optional ``latent_root`` column as NaN,
        not None. ``pd.isna`` therefore falls back to ``default_root`` for
        missing, None and NaN values alike.
        """
        root = row.get("latent_root")
        if pd.isna(root):
            root = default_root
        return f"{root}/{row['relpath']}"

    @staticmethod
    def _to_cpu(payload):
        """Return a copy of ``payload`` with tensor values detached and on CPU."""
        return {
            key: value.detach().cpu() if isinstance(value, torch.Tensor) else value
            for key, value in payload.items()
        }

    def get_latent_text_pair(self, idx):
        """Load the pickled latents for CSV row ``idx``.

        Returns:
            ``(latent_dict, text, short_text, use_motion_guide)``. For keys
            containing ``"webvid"`` the caption stored in the pickle replaces
            both ``text`` and ``short_text``.

        Raises:
            ValueError: if a non-WebVid pickle stores a ``"text"`` entry that
                differs from the CSV caption.
            Any error from the S3 request or from unpickling.
        """
        row = self.latent_df.iloc[idx].to_dict()
        text = row["text"]
        latent_key = self._resolve_latent_key(row, self.latent_root)
        use_motion_guide = bool(row.get("use_motion_guide", True))

        short_text = row.get("short_text", "")
        if pd.isna(short_text):  # empty CSV cell -> NaN
            short_text = ""

        # NOTE(review): pickle.loads can execute arbitrary code; only use this
        # with a trusted bucket.
        body = self.s3_resource.Bucket(self.bucket).Object(latent_key).get()["Body"]
        latent_dict = pickle.loads(body.read())

        if "webvid" in latent_key:
            text = latent_dict.pop("text")
            short_text = text
        elif "text" in latent_dict:
            # Pop outside of any assert so the key is removed even under `python -O`.
            stored_text = latent_dict.pop("text")
            if stored_text != text:
                raise ValueError(
                    f"caption mismatch between CSV and pickle for {latent_key}"
                )
        return latent_dict, text, short_text, use_motion_guide

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return a dict with ``txt``, ``short_txt``, ``use_motion_guide`` and
        every entry of the latent pickle (tensors detached and moved to CPU).

        A sample that fails to load is skipped by drawing a random index
        instead. After ``max_retries`` consecutive failures a ``RuntimeError``
        is raised, chained to the last load error.
        """
        last_error = None
        for _ in range(self.max_retries):
            try:
                latent_dict, text, short_text, use_motion_guide = (
                    self.get_latent_text_pair(idx)
                )
            except Exception as err:  # best-effort: skip unreadable samples
                last_error = err
                logger.warning("failed to load sample %s (%r); resampling", idx, err)
                idx = random.randint(0, self.length - 1)
            else:
                break
        else:
            raise RuntimeError(
                f"failed to load a latent sample after {self.max_retries} attempts"
            ) from last_error

        sample = dict(
            txt=text, short_txt=short_text, use_motion_guide=use_motion_guide
        )
        # Pickle entries win on key collisions, as before.
        sample.update(self._to_cpu(latent_dict))
        return sample


class JointLatentDataset(Dataset):
    """Latent/caption dataset, optionally paired with a WebVid latent per item.

    The primary CSV determines ``__len__``. When ``path_to_webvid_csv`` is
    given, item ``idx`` is also paired with WebVid row ``idx % webvid_length``.
    That row's pickle entries are added to the sample under ``webvid_``-prefixed
    keys.

    Args:
        path_to_csv: Primary CSV with ``relpath`` and ``text`` columns; optional
            ``latent_root`` and ``use_motion_guide`` columns.
        path_to_webvid_csv: Optional CSV with a ``relpath`` column (and an
            optional ``latent_root`` column) for the paired WebVid latents.
        latent_root: Default S3 key prefix for primary rows.
        webvid_latent_root: Default S3 key prefix for WebVid rows.
    """

    def __init__(
        self,
        path_to_csv,
        path_to_webvid_csv=None,
        latent_root="openvid/latents_revised",
        webvid_latent_root="data/webvid_latents",
    ):
        self.s3_resource = boto3.resource("s3")
        self.bucket = "BUCKET_NAME"
        self.latent_root = latent_root

        logger.info(f"loading annotations from {path_to_csv} ...")
        self.latent_df = pd.read_csv(path_to_csv)
        self.length = len(self.latent_df)
        logger.info(f"data scale: {self.length}")

        self.webvid_df = None
        self.webvid_length = 0  # always defined, even without a WebVid CSV
        self.webvid_latent_root = webvid_latent_root
        if path_to_webvid_csv is not None:
            self.webvid_df = pd.read_csv(path_to_webvid_csv)
            self.webvid_length = len(self.webvid_df)
            logger.info(f"webvid data scale: {self.webvid_length}")

    @staticmethod
    def _resolve_latent_key(row, default_root):
        """Return the S3 key for ``row``, honouring a per-row ``latent_root``.

        pandas reads empty cells of an optional ``latent_root`` column as NaN,
        not None. ``pd.isna`` therefore falls back to ``default_root`` for
        missing, None and NaN values alike.
        """
        root = row.get("latent_root")
        if pd.isna(root):
            root = default_root
        return f"{root}/{row['relpath']}"

    @staticmethod
    def _to_cpu(payload):
        """Return a copy of ``payload`` with tensor values detached and on CPU."""
        return {
            key: value.detach().cpu() if isinstance(value, torch.Tensor) else value
            for key, value in payload.items()
        }

    def _load_pickle(self, key):
        """Fetch ``key`` from ``self.bucket`` and unpickle it."""
        # NOTE(review): pickle.loads can execute arbitrary code; only use this
        # with a trusted bucket.
        body = self.s3_resource.Bucket(self.bucket).Object(key).get()["Body"]
        return pickle.loads(body.read())

    def get_latent_text_pair(self, idx):
        """Load the primary latents for CSV row ``idx``.

        Returns:
            ``(latent_dict, text, use_motion_guide)``. For keys containing
            ``"webvid"``, the caption stored in the pickle (popped from it)
            replaces the CSV caption.
        """
        row = self.latent_df.iloc[idx].to_dict()
        text = row["text"]
        latent_key = self._resolve_latent_key(row, self.latent_root)
        use_motion_guide = bool(row.get("use_motion_guide", True))

        latent_dict = self._load_pickle(latent_key)
        if "webvid" in latent_key:
            text = latent_dict.pop("text")
        return latent_dict, text, use_motion_guide

    def get_webvid_latent_text_pair(self, idx):
        """Load WebVid row ``idx`` and prefix every pickle key with ``webvid_``."""
        row = self.webvid_df.iloc[idx].to_dict()
        latent_key = self._resolve_latent_key(row, self.webvid_latent_root)
        latent_dict = self._load_pickle(latent_key)
        return {f"webvid_{key}": value for key, value in latent_dict.items()}

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``txt``, ``use_motion_guide``, the primary pickle entries and,
        if configured, the ``webvid_``-prefixed entries of the paired WebVid
        row. Tensors are detached and moved to CPU. Load errors propagate.
        """
        latent_dict, text, use_motion_guide = self.get_latent_text_pair(idx)
        sample = dict(txt=text, use_motion_guide=use_motion_guide)
        sample.update(self._to_cpu(latent_dict))
        if self.webvid_df is not None:
            # Cycle through the (possibly shorter) WebVid table.
            webvid_idx = idx % self.webvid_length
            sample.update(self._to_cpu(self.get_webvid_latent_text_pair(webvid_idx)))
        return sample


if __name__ == "__main__":
    # Smoke test: load one shuffled batch from OpenVidLatentDataset and print it.
    # Usage: python <this file> [path_to_csv]
    from torch.utils.data import DataLoader

    csv_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "/home/ubuntu/data/vidgen_144k_uniform_motion.csv"
    )
    dataset = OpenVidLatentDataset(csv_path)
    data_loader = DataLoader(dataset, batch_size=4, shuffle=True)

    for sample in data_loader:
        print(sample["txt"])
        print(sample["short_txt"])
        print(sample["use_motion_guide"])
        # "index" only exists if the latent pickles store it.
        if "index" in sample:
            print(sample["index"])
        for key, value in sample.items():
            if isinstance(value, torch.Tensor):
                print(key, value.shape)
        break
