import torch
from PIL import Image
import wilds
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import torchvision
from collections import Counter
from domainbed.datasets import MultipleDomainDataset
from torchvision import transforms
import random
from torchvision.transforms.functional import crop
from wilds.datasets.poverty_dataset import PovertyDataset


class POVERTY(MultipleDomainDataset):
    """
    Wrapper around one split of the WILDS ``PovertyDataset``.

    Each item is fetched from ``PovertyDataset(root_dir=root).get_subset(split)``
    and optionally passed through a torchvision augmentation pipeline selected
    by ``aug``.

    Args:
        root: directory handed to ``PovertyDataset`` as ``root_dir``; the ADA
            scratch directory is also created below it.
        split: name of the subset requested from the WILDS dataset.
            Augmentations are only applied when this is ``'train'``.
        aug: augmentation mode (``'no_aug'``, ``'imgnet'``, ``'augmix'``,
            ``'randaug'``, ``'autoaug'``). A ``'+R'`` suffix makes the
            augmentation hyper-parameters random (sampled once, at
            construction time).
        algo: name of the training algorithm. ``'zeroshot'`` disables the
            transform; ``'ADA'``/``'ME_ADA'`` cache the training samples in
            memory and create a scratch directory; ``'BPA'``, ``'PnD'`` and
            ``'OccamNets'`` make ``__getitem__`` also return the index.

    NOTE(review): ``num_classes`` and ``label_names`` hold 62 land-use
    categories, and ``input_shape`` declares 8 channels - confirm these match
    the task this dataset is used for.
    """

    def __init__(self, root: str = '/data/ubuntu/robustness/data', split: str = 'train',
                 aug: str = 'no_aug', algo: str = 'ERM'):
        full_dataset = PovertyDataset(root_dir=root)
        self.dataset = full_dataset.get_subset(split)
        self.input_shape = (8, 224, 224,)
        self.num_classes = 62
        self.split = split
        self.algo = algo

        # '+R' is a suffix flag: strip it to obtain the base augmentation name.
        self.random_aug = '+R' in aug
        aug = aug.split('+R')[0]

        self.transform = None if self.algo == 'zeroshot' else self.get_transforms(aug)

        self.label_names = [
            "airport", "airport_hangar", "airport_terminal", "amusement_park", "aquaculture",
            "archaeological_site", "barn", "border_checkpoint", "burial_site", "car_dealership",
            "construction_site", "crop_field", "dam", "debris_or_rubble", "educational_institution",
            "electric_substation", "factory_or_powerplant", "fire_station", "flooded_road", "fountain",
            "gas_station", "golf_course", "ground_transportation_station", "helipad", "hospital",
            "impoverished_settlement", "interchange", "lake_or_pond", "lighthouse", "military_facility",
            "multi-unit_residential", "nuclear_powerplant", "office_building", "oil_or_gas_facility", "park",
            "parking_lot_or_garage", "place_of_worship", "police_station", "port", "prison", "race_track",
            "railway_bridge", "recreational_facility", "road_bridge", "runway", "shipyard", "shopping_mall",
            "single-unit_residential", "smokestack", "solar_farm", "space_facility", "stadium",
            "storage_tank", "surface_mine", "swimming_pool", "toll_booth", "tower", "tunnel_opening",
            "waste_disposal", "water_treatment_facility", "wind_farm", "zoo",
        ]

        if algo in ['ADA', 'ME_ADA'] and split == 'train':
            # Materialise every (input, label) pair in memory, dropping the
            # third element yielded by the WILDS subset.
            self.samples = [(data, label) for data, label, _ in self.dataset]

            # Create a fresh scratch directory. Creating it and retrying on
            # collision (instead of checking existence first) avoids a race
            # when several jobs start at the same time under the same root.
            while True:
                ada_root = os.path.join(root, f'ADA_{random.randint(0, 10000)}')
                try:
                    os.makedirs(ada_root)
                except FileExistsError:
                    continue
                break

            self.ada_root = ada_root
            self.ada_samples = []

        self.ori_samples = None

    def build_imgnet(self, img_size):
        """Return a random-resized-crop + colour-jitter pipeline.

        Stores the lower bound of the crop scale in ``self.arg1`` (sampled
        from [0.3, 0.7] when ``random_aug`` is set, else 0.08) and sets
        ``self.arg2`` to 1.
        """
        self.arg1 = random.uniform(0.3, 0.7) if self.random_aug else 0.08
        self.arg2 = 1

        # No horizontal flip and no ToTensor here: the flip is not mentioned
        # in the appendix of the DeepMind paper, and ToTensor is appended by
        # get_transforms.
        return transforms.Compose([
            transforms.RandomResizedCrop(img_size, scale=(self.arg1, 1)),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
        ])

    def build_augmix(self, img_size):
        """Return an ``AugMix`` transform; its two positional arguments are kept in ``arg1``/``arg2``.

        Defaults are (7, 5); with ``random_aug`` both are drawn from [2, 8].
        PIL image recommended. For torch tensor, it should be of torch.uint8.
        """
        self.arg1 = random.randint(2, 8) if self.random_aug else 7
        self.arg2 = random.randint(2, 8) if self.random_aug else 5
        return transforms.AugMix(self.arg1, self.arg2)

    def build_randaug(self, img_size):
        """Return a ``RandAugment`` transform; its two positional arguments are kept in ``arg1``/``arg2``.

        Defaults are (2, 9); with ``random_aug`` they are drawn from [2, 5]
        and [5, 15] respectively.
        PIL image recommended. For torch tensor, it should be of torch.uint8.
        """
        self.arg1 = random.randint(2, 5) if self.random_aug else 2
        self.arg2 = random.randint(5, 15) if self.random_aug else 9
        return transforms.RandAugment(self.arg1, self.arg2)

    def build_autoaug(self, img_size):
        """Return an ``AutoAugment`` transform using the IMAGENET policy.

        PIL image recommended. For torch tensor, it should be of torch.uint8.
        """
        policy = transforms.autoaugment.AutoAugmentPolicy.IMAGENET
        return transforms.AutoAugment(policy=policy)

    def get_transforms(self, mode='no_aug', gray=False):
        """Build the preprocessing pipeline for augmentation ``mode``.

        Args:
            mode: augmentation name; for training splits and any mode other
                than ``'no_aug'`` the method ``build_<mode>`` is looked up on
                ``self`` (an unknown name raises ``AttributeError``).
            gray: unused; kept for interface compatibility.

        Returns:
            A ``transforms.Compose`` of the selected steps.

        NOTE(review): ``ToTensor`` is documented for PIL images / ndarrays and
        ``ColorJitter`` for 1- or 3-channel images - confirm these suit what
        ``PovertyDataset`` yields.
        """
        is_train = self.split == 'train'
        steps = []

        if mode != 'no_aug' and is_train:
            # augmix / randaug / autoaug are run on uint8 tensors, so the
            # input is converted before and turned back to float afterwards.
            tensor_based = mode != 'imgnet'
            if tensor_based:
                steps.append(transforms.ToTensor())
                # ConvertImageDtype instead of a lambda keeps the pipeline
                # picklable for multi-process data loading.
                steps.append(transforms.ConvertImageDtype(torch.uint8))
            steps.append(getattr(self, f'build_{mode}')(self.input_shape[1]))
            if tensor_based:
                steps.append(transforms.ConvertImageDtype(torch.float32))

        # Every other configuration still needs the final tensor conversion.
        if mode not in ['augmix', 'autoaug', 'randaug'] or not is_train:
            steps.append(transforms.ToTensor())

        return transforms.Compose(steps)

    def __len__(self):
        """Return the number of items in the wrapped WILDS subset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(img, label)``, or ``(img, label, idx)`` for BPA/PnD/OccamNets on the train split."""
        is_train = self.split == 'train'

        # The in-memory cache is only built for 'ADA'/'ME_ADA'; any other
        # algorithm whose name merely contains 'ADA' falls back to the dataset
        # instead of failing on a missing attribute.
        cached = getattr(self, 'samples', None)
        if 'ADA' in self.algo and is_train and cached is not None:
            img, label = cached[idx]
        else:
            img, label, _ = self.dataset[idx]

        if self.transform:
            img = self.transform(img)

        if self.algo in ['BPA', 'PnD', 'OccamNets'] and is_train:
            return img, label, idx
        return img, label
