import os
import torch
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import pandas as pd
import random
from tqdm import tqdm
import pickle
import shutil
import argparse
import sys
# Print numpy arrays in full (never summarised with "...") and in fixed-point
# rather than scientific notation.
np.set_printoptions(threshold=sys.maxsize, suppress=True)

class CelebADatasetGen(Dataset):
    """Dataset of generated CelebA images stored as ``<gen_dir>/<target>/<bias_name>/<file>``.

    Targets are the integers ``0`` and ``1`` (first sub-directory level);
    biases index ``self.bias_names`` (``0 -> "Female"``, ``1 -> "Male"``).
    Every (target, bias) group is sliced independently: each split takes the
    ``[start, end)`` fraction of the group's *sorted* file list given by
    ``self.split_ratio``. For ``split='train'`` the end index is replaced by
    ``opt.num_per_group`` when ``opt`` has that attribute.

    Items are dicts; see ``__getitem__``.
    """

    def __init__(self, root_dir, split='train', transform=None, opt=None, gen_dir=None):
        """
        Args:
            root_dir: Stored as ``self.root_dir``; not otherwise used by this class.
            split: ``'train'``, ``'valid'`` or ``'test'`` (a key of ``self.split_ratio``).
            transform: Optional callable applied to each loaded PIL image.
            opt: Options object. ``opt.target_attr`` names the default
                generated-image directory; the optional ``opt.num_per_group``
                sets the per-group end index for the train split.
            gen_dir: Optional explicit generated-image directory. When omitted,
                ``../ic_celeba/generated_images/<opt.target_attr>`` is used
                (resolved against the current working directory).

        Raises:
            ValueError: if both ``opt`` and ``gen_dir`` are ``None``.
        """
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.opt = opt
        self.split_idx = {
            'train': 0,
            'valid': 1,
            'test': 2
        }
        # Fraction [start, end) of every (target, bias) group assigned to each split.
        self.split_ratio = {
            'train': [0.0, 0.7],
            'valid': [0.7, 0.85],
            'test': [0.85, 1.0],
        }

        # Bias index -> sub-directory name.
        self.bias_names = ["Female", "Male"]

        if gen_dir is None:
            if opt is None:
                raise ValueError("either `opt` (with `target_attr`) or `gen_dir` must be provided")
            gen_dir = f"../ic_celeba/generated_images/{opt.target_attr}"
        self.gen_dir = gen_dir

        self.load_gen_data()
        self.set_group_counts()

        self.targets_values = range(2)
        self.biases_values = range(2)

        # Group weights / indices (and the distribution printout) are train-only.
        if split == 'train':
            self.create_real_gen_weights()
            self.get_class_distribution()

    def get_classes(self):
        """Return the possible target values (``range(2)``)."""
        return self.targets_values

    def get_biases(self):
        """Return the possible bias values (``range(2)``)."""
        return self.biases_values

    def group_counts(self):
        """Return the ``(2, 2)`` per-group count array indexed ``[target, bias]``."""
        return self.group_counts_

    def num_heads_di_real_gen(self):
        """Return the number of distinct values in ``self.bias_real_gen``.

        NOTE(review): ``bias_real_gen`` is never assigned in this class;
        presumably it is set by external code — confirm before calling.
        """
        return len(np.unique(self.bias_real_gen))

    def num_heads_di(self):
        """Return the number of distinct bias values present in the split."""
        return len(np.unique(self.biases))

    def group_counts_dro(self):
        """Return the sample count of each distinct ``group_idx`` value, in ascending group order.

        Requires ``self.group_idx``, which is only built for the train split.
        """
        return [np.sum(self.group_idx == idx) for idx in np.unique(self.group_idx)]

    def group_counts_dro_real_gen(self):
        """Like ``group_counts_dro`` but over ``self.group_idx_real_gen``.

        NOTE(review): ``group_idx_real_gen`` is never assigned in this class;
        presumably it is set by external code — confirm before calling.
        """
        return [np.sum(self.group_idx_real_gen == idx) for idx in np.unique(self.group_idx_real_gen)]

    def n_groups_dro(self):
        """Return the number of distinct groups in ``self.group_idx`` (train split only)."""
        return len(np.unique(self.group_idx))

    def n_groups_dro_real_gen(self):
        """Return the number of distinct groups in ``self.group_idx_real_gen`` (set externally — see above)."""
        return len(np.unique(self.group_idx_real_gen))

    def load_gen_data(self):
        """Collect file paths, targets and biases for ``self.split``.

        Populates ``self.filenames`` (list of paths) and the aligned numpy
        arrays ``self.targets`` and ``self.biases``.
        """
        start_frac, end_frac = self.split_ratio[self.split]
        # Only the train split uses an explicit per-group end index, and only if opt provides one.
        use_num_per_group = self.split == "train" and hasattr(self.opt, "num_per_group")

        filenames_all = []
        targets_all = []
        biases_all = []

        for target_value in range(2):
            for bias_value in range(2):
                to_load_path = os.path.join(self.gen_dir, str(target_value), self.bias_names[bias_value])

                # os.listdir returns entries in arbitrary order; sort so the split
                # boundaries pick the same files on every run and machine.
                filenames = [os.path.join(to_load_path, name) for name in sorted(os.listdir(to_load_path))]

                start_idx = int(len(filenames) * start_frac)
                end_idx = int(len(filenames) * end_frac)

                if use_num_per_group:
                    # NOTE(review): if num_per_group exceeds the train fraction of a
                    # group (or is None), these files overlap the 'valid'/'test' ranges.
                    end_idx = self.opt.num_per_group

                selected = filenames[start_idx:end_idx]
                filenames_all.extend(selected)
                targets_all.extend([target_value] * len(selected))
                biases_all.extend([bias_value] * len(selected))

        self.filenames = filenames_all
        self.targets = np.array(targets_all)
        self.biases = np.array(biases_all)

    def load_image(self, filename):
        """Open ``filename`` and return it as a fully loaded 3-channel RGB PIL image.

        The underlying file is closed before returning; images stored with an
        alpha channel or in grayscale are converted to RGB.
        """
        with Image.open(filename) as img:
            return img.convert('RGB')

    def get_class_distribution(self):
        """Print the number of samples in every (class, bias) group."""
        for class_value_ in range(2):
            class_index = self.targets == class_value_
            for bias_value in self.biases_values:
                num_images = np.sum(class_index & (self.biases == bias_value))
                print(f'Class Value {class_value_}, Bias {bias_value}:  {num_images:.4f}')

        print("====================================")

    def set_group_counts(self):
        """Store per-group sample counts in ``self.group_counts_`` (float ``(2, 2)``, indexed ``[target, bias]``)."""
        self.group_counts_ = np.zeros((2, 2))
        for class_value in range(2):
            for bias_value in range(2):
                mask = (self.targets == class_value) & (self.biases == bias_value)
                self.group_counts_[class_value, bias_value] = np.sum(mask)

    def create_real_gen_weights(self, quit_=False):
        """Assign inverse-frequency sample weights and a flat group index.

        Sets ``self.real_weights_groups`` (float32; each sample gets
        ``1 / P(target, bias)`` computed from ``self.group_counts_``),
        ``self.group_idx`` (int32; ``target + 2 * bias``) and
        ``self.n_groups = 4``. ``quit_`` is accepted but unused.
        """
        probs = self.group_counts_ / np.sum(self.group_counts_)
        # Empty groups get a tiny probability to avoid division by zero; no sample carries that weight.
        probs[probs == 0] = 1e-10
        inv_probs = 1 / probs

        self.real_weights_groups = np.zeros_like(self.targets, dtype=np.float32)
        self.group_idx = np.zeros_like(self.targets, dtype=np.int32)
        for class_value in range(2):
            for bias_value in range(2):
                mask = (self.targets == class_value) & (self.biases == bias_value)
                self.real_weights_groups[mask] = inv_probs[class_value, bias_value]
                self.group_idx[mask] = class_value + bias_value * 2

        self.n_groups = 4

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return a dict describing sample ``idx``.

        Keys: ``idx``, ``img`` (transformed if a transform was given),
        ``target``, ``bias``, ``filenames`` (the sample's path). The train split
        additionally includes ``real_weights_groups`` and ``group_idx``.
        """
        filename = self.filenames[idx]
        img = self.load_image(filename)

        if self.transform:
            img = self.transform(img)

        data = {
            'idx': idx,
            'img': img,
            'target': self.targets[idx],
            'bias': self.biases[idx],
            'filenames': filename,
        }

        if self.split == 'train':
            data['real_weights_groups'] = self.real_weights_groups[idx]
            data['group_idx'] = self.group_idx[idx]

        return data


# Script entry point: builds a sample image transform and parses CLI options.
# NOTE(review): the dataset itself is never instantiated here, and this parser
# defines neither `--target_attr` nor `--num_per_group`, both of which
# CelebADatasetGen reads from `opt`.
if __name__ == '__main__':
    transform = transforms.Compose([transforms.Resize((128, 128)), transforms.ToTensor()])

    # (flag, add_argument keyword arguments) in the order they appear in --help.
    option_specs = [
        ('--output_dir', dict(type=str, default='test_records')),
        ('--epochs', dict(type=int, default=50)),
        ('--seed', dict(type=int, default=1)),
        ('--batch_size', dict(type=int, default=128, help='batch_size')),
        ('--lr', dict(type=float, default=1e-4)),
        ('--weight_decay', dict(type=float, default=1e-5)),
        ('--gen_balance', dict(action='store_true')),
        ('--make_harder', dict(action='store_true')),
    ]

    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    opt = parser.parse_args()

     
