import os
import torch
import pandas as pd
from PIL import Image
import numpy as np
from wilds.datasets.wilds_dataset import WILDSDataset
from wilds.common.grouper import CombinatorialGrouper
from wilds.common.metrics.all_metrics import Accuracy
from pathlib import Path

import tarfile
from zipfile import ZipFile
import logging
import gdown
import pandas as pd
import numpy as np
from pathlib import Path

def download_and_extract(url, dst, remove=True):
    """Download an archive with gdown and unpack it next to itself.

    Args:
        url: Location passed to ``gdown.download``.
        dst: Destination file path of the downloaded archive (``str`` or
            path-like). The suffix selects the extractor: ``.tar.gz``,
            ``.tar`` or ``.zip``. Any other suffix is downloaded but not
            extracted.
        remove: If true, delete the downloaded archive afterwards.
    """
    # Normalise to ``str`` so path-like destinations work with ``endswith``.
    dst = os.fspath(dst)
    gdown.download(url, dst, quiet=False)

    extract_dir = os.path.dirname(dst)

    # NOTE(review): ``extractall`` trusts the member paths stored in the
    # archive; only use this helper with archives from trusted sources.
    if dst.endswith((".tar.gz", ".tar")):
        tar_mode = "r:gz" if dst.endswith(".tar.gz") else "r:"
        # Context manager closes the archive even if extraction fails.
        with tarfile.open(dst, tar_mode) as archive:
            archive.extractall(extract_dir)
    elif dst.endswith(".zip"):
        with ZipFile(dst, "r") as archive:
            archive.extractall(extract_dir)

    if remove:
        os.remove(dst)


def download_imagenetbg(data_path):
    """Fetch the ImageNet Backgrounds Challenge archives into *data_path*.

    Creates ``<data_path>/backgrounds_challenge`` if needed, then downloads
    and extracts each archive there; the archives are removed afterwards.
    """
    logging.info("Downloading ImageNet Backgrounds Challenge...")

    target_dir = os.path.join(data_path, "backgrounds_challenge")
    os.makedirs(target_dir, exist_ok=True)

    # (source URL, local archive file name), processed in this order.
    archives = (
        (
            "https://github.com/MadryLab/backgrounds_challenge/releases/download/data/backgrounds_challenge_data.tar.gz",
            "backgrounds_challenge_data.tar.gz",
        ),
        (
            "https://www.dropbox.com/s/0vv2qsc4ywb4z5v/original.tar.gz?dl=1",
            "original.tar.gz",
        ),
    )
    for archive_url, archive_name in archives:
        download_and_extract(
            archive_url,
            os.path.join(target_dir, archive_name),
            remove=True
        )


class ImageNetBGDataset(WILDSDataset):
    """ImageNet Backgrounds Challenge images exposed as a WILDS dataset.

    Images are discovered on disk under the dataset directory, one folder
    per split and per class. Every sample carries a class label ``y`` and a
    group id ``g`` derived from the split it was found in.
    """

    _dataset_name = 'ImageNetBG'

    def __init__(self, root_dir='',
            split_scheme='official', test_pct=0.2, val_pct=0.1, data_seed=None):
        """Index the images on disk and set up the WILDS bookkeeping fields.

        Args:
            root_dir: Passed to ``initialize_data_dir`` to resolve the
                dataset directory.
            split_scheme: Stored as ``_split_scheme`` and forwarded to the
                base-class constructor.
            test_pct: Accepted for interface compatibility; not used here.
            val_pct: Accepted for interface compatibility; not used here.
            data_seed: If given, the global NumPy RNG is seeded with it
                while the index is built and restored afterwards.

        Raises:
            FileNotFoundError: If no ``*.JPEG`` file is found in any of the
                expected split/class folders.
        """
        self._data_dir = Path(self.initialize_data_dir(root_dir))

        saved_rng_state = None
        if data_seed is not None:
            saved_rng_state = np.random.get_state()
            np.random.seed(data_seed)

        try:
            self._load_image_index(split_scheme)
        finally:
            # Restore the caller's RNG state even if indexing fails.
            if saved_rng_state is not None:
                np.random.set_state(saved_rng_state)

        super().__init__(self._data_dir, split_scheme)

    def _load_image_index(self, split_scheme):
        """Scan the split/class folders and populate the dataset attributes."""
        # Split name -> directory relative to the dataset directory.
        # Insertion order defines the integer split codes.
        split_dirs = {
            'train': 'original/train',
            'val': 'original/val',
            'test': 'bg_challenge/original/val',
            'mixed_rand': 'bg_challenge/mixed_rand/val',
            'only_fg': 'bg_challenge/only_fg/val',
            'no_fg': 'bg_challenge/no_fg/val',
            'paintings_bg': 'bg_challenge/paintings_bg/val',
        }
        # Group names; a sample's ``g`` is the index into this list.
        groups = ['og', 'fg', 'mr', 'bg']
        split_to_group = {'train': 'og', 'val': 'og', 'test': 'og',
                          'mixed_rand': 'mr', 'only_fg': 'fg',
                          'no_fg': 'bg', 'paintings_bg': 'mr'}
        classes = {
            0: 'dog',
            1: 'bird',
            2: 'wheeled vehicle',
            3: 'reptile',
            4: 'carnivore',
            5: 'insect',
            6: 'musical instrument',
            7: 'primate',
            8: 'fish'
        }

        self._split_dict = {name: code for code, name in enumerate(split_dirs)}
        self._split_names = {name: name for name in split_dirs}

        records = []
        for split_name, rel_dir in split_dirs.items():
            split_code = self._split_dict[split_name]
            group_idx = groups.index(split_to_group[split_name])
            for label, class_name in classes.items():
                # Class folders are named "<two-digit label>_<class name>".
                class_dir = self._data_dir / rel_dir / f'{label:02d}_{class_name}'
                # glob() yields files in filesystem order; sort so sample
                # indices are the same on every machine.
                for img_path in sorted(class_dir.glob('*.JPEG')):
                    records.append({
                        'split': split_code,
                        'path': img_path,
                        'y': label,
                        'g': group_idx,
                    })

        if not records:
            # Fail with a clear message instead of a KeyError on an empty frame.
            raise FileNotFoundError(
                f"No '*.JPEG' images found under {self._data_dir}; "
                f"expected folders such as {split_dirs['train']}/00_dog."
            )

        df = pd.DataFrame(records)
        self._meta_df = df

        self._y_array = torch.LongTensor(df['y'].values)
        self._y_size = 1
        self._n_classes = len(classes)

        # Metadata columns are (group id, label), matching _metadata_fields.
        self._metadata_array = torch.stack(
            (torch.LongTensor(df['g'].values), self._y_array),
            dim=1
        )
        self._metadata_fields = ['g', 'y']
        self._metadata_map = {
            'g': groups,
            'y': [classes[i] for i in range(len(classes))]
        }

        # Images come in different sizes; get_input resizes all to 224 x 224.
        self._original_resolution = (224, 224)

        self._eval_grouper = CombinatorialGrouper(
            dataset=self,
            groupby_fields=(['g']))

        self._split_scheme = split_scheme
        self._split_array = torch.LongTensor(df['split'].values)

        self.df = df

    def get_input(self, idx):
        """Load the image at position *idx* as RGB, resized to 224 x 224."""
        # Column-first lookup avoids building a full row Series per sample.
        img_path = self.df['path'].iloc[idx]
        # convert() returns a new, fully loaded image, so the file can be
        # closed as soon as the with-block exits.
        with Image.open(img_path) as img:
            return img.convert('RGB').resize(self._original_resolution)

    def eval(self, y_pred, y_true, metadata, prediction_fn=None):
        """
        Computes all evaluation metrics.
        Args:
            - y_pred (Tensor): Predictions from a model. By default, they are predicted labels (LongTensor).
                               But they can also be other model outputs such that prediction_fn(y_pred)
                               are predicted labels.
            - y_true (LongTensor): Ground-truth labels
            - metadata (Tensor): Metadata
            - prediction_fn (function): A function that turns y_pred into predicted labels
        Output:
            - results (dictionary): Dictionary of evaluation metrics
            - results_str (str): String summarizing the evaluation metrics
        """
        # Accuracy is reported per group as defined by the 'g' metadata field.
        metric = Accuracy(prediction_fn=prediction_fn)
        return self.standard_group_eval(
            metric,
            self._eval_grouper,
            y_pred, y_true, metadata)
