import json
import typing
from pathlib import Path
import random

import torch
from torch.utils.data import Dataset

from util.globals import *

# Base URL under which the remote dataset JSON files live; CounterFactDataset
# appends the filename to it when it needs to download. REMOTE_ROOT_URL is
# presumably provided by the `util.globals` star import above.
REMOTE_ROOT = "{}/data/dsets".format(REMOTE_ROOT_URL)

class CounterFactDataset(Dataset):
    """Map-style dataset over the records of a CounterFact JSON file.

    By default the file is ``counterfact.json`` (or ``multi_counterfact.json``
    when ``multi`` is true) inside ``data_dir``; if it is missing it is
    downloaded from ``REMOTE_ROOT`` first. An explicit ``data_file`` path may be
    given instead (e.g. a hand-made test split), in which case ``data_dir`` and
    ``multi`` are not used and nothing is downloaded.
    """

    def __init__(
        self,
        data_dir: str,
        multi: bool = False,
        size: typing.Optional[int] = None,
        *args,
        data_file: typing.Optional[str] = None,
        **kwargs,
    ):
        """Load the JSON records into ``self.data``.

        Args:
            data_dir: directory expected to hold the CounterFact JSON file;
                created if the file has to be downloaded.
            multi: load ``multi_counterfact.json`` instead of ``counterfact.json``.
            size: if given, keep only the first ``size`` records.
            data_file: explicit path of a JSON file to load instead of the
                default file in ``data_dir`` (keyword-only).
            *args, **kwargs: accepted and ignored.
        """
        if data_file is not None:
            cf_loc = Path(data_file)
        else:
            data_dir = Path(data_dir)
            cf_loc = data_dir / (
                "multi_counterfact.json" if multi else "counterfact.json"
            )
            if not cf_loc.exists():
                remote_url = f"{REMOTE_ROOT}/{'multi_' if multi else ''}counterfact.json"
                print(f"{cf_loc} does not exist. Downloading from {remote_url}")
                data_dir.mkdir(exist_ok=True, parents=True)
                torch.hub.download_url_to_file(remote_url, cf_loc)

        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(cf_loc, "r", encoding="utf-8") as f:
            self.data = json.load(f)

        if size is not None:
            # Deterministic prefix of the records, not a random sample.
            self.data = self.data[:size]

        print(f"Loaded dataset with {len(self)} elements")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        return self.data[item]


class MultiCounterFactDataset(CounterFactDataset):
    """CounterFactDataset preset that always loads the multi-CounterFact file."""

    def __init__(
        self, data_dir: str, size: typing.Optional[int] = None, *args, **kwargs
    ):
        # Pass ``multi`` and ``size`` positionally: passing ``multi=True`` by
        # keyword after ``*args`` made any extra positional argument bind to the
        # parent's ``multi`` parameter as well, raising
        # "got multiple values for argument 'multi'".
        super().__init__(data_dir, True, size, *args, **kwargs)


class _JsonListDataset(Dataset):
    """Map-style dataset over the top-level value of a JSON file.

    The loaded value is sliced and indexed, so it is expected to be a list.
    Subclasses set ``DEFAULT_FILENAME``; the file is read from ``data_dir``
    unless an explicit ``data_file`` path is passed.
    """

    DEFAULT_FILENAME: typing.ClassVar[str] = ""

    def __init__(
        self,
        data_dir: str,
        multi: bool = False,
        size: typing.Optional[int] = None,
        *args,
        data_file: typing.Optional[str] = None,
        **kwargs,
    ):
        """Load the JSON records into ``self.data``.

        Args:
            data_dir: directory containing ``DEFAULT_FILENAME``.
            multi: accepted and ignored; kept so the signature matches
                CounterFactDataset's.
            size: if given, keep only the first ``size`` records.
            data_file: explicit path of the JSON file to load instead of
                ``data_dir / DEFAULT_FILENAME`` (keyword-only).
            *args, **kwargs: accepted and ignored.
        """
        if data_file is not None:
            loc = Path(data_file)
        else:
            loc = Path(data_dir) / self.DEFAULT_FILENAME

        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(loc, "r", encoding="utf-8") as f:
            self.data = json.load(f)

        if size is not None:
            # Deterministic prefix of the records, not a random sample.
            self.data = self.data[:size]

        print(f"Loaded dataset with {len(self)} elements")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        return self.data[item]


class RippleEditDataset(_JsonListDataset):
    """Records loaded from ``case_study.json`` (in ``data_dir`` by default)."""

    DEFAULT_FILENAME = "case_study.json"


class MRippleEditDataset(_JsonListDataset):
    """Records loaded from ``new_collect_j_s.json`` (in ``data_dir`` by default)."""

    DEFAULT_FILENAME = "new_collect_j_s.json"


class RippleEditSubjectBatchDataset(_JsonListDataset):
    """Records loaded from ``batch_subject.json`` (in ``data_dir`` by default)."""

    DEFAULT_FILENAME = "batch_subject.json"
