import os
import torch
import pandas as pd
from PIL import Image
import numpy as np
from wilds.datasets.wilds_dataset import WILDSDataset
from wilds.common.grouper import CombinatorialGrouper
from wilds.common.metrics.all_metrics import Accuracy
from pathlib import Path
import json
from sklearn.model_selection import train_test_split

class NICOppDataset(WILDSDataset):
    """NICO++ DG-track training images exposed as a WILDS dataset.

    Every ``*.jpg`` under ``<data_dir>/public_dg_0416/train/<attribute>/<label>/``
    becomes one example, for the six attributes listed in ``__init__`` and every
    label in ``dg_label_id_mapping.json``. Per example:

    - ``y``: the integer class id the mapping file assigns to ``<label>``;
    - ``a``: the position of ``<attribute>`` in the attribute list;
    - ``g``: the group string ``"<attribute>_<label>"``.

    Groups with fewer than ``min_group_size`` images are dropped; the remaining
    rows are split into train / val / test (``0`` / ``1`` / ``2`` in
    ``split_array``) by sampling stratified on ``g``.
    """
    _dataset_name = 'nicopp'

    def __init__(self, root_dir='',
            split_scheme='official', test_pct = 0.2, val_pct = 0.1, data_seed = None,
            min_group_size=150):
        """
        Args:
            - root_dir (str): Passed to ``initialize_data_dir`` to locate the data.
            - split_scheme (str): Stored as ``_split_scheme`` and forwarded to the
              base-class constructor.
            - test_pct (float): Fraction of the kept rows assigned to the test split.
            - val_pct (float): Fraction of the kept rows (of all of them, not of the
              non-test remainder) assigned to the validation split.
            - data_seed (int or None): ``random_state`` for both ``train_test_split``
              calls. When set, NumPy's global RNG is also seeded while the dataset
              is built, and its previous state is restored afterwards, even if
              building fails.
            - min_group_size (int): Groups with fewer images than this are
              discarded. Defaults to 150, the previously hard-coded threshold.
        Raises:
            - FileNotFoundError: If no images are found under the data directory.
        """
        # NOTE(review): upstream WILDS declares initialize_data_dir(root_dir, download)
        # and WILDSDataset.__init__(root_dir, download, split_scheme); the shorter calls
        # in this class presumably target a modified base class — confirm.
        self._data_dir = Path(self.initialize_data_dir(root_dir))

        if data_seed is not None:
            state = np.random.get_state()
            np.random.seed(data_seed)
        try:
            attributes = ['autumn', 'dim', 'grass', 'outdoor', 'rock', 'water']   # 6 attributes, 60 labels
            # Maps label (folder) name -> integer class id.
            with open(self._data_dir/'dg_label_id_mapping.json', 'r', encoding='utf-8') as f:
                meta = json.load(f)
            inv_meta = {j: i for i, j in meta.items()}

            df = self._collect_images(self._data_dir, attributes, meta)
            self.orig_df = df.copy()   # snapshot taken before the group-size filter

            g_counts = df.g.value_counts()
            gs_to_take = g_counts[g_counts >= min_group_size]
            df = df[df.g.isin(gs_to_take.index)]
            # Group string -> dense integer id, in value_counts (descending size) order.
            self.g_mapping = {g: i for i, g in enumerate(gs_to_take.index)}

            # Get the y values
            self._y_array = torch.LongTensor(df['y'].values)
            self._y_size = 1
            self._n_classes = len(df['y'].unique())

            # Get metadata: columns are (attr, y, g), in the order of _metadata_fields.
            self._metadata_array = torch.stack(
                (torch.LongTensor(df['a'].values), self._y_array, torch.LongTensor(df['g'].map(self.g_mapping).values)),
                dim=1
            )
            self._metadata_fields = ['attr', 'y', 'g']
            self._metadata_map = {
                'attr': attributes,
                # NOTE(review): assumes the kept class ids are exactly 0.._n_classes-1;
                # if the size filter removes every group of some class, ids and names
                # here no longer line up — confirm against the data.
                'y': [inv_meta[i] for i in range(self._n_classes)],
                'g': gs_to_take.index
            }

            self._original_resolution = (224, 224)

            self._eval_grouper = CombinatorialGrouper(
                dataset=self,
                groupby_fields=(['g']))

            self._split_scheme = split_scheme

            df = df.reset_index(drop = True)
            self._split_array = self._stratified_split(df.g, test_pct, val_pct, data_seed)
            self.df = df
        finally:
            # Hand the caller's global RNG state back, even when building failed.
            if data_seed is not None:
                np.random.set_state(state)

        super().__init__(self._data_dir, split_scheme)

    @staticmethod
    def _collect_images(data_dir, attributes, meta):
        """Build one row per ``*.jpg`` found under ``public_dg_0416/train/<attr>/<label>``.

        Args:
            - data_dir (Path): Dataset root.
            - attributes (list of str): Attribute folder names; ``a`` is the index here.
            - meta (dict): Label folder name -> integer class id (``y``).
        Returns:
            - DataFrame with columns ``path``, ``y``, ``a``, ``y_name``, ``a_name``
              and ``g`` (``"<attr>_<label>"``).
        Raises:
            - FileNotFoundError: If no images are found; an empty table would
              otherwise fail later with an opaque ``KeyError``.
        """
        train_dir = data_dir/'public_dg_0416'/'train'
        records = []
        for a, attr in enumerate(attributes):
            for label, y in meta.items():
                # NOTE(review): glob order follows the directory listing, which is not
                # guaranteed to match across filesystems, so seeded splits may differ
                # between machines — consider sorting if exact reproducibility matters.
                for img_path in (train_dir/attr/label).glob('*.jpg'):
                    records.append({
                        'path': img_path,
                        'y': y,
                        'a': a,
                        'y_name': label,
                        'a_name': attr
                    })
        if not records:
            raise FileNotFoundError(f'No *.jpg images found under {train_dir}')

        df = pd.DataFrame(records)
        # Plain string concatenation; the former leading unary `+` on a string
        # Series is deprecated/erroring in NumPy and has been removed.
        df['g'] = df['a_name'] + '_' + df['y_name']
        return df

    @staticmethod
    def _stratified_split(groups, test_pct, val_pct, seed):
        """Assign every row to train (0), val (1) or test (2), stratified on ``groups``.

        Args:
            - groups (Series): Group label per row, indexed 0..N-1.
            - test_pct (float): Fraction of all N rows sent to test.
            - val_pct (float): Fraction of all N rows sent to val.
            - seed (int or None): ``random_state`` for ``train_test_split``.
        Returns:
            - ndarray of shape (N, 1) holding the split id of each row.
        """
        idx_train_val, idx_test = train_test_split(
            groups.index, test_size=test_pct, stratify=groups, random_state=seed)
        # Rescale so val ends up as val_pct of the whole dataset, not of the remainder.
        _, idx_val = train_test_split(
            idx_train_val, test_size=val_pct/(1 - test_pct),
            stratify=groups.loc[idx_train_val], random_state=seed)

        split_array = np.zeros((len(groups), 1))
        split_array[idx_val] = 1
        split_array[idx_test] = 2
        return split_array

    def get_input(self, idx):
        """Return row ``idx`` of ``self.df`` as an RGB PIL image resized to
        ``self._original_resolution``."""
        img_filename = self.df.iloc[idx]['path']
        # The context manager releases the file handle; convert/resize return new images.
        with Image.open(img_filename) as img:
            x = img.convert('RGB').resize(self._original_resolution)
        return x

    def eval(self, y_pred, y_true, metadata, prediction_fn=None):
        """
        Computes all evaluation metrics.
        Accuracy is computed through ``standard_group_eval`` with
        ``self._eval_grouper``, which groups examples by the ``g`` metadata field.
        Args:
            - y_pred (Tensor): Predictions from a model. By default, they are predicted labels (LongTensor).
                               But they can also be other model outputs such that prediction_fn(y_pred)
                               are predicted labels.
            - y_true (LongTensor): Ground-truth labels
            - metadata (Tensor): Metadata
            - prediction_fn (function): A function that turns y_pred into predicted labels
        Output:
            - results (dictionary): Dictionary of evaluation metrics
            - results_str (str): String summarizing the evaluation metrics
        """
        metric = Accuracy(prediction_fn=prediction_fn)
        return self.standard_group_eval(
            metric,
            self._eval_grouper,
            y_pred, y_true, metadata)
