from math import ceil
from PIL.Image import BICUBIC
from PIL import Image
import torchvision
from torchvision.transforms import Compose, RandomCrop, Pad, RandomHorizontalFlip, Resize, RandomAffine
from torchvision.transforms import ToTensor, Normalize

from torch.utils.data import Subset,Dataset, Sampler

import torchvision.utils as vutils
import random
from torch.utils.data import DataLoader
import numpy as np
import random
import pandas as pd
import os
import seaborn as sns
import matplotlib.pyplot as plt

from skimage import io
import torch
from torchvision import models
import torchvision
import torch.nn as nn
import numpy as np
import random
from IPython.display import clear_output
import time
import os.path
from sklearn.metrics import f1_score
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.optim.lr_scheduler
import torch.nn.init
from torch.autograd import Variable
from sklearn.metrics import accuracy_score
from skimage.transform import resize
from sklearn.model_selection import train_test_split

import imageio
import numpy as np
import torch
from skimage import img_as_float32, img_as_ubyte
from skimage.transform import resize
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torchvision.datasets.utils import download_and_extract_archive
from colorama import Fore


class EurosatDataset(torch.utils.data.Dataset):
    """
    EuroSAT: Land Use and Land Cover Classification with Sentinel-2.

    EuroSAT is a dataset and deep learning benchmark for land use and land
    cover classification. The dataset is based on Sentinel-2 satellite images
    covering 13 spectral bands and consists of 10 classes with a total of
    27,000 labeled and geo-referenced images.

    The whole image folder is read into memory at construction time and then
    split into a stratified train/test partition; the instance keeps only the
    partition selected by ``is_train`` in ``self.data`` / ``self.targets``.
    """

    def __init__(self, is_train, root_dir="data/EuroSAT/", transform=None, seed=42, download=False):
        """
        EurosatDataset
        Args:
            is_train (bool): If true returns training set, else test set.
            root_dir (str, optional): Root directory of dataset; the images are
                expected in its ``2750`` sub-folder. Defaults to "data/EuroSAT/".
            transform (callable, optional): Optional transform to be applied on a sample. Defaults to None.
            seed (int, optional): Seed used for train/test split. Defaults to 42.
            download (bool, optional): If True, downloads the dataset from the internet and puts it in
                root directory. If the dataset is already extracted there it is not downloaded again.
                Defaults to False.
        Raises:
            RuntimeError: If the extracted dataset folder cannot be found.
        """

        self.seed = seed
        self.root_dir = root_dir
        self.transform = transform
        self.is_train = is_train
        self.download = download

        self.size = [64, 64]
        self.num_channels = 3
        self.num_classes = 10
        self.test_ratio = 0.2
        # Capacity of the in-memory image buffer (number of images expected on disk).
        self.N = 27000
        # Name of the folder the archive extracts to (attribute spelling kept for compatibility).
        self.extaraced = '2750'
        self._load_data()

    def _load_data(self):
        """
        Loads the data from the passed root directory. Splits in test/train based on seed.

        Every sub-folder of ``<root_dir>/2750`` is treated as one class and every
        entry inside it as one image of that class.

        Raises:
            RuntimeError: If the extracted dataset folder does not exist.
        """

        if self.download:
            self.download_dataset()

        if not self._check_exists():
            raise RuntimeError(
                "Dataset not found. You can use download=True to download it"
            )

        # Allocate the buffer only once we know there is something to read.
        images = np.zeros(
            [self.N, self.size[0], self.size[1], 3], dtype="uint8")
        labels = []
        filenames = []

        i = 0
        data_dir = os.path.join(self.root_dir, self.extaraced)

        with tqdm(os.listdir(data_dir), bar_format="{l_bar}%s{bar}%s{r_bar}" % (Fore.GREEN, Fore.RESET)) as dir_bar:
            for item in dir_bar:
                f = os.path.join(data_dir, item)
                # Only class folders are of interest; skip stray files.
                if os.path.isfile(f):
                    continue
                for subitem in os.listdir(f):
                    sub_f = os.path.join(f, subitem)
                    filenames.append(sub_f)

                    # a few images are a few pixels off, we will resize them
                    image = imageio.imread(sub_f)
                    if image.shape[0] != self.size[0] or image.shape[1] != self.size[1]:
                        image = img_as_ubyte(
                            resize(
                                image, (self.size[0], self.size[1]), anti_aliasing=True)
                        )
                    images[i] = img_as_ubyte(image)
                    i += 1
                    # The class label is the name of the containing folder.
                    labels.append(item)

                dir_bar.set_description(
                    f"{'Train' if self.is_train else 'Test'} images are reading..")
                dir_bar.set_postfix(category=item)

        labels = np.asarray(labels)
        filenames = np.asarray(filenames)

        # Sort by filename so the split does not depend on os.listdir order.
        # The permutation is computed once and reused for images and labels;
        # it also drops any unused trailing rows of the buffer.
        order = filenames.argsort()
        images = images[order]
        labels = labels[order]

        # convert to integer labels
        label_encoder = preprocessing.LabelEncoder()
        label_encoder.fit(np.sort(np.unique(labels)))
        labels = label_encoder.transform(labels)
        labels = np.asarray(labels)
        # remember label encoding
        self.label_encoding = list(label_encoder.classes_)

        # split into a train and test set as provided data is not presplit
        x_train, x_test, y_train, y_test = train_test_split(
            images,
            labels,
            test_size=self.test_ratio,
            random_state=self.seed,
            stratify=labels,
        )

        if self.is_train:
            self.data = x_train
            self.targets = y_train
        else:
            self.data = x_test
            self.targets = y_test

    def __len__(self):
        """Number of samples in the selected (train or test) partition."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Return ``(image, target)`` for position ``idx``.

        The image is scaled by 1/255, cast to float32 and moved from
        height-width-channel to channel-height-width order.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img = self.data[idx]

        if self.transform:
            # NOTE(review): the scaling and transpose below assume the transform
            # still returns an HWC array in the 0-255 range - confirm for any
            # transform passed in.
            img = self.transform(img)

        image = np.asarray(img / 255, dtype="float32")

        return image.transpose(2, 0, 1), self.targets[idx]

    def _check_exists(self) -> bool:
        """
        Check that the extracted image folder (``<root_dir>/2750``) exists.

        Checking the root directory alone is not enough: it may exist while
        being empty, in which case loading would fail later on.
        """
        return os.path.isdir(os.path.join(self.root_dir, self.extaraced))

    def download_dataset(self) -> None:
        """
        Download the dataset from the internet and extract it into the root directory.

        Does nothing when the extracted image folder is already present.
        """

        if self._check_exists():
            return

        os.makedirs(self.root_dir, exist_ok=True)
        download_and_extract_archive(
            "https://madm.dfki.de/files/sentinel/EuroSAT.zip",
            download_root=self.root_dir,
            md5="c8fa014336c82ac7804f0398fcb19387",
        )

class BalancedSampler(Sampler):
    """Sampler that picks a bucket uniformly at random for every draw.

    ``buckets`` is a list of index lists. Each draw selects one bucket with
    equal probability and returns that bucket's next element; a bucket that
    has been fully consumed is reshuffled and started over.
    The bucket lists are shuffled in place.
    """

    def __init__(self, buckets, retain_epoch_size=False):
        # Start every bucket in a random order.
        for members in buckets:
            random.shuffle(members)

        self.buckets = buckets
        self.bucket_num = len(buckets)
        # One read cursor per bucket.
        self.bucket_pointers = [0] * self.bucket_num
        self.retain_epoch_size = retain_epoch_size

    def __iter__(self):
        # One epoch yields exactly len(self) items.
        for _ in range(len(self)):
            yield self._next_item()

    def _next_item(self):
        """Draw one element from a uniformly chosen bucket and advance its cursor."""
        chosen = random.randint(0, self.bucket_num - 1)
        members = self.buckets[chosen]
        cursor = self.bucket_pointers[chosen]
        item = members[cursor]

        cursor += 1
        if cursor == len(members):
            # Bucket exhausted: rewind and reshuffle for the next pass.
            cursor = 0
            random.shuffle(members)
        self.bucket_pointers[chosen] = cursor
        return item

    def __len__(self):
        sizes = [len(members) for members in self.buckets]
        if self.retain_epoch_size:
            # Same number of draws as there are elements overall.
            # (Actually we need to upscale to next full batch.)
            return sum(sizes)
        # Ensures every instance has the chance to be visited in an epoch.
        return max(sizes) * self.bucket_num

def iloader(path):
    """Read the image at ``path`` and return it as a float32 array.

    The pixel values are divided by 32000 and the axes are reordered from
    (0, 1, 2) to (2, 0, 1).
    """
    raw = io.imread(path)
    scaled = np.asarray(raw / 32000, dtype='float32')
    return np.transpose(scaled, (2, 0, 1))


def Load_data(root='/local/data/sghosh_dg/Data/2750', eurosat_type='RGB'):
    """Load a EuroSAT image folder and split it 80/20 into train and validation parts.

    Args:
        root (str, optional): Folder containing one sub-folder per class.
            Defaults to the path that used to be hard-coded here.
        eurosat_type (str, optional): 'RGB' to read ``jpg`` files or 'ALL' to
            read ``tif`` files. Defaults to 'RGB'.

    Returns:
        tuple: ``(train_set, val_set)`` as produced by ``train_test_split``,
        stratified by class.
        # NOTE(review): presumably two lists of (image, class_index) samples,
        # which means every image is read by ``iloader`` during the split - confirm.

    Raises:
        ValueError: If ``eurosat_type`` is neither 'RGB' nor 'ALL'.
    """
    extension_by_type = {'RGB': 'jpg', 'ALL': 'tif'}
    if eurosat_type not in extension_by_type:
        raise ValueError(
            f"eurosat_type must be 'RGB' or 'ALL', got {eurosat_type!r}")

    data = torchvision.datasets.DatasetFolder(
        root=root,
        loader=iloader,
        transform=None,
        extensions=extension_by_type[eurosat_type],
    )
    # No random_state is given, so the split differs from call to call.
    train_set, val_set = train_test_split(data, test_size=0.2, stratify=data.targets)
    return train_set, val_set


def _long_tail_counts(head_size, rho, num_classes):
    """Per-class sample counts that decay geometrically from ``head_size``.

    Class ``i`` gets ``ceil(head_size * mu**i)`` with ``mu = rho**(1/9)``, so
    class 9 ends up with about ``rho`` times the samples of class 0.
    """
    mu = rho ** (1. / 9.)
    return [ceil(head_size * (mu ** i)) for i in range(num_classes)]


def _head_tail_indices(targets, head_counts, tail_counts, num_classes):
    """For each class collect the first ``head_counts[c]`` and the last ``tail_counts[c]`` positions in ``targets``."""
    head, tail = [], []
    for cls in range(num_classes):
        positions = np.where(targets == cls)[0]
        head.extend(positions[:head_counts[cls]])
        tail.extend(positions[-tail_counts[cls]:])
    return head, tail


def load_Eurosat(save_path=None, train_size=4000, train_rho=0.01, val_size=1000, val_rho=0.01, image_size=64, batch_size=128, num_workers=0, path='./data', num_classes=10, balance_val=False):
    """Build long-tailed EuroSAT train/validation loaders plus a test loader.

    The EuroSAT training partition is sub-sampled so that the number of
    training images per class decays geometrically from ``train_size``
    (class 0) by an overall factor of ``train_rho``. The validation set is
    taken from the same partition, either with the same kind of decay
    (``val_size`` / ``val_rho``) or with ``val_size`` images for every class
    when ``balance_val`` is true.

    Args:
        save_path: Currently unused.
        train_size (int): Requested number of training images for class 0.
        train_rho (float): Ratio between the last and the first class count in the training set.
        val_size (int): Requested number of validation images for class 0 (or for every class if ``balance_val``).
        val_rho (float): Same as ``train_rho`` for the validation set; ignored if ``balance_val``.
        image_size: Currently unused.
        batch_size (int): Batch size of the train, validation and train-eval loaders.
        num_workers (int): Worker count passed to every loader.
        path: Currently unused; the data directory is hard-coded below.
        num_classes (int): Number of classes to sample from.
        balance_val (bool): Use a class-balanced validation set.

    Returns:
        tuple: ``(train_loader, val_loader, test_loader, eval_train_loader,
        eval_val_loader, num_train_samples, num_val_samples)`` where the last
        two entries are the requested per-class counts.
    """
    train_transform = Compose([
        RandomCrop(64, padding=4),
        RandomHorizontalFlip(),
        ToTensor(),
    ])
    test_transform = Compose([ToTensor()])

    # NOTE(review): machine-specific location; the ``path`` argument is ignored.
    data_dir = '/local/data/sghosh_dg/Data'
    # Forwarded as random_state of the train/test split.
    # NOTE(review): ``False`` is presumably treated like the integer seed 0 - confirm.
    seed = False
    train_dataset = EurosatDataset(is_train=True, seed=seed, root_dir=data_dir)

    print(f'Number of data on Train Dataset is {len(train_dataset)}')

    test_dataset = EurosatDataset(is_train=False, seed=seed, root_dir=data_dir)
    print(f'Number of data on Test Dataset is {len(test_dataset)}')

    train_x, train_y = np.array(train_dataset.data), np.array(train_dataset.targets)
    print(f"shape = {train_x.shape}")

    # Requested per-class sizes.
    num_train_samples = _long_tail_counts(train_size, train_rho, num_classes)
    if balance_val:
        num_val_samples = [val_size] * num_classes
    else:
        num_val_samples = _long_tail_counts(val_size, val_rho, num_classes)

    # First pass: keep only the images that are candidates for either split.
    # NOTE(review): the training images are the first and the validation images
    # the last entries of each class, so the two slices overlap whenever a class
    # has fewer images than requested for both together - confirm this is intended.
    train_index, val_index = _head_tail_indices(
        train_y, num_train_samples, num_val_samples, num_classes)
    total_index = list(set(train_index + val_index))
    random.shuffle(total_index)
    train_x, train_y = train_x[total_index], train_y[total_index]

    print(f"train histogram: {num_train_samples}, val histogram: {num_val_samples}")
    print(f"train histogram: {np.sum(num_train_samples)}, val histogram: {np.sum(num_val_samples)}")

    # Second pass: pick the final train/val positions inside the shuffled subset.
    train_index, val_index = _head_tail_indices(
        train_y, num_train_samples, num_val_samples, num_classes)

    random.shuffle(train_index)
    random.shuffle(val_index)

    train_data, train_targets = train_x[train_index], train_y[train_index]
    val_data, val_targets = train_x[val_index], train_y[val_index]

    # Augmented views for optimisation, plain views for evaluation.
    train_dataset = CustomDataset(train_data, train_targets, train_transform)
    val_dataset = CustomDataset(val_data, val_targets, train_transform)
    train_eval_dataset = CustomDataset(train_data, train_targets, test_transform)
    val_eval_dataset = CustomDataset(val_data, val_targets, test_transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers,
                              shuffle=True, drop_last=False, pin_memory=True)

    val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers,
                            shuffle=True, drop_last=False, pin_memory=True)
    # The test loader wraps the EurosatDataset directly (no transform).
    test_loader = DataLoader(test_dataset, batch_size=5400, num_workers=num_workers,
                             shuffle=False, drop_last=False, pin_memory=True)

    eval_train_loader = DataLoader(train_eval_dataset, batch_size=batch_size, num_workers=num_workers,
                                   shuffle=False, drop_last=False, pin_memory=True)
    # The whole validation set is delivered as a single batch.
    eval_val_loader = DataLoader(val_eval_dataset, batch_size=len(val_data), num_workers=num_workers,
                                 shuffle=False, drop_last=False, pin_memory=True)

    return train_loader, val_loader, test_loader, eval_train_loader, eval_val_loader, num_train_samples, num_val_samples

class CustomDataset(Dataset):
    """In-memory dataset of image arrays and targets with an optional transform.

    Every sample is converted to a PIL image before the transform is applied.
    """

    def __init__(self, data, targets, transform=None):
        self.data, self.targets = data, targets
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        target = self.targets[index]
        picture = Image.fromarray(sample)

        # Without a transform the PIL image itself is handed back.
        if self.transform is None:
            return picture, target
        return self.transform(picture), target
#load_cifar10()