import os
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from os.path import exists
from PIL import Image as im
from torchvision import transforms
from torch.utils.data import Dataset


def img_train_transform():
    """Build the augmentation pipeline applied to training images.

    The pipeline does the following, in order:
    - random resized crop to 224;
    - random horizontal flip;
    - colour jitter, applied with probability 0.8;
    - random grayscale, applied with probability 0.2;
    - conversion of the PIL image to a float tensor;
    - per-channel normalization with the commonly used ImageNet mean/std.

    Returns:
        transforms.Compose: the composed training transform.
    """
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)
    jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
    steps = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.PILToTensor(),
        transforms.ConvertImageDtype(torch.float),
        transforms.Normalize(channel_mean, channel_std),
    ]
    return transforms.Compose(steps)


def img_val_transform():
    """Build the deterministic pipeline applied to validation/test images.

    The pipeline does the following, in order:
    - resize (shorter side to 256);
    - center crop to 224;
    - conversion of the PIL image to a float tensor;
    - per-channel normalization with the same mean/std as the training pipeline.

    Returns:
        transforms.Compose: the composed evaluation transform.
    """
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.PILToTensor(),
        transforms.ConvertImageDtype(torch.float),
        transforms.Normalize(channel_mean, channel_std),
    ]
    return transforms.Compose(steps)


class GeoDataset(Dataset):
    """
    Dataset of (image, GPS) pairs indexed by a CSV file.

    The CSV file with the dataset information must have columns:
    - 'IMG_ID' for the image filename, relative to ``dataset_folder``,
    - 'LAT' for latitude, and
    - 'LON' for longitude.

    Rows whose image file does not exist under ``dataset_folder`` are skipped.
    Latitude is divided by 90 and longitude by 180. If the CSV stores degrees,
    this maps both into [-1, 1].

    Each item is ``(image, gps)``:
    - ``image`` is the RGB image after ``self.transform``;
    - ``gps`` is a float tensor holding the two normalized coordinates.

    Attributes:
        dataset_folder (str): Base folder where images are stored.
        device: Stored as given; not used by this class.
        transform (callable): ``img_train_transform()`` when ``train`` is True,
            otherwise ``img_val_transform()``.
        images (list[str]): Full paths of the images that were found on disk.
        coordinates (list[tuple[float, float]]): Normalized (lat, lon) per image.
    """

    def __init__(self, dataset_file, dataset_folder, device, train=False):
        self.dataset_folder = dataset_folder
        self.device = device
        if train:
            self.transform = img_train_transform()
        else:
            self.transform = img_val_transform()
        self.images, self.coordinates = self.load_dataset(dataset_file)

    def load_dataset(self, dataset_file):
        """Read the CSV index and collect existing image paths with coordinates.

        Args:
            dataset_file (str): Path of the CSV file.

        Returns:
            tuple[list[str], list[tuple[float, float]]]: Image paths and the
            matching normalized (lat / 90, lon / 180) pairs.

        Raises:
            IOError: If the CSV cannot be read. The original error is chained.
        """
        try:
            dataset_info = pd.read_csv(dataset_file)
        except Exception as e:
            raise IOError(f"Error reading {dataset_file}: {e}") from e

        images = []
        coordinates = []

        # Passing total lets tqdm show a real progress bar; iterrows() has no length.
        for _, row in tqdm(dataset_info.iterrows(), total=len(dataset_info),
                           desc="Loading image paths and coordinates"):
            filename = os.path.join(self.dataset_folder, row['IMG_ID'])
            if not exists(filename):
                continue  # silently skip rows whose image is missing
            images.append(filename)
            latitude = float(row['LAT']) / 90
            longitude = float(row['LON']) / 180
            coordinates.append((latitude, longitude))

        return images, coordinates

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Load, convert to RGB and transform image ``idx``; return it with its GPS tensor."""
        img_path = self.images[idx]
        gps = self.coordinates[idx]

        # Close the underlying file handle. convert() returns a new, fully
        # loaded image, so it stays usable after the handle is released.
        with im.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        gps = torch.tensor(gps, dtype=torch.float)

        return image, gps

class PretrainedDataset(Dataset):
    """
    Dataset of precomputed embeddings paired with GPS coordinates.

    ``dataset_file`` is expected to be an ``.npz`` archive with two arrays:
    - ``"embedding"``: shape (N, D), or (N, A, D) when A augmented views were
      stored per sample;
    - ``"location"``: per-sample coordinates. Presumably shape (N, 2), since a
      length-2 offset is added per item (confirm against the producer).

    Each item is ``(embedding, gps, idx)``: two float tensors plus the index.

    Args:
        dataset_file: Path of the ``.npz`` archive.
        train: Stored as ``self.train``; not used by this class.
        use_augmentation: For 3-D embeddings, pick a random view per access and
            add the image perturbation. Otherwise view 0 is always used.
        img_perturb: Scale of the offset added to the embedding. It is only
            applied when a random augmented view is sampled.
        gps_purturb: Scale of the offset added to the coordinates on every
            access. The misspelled name is kept for caller compatibility.

    NOTE(review): offsets come from ``np.random.rand`` (uniform on [0, 1)), so
    they are always non-negative. Confirm that this one-sided noise is intended
    rather than zero-mean noise.
    """

    def __init__(self, dataset_file, train=True, use_augmentation=False, img_perturb=0.00, gps_purturb=0.00):
        self.train = train
        self.img_perturb, self.gps_perturb = img_perturb, gps_purturb
        self.images, self.coordinates = self.load_dataset(dataset_file)
        self.use_augmentation = use_augmentation

    def load_dataset(self, dataset_file):
        """Load the embedding and location arrays from the archive.

        This also sets the following attributes:
        - ``self.augmented``;
        - ``self.embedding_dim``;
        - ``self.num_augmentation``, only when the embeddings are 3-D.

        Returns:
            tuple: ``(embeddings, coordinates)`` as NumPy arrays.
        """
        # Open the archive once and close it deterministically. Indexing an
        # NpzFile reads the member into memory, so the arrays stay valid after
        # the file is closed.
        with np.load(dataset_file) as archive:
            embeddings = archive["embedding"]
            coordinates = archive["location"]
        print(embeddings.shape)
        if embeddings.ndim == 3:
            self.augmented = True
            self.num_augmentation = embeddings.shape[1]
            self.embedding_dim = embeddings.shape[2]
        else:
            self.augmented = False
            self.embedding_dim = embeddings.shape[1]

        return embeddings, coordinates

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Both offsets are drawn on every call, in this order, even when unused.
        # This keeps the global NumPy RNG stream identical across code paths.
        img_perturbation = self.img_perturb * np.random.rand(self.embedding_dim)
        gps_perturbation = self.gps_perturb * np.random.rand(2)
        if self.augmented and self.use_augmentation:
            aug_idx = np.random.choice(self.num_augmentation)
            embedding = torch.tensor(self.images[idx, aug_idx] + img_perturbation, dtype=torch.float)
        elif self.augmented:
            # Augmented archive but augmentation disabled: use the first view.
            embedding = torch.tensor(self.images[idx, 0], dtype=torch.float)
        else:
            embedding = torch.tensor(self.images[idx], dtype=torch.float)
        gps = torch.tensor(self.coordinates[idx] + gps_perturbation, dtype=torch.float)

        return embedding, gps, idx