import os
import torch
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import pandas as pd
import random
from tqdm import tqdm
import pickle
import shutil
import argparse
import sys
# Module-wide NumPy print settings (applied on import): never truncate arrays
# with "..." and print small floats in fixed-point rather than scientific
# notation.
np.set_printoptions(threshold=sys.maxsize, suppress=True)


# (target, bias) label pairs that make up the dataset. Targets 0/1 are only
# paired with biases 0/1, and targets 2/3 only with biases 2/3.
GROUPS = [
    (0, 0), (0, 1), (1, 0), (1, 1),
    (2, 2), (2, 3), (3, 2), (3, 3),
]


class CelebADatasetGen(Dataset):
    """Dataset over generated images stored in one folder per (class, bias) group.

    Images are read from ``./generated_images/<class name>_<bias name>/`` for
    every ``(target, bias)`` pair in the module-level ``GROUPS`` list, using
    the ``class_to_name`` / ``bias_to_name`` tables below to build the folder
    names. ``root_dir`` is stored on the instance but is not used for loading.

    Args:
        root_dir: Stored as ``self.root_dir``.
        split: One of ``'train'``, ``'valid'`` or ``'test'``; selects the
            fractional window of each group's (sorted) file list.
        transform: Optional callable applied to every loaded image.
        opt: Options object. For the train split, ``opt.num_per_group`` (when
            present and not ``None``) caps the number of images per group.

    Raises:
        KeyError: If ``split`` is not one of the known split names.
        FileNotFoundError: If a group directory does not exist.
    """

    def __init__(self, root_dir, split='train', transform=None, opt=None):
        super().__init__()

        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.opt = opt
        self.split_idx = {
            'train': 0,
            'valid': 1,
            'test': 2,
        }
        # Fractional [start, end) window of each group's file list per split.
        self.split_ratio = {
            'train': [0.0, 0.7],
            'valid': [0.7, 0.85],
            'test': [0.85, 1.0],
        }

        self.class_to_name = {
            0: "landbird",
            1: "waterbird",
            2: "small_dog",
            3: "big_dog",
        }

        self.bias_to_name = {
            0: "Land",
            1: "Water",
            2: "Indoors",
            3: "Outdoors",
        }

        # NOTE(review): these two lists do not match the label tables above
        # and are not read inside this class; kept for external readers.
        self.class_names = ["Male", "Female"]
        self.bias_names = ["old", "young"]

        self.gen_dir = "./generated_images"

        # Label spaces follow the name tables (four classes, four biases).
        self.targets_values = range(len(self.class_to_name))
        self.biases_values = range(len(self.bias_to_name))

        self.load_gen_data()
        self.set_group_counts()

        if split == 'train':
            self.create_real_gen_weights()
            self.get_class_distribution()

    def get_classes(self):
        """Return the range of target label values."""
        return self.targets_values

    def get_biases(self):
        """Return the range of bias label values."""
        return self.biases_values

    def group_counts(self):
        """Return the (n_classes, n_biases) matrix of per-group sample counts."""
        return self.group_counts_

    def num_heads_di(self):
        """Return the number of distinct bias labels present in this split."""
        return len(np.unique(self.biases))

    def _grid_shape(self):
        """Return ``(n_classes, n_biases)`` taken from the label-name tables."""
        return len(self.class_to_name), len(self.bias_to_name)

    def _label_arrays(self):
        """Return targets and biases as int64 arrays usable for fancy indexing."""
        return (
            np.asarray(self.targets, dtype=np.int64),
            np.asarray(self.biases, dtype=np.int64),
        )

    def load_gen_data(self):
        """Collect file paths and labels for the current split from every group folder.

        Sets ``self.filenames`` (list of paths), ``self.targets`` and
        ``self.biases`` (int64 arrays aligned with ``self.filenames``).
        """
        start_frac, end_frac = self.split_ratio[self.split]

        # Optional per-group cap, only honoured for the train split.
        per_group_cap = None
        if self.split == "train":
            per_group_cap = getattr(self.opt, "num_per_group", None)

        filenames_all = []
        targets_all = []
        biases_all = []

        for target_value, bias_value in GROUPS:
            group_dir = os.path.join(
                self.gen_dir,
                f"{self.class_to_name[target_value]}_{self.bias_to_name[bias_value]}",
            )

            # os.listdir returns entries in arbitrary order; sort so that the
            # train/valid/test windows are reproducible and never overlap.
            paths = [
                os.path.join(group_dir, name)
                for name in sorted(os.listdir(group_dir))
            ]

            start_idx = int(len(paths) * start_frac)
            end_idx = int(len(paths) * end_frac)

            if per_group_cap is not None:
                # Cap inside the train window so the cap can never pull in
                # files that belong to the valid/test windows.
                capped_end = min(end_idx, start_idx + int(per_group_cap))
                end_idx = max(start_idx, capped_end)

            selected = paths[start_idx:end_idx]
            filenames_all.extend(selected)
            targets_all.extend([target_value] * len(selected))
            biases_all.extend([bias_value] * len(selected))

        self.filenames = filenames_all
        # Explicit integer dtype: an empty list would otherwise become float64.
        self.targets = np.array(targets_all, dtype=np.int64)
        self.biases = np.array(biases_all, dtype=np.int64)

    def load_image(self, filename):
        """Open ``filename`` with PIL and return the image object."""
        img = Image.open(filename)
        return img

    def get_class_distribution(self):
        """Print the number of samples in every (target, bias) pair of ``GROUPS``."""
        for class_value_, bias_value in GROUPS:
            in_group = (self.targets == class_value_) & (self.biases == bias_value)
            num_images = np.sum(in_group)
            print(f'Class Value {class_value_}, Bias {bias_value}:  {num_images}')

        print("====================================")

    def set_group_counts(self):
        """Count samples per (target, bias) cell into ``self.group_counts_``.

        The matrix has shape ``(n_classes, n_biases)`` so that every pair in
        the label space is counted; cells for pairs with no samples stay 0.
        """
        targets, biases = self._label_arrays()
        counts = np.zeros(self._grid_shape())
        # np.add.at accumulates correctly when the same cell appears many times.
        np.add.at(counts, (targets, biases), 1)
        self.group_counts_ = counts

    def create_real_gen_weights(self, quit_=False):
        """Compute per-sample inverse-frequency weights and group indices.

        Sets:
            real_weights_groups: float32 array, ``1 / (group share)`` of the
                group each sample belongs to.
            group_idx: int32 array, ``target + bias * n_classes``.
            n_groups: Number of cells in the (class, bias) grid.

        Args:
            quit_: Unused; kept for backward compatibility.
        """
        n_classes, n_biases = self._grid_shape()
        targets, biases = self._label_arrays()

        counts = np.asarray(self.group_counts_, dtype=np.float64)
        total = counts.sum()
        probs = counts / total if total > 0 else np.zeros_like(counts)
        # Floor empty cells so the inversion below never divides by zero.
        probs[probs == 0] = 1e-10
        inverse_probs = 1.0 / probs

        self.real_weights_groups = inverse_probs[targets, biases].astype(np.float32)

        # One distinct index per grid cell, so no two groups share an index.
        self.group_idx = (targets + biases * n_classes).astype(np.int32)
        self.n_groups = n_classes * n_biases

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict.

        Keys: ``idx``, ``img`` (transformed when a transform is set),
        ``target``, ``bias``, ``filenames``; for the train split also
        ``real_weights_groups`` and ``group_idx``.
        """
        filename = self.filenames[idx]
        img = self.load_image(filename)

        if self.transform:
            img = self.transform(img)

        data = {
            'idx': idx,
            'img': img,
            'target': self.targets[idx],
            'bias': self.biases[idx],
            'filenames': filename,
        }

        if self.split == 'train':
            data['real_weights_groups'] = self.real_weights_groups[idx]
            data['group_idx'] = self.group_idx[idx]

        return data

