#!/usr/bin/env python3

"https://github.com/alinlab/BAR/archive/refs/heads/master.zip"

import torch
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms
from typing import List, Tuple, Generator, Union
import os
import shutil
import sys
from os import path
from PIL import Image
from sklearn.model_selection import train_test_split
import pandas as pd
import zipfile
import requests

class BAR(Dataset):
    """BAR (Biased Action Recognition) image dataset.

    Each split is read from ``<root_dir>/alinlab-BAR-1364b0a/<split>/`` and every
    image file name starts with its action class followed by ``"_"`` (see
    ``classes_str_to_idx``). If that folder is missing, the GitHub zipball at
    ``DOWNLOAD_URL`` is downloaded and extracted into ``root_dir``.

    Indexing returns ``(image, class_label, group_label)`` where
    ``group_label = 2 * class_label + bias_label``. ``bias_label`` comes from an
    external CSV for the training split and is 0 otherwise, so every split uses
    the same ``2 * n_classes`` group ids.
    """

    DOWNLOAD_URL = "https://api.github.com/repos/alinlab/BAR/zipball"
    DATASET_NAME = "bar"
    # Name of the top-level folder inside the downloaded zipball.
    # NOTE(review): this looks like "<owner>-<repo>-<short commit>", so it may
    # change if the upstream default branch moves — confirm and pin if needed.
    EXTRACTED_DIR = "alinlab-BAR-1364b0a"

    # TODO: check which transforms prior work applies to BAR.
    eval_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Maps the class prefix of each image file name to its integer label.
    classes_str_to_idx = {
        "climbing": 0,
        "diving":   1,
        "fishing":  2,
        "racing":   3,
        "throwing": 4,
        "pole vaulting": 5,
    }

    def __init__(
        self,
        root_dir="./data",
        env="train",
        target_name="action",
        confounder_names="background",
        return_index=False,
        external_bias_labels: bool = True,
        **kwargs
    ) -> None:
        """Index one split of BAR, downloading the dataset first if needed.

        Args:
            root_dir: Directory that contains (or will receive) the extracted
                ``EXTRACTED_DIR`` folder.
            env: ``"train"``, ``"val"`` or ``"test"``. ``"val"`` reads the
                training images (there is no separate validation folder) and,
                like ``"train"``, uses ``train_transform``.
            target_name: Stored on the instance; not otherwise used here.
            confounder_names: Stored on the instance; not otherwise used here.
            return_index: Stored on the instance; ``__getitem__`` does not
                currently read it.
            external_bias_labels: For the training split, read per-sample bias
                labels from the ``"clip"`` column of ``bar_metadata_aug.csv``
                (resolved relative to the current working directory).
            **kwargs: Accepted for interface compatibility and ignored.

        Raises:
            RuntimeError: if the dataset has to be downloaded and that fails.
        """
        # Remembered so get_splits() can rebuild sibling splits with the same setup.
        self._root_dir = root_dir
        self._external_bias_labels = external_bias_labels

        # Only "test" gets the deterministic transform; "val" aliases "train".
        self.transform = BAR.eval_transform if env == "test" else BAR.train_transform
        self.env = "train" if env == "val" else env

        self.return_index = return_index
        self.target_name = target_name
        self.confounder_names = confounder_names

        # Honour root_dir (previously it was unconditionally replaced by "./data").
        self.root = root_dir
        dataset_dir = os.path.join(self.root, BAR.EXTRACTED_DIR)
        if os.path.isdir(dataset_dir):
            self.root = dataset_dir
        else:
            # On success this also points self.root at the extracted folder.
            self.__download_dataset()

        self.samples_folder = path.join(self.root, self.env)

        self.bias_labels = None
        if external_bias_labels:
            if self.env == "train":
                print("Loading external bias labels!")
                self.bias_labels = pd.read_csv("bar_metadata_aug.csv")["clip"].to_numpy()
            else:
                print("External bias labels are supported only for the training set, skipping...")

        # Samples are numbered in sorted file-name order. The external bias
        # labels are paired with them positionally.
        # NOTE(review): assumes the CSV rows follow that same order.
        self.samples = {}
        for i, file in enumerate(sorted(os.listdir(self.samples_folder))):
            class_label = int(BAR.classes_str_to_idx[file.split("_")[0]])
            self.samples[i] = {
                "image_path": path.join(self.samples_folder, file),
                "class_label": class_label,
                "bias_label": class_label if self.bias_labels is None else self.bias_labels[i],
            }
        self.num_samples = len(self.samples)

        self.filename_array = np.array([s["image_path"] for s in self.samples.values()])
        self.y_array = np.array([s["class_label"] for s in self.samples.values()])

        self.n_classes = 6  # Six classes
        # Two group slots per class (confounder 0/1) in every split, so the ids
        # 2 * y + bias always lie in [0, n_groups).
        self.n_groups = self.n_classes * 2
        self.n_confounders = 1 if self.env == "train" else 0
        # Without external labels the bias is unknown; use confounder 0, which
        # is what the evaluation splits have always used.
        bias = np.zeros_like(self.y_array) if self.bias_labels is None else self.bias_labels
        self.group_array = (self.y_array * 2 + bias).astype("int")

        if self.bias_labels is not None:
            # Sanity print: these two numbers are expected to match.
            print(self.num_samples, len(self.bias_labels))

    def __len__(self):
        return len(self.y_array)

    def get_group_array(self):
        """Return the per-sample group ids (``2 * class + bias``)."""
        return self.group_array

    def get_label_array(self):
        """Return the per-sample class labels."""
        return self.y_array

    def __getitem__(self, index: Union[int, slice, list]) -> Union[Tuple[torch.Tensor, np.integer, np.integer], list]:
        """Return ``(image, class_label, group_label)`` for one sample.

        A slice or a list of indices returns a list of such tuples. The third
        element is the group id ``2 * class + bias``, not the raw bias label.
        """
        if isinstance(index, slice):
            return [self[i] for i in range(*index.indices(len(self)))]

        if isinstance(index, list):
            return [self[i] for i in index]

        # Use the image as a context manager so its file handle is released.
        # convert() returns a separate, fully loaded image.
        with Image.open(self.filename_array[index]) as img:
            image = self.transform(img.convert("RGB"))
        return image, self.y_array[index], self.group_array[index]

    def __download_dataset(self) -> None:
        """Download the BAR zipball into ``self.root`` and extract it there.

        On success the archive is deleted and ``self.root`` is updated to the
        extracted ``EXTRACTED_DIR`` folder.

        Raises:
            RuntimeError: if downloading or extracting fails, or the archive does
                not contain the expected ``EXTRACTED_DIR`` folder.
        """
        os.makedirs(self.root, exist_ok=True)
        output_path = os.path.join(self.root, "BAR-master.zip")
        print(f"=> Downloading {BAR.DATASET_NAME} from {BAR.DOWNLOAD_URL}")

        try:
            # timeout bounds each connect/read wait so a stalled server cannot hang forever.
            with requests.get(
                BAR.DOWNLOAD_URL,
                stream=True,
                headers={'Accept-Encoding': None},
                timeout=60,
            ) as response:
                response.raise_for_status()
                with open(output_path, mode="wb") as write_stream:
                    for chunk in response.iter_content(chunk_size=8192):
                        write_stream.write(chunk)
        except (requests.RequestException, OSError) as err:
            # Do not leave a truncated archive behind.
            if os.path.exists(output_path):
                os.remove(output_path)
            raise RuntimeError("Unable to complete dataset download, check for your internet connection or try changing download link.") from err

        print(f"=> Extracting BAR-master.zip to directory {self.root}")
        try:
            with zipfile.ZipFile(output_path, mode="r") as unzipper:
                unzipper.extractall(self.root)
        except Exception as err:  # corrupt archives can raise several unrelated types
            raise RuntimeError(f"Unable to extract {output_path}, an error occurred.") from err

        os.remove(output_path)

        dataset_dir = os.path.join(self.root, BAR.EXTRACTED_DIR)
        if not os.path.isdir(dataset_dir):
            raise RuntimeError(
                f"Expected {dataset_dir} after extraction; the upstream archive "
                f"layout may have changed (update BAR.EXTRACTED_DIR)."
            )
        self.root = dataset_dir

    def get_splits(self, splits, train_frac=1.0):
        """Build one ``BAR`` dataset per requested split name.

        The new datasets reuse this instance's ``root_dir``, ``target_name``,
        ``confounder_names`` and ``external_bias_labels`` settings, and are
        created with ``return_index=True``.

        Args:
            splits: Iterable of split names, each one of ``"train"``, ``"val"``, ``"test"``.
            train_frac: Accepted for interface compatibility; currently unused.

        Returns:
            Dict mapping each split name to its dataset.
        """
        subsets = {}
        for split in splits:
            assert split in ("train", "val", "test"), f"{split} is not a valid split"
            subsets[split] = BAR(
                root_dir=self._root_dir,
                env=split,
                target_name=self.target_name,
                confounder_names=self.confounder_names,
                return_index=True,
                external_bias_labels=self._external_bias_labels,
            )

        return subsets

    def group_str(self, group_idx):
        """Return a readable name like ``"Class: 3, Confounder: 1"`` for a group id."""
        groups_per_class = self.n_groups // self.n_classes
        y = group_idx // groups_per_class  # class label
        c = group_idx % groups_per_class   # confounder value

        # Binary confounder, zero-padded to n_confounders digits.
        bin_str = format(c, f"0{self.n_confounders}b")
        return f"Class: {int(y)}, Confounder: {bin_str}"

    def perclass_populations(self, return_labels: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Count the samples of each class present in this split.

        Labels are read from ``y_array`` directly instead of loading every image.

        Args:
            return_labels: Also return the per-sample label tensor.

        Returns:
            ``long`` tensor of counts for the classes that occur, ordered by class
            index. With ``return_labels=True``, returns ``(counts, labels)``.
        """
        labels = torch.as_tensor(self.y_array, dtype=torch.long)
        _, pop_counts = labels.unique(return_counts=True)

        if return_labels:
            return pop_counts.long(), labels

        return pop_counts

    def get_bias_labels(self) -> Generator[np.integer, None, None]:
        """Yield each sample's group id (the third item of ``self[i]``) without loading images."""
        yield from self.group_array

    def __repr__(self) -> str:
        return f"BAR(env={self.env}, bias_amount=Unknown, num_classes={self.n_classes})"
    



if __name__ == "__main__":
    # Smoke run: building the default (training) split downloads the dataset
    # first when its folder is not already on disk.
    _dataset = BAR()