import os
import torch
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import pandas as pd
import random
from tqdm import tqdm
import pickle
import shutil
import argparse
import sys
import csv
from collections import namedtuple

# Named tuple with the three fields ``header``, ``index`` and ``data``.
CSV = namedtuple("CSV", ["header", "index", "data"])

# Print NumPy arrays in full (no summarisation) and suppress scientific
# notation for small floats.
np.set_printoptions(threshold=sys.maxsize, suppress=True)

class CelebADataset(Dataset):
    def __init__(self, root_dir, split='train', transform=None, opt=None):

        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.opt = opt
        self.split_idx = { 
            'train': 0,
            'valid': 1,
            'test': 2
        }

        self.split_ratio = {
            'train': [0.0, 0.8],
            'valid': [0.8, 0.9],
            'test': [0.9, 1.0],
        }


        self.targets_values = range(2)
        self.biases_values = range(2)

        self.load_filenames()
        self.load_labels()

        self.gen_dir = f"../ic_celeba/generated_images/{opt.target_attr}"        
        
        path_file = f"data/celeba_{opt.target_attr}_{split}_{self.opt.minority_to_keep}.pkl"

        if os.path.exists(path_file):
            data = pickle.load(open(path_file, "rb"))
            self.filenames = data['filenames']
            self.targets = data['targets']
            self.biases = data['biases']
            self.real_gen = data['real_gen']

        else: 
            self.load_labels()
            self.fix_filenames()
            self.set_group_counts()

            if split == 'train':
                self.make_harder(False)
                self.set_group_counts()
            
            if split == 'valid':
                self.make_harder(True)
                self.set_group_counts()

            # save filenames targets biases into a pickle file
            data = {
                'filenames': self.filenames,
                'targets': self.targets,
                'biases': self.biases,
                'real_gen': self.real_gen
            }
            pickle.dump(data, open(path_file, "wb"))

        self.set_group_counts()
        self.create_real_gen_weights()
        # self.get_class_distribution()

    def upweight_misclassified_samples(self, alpha, misclassified_filenames): 
        idx_misclassified = [self.filenames.index(fname) for fname in misclassified_filenames]
        #repeat the misclassified samples alpha times
        self.targets_misclassified = self.targets[idx_misclassified]
        self.biases_misclassified = self.biases[idx_misclassified]
        self.real_gen_misclassified = self.real_gen[idx_misclassified]

        #repeat targets_misclassified alpha times
        self.targets = np.concatenate((self.targets, np.tile(self.targets_misclassified,alpha)))
        self.biases = np.concatenate((self.biases, np.tile(self.biases_misclassified, alpha)))
        self.real_gen = np.concatenate((self.real_gen, np.tile(self.real_gen_misclassified, alpha)))
        self.filenames.extend([self.filenames[i] for i in idx_misclassified for _ in range(alpha)])

        self.create_real_gen_weights()


    def get_gen_ratio(self): 

        num_gen = np.sum(self.real_gen == 1)
        num_real = np.sum(self.real_gen == 0)

        print("Ratio REAL/GEN: ", num_gen/num_real)
        return num_gen, num_real

    def get_remove_idx(self, min_bias, counts_class, class_idx, min_enforce=True): 
            max_bias = 1 - min_bias 
        
            ratio_minority = 1 - self.opt.minority_to_keep 
            min_bias_new_count = int((ratio_minority * counts_class[max_bias])/(1-ratio_minority))

            if min_enforce:
                min_bias_new_count = max(min_bias_new_count, 15)
            to_remove_samples_num = int(counts_class[min_bias] - min_bias_new_count)

            idx_to_remove = np.where((self.targets == class_idx) & (self.biases == min_bias))[0]
            idx_to_remove = np.random.choice(idx_to_remove, to_remove_samples_num, replace=False)

            return idx_to_remove


    def take_subset_of_data(self, perc): 


        to_keep_all = [] 
        for bias_idx in self.biases_values: 
            for target_idx in self.targets_values: 

                idxs = np.where((self.targets == target_idx) & (self.biases == bias_idx))[0]
                np.random.shuffle(idxs)

                num_to_keep = max(int(perc * len(idxs)), 5)
                to_keep_all.extend(idxs[:num_to_keep])

        self.filenames = [self.filenames[i] for i in to_keep_all]
        self.targets = self.targets[to_keep_all]
        self.biases = self.biases[to_keep_all]
        self.real_gen = self.real_gen[to_keep_all]


    def cut_data(self, data_amount): 
        idxs = [] 
        for target_value in range(2):
            for bias_value in range(2):

                idx = np.where((self.targets == target_value) & (self.biases == bias_value))[0]
                
                start_idx = 0
                end_idx = data_amount[target_value][bias_value]

                idxs.extend(idx[start_idx:end_idx])

        self.filenames = [self.filenames[i] for i in idxs]
        self.targets = self.targets[idxs]
        self.biases = self.biases[idxs]
        self.real_gen = self.real_gen[idxs]


    def make_harder(self, min_enforce=True): 

        to_remove_samples_total = [] 

        min_group_positive = np.argmin(self.group_counts_[1])
        min_group_negative = 1 - min_group_positive

        to_remove_samples_total.extend(self.get_remove_idx(min_group_positive, self.group_counts_[1], 1, min_enforce))
        to_remove_samples_total.extend(self.get_remove_idx(min_group_negative, self.group_counts_[0], 0, min_enforce))            

        self.targets = np.delete(self.targets, to_remove_samples_total)
        self.biases = np.delete(self.biases, to_remove_samples_total)
        self.filenames = list(np.delete(np.array(self.filenames), to_remove_samples_total))

        self.real_gen = np.delete(self.real_gen, to_remove_samples_total)

    def load_gen_data_balance(self):

        to_add_all_filenames = []
        to_add_all_targets = [] 
        to_add_all_biases = [] 
        to_add_all_real_gen = [] 

        max_count = np.max(self.group_counts_) 
        for target_value in range(2):
            for bias_value in range(2): 
                to_add = int(max_count - self.group_counts_[target_value][bias_value])
                to_load_path = os.path.join(self.gen_dir, str(target_value), self.bias_names[bias_value])

                filenames = os.listdir(to_load_path)
                filenames = [os.path.join(to_load_path, filename) for filename in filenames]
                filenames = filenames[:to_add]

                to_add_all_targets.extend([target_value] * len(filenames))
                to_add_all_biases.extend([bias_value] * len(filenames))
                to_add_all_real_gen.extend([1] * len(filenames))
                to_add_all_filenames.extend(filenames)

        if self.opt.limit_to_gen: 
            self.filenames = to_add_all_filenames
            self.targets = np.array(to_add_all_targets)
            self.biases = np.array(to_add_all_biases)
            self.real_gen = np.array(to_add_all_real_gen)

        else: 
            self.filenames.extend(to_add_all_filenames)
            self.targets = np.concatenate((self.targets, np.array(to_add_all_targets)), axis=0)
            self.biases = np.concatenate((self.biases, np.array(to_add_all_biases)), axis=0)
            self.real_gen = np.concatenate((self.real_gen, np.array(to_add_all_real_gen)), axis=0)

        self.set_group_counts()
        self.create_real_gen_weights()

    def get_classes(self):
        """Return the target values of the dataset (``self.targets_values``)."""
        class_values = self.targets_values
        return class_values

    def get_biases(self):
        """Return the bias values of the dataset (``self.biases_values``)."""
        bias_values = self.biases_values
        return bias_values

    def num_heads_di_real_gen(self): 
        return len(np.unique(self.bias_real_gen))

    def num_heads_di(self): 
        return len(np.unique(self.biases))

    def group_counts_dro(self): 
        group_counts = [] 
        for idx in np.unique(self.group_idx): 
            group_counts.append(np.sum(self.group_idx == idx))
        return group_counts

    def group_counts_dro_real_gen(self):
        group_counts = [] 
        for idx in np.unique(self.group_idx_real_gen): 
            group_counts.append(np.sum(self.group_idx_real_gen == idx))
        return group_counts

    def n_groups_dro(self): 
        return len(np.unique(self.group_idx))

    def n_groups_dro_real_gen(self):
        return len(np.unique(self.group_idx_real_gen))


    def limit_to_group(self, target_idx, target_value, bias_value): 
        idxs = np.where((self.targets[:, target_idx] == target_value) & (self.biases == bias_value))[0]

        self.filenames = [self.filenames[i] for i in idxs]
        self.targets = self.targets[idxs]
        self.biases = self.biases[idxs]

    def load_filenames(self):
        """List the image folder and keep the slice belonging to ``self.split``.

        The listing of ``<root_dir>/CelebA-HQ-img`` is sorted before slicing
        with the fractions in ``self.split_ratio``: ``os.listdir`` returns
        entries in arbitrary order, so without sorting the train/valid/test
        boundaries are not guaranteed to be the same between runs or machines.
        """
        all_names = sorted(os.listdir(os.path.join(self.root_dir, "CelebA-HQ-img")))

        start_fraction, end_fraction = self.split_ratio[self.split]
        start_idx = int(start_fraction * len(all_names))
        end_idx = int(end_fraction * len(all_names))

        self.filenames = all_names[start_idx:end_idx]

    def fix_filenames(self):
        """Prefix every entry of ``self.filenames`` with the image folder path."""
        image_dir = os.path.join(self.root_dir, 'CelebA-HQ-img')
        full_paths = []
        for name in self.filenames:
            full_paths.append(os.path.join(image_dir, name))
        self.filenames = full_paths

    def load_labels(self):
        """Read the attribute annotations for ``self.filenames``.

        Sets ``self.class_names`` (attribute column names), ``self.labels``
        (attribute matrix with -1 mapped to 0), ``self.targets`` (column
        ``opt.target_attr``), ``self.biases`` (the ``Male`` column),
        ``self.bias_names`` and ``self.real_gen`` (all zeros, i.e. every
        sample is marked as real).
        """
        annotation_path = os.path.join(self.root_dir, 'CelebAMask-HQ-attribute-anno.txt')
        # ``sep=r"\s+"`` is the documented replacement for the deprecated
        # ``delim_whitespace=True``. The first line of the file is skipped.
        annotations = pd.read_csv(annotation_path, sep=r"\s+", skiprows=1, index_col=0)
        self.class_names = list(annotations.columns.values)

        # Select and order the rows to match self.filenames. Take an explicit
        # copy: the array is modified in place below, and pandas may hand out
        # a read-only view of the frame's data.
        self.labels = annotations.loc[self.filenames].to_numpy(copy=True)
        self.labels[self.labels == -1] = 0

        self.targets = self.labels[:, self.class_names.index(self.opt.target_attr)]

        # Look the bias column up by name; fall back to the historical
        # hard-coded position (20) if no "Male" column is present.
        male_column = self.class_names.index("Male") if "Male" in self.class_names else 20
        self.biases = self.labels[:, male_column]
        self.bias_names = ["Female", "Male"]

        self.real_gen = np.zeros((len(self.filenames)))

        self.targets = self.targets.astype(int)
        self.real_gen = self.real_gen.astype(int)


    def load_image(self, filename):
        """Open ``filename`` with PIL and return the resulting image object."""
        return Image.open(filename)

    def get_class_distribution(self):
        for class_value_ in range(2): 
            for bias_value in self.biases_values: 

                class_index = self.targets == class_value_
                bias_idx = self.biases == bias_value

                num_images =np.sum(class_index & bias_idx)
                #print(f'Class Value {class_value_}, Bias {bias_value}:  {num_images / np.sum(class_index):.4f}')
                print(f'Class Value {class_value_}, Bias {bias_value}:  {num_images}')

        print("====================================")

    def set_group_counts(self): 
        self.group_counts_ = np.zeros((2,2))
        for class_value in range(2): 
            for bias_value in range(2): 
                class_index = self.targets == class_value
                bias_idx = self.biases == bias_value

                self.group_counts_[class_value, bias_value] = np.sum(class_index & bias_idx)

    def create_real_gen_weights(self, quit_ = False): 

        self.real_weights_groups = np.zeros_like(self.targets, dtype=np.float32)

        probs = self.group_counts_/np.sum(self.group_counts_)
        probs[probs == 0] = 1e-10
        probs = 1/probs

        for class_value in range(2): 
            for bias_value in range(2): 
                
                class_index = self.targets == class_value
                bias_idx = self.biases == bias_value

                self.real_weights_groups[class_index & bias_idx] = probs[class_value, bias_value]

        self.group_idx = np.zeros_like(self.targets, dtype=np.int32) 
        
        class_bias_pairs = [] 
        for class_label, bias_label in zip(self.targets, self.biases):
            if (class_label, bias_label) not in class_bias_pairs:
                class_bias_pairs.append((class_label, bias_label))
        
        for idx, class_bias_pair in enumerate(class_bias_pairs):
            class_value, bias_value = class_bias_pair

            class_index = self.targets == class_value
            bias_idx = self.biases == bias_value

            self.group_idx[class_index & bias_idx] = idx

        self.group_idx_real_gen = np.zeros_like(self.targets, dtype=np.int32) 
        class_bias_real_gen_pairs = []
        for class_label, bias_label, real_gen_label in zip(self.targets, self.biases, self.real_gen):
            if (class_label, bias_label, real_gen_label) not in class_bias_real_gen_pairs:
                class_bias_real_gen_pairs.append((class_label, bias_label, real_gen_label))

        for idx, class_bias_real_gen_pair in enumerate(class_bias_real_gen_pairs):
            class_value, bias_value, real_gen_value = class_bias_real_gen_pair

            class_index = self.targets == class_value
            bias_idx = self.biases == bias_value
            real_gen_idx = self.real_gen == real_gen_value

            self.group_idx_real_gen[class_index & bias_idx & real_gen_idx] = idx
        
        self.bias_real_gen = np.zeros_like(self.targets, dtype=np.int32)
        bias_real_gen_pairs = []

        for bias_label, real_gen_label in zip(self.biases, self.real_gen):
            if (bias_label, real_gen_label) not in bias_real_gen_pairs:
                bias_real_gen_pairs.append((bias_label, real_gen_label))
        
        for idx, bias_real_gen_pair in enumerate(bias_real_gen_pairs):
            bias_value, real_gen_value = bias_real_gen_pair

            bias_idx = self.biases == bias_value
            real_gen_idx = self.real_gen == real_gen_value

            self.bias_real_gen[bias_idx & real_gen_idx] = idx

    def load_gen_data(self):

        filenames_all = []
        targets_all = []
        biases_all = [] 

        for target_value in range(2): 
            for bias_value in range(2): 
                to_load_path = os.path.join(self.gen_dir, str(target_value), self.bias_names[bias_value])
                
                filenames = os.listdir(to_load_path)
                filenames = [os.path.join(to_load_path, filename) for filename in filenames]
                filenames = list(set(filenames) - set(self.filenames))

                start_idx = 0
                end_idx = self.opt.num_per_group 

                filenames_all.extend(filenames[start_idx:end_idx])
                targets_all.extend([target_value] * len(filenames[start_idx:end_idx]))
                biases_all.extend([bias_value] * len(filenames[start_idx:end_idx]))


        self.filenames.extend(filenames_all)
        self.targets = np.concatenate((self.targets, np.array(targets_all)), axis=0)
        self.biases = np.concatenate((self.biases, np.array(biases_all)), axis=0)
        self.real_gen = np.concatenate((self.real_gen, np.ones(len(targets_all))), axis=0)


        self.set_group_counts()
        self.create_real_gen_weights()

    def __len__(self):
        """Return the number of samples, i.e. the length of ``self.filenames``."""
        n_samples = len(self.filenames)
        return n_samples


    def __getitem__(self, idx):
        filename = self.filenames[idx]
        target = self.targets[idx]
        bias = self.biases[idx]
        idx_img = filename.split('/')[-1].split('.')[0]
        real_gen = self.real_gen[idx]
        img = self.load_image(filename)
        group_idx = self.group_idx[idx]
        real_gen_weights = self.real_weights_groups[idx]
        group_idx_real_gen = self.group_idx_real_gen[idx]
        bias_real_gen = self.bias_real_gen[idx]

        if self.transform:
            img = self.transform(img)


        data = {
            'idx': idx,
            'img': img,
            'target': target,
            'bias': bias,
            'filenames': filename, 
            'real_gen': real_gen,
            'group_idx':group_idx,
            'real_weights_groups': real_gen_weights,
            'group_idx_real_gen': group_idx_real_gen,
            'bias_real_gen': bias_real_gen

        }

        return data

#test the dataloader
if __name__ == '__main__':
    # Image preprocessing: resize to 128x128, then convert to a tensor.
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ])

    parser = argparse.ArgumentParser()

    # Options that take a value: (flag, add_argument keyword arguments).
    valued_options = [
        ('--output_dir', {'type': str, 'default': 'test_records'}),
        ('--epochs', {'type': int, 'default': 50}),
        ('--seed', {'type': int, 'default': 1}),
        ('--batch_size', {'type': int, 'default': 128, 'help': 'batch_size'}),
        ('--lr', {'type': float, 'default': 1e-4}),
        ('--weight_decay', {'type': float, 'default': 1e-5}),
    ]
    for flag, options in valued_options:
        parser.add_argument(flag, **options)

    # Boolean switches, off unless given on the command line.
    for flag in ('--gen_balance', '--make_harder'):
        parser.add_argument(flag, action='store_true')

    opt = parser.parse_args()

 
