import os
import torch
import pandas as pd
from PIL import Image
import numpy as np
from wilds.datasets.wilds_dataset import WILDSDataset, dataset_root
from wilds.common.grouper import CombinatorialGrouper
from wilds.common.metrics.all_metrics import Accuracy
from pathlib import Path
from sklearn.model_selection import train_test_split

import tarfile
from zipfile import ZipFile
import logging
import gdown
import pandas as pd
import numpy as np
from pathlib import Path
import pickle
from torch.utils.data import Dataset
import glob
import numpy as np
import torch
import scipy
from scipy import ndimage
from tqdm import tqdm
from torchvision import models, transforms


class HardImageNet(WILDSDataset):
    """WILDS dataset wrapper for Hard ImageNet.

    Builds one metadata row per (image, split) pair:
      * the original Hard ImageNet 'train' images, re-split into 'train'/'val';
      * the original 'val' images, used as 'test';
      * a balanced training subset (``HardImageNetBase(..., balanced_subset=True)``),
        re-split into 'spur_balanced_tr'/'spur_balanced_va';
      * three ablated copies of the test images ('tile_test',
        'replace_with_gray_test', 'replace_bbox_with_gray_test').
    The split index doubles as the evaluation group 'g'.
    """
    _dataset_name = 'hard_imagenet'

    SPLITS = [
        'train', # their dataset only contains train and val. Here, we split train into train and val; and val becomes test.
        'val',
        'test',
        ## balanced training sets for finetuning
        'spur_balanced_tr',
        'spur_balanced_va',
        ## augmented test sets
        'tile_test',
        'replace_with_gray_test',
        'replace_bbox_with_gray_test'
    ]

    def __init__(self, root_dir='',
            split_scheme='official', test_pct=0.2, val_pct=0.2, data_seed = None):
        """
        Args:
            root_dir: passed to ``initialize_data_dir``; the resulting directory must hold
                the mask folders and a ``meta`` folder with the pickles read below.
            split_scheme: forwarded to ``WILDSDataset``.
            test_pct: currently unused (the test split is the original 'val' split).
            val_pct: fraction of the original train images, and of the balanced subset,
                held out as validation.
            data_seed: if given, numpy's global RNG is seeded with it during construction
                (the previous state is restored afterwards, even on error), and it is
                used as ``random_state`` for the train/val splits.
        """
        self._data_dir = Path(self.initialize_data_dir(root_dir))

        state = None
        if data_seed is not None:
            state = np.random.get_state()
            np.random.seed(data_seed)

        try:
            _IMAGENET_ROOT = dataset_root['imagenet']
            # Ablations applied lazily in get_input; each takes batched (img, mask) tensors.
            self.ablations = {
                'tile': tile,
                'replace_with_gray': replace_with_gray,
                'replace_bbox_with_gray': lambda x, y: replace_with_gray(x, y, False)
            }

            self._split_dict = {name: idx for idx, name in enumerate(self.SPLITS)}
            self._split_names = {name: name for name in self.SPLITS}

            orig_train = HardImageNetBase(self._data_dir, _IMAGENET_ROOT, 'train', False)
            orig_test = HardImageNetBase(self._data_dir, _IMAGENET_ROOT, 'val', False)
            balanced_train = HardImageNetBase(self._data_dir, _IMAGENET_ROOT, 'train', True)
            n_test = len(orig_test)

            # The last 3 * n_test rows are the ablated copies of the test set.
            df = pd.DataFrame({
                'ds_reference': ([orig_train] * len(orig_train) + [orig_test] * n_test +
                    [balanced_train] * len(balanced_train) + [orig_test] * (n_test * 3)),
                'split': (['orig_train'] * len(orig_train) + ['test'] * n_test +
                    ['spur_balanced_train'] * len(balanced_train) + ['tile_test'] * n_test +
                    ['replace_with_gray_test'] * n_test + ['replace_bbox_with_gray_test'] * n_test
                )
            })

            df['ablation'] = None
            if n_test:
                # Guarded because iloc[-0:] would select the whole column.
                df.iloc[-n_test * 3:, df.columns.get_loc('ablation')] = (['tile'] * n_test +
                        ['replace_with_gray'] * n_test + ['replace_bbox_with_gray'] * n_test)
            df['mask_path'] = (orig_train.mask_paths + orig_test.mask_paths +
                               balanced_train.mask_paths + orig_test.mask_paths * 3)
            df['imagenet_y'] = self.map_names_to_y(df['mask_path'].values, orig_train.wnid_to_idx)

            with open(os.path.join(self._data_dir, 'meta', 'hard_imagenet_idx.pkl'), 'rb') as f:
                inet_idx = pickle.load(f)

            # y is the position of the row's ImageNet class index within inet_idx.
            df['y'] = df['imagenet_y'].apply(lambda x: inet_idx.index(x))

            # Hold out val_pct of each training pool as its validation split.
            for pool, train_name, val_name in (
                    ('orig_train', 'train', 'val'),
                    ('spur_balanced_train', 'spur_balanced_tr', 'spur_balanced_va')):
                sub_idx = df[df.split == pool].index
                sub_idx_train, sub_idx_val = train_test_split(
                    sub_idx, test_size=val_pct, random_state=data_seed)
                df.loc[sub_idx_train, 'split'] = train_name
                df.loc[sub_idx_val, 'split'] = val_name

            self._meta_df = df
            self._original_resolution = (224, 224)
            self._n_classes = orig_train.num_classes
            self._y_array = torch.LongTensor(df['y'].values)
            self._y_size = 1

            df['split_idx'] = df['split'].map(self._split_dict)

            self._metadata_array = torch.stack(
                (torch.LongTensor(df['split_idx']), self._y_array),
                dim=1
            )
            self._metadata_fields = ['g', 'y']
            self._metadata_map = {
                'g': self.SPLITS,
                'y': np.unique(self._y_array.numpy())
            }

            self._eval_grouper = CombinatorialGrouper(
                dataset=self,
                groupby_fields=(['g']))

            self._split_scheme = split_scheme
            self._split_array = torch.LongTensor(df['split_idx'])
            self.df = df
        finally:
            # Restore the caller's RNG state even if construction fails part-way.
            if state is not None:
                np.random.set_state(state)

        super().__init__(self._data_dir, split_scheme)


    def map_names_to_y(self, mask_paths, wnid_to_idx):
        """Map each mask path to a class index via its WordNet-id file-name prefix.

        The wnid is the text before the first '_' in the last '/'-separated path
        component; it is looked up in ``wnid_to_idx`` (KeyError if absent).
        """
        return [wnid_to_idx[path.split('/')[-1].split('_')[0]] for path in mask_paths]

    def get_input(self, idx, return_mask = False):
        """Load row ``idx`` as a PIL image, applying the row's ablation if it has one.

        Returns the PIL image, or ``(pil_image, mask_tensor)`` when ``return_mask``.
        """
        row = self.df.iloc[idx]
        img, mask, _ = row['ds_reference'].__getitem__(None, row['mask_path'])
        ablation = row['ablation']
        # Non-ablated rows hold None; isinstance also tolerates pandas storing NaN.
        if isinstance(ablation, str):
            # Ablations operate on batched tensors, so add and then drop a batch axis.
            img = self.ablations[ablation](img.unsqueeze(0), mask.unsqueeze(0)).squeeze(0)

        pil_img = transforms.ToPILImage()(img.cpu())
        if return_mask:
            return pil_img, mask
        return pil_img

    def eval(self, y_pred, y_true, metadata, prediction_fn=None):
        """
        Computes all evaluation metrics.
        Args:
            - y_pred (Tensor): Predictions from a model. By default, they are predicted labels (LongTensor).
                               But they can also be other model outputs such that prediction_fn(y_pred)
                               are predicted labels.
            - y_true (LongTensor): Ground-truth labels
            - metadata (Tensor): Metadata
            - prediction_fn (function): A function that turns y_pred into predicted labels
        Output:
            - results (dictionary): Dictionary of evaluation metrics
            - results_str (str): String summarizing the evaluation metrics
        """
        metric = Accuracy(prediction_fn=prediction_fn)
        return self.standard_group_eval(
            metric,
            self._eval_grouper,
            y_pred, y_true, metadata)


'''
The helpers below are adapted from https://github.com/mmoayeri/HardImageNet/
'''

def replace_with_gray(img, mask, keep_shape=True):
    """Paint masked regions of a batch of images mid-gray (0.5).

    With ``keep_shape`` the exact mask pixels are replaced. Otherwise each
    sample's mask channel 0 is turned into a bounding box (via ``get_bbox``)
    and the whole box is painted gray.
    """
    gray = 0.5 * torch.ones_like(mask)
    if not keep_shape:
        boxed = []
        for i, sample in enumerate(img):
            box = get_bbox(mask[i, 0])
            boxed.append(sample * (1 - box) + gray[i] * box)
        return torch.stack(boxed)
    return img * (1 - mask) + gray * mask

def replace_with_noise(img, mask):
    """Replace masked pixels with Gaussian noise (std 0.5) drawn in the mask's shape."""
    noise = 0.5 * torch.randn_like(mask)
    keep = 1 - mask
    return keep * img + mask * noise

def get_corners(arr):
    """Return (x_min, x_max, y_min, y_max) of the non-zero entries of ``arr``.

    "x" is axis 0 and "y" is axis 1; both maxima are inclusive indices.
    Raises ValueError if ``arr`` has no non-zero entries.
    """
    # np.where (not np.nonzero) so torch tensors are converted rather than
    # dispatched to Tensor.nonzero, which returns a different layout.
    nonzero = np.where(arr != 0)
    rows, cols = nonzero[0], nonzero[1]
    return np.min(rows), np.max(rows), np.min(cols), np.max(cols)

def get_bbox(arr, expand=False):
    """Return a numpy array shaped like ``arr`` with 1s over the box of its non-zero entries.

    ``expand`` is unused and kept only for call compatibility. An input whose sum is
    not positive yields all zeros.

    NOTE: the maximum row/column (inclusive indices) are excluded by the slice, so
    the box is one pixel short on the bottom/right and is empty for a component only
    one pixel tall or wide. Kept unchanged because altering it would change every
    ablated image derived from it.
    """
    box = np.zeros_like(arr)
    if not arr.sum() > 0:
        return box
    nonzero = np.where(arr != 0)
    top, bottom = np.min(nonzero[0]), np.max(nonzero[0])
    left, right = np.min(nonzero[1]), np.max(nonzero[1])
    box[top:bottom, left:right] = 1
    return box

def num_nonzero_pixels(x):
    """Count the non-zero entries of ``x`` (numpy array or torch tensor).

    The previous version binarised ``x`` in place (``x[x != 0] = 1``) before
    summing, which silently corrupted the caller's data. This version leaves ``x``
    untouched.
    """
    return (x != 0).sum()

def fill_instance_avg_surrounding_color(img, mask, bbox, keep_shape=False):
    '''
    Fill one object instance with the average colour of the non-object pixels in its box.

    img and mask: 3 x H x W (224 x 224 in this dataset), bbox: H x W (numpy array or tensor).
    With keep_shape the object pixels inside the box are replaced; otherwise the whole
    box is replaced. If the box holds no non-object pixels, img is returned unchanged.
    '''
    bbox = torch.as_tensor(bbox, dtype=img.dtype, device=img.device)
    img_in_box = img * bbox
    object_in_box = img_in_box * mask
    mask_in_box = mask * bbox
    # Count without mutating the tensors. Binarising img_in_box in place, as the old
    # helper did, turned the colour sum below into a pixel count.
    num_non_obj_pixels_in_box = (img_in_box != 0).sum() - (mask_in_box != 0).sum()
    if num_non_obj_pixels_in_box <= 0:
        # Nothing to average over; avoid dividing by zero and filling with inf/nan.
        return img

    sum_non_obj_pixels_in_box = img_in_box.flatten(1).sum(-1) - object_in_box.flatten(1).sum(-1)
    # The counts span all 3 channels, so "* 3" yields a per-channel mean
    # (assumes mask has the same 3 channels as img).
    avg_color = sum_non_obj_pixels_in_box / num_non_obj_pixels_in_box * 3
    per_channel = avg_color.view(-1, 1, 1)
    if keep_shape:
        obj_filled_in = mask_in_box[0].unsqueeze(0) * per_channel
        out = img * (1 - mask_in_box) + obj_filled_in
    else:
        box_filled_in = bbox.unsqueeze(0) * per_channel
        out = img * (1 - bbox) + box_filled_in
    return out


def fill_with_avg_surrounding_color(img, mask, keep_shape=False):
    '''
    Fill each connected object instance of mask[0] with its surrounding average colour.

    img and mask: 3 x H x W tensors (mask on CPU, since it goes through .numpy()).
    Instances whose bounding box covers fewer than 100 pixels are skipped.
    '''
    labels, num_features = scipy.ndimage.label(mask[0].numpy())
    for i in range(1, num_features + 1):
        instance_labels = labels.copy()
        instance_labels[instance_labels != i] = 0
        # get_bbox returns the full 2-D box. The former trailing "[0]" kept only its
        # first row, so both the size check and the fill used a 1-D slice.
        bbox = get_bbox(instance_labels)
        if bbox.sum() < 100:
            continue
        img = fill_instance_avg_surrounding_color(img, mask, bbox, keep_shape)
    return img

def trim_tile(img, mask, x1, x2, y1, y2, dir):
    '''
    Crop the candidate tile img[:, y1:y2, x1:x2] beside an object's bounding box,
    shrinking it along `dir` when the region contains object pixels in mask[0].

    Args:
        img, mask: 3-index tensors sliced as [channel, row, col] below
            (presumably C x H x W — TODO confirm).
        x1, x2: range on the LAST axis of the slices below. x1 is clamped to >= 0,
            and x2 is clamped to <= h.
            NOTE(review): x2 is clamped by h (the row count), not w; this only
            matters for non-square images.
        y1, y2: range on axis 1. y1 is clamped to >= 0, and y2 is clamped to <= w.
        dir: 'left', 'right', 'up' or 'down', naming which side of the box the tile
            comes from (see the calls in largest_adjacent_tile).
    Returns:
        (out, size): the cropped tensor, and an area-like score the caller uses to
        pick the largest tile.
    If the region overlaps the mask and dir is none of the four values, `out` is
    never assigned and an UnboundLocalError is raised.
    '''
    _, h, w = img.shape
    x1, y1 = [max(a,0) for a in [x1, y1]]
    x2, y2 = [min(a,d) for a,d in zip([x2, y2], [h,w])]
    if mask[:, y1:y2, x1:x2].sum() == 0:
        # No object pixels in the candidate region: keep it whole.
        out = img[:, y1:y2, x1:x2]
        size = (x2-x1) * (y2-y1) 
    else:
        # find first instance of other object
        # Each branch builds cumulative sums with Python's builtin sum over tensor
        # elements; `0 in ...` finds the first all-clear position (0 if none).
        if dir == 'left':
            # NOTE(review): .sum(1) reduces over columns, so this holds one value per
            # ROW (length y2-y1), yet the list below ranges over x2-x1 column offsets.
            # This looks like an axis mix-up; confirm.
            is_there_obj_by_col = mask[0, y1:y2, x1:x2].sum(1)
            # we take the sum from col i leftwards (towards bbox) looking for lowest i where all leftward sums are 0 (no object)
            sum_moving_right = [sum(is_there_obj_by_col[i:]) for i in range(x2-x1)]
            furthest_we_can_go = sum_moving_right.index(0) if 0 in sum_moving_right else 0
            out = img[:, y1:y2, (x1+furthest_we_can_go):x2]
            # NOTE(review): the kept tile is (x2-x1-furthest_we_can_go) wide, but size
            # uses furthest_we_can_go; confirm intended.
            size = furthest_we_can_go * (y2-y1) 
        elif dir == 'right':   
            is_there_obj_by_col = mask[0, y1:y2, x1:x2].sum(0)
            # now its sum from box to col i
            # sum_moving_left[k] is the sum over the first (x2-x1-k) columns, so
            # furthest_we_can_go is (x2-x1) minus the widest object-free prefix.
            # NOTE(review): that value is then used as the tile width; confirm intended.
            sum_moving_left = [sum(is_there_obj_by_col[:i]) for i in range(x2-x1,0,-1)]
            furthest_we_can_go = sum_moving_left.index(0) if 0 in sum_moving_left else 0
            out = img[:, y1:y2, x1:(x1+furthest_we_can_go)]
            size = furthest_we_can_go * (y2-y1)
        elif dir == 'up':   # actually down bc images have increasing y going downwards but whatever
            # NOTE(review): .sum(0) gives one value per COLUMN (length x2-x1) while the
            # list below ranges over y2-y1 rows. This looks like an axis mix-up; the
            # same position-vs-length caveat as 'right' also applies. Confirm.
            is_there_obj_by_col = mask[0, y1:y2, x1:x2].sum(0)
            # now its sum from box to row i
            sum_moving_down = [sum(is_there_obj_by_col[:i]) for i in range(y2-y1, 0, -1)]
            furthest_we_can_go = sum_moving_down.index(0) if 0 in sum_moving_down else 0
            out = img[:, y1:(y1+furthest_we_can_go), x1:x2]
            size = (x2-x1) * furthest_we_can_go
        elif dir == 'down':
            # .sum(1) gives one value per row, matching the row offsets iterated below.
            is_there_obj_by_col = mask[0, y1:y2, x1:x2].sum(1)
            # we take the sum from row i upwards (towards bbox) looking for lowest i where all upward sums are 0 (no object)
            sum_moving_up = [sum(is_there_obj_by_col[i:]) for i in range(y2-y1)]
            furthest_we_can_go = sum_moving_up.index(0) if 0 in sum_moving_up else 0
            out = img[:, (y1+furthest_we_can_go):y2, x1:x2]
            # NOTE(review): as in 'left', size uses the offset rather than the kept height.
            size = (x2-x1) * furthest_we_can_go

    return out, size

def repeat_tile_to_fill_bbox(tile, bbox_w, bbox_h, dir):
    '''
    Build a 3 x bbox_h x bbox_w canvas (torch.zeros) by repeating `tile` along one axis.

    'right': copies are laid left-to-right starting at column 0; a leftover strip at
             the right edge gets the tile's first columns.
    'left':  copies are laid right-to-left starting at column bbox_w; a leftover strip
             at column 0 gets the tile's last columns.
    'up':    copies are laid top-down from row 0; leftover rows get the tile's first rows.
    'down':  copies are laid bottom-up from row bbox_h; leftover rows get the tile's
             last rows.
    Only the first tile_h rows (horizontal modes) or first tile_w columns (vertical
    modes) are written; the rest of the canvas stays zero. An empty tile, or any other
    `dir` value, returns the all-zero canvas. The tile is assumed to have 3 channels
    to match the canvas.
    '''
    out = torch.zeros(3, bbox_h, bbox_w)

    _, tile_h, tile_w = tile.shape
    
    if tile_h == 0 or tile_w == 0:
        return out
    
    if dir == 'right':
        num_tile_copies = bbox_w // tile_w
        # Crop rows to the canvas height; out[:, :tile_h] is clipped the same way.
        tile = tile[:, :bbox_h, :]
        for i in range(num_tile_copies):
            out[:, :tile_h, i*tile_w:min(bbox_w, (i+1)*tile_w)] = tile
        # tile_h > 0 always holds here because of the early return above.
        if tile_h > 0 and bbox_w % tile_w != 0:
            out[:, :tile_h, (tile_w*num_tile_copies):] = tile[:,:,:(bbox_w % tile_w)]
    elif dir == 'left':
        num_tile_copies = bbox_w // tile_w
        tile = tile[:, :bbox_h, :]
        for i in range(num_tile_copies):
            out[:, :tile_h, max(0, bbox_w-(i+1)*tile_w):(bbox_w-i*tile_w)] = tile
        if tile_h > 0 and bbox_w % tile_w != 0:
            out[:, :tile_h, :(bbox_w % tile_w)] = tile[:,:,-1*(bbox_w % tile_w):]
    # A fresh `if` rather than `elif`; harmless because the dir values are exclusive.
    if dir == 'up':
        num_tile_copies = bbox_h // tile_h
        # NOTE(review): this crops to the tile's own width (a no-op). The horizontal
        # branches crop to bbox_h, so bbox_w may have been intended here; confirm.
        tile = tile[:, :, :tile_w]
        for i in range(num_tile_copies):
            out[:, i*tile_h:min(bbox_h, (i+1)*tile_h), :tile_w] = tile
        if tile_w > 0 and bbox_h % tile_h != 0:
            out[:, (tile_h*num_tile_copies):, :tile_w] = tile[:,:(bbox_h % tile_h),:]
    elif dir == 'down':
        num_tile_copies = bbox_h // tile_h
        tile = tile[:, :, :tile_w]
        for i in range(num_tile_copies):
            out[:, max(0, bbox_h-(i+1)*tile_h):(bbox_h-i*tile_h), :tile_w] = tile
        if tile_w > 0 and bbox_h % tile_h != 0:
            out[:, :(bbox_h % tile_h), :tile_w] = tile[:,-1*(bbox_h % tile_h):, :]
    return out


def largest_adjacent_tile(img, mask, bbox):
    '''
    Fill a bounding box with copies of its largest usable neighbouring tile.

    The four same-sized regions to the right, left, "up" and "down" of `bbox` are
    each trimmed by trim_tile (at the image border or at other object pixels in
    `mask`). The largest one (the first on ties, in that order) is repeated by
    repeat_tile_to_fill_bbox to the box's size.
    '''
    x_min, x_max, y_min, y_max = get_corners(bbox)
    # Candidate regions, in the original evaluation order (which sets tie-breaking).
    candidates = {
        'right': (x_max, 2 * x_max - x_min, y_min, y_max),
        'left': (2 * x_min - x_max, x_min, y_min, y_max),
        'up': (x_min, x_max, y_max, 2 * y_max - y_min),
        'down': (x_min, x_max, 2 * y_min - y_max, y_max),
    }
    trimmed = [(direction, *trim_tile(img, mask, *coords, dir=direction))
               for direction, coords in candidates.items()]
    best = np.argmax([size for _, _, size in trimmed])
    direction, best_tile, _ = trimmed[best]
    return repeat_tile_to_fill_bbox(best_tile, x_max - x_min, y_max - y_min, direction)


def tile(img, mask):
    """Apply the 'tile' ablation to a batch, in place, and return it.

    Connected components are taken from mask[0, 0]. For each component with a
    non-empty box, the largest adjacent tile is computed from the first sample
    (img[0], mask[0]) and written into that box for every sample in img.
    """
    components, count = scipy.ndimage.label(mask[0, 0])
    for label_id in range(1, count + 1):
        isolated = components.copy()
        isolated[isolated != label_id] = 0
        if not isolated.sum() > 0:
            continue
        box = get_bbox(isolated)
        if not box.sum() > 0:
            continue
        patch = largest_adjacent_tile(img[0], mask[0], box)
        r0, r1, c0, c1 = get_corners(box)
        # The patch comes back as (C, y-range, x-range); swap it to match the slice.
        img[:, :, r0:r1, c0:c1] = patch.swapaxes(1, 2)
    return img


def to_tens(img, mask):
    """Convert a PIL image and its PIL mask to tensors with torchvision's ToTensor."""
    converter = transforms.ToTensor()
    img_tensor = converter(img)
    mask_tensor = converter(mask)
    return img_tensor, mask_tensor

class HardImageNetBase(Dataset):
    """Hard ImageNet images paired with their segmentation masks.

    Mask files are listed from ``<_MASK_ROOT>/<split>/`` and named
    ``<wnid>_<file name>``. The matching image is read from
    ``<_IMAGENET_ROOT>/<split>/<wnid>/<file name>``.
    """
    def __init__(self, _MASK_ROOT, _IMAGENET_ROOT, split='val', balanced_subset=False):

        with open(os.path.join(_MASK_ROOT, 'meta' ,'idx_to_wnid.pkl'), 'rb') as f:
            idx_to_wnid = pickle.load(f)
        self.wnid_to_idx = {v: k for k, v in idx_to_wnid.items()}

        self.aug = to_tens
        self.split = split
        self.balanced_subset = balanced_subset
        self._MASK_ROOT = _MASK_ROOT
        self._IMAGENET_ROOT = _IMAGENET_ROOT
        self.collect_mask_paths()
        self.num_classes = 15

    def map_wnid_to_label(self, wnid):
        """Return the class index for a WordNet id (KeyError if unknown)."""
        return self.wnid_to_idx[wnid]

    def collect_mask_paths(self):
        """Populate ``self.mask_paths``.

        For the balanced train subset, take the first and last ``subset_size`` entries
        per class from ``meta/paths_by_rank.pkl`` and keep only those whose mask file
        exists. Otherwise list every file under ``<_MASK_ROOT>/<split>`` in sorted order.
        """
        if self.balanced_subset and self.split == 'train':
            # hard coded for now
            self.subset_size = 100

            with open(os.path.join(self._MASK_ROOT, 'meta', 'paths_by_rank.pkl'), 'rb') as f:
                ranked_paths = pickle.load(f)
            paths = []
            for c in ranked_paths:
                cls_paths = ranked_paths[c]
                paths += cls_paths[:self.subset_size] + cls_paths[-self.subset_size:]
            candidates = [os.path.join(self._MASK_ROOT, 'train/', '_'.join(p.split('/')[-2:]))
                          for p in paths]
            # Build a new list. Calling list.remove() while iterating skipped the element
            # after each removal, so some missing files were kept.
            self.mask_paths = [p for p in candidates if os.path.exists(p)]
        else:
            # glob order is filesystem-dependent; sort so seeded splits built from this
            # list are reproducible across machines.
            self.mask_paths = sorted(glob.glob(str(os.path.join(self._MASK_ROOT, self.split)) + '/*'))

    def __getitem__(self, ind = None, mask_path = None):
        """Return ``(img, mask, class_index)`` for index ``ind`` or an explicit ``mask_path``.

        img is forced to 3 channels. The mask is binarised so that any value > 0 becomes 1.
        """
        if ind is not None:
            mask_path = self.mask_paths[ind]
        mask_path_suffix = mask_path.split('/')[-1]
        wnid = mask_path_suffix.split('_')[0]
        fname = mask_path_suffix[len(wnid)+1:]

        img_path = os.path.join(self._IMAGENET_ROOT, self.split, wnid, fname)
        # Close both files once converted. self.aug (to_tens) copies the pixel data
        # into tensors, so the PIL objects are not needed afterwards.
        with Image.open(img_path) as img_file, Image.open(mask_path) as mask_file:
            img, mask = self.aug(img_file, mask_file)

        if img.shape[0] > 3: #weird bug
            img, mask = [x[:3] for x in [img, mask]]

        if img.shape[0] == 1:
            # Grayscale image: replicate to 3 channels.
            img = torch.concat((img, img, img), axis = 0)

        class_ind = self.map_wnid_to_label(wnid)
        mask[mask > 0] = 1
        return img, mask, class_ind

    def __len__(self):
        return len(self.mask_paths)
