import os
import torch
import pathlib
import pandas as pd
import numpy as np
from PIL import Image
from matplotlib import cm
from wilds.datasets.wilds_dataset import WILDSDataset
from wilds.common.grouper import CombinatorialGrouper
from wilds.common.metrics.all_metrics import Accuracy

# Integer environment id -> environment name.
# NOTE(review): not referenced anywhere in the visible code; confirm where
# (e.g. an env file passed via ``env_file_path``) these ids are consumed.
ENV_DICT = dict(enumerate((
    'in_distrib_maj',
    'in_distrib_min',
    'biasing_places',
    'ood',
)))

class COCOonPlacesDataset(WILDSDataset):
    """
    WILDS-style version of the COCO-on-Places dataset.

    Examples are listed in a metadata CSV (``coco_places_dataset.csv``) stored
    directly under ``root_dir``. Each row references a ``.npy`` image file via
    its ``img_filename`` column. The CSV must also provide:
        - ``y``: integer class index (expected to index the 9 class names
          in ``_metadata_map['y']``),
        - ``group``: integer background index (indexes
          ``_metadata_map['background']``),
        - ``data_type``: the split id of the row.

    Metadata fields are ``background`` and ``y``. Evaluation groups examples by
    ``background`` only.
    """
    _dataset_name = 'coco_on_places'
    _versions_dict = {
        '1.0': {'csv': 'coco_places_dataset.csv'}
    }

    def __init__(
        self,
        version=None,
        root_dir='data',
        download=False,
        split_scheme='official',
        env_file_path=None,
        get_img_idx=False,
    ):
        """
        Load the metadata CSV and build the label, metadata and split arrays.

        Args:
            - version (str or None): Stored as-is. Only the '1.0' CSV entry is
              ever read, regardless of this value.
            - root_dir (str): Directory holding the metadata CSV and the image
              files. Used directly as the data directory.
            - download (bool): Forwarded to ``WILDSDataset.__init__``.
            - split_scheme (str): Forwarded to ``WILDSDataset.__init__``.
            - env_file_path: Forwarded to ``WILDSDataset.__init__``.
            - get_img_idx (bool): Forwarded to ``WILDSDataset.__init__``.
        Raises:
            - ValueError: if the data directory does not exist.
        """
        self._version = version
        self._data_dir = root_dir
        if not os.path.exists(self.data_dir):
            raise ValueError(
                f'{self.data_dir} does not exist yet. Please generate the dataset first.')

        # Read in metadata. Only a '1.0' entry exists in the versions dict, so
        # that CSV is used unconditionally.
        metadata_df = pd.read_csv(
            os.path.join(self._data_dir, self._versions_dict['1.0']['csv']))

        # Labels: one integer class index per row.
        self._y_array = torch.LongTensor(metadata_df['y'].values)
        self._y_size = 1
        self._n_classes = 9

        # Metadata columns, in order: (background group, label).
        self._metadata_array = torch.stack(
            (torch.LongTensor(metadata_df['group'].values), self._y_array),
            dim=1
        )
        self._metadata_fields = ['background', 'y']
        self._metadata_map = {
            # Indexed by the CSV 'group' column. ``eval`` looks up per-group
            # results under keys such as 'acc_background:bias' and
            # 'acc_background:non-bias', so keep these names in sync with it.
            'background': ['bias', 'non-bias', 'sys shift', 'ood'],
            'y': [
                'boat',
                'airplane',
                'truck',
                'dog',
                'zebra',
                'horse',
                'bird',
                'train',
                'bus'
            ]
        }

        # Image filenames, relative to the data directory (see get_input).
        self._input_array = metadata_df['img_filename'].values
        self._original_resolution = (224, 224)

        # Per-row split ids.
        self._split_array = metadata_df['data_type'].values

        self._eval_grouper = CombinatorialGrouper(
            dataset=self,
            groupby_fields=['background'])

        super().__init__(root_dir, download, split_scheme, env_file_path,
                         get_img_idx, remove_check_init=True)

    def get_input(self, idx):
        """
        Returns x for a given idx as an RGB PIL image.

        The array is loaded with ``np.load`` and transposed with axes (1, 2, 0),
        scaled by 255 and converted to uint8.
        NOTE(review): this assumes the stored array is channel-first with values
        in [0, 1]; confirm against the dataset generator.
        """
        img_filename = os.path.join(self.data_dir, self._input_array[idx])
        x = np.load(img_filename).transpose((1, 2, 0))
        # Clip before the uint8 cast: casting out-of-range floats to uint8 wraps
        # or is platform-dependent. This is a no-op for values in [0, 1].
        pixels = np.uint8(np.clip(x * 255., 0., 255.))
        return Image.fromarray(pixels).convert('RGB')

    def eval(self, y_pred, y_true, metadata, prediction_fn=None):
        """
        Computes all evaluation metrics.
        Args:
            - y_pred (Tensor): Predictions from a model. By default, they are predicted labels (LongTensor).
                               But they can also be other model outputs such that prediction_fn(y_pred)
                               are predicted labels.
            - y_true (LongTensor): Ground-truth labels
            - metadata (Tensor): Metadata
            - prediction_fn (function): A function that turns y_pred into predicted labels
        Output:
            - results (dictionary): Dictionary of evaluation metrics. The overall
              'acc_avg' entry is replaced by 'ind_acc_avg', the count-weighted
              accuracy over the 'bias' and 'non-bias' background groups. It is
              0.0 when neither group is present.
            - results_str (str): String summarizing the evaluation metrics
        """
        metric = Accuracy(prediction_fn=prediction_fn)

        results, results_str = self.standard_group_eval(
            metric,
            self._eval_grouper,
            y_pred, y_true, metadata)

        # In-distribution accuracy: count-weighted mean of the 'bias' and
        # 'non-bias' background group accuracies.
        bias_count = results['count_background:bias']
        non_bias_count = results['count_background:non-bias']
        ind_count = bias_count + non_bias_count
        if ind_count > 0:
            results['ind_acc_avg'] = (
                (results['acc_background:bias'] * bias_count
                 + results['acc_background:non-bias'] * non_bias_count)
                / ind_count)
        else:
            # No in-distribution examples in this split (e.g. an OOD-only
            # split): report 0.0 instead of raising ZeroDivisionError.
            results['ind_acc_avg'] = 0.0
        del results['acc_avg']

        # The first line of results_str is assumed to be the overall-average
        # line; replace it with the in-distribution accuracy.
        results_str = (
            f"In-distribution acc: {results['ind_acc_avg']:.3f}\n"
            + '\n'.join(results_str.split('\n')[1:]))

        return results, results_str
