import os
import torch
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import pandas as pd
import random
from tqdm import tqdm
import pickle
import shutil
import argparse
import sys
import csv
from collections import namedtuple
from utils import *

# Record type with header / index / data fields (not referenced by the visible code).
CSV = namedtuple("CSV", ("header", "index", "data"))

# Print NumPy arrays in full (no truncation) and without scientific notation.
np.set_printoptions(threshold=sys.maxsize, suppress=True)

class CelebADataset(Dataset):
    def __init__(self, root_dir, split='train', transform=None, opt=None):
        """Build one split of the dataset, using a pickle cache when present.

        Args:
            root_dir: Directory that holds the real image files.
            split: 'train', 'valid' or 'test'; selects the slice of the file
                listing given by ``self.split_ratio``.
            transform: Optional callable applied to each image in
                ``__getitem__``.
            opt: Options object. ``opt.minority_to_keep`` is read here to
                name the cache file, so ``opt`` must not be None in practice.
        """
        self.bias_rate = 0.
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.opt = opt
        self.split_idx = {
            'train': 0,
            'valid': 1,
            'test': 2
        }

        # Fractions of the directory listing assigned to each split.
        self.split_ratio = {
            'train': [0.0, 0.8],
            'valid': [0.8, 0.9],
            'test': [0.9, 1.0],
        }

        self.target_attr = "gender"
        self.bias_attr = "age"

        # Sub-directory names used when reading generated images.
        self.class_names = ["Male", "Female"]
        self.bias_names = ["old", "young"]

        self.targets_values = range(2)
        self.biases_values = range(2)

        self.gen_dir = "./generated_images/"

        cache_path = f"data/utk_face_{split}_{self.opt.minority_to_keep}.pkl"

        if os.path.exists(cache_path):
            # NOTE: unpickling can execute arbitrary code; only load cache
            # files that this code wrote itself.
            with open(cache_path, "rb") as cache_file:
                data = pickle.load(cache_file)
            self.filenames = data['filenames']
            self.targets = data['targets']
            self.biases = data['biases']
            self.real_gen = data['real_gen']

        else:
            self.load_filenames()
            self.load_data()
            self.fix_filenames()
            # make_harder reads group_counts_, so it must be current here.
            self.set_group_counts()

            if split == "train":
                self.make_harder(False)
            elif split == 'valid':
                self.make_harder(True)

            # Cache filenames / targets / biases so later runs reuse the
            # exact same (randomly sub-sampled) split.
            data = {
                'filenames': self.filenames,
                'targets': self.targets,
                'biases': self.biases,
                'real_gen': self.real_gen
            }
            os.makedirs(os.path.dirname(cache_path), exist_ok=True)
            with open(cache_path, "wb") as cache_file:
                pickle.dump(data, cache_file)

        self.set_group_counts()
        self.create_real_gen_weights()
        self.targets_bin = get_targets_bin(self.targets)

    def get_target_distro(self, target):
        """Count the samples of class ``target`` under each bias id.

        Returns a list with one count per bias id in
        ``range(number of distinct values in self.biases)``.
        """
        is_target = self.targets == target
        n_bias_values = len(np.unique(self.biases))
        return [
            np.sum(np.logical_and(is_target, self.biases == bias_id))
            for bias_id in range(n_bias_values)
        ]


    def get_kept_indices(self, target, target_prime, target_prime_new_distro):
        """Randomly pick sample positions of class ``target_prime`` per bias.

        For bias id ``b``, ``target_prime_new_distro[b]`` positions are drawn
        without replacement (``random.sample``) from the samples whose target
        is ``target_prime`` and whose bias is ``b``. ``target`` is unused.
        """
        positions = np.arange(len(self.targets))
        is_prime = self.targets == target_prime

        kept = []
        for bias_id, n_keep in enumerate(target_prime_new_distro):
            in_group = np.logical_and(is_prime, self.biases == bias_id)
            candidates = list(positions[in_group])
            kept += random.sample(candidates, n_keep)

        return kept

    def bias_mimick(self):
        """Mask out samples in ``self.targets_bin`` one target column at a time.

        For every target class ``t``: all samples of class ``t`` are kept,
        and for every other class ``t'`` a random subset is kept whose
        per-bias sizes come from ``solve_linear_program`` (defined outside
        this file; presumably it matches the bias distribution of ``t'`` to
        that of ``t`` -- confirm against utils). Every sample that is not
        kept gets ``targets_bin[sample, t] = -1``.
        """
        num_targets = len(np.unique(self.targets))
        num_samples = len(self.targets)

        for target in range(num_targets):
            target_distro = self.get_target_distro(target)
            to_keep_indices = []
            for target_prime in range(num_targets):
                if target_prime == target:
                    to_keep_indices.extend(np.flatnonzero(self.targets == target))
                else:
                    target_prime_distro = self.get_target_distro(target_prime)
                    target_prime_new_distro = solve_linear_program(
                        target_distro, target_prime_distro, self.biases)
                    to_keep_indices.extend(
                        self.get_kept_indices(target, target_prime, target_prime_new_distro))

            # Boolean mask of samples to drop. The builtin ``bool`` is used
            # because the ``np.bool`` alias no longer exists in recent NumPy.
            dropped = np.ones(num_samples, dtype=bool)
            dropped[np.asarray(to_keep_indices, dtype=np.intp)] = False

            self.targets_bin[dropped, target] = -1

    def upweight_misclassified_samples(self, alpha, misclassified_filenames):
        """Append ``alpha`` extra copies of each misclassified sample.

        Args:
            alpha: Number of additional copies to append per sample.
            misclassified_filenames: Filenames (as stored in
                ``self.filenames``) of the samples to duplicate.

        Raises:
            ValueError: If a filename is not present in ``self.filenames``.

        NOTE(review): ``self.group_counts_`` is not recomputed here, so the
        weights rebuilt by ``create_real_gen_weights`` still use the counts
        from before the duplication -- confirm this is intended.
        """
        # First position of every filename (what list.index would return),
        # built once instead of scanning the list for every lookup.
        first_position = {}
        for position, fname in enumerate(self.filenames):
            first_position.setdefault(fname, position)
        try:
            idx_misclassified = [first_position[fname] for fname in misclassified_filenames]
        except KeyError as err:
            # Keep the exception type that list.index raised previously.
            raise ValueError(f"{err.args[0]!r} is not in list") from err

        self.targets_misclassified = self.targets[idx_misclassified]
        self.biases_misclassified = self.biases[idx_misclassified]
        self.real_gen_misclassified = self.real_gen[idx_misclassified]

        # np.tile appends the whole misclassified block alpha times ...
        self.targets = np.concatenate((self.targets, np.tile(self.targets_misclassified, alpha)))
        self.biases = np.concatenate((self.biases, np.tile(self.biases_misclassified, alpha)))
        self.real_gen = np.concatenate((self.real_gen, np.tile(self.real_gen_misclassified, alpha)))
        # ... so the filenames must be appended block by block as well,
        # otherwise filenames and labels stop lining up for alpha > 1.
        self.filenames.extend(
            [self.filenames[i] for _ in range(alpha) for i in idx_misclassified])

        self.create_real_gen_weights()

    def get_class_from_filename(self, filenames, cls_idx):
        """Parse field ``cls_idx`` of each underscore-separated filename.

        Names that do not split into exactly four '_' fields get the
        placeholder value 10.
        """
        labels = []
        for fname in filenames:
            fields = fname.split('_')
            if len(fields) == 4:
                labels.append(int(fields[cls_idx]))
            else:
                labels.append(10)
        return np.array(labels)

    def get_gen_ratio(self):
        """Print the generated-to-real sample ratio and return both counts.

        Returns:
            Tuple ``(num_gen, num_real)``: number of samples with
            ``real_gen == 1`` and with ``real_gen == 0`` respectively.
        """
        num_gen = np.sum(self.real_gen == 1)
        num_real = np.sum(self.real_gen == 0)

        # The printed value is generated / real, so the label says GEN/REAL.
        print("Ratio GEN/REAL: ", num_gen / num_real)
        return num_gen, num_real

    def get_remove_idx(self, min_bias, counts_class, class_idx, min_enforce=True):
        """Pick random samples to drop from the minority bias group of a class.

        The minority group is shrunk to
        ``(1 - k) / k * counts_class[1 - min_bias]`` samples, where
        ``k = self.opt.minority_to_keep``.

        Args:
            min_bias: Bias id (0 or 1) of the group to shrink.
            counts_class: Per-bias sample counts for class ``class_idx``.
            class_idx: Target class whose minority group is shrunk.
            min_enforce: When true, never shrink the group below 15 samples.

        Returns:
            Array of sample positions to remove. It is empty when the group
            is already at or below the requested size.
        """
        max_bias = 1 - min_bias

        ratio_minority = 1 - self.opt.minority_to_keep
        min_bias_new_count = int((ratio_minority * counts_class[max_bias]) / (1 - ratio_minority))

        if min_enforce:
            min_bias_new_count = max(min_bias_new_count, 15)

        candidates = np.where((self.targets == class_idx) & (self.biases == min_bias))[0]

        # Clamp the count: when the requested size exceeds the current group
        # size the difference is negative, which np.random.choice rejects.
        to_remove_samples_num = int(counts_class[min_bias] - min_bias_new_count)
        to_remove_samples_num = min(max(to_remove_samples_num, 0), len(candidates))

        return np.random.choice(candidates, to_remove_samples_num, replace=False)

    def cut_data(self, data_amount):
        """Keep only the first ``data_amount[target][bias]`` samples per group.

        Groups are visited in (target, bias) order over {0, 1} x {0, 1};
        filenames, targets, biases and real_gen are all re-indexed to the
        retained samples.
        """
        selected = []
        for target_value, bias_value in ((t, b) for t in range(2) for b in range(2)):
            in_group = (self.targets == target_value) & (self.biases == bias_value)
            group_positions = np.where(in_group)[0]
            limit = data_amount[target_value][bias_value]
            selected.extend(group_positions[:limit])

        self.filenames = [self.filenames[pos] for pos in selected]
        self.targets = self.targets[selected]
        self.biases = self.biases[selected]
        self.real_gen = self.real_gen[selected]


    def load_filenames(self):
        """Set ``self.filenames`` to this split's slice of the directory listing.

        NOTE(review): the slice depends on the order returned by
        ``os.listdir``, which is not guaranteed to be stable across
        machines or filesystems.
        """
        listing = os.listdir(self.root_dir)
        total = len(listing)
        lower_frac = self.split_ratio[self.split][0]
        upper_frac = self.split_ratio[self.split][1]

        self.filenames = listing[int(lower_frac * total):int(upper_frac * total)]

    def make_harder(self, min_enforce=True):
        """Shrink the minority bias group of each class.

        The minority bias of class 1 is the bias with the smaller count in
        ``self.group_counts_[1]``; class 0 is assumed to have the opposite
        minority bias. Samples chosen by ``get_remove_idx`` are deleted from
        targets, biases, filenames and real_gen.
        """
        rare_bias_positive = np.argmin(self.group_counts_[1])
        rare_bias_negative = 1 - rare_bias_positive

        drop = list(self.get_remove_idx(rare_bias_positive, self.group_counts_[1], 1, min_enforce))
        drop += list(self.get_remove_idx(rare_bias_negative, self.group_counts_[0], 0, min_enforce))

        self.targets = np.delete(self.targets, drop)
        self.biases = np.delete(self.biases, drop)
        self.real_gen = np.delete(self.real_gen, drop)
        self.filenames = list(np.delete(np.array(self.filenames), drop))

    def load_gen_data(self):
        """Append generated images from ``self.gen_dir`` to the dataset.

        Images are read from ``gen_dir/<class name>/<bias name>``. Each
        listing is sliced by the split fractions, except that for the train
        split the end index is the absolute count ``opt.num_per_group``.
        Appended samples get ``real_gen == 1``; group counts and weights are
        rebuilt afterwards.
        """
        gen_files = []
        gen_targets = []
        gen_biases = []

        for target_value in range(2):
            for bias_value in range(2):
                group_dir = os.path.join(
                    self.gen_dir, self.class_names[target_value], self.bias_names[bias_value])
                paths = [os.path.join(group_dir, name) for name in os.listdir(group_dir)]

                first = int(len(paths) * self.split_ratio[self.split][0])
                last = int(len(paths) * self.split_ratio[self.split][1])
                if self.split == "train":
                    last = self.opt.num_per_group

                chosen = paths[first:last]
                gen_files.extend(chosen)
                gen_targets.extend([target_value] * len(chosen))
                gen_biases.extend([bias_value] * len(chosen))

        self.filenames.extend(gen_files)
        self.targets = np.concatenate((self.targets, np.array(gen_targets)), axis=0)
        self.biases = np.concatenate((self.biases, np.array(gen_biases)), axis=0)
        self.real_gen = np.concatenate((self.real_gen, np.ones(len(gen_targets))), axis=0)

        self.set_group_counts()
        self.create_real_gen_weights()

    def load_gen_data_balance(self):
        """Top up every (class, bias) group with generated images.

        Each group receives up to ``max(group_counts_) - group count``
        generated images, taken from the start of the listing of
        ``gen_dir/<class name>/<bias name>``. With ``opt.limit_to_gen`` the
        dataset is replaced by the generated samples only; otherwise they
        are appended. Group counts and weights are rebuilt afterwards.
        """
        extra_files, extra_targets, extra_biases, extra_real_gen = [], [], [], []

        largest_group = np.max(self.group_counts_)
        for target_value in range(2):
            for bias_value in range(2):
                deficit = int(largest_group - self.group_counts_[target_value][bias_value])
                group_dir = os.path.join(
                    self.gen_dir, self.class_names[target_value], self.bias_names[bias_value])

                picked = [os.path.join(group_dir, name) for name in os.listdir(group_dir)]
                picked = picked[:deficit]

                extra_targets.extend([target_value] * len(picked))
                extra_biases.extend([bias_value] * len(picked))
                extra_real_gen.extend([1] * len(picked))
                extra_files.extend(picked)

        if not self.opt.limit_to_gen:
            self.filenames.extend(extra_files)
            self.targets = np.concatenate((self.targets, np.array(extra_targets)), axis=0)
            self.biases = np.concatenate((self.biases, np.array(extra_biases)), axis=0)
            self.real_gen = np.concatenate((self.real_gen, np.array(extra_real_gen)), axis=0)
        else:
            self.filenames = extra_files
            self.targets = np.array(extra_targets)
            self.biases = np.array(extra_biases)
            self.real_gen = np.array(extra_real_gen)

        self.set_group_counts()
        self.create_real_gen_weights()

    def get_classes(self):
        """Return the iterable of target class ids (``self.targets_values``)."""
        class_ids = self.targets_values
        return class_ids

    def get_biases(self):
        """Return the iterable of bias ids (``self.biases_values``)."""
        bias_ids = self.biases_values
        return bias_ids

    def num_heads_di_real_gen(self):
        """Return the number of distinct (bias, real/generated) group ids."""
        distinct_ids = np.unique(self.bias_real_gen)
        return distinct_ids.size

    def num_heads_di(self):
        """Return the number of distinct bias values in ``self.biases``."""
        distinct_biases = np.unique(self.biases)
        return distinct_biases.size

    def group_counts_dro(self):
        """Return the sample count of each (class, bias) group id, in sorted id order."""
        return [
            np.sum(self.group_idx == group_id)
            for group_id in np.unique(self.group_idx)
        ]

    def group_counts_dro_real_gen(self):
        """Return the sample count of each (class, bias, real/generated) group id, in sorted id order."""
        return [
            np.sum(self.group_idx_real_gen == group_id)
            for group_id in np.unique(self.group_idx_real_gen)
        ]

    def n_groups_dro(self):
        """Return the number of distinct (class, bias) group ids."""
        distinct_groups = np.unique(self.group_idx)
        return distinct_groups.size

    def n_groups_dro_real_gen(self):
        """Return the number of distinct (class, bias, real/generated) group ids."""
        distinct_groups = np.unique(self.group_idx_real_gen)
        return distinct_groups.size

    def load_data(self):
        """Derive targets and biases from the filenames and select the samples.

        Each ``attr_dict`` entry is ``(filename field index, filter for
        group 0, filter for group 1)``. For bias group ``i`` the "major"
        samples satisfy target filter ``i`` and bias filter ``i``; the
        "minor" samples satisfy target filter ``1 - i`` and bias filter
        ``i``. On the train split the minors are randomly sub-sampled to at
        most ``num_major * (1 - bias_rate)``; other splits keep all of them.

        Sets ``self.filenames`` (array), ``self.targets``, ``self.biases``
        and an all-zero ``self.real_gen``.
        """
        attr_dict = {
            'age': (0, lambda x: x >= 20, lambda x: x <= 10,),
            'gender': (1, lambda x: x == 0, lambda x: x == 1),
            'race': (2, lambda x: x == 0, lambda x: x != 0),
        }

        assert self.target_attr in attr_dict.keys()
        target_cls_idx, *target_filters = attr_dict[self.target_attr]
        bias_cls_idx, *bias_filters = attr_dict[self.bias_attr]

        target_classes = self.get_class_from_filename(self.filenames, target_cls_idx)
        bias_classes = self.get_class_from_filename(self.filenames, bias_cls_idx)
        # Array view of the names so groups can be gathered by index instead
        # of scanning every filename against the index arrays.
        names = np.asarray(self.filenames)

        total_files = []
        total_targets = []
        total_bias_targets = []

        for group in (0, 1):
            in_bias_group = bias_filters[group](bias_classes)
            major_idx = np.where(target_filters[group](target_classes) & in_bias_group)[0]
            minor_idx = np.where(target_filters[1 - group](target_classes) & in_bias_group)[0]
            np.random.shuffle(minor_idx)

            num_major = major_idx.shape[0]
            num_minor_org = minor_idx.shape[0]
            if self.split == "train":
                num_minor = int(num_major * (1 - self.bias_rate))
            else:
                num_minor = num_minor_org
            num_minor = min(num_minor, num_minor_org)
            num_total = num_major + num_minor

            majors = names[major_idx]
            # Index with the shuffled positions so the truncation keeps a
            # random subset rather than always the first minors in file order.
            minors = names[minor_idx[:num_minor]]

            total_files.append(np.concatenate((majors, minors)))
            total_bias_targets.append(np.ones(num_total) * group)
            total_targets.append(
                np.concatenate((np.ones(num_major) * group, np.ones(num_minor) * (1 - group))))

        self.filenames = np.concatenate(total_files)
        self.targets = np.concatenate(total_targets).astype(np.int64)
        self.biases = np.concatenate(total_bias_targets).astype(np.int64)
        self.real_gen = np.zeros(len(self.filenames))

    def fix_filenames(self):
        """Prefix every filename with ``self.root_dir`` and store them as a list."""
        full_paths = []
        for name in self.filenames:
            full_paths.append(os.path.join(self.root_dir, name))
        self.filenames = full_paths

    def load_image(self, filename):
        """Open ``filename`` with ``PIL.Image.open`` and return the image object."""
        return Image.open(filename)

    def get_class_distribution(self):
        """Print the number of samples in every (class, bias) group."""
        for class_value_ in range(2):
            in_class = self.targets == class_value_
            for bias_value in self.biases_values:
                num_images = np.sum(in_class & (self.biases == bias_value))
                print(f'Class Value {class_value_}, Bias {bias_value}:  {num_images}')

        print("====================================")


    def get_class_distribution_bin(self):
        """Print per-group sample counts for each ``targets_bin`` column.

        Samples whose entry in the column equals -1 are not counted.
        """
        for target_bin in range(2):
            active = self.targets_bin[:, target_bin] != -1
            for class_value_ in range(2):
                for bias_value in self.biases_values:
                    in_group = np.logical_and(self.targets == class_value_,
                                              self.biases == bias_value)
                    num_images = np.sum(np.logical_and(in_group, active))

                    print(f'Target Bin {target_bin} Class Value {class_value_}, Bias {bias_value}:  {num_images}')

        print("====================================")


    def set_group_counts(self):
        """Recompute ``self.group_counts_``: a 2x2 array of sample counts indexed [class, bias]."""
        counts = np.zeros((2, 2))
        for class_value in range(2):
            in_class = self.targets == class_value
            for bias_value in range(2):
                counts[class_value, bias_value] = np.sum(in_class & (self.biases == bias_value))
        self.group_counts_ = counts

    def create_real_gen_weights(self, quit_ = False):
        """Rebuild the per-sample weights and group-id arrays.

        Sets:
            real_weights_groups: inverse relative frequency of each sample's
                (class, bias) group, taken from ``self.group_counts_``
                (empty groups use a frequency of 1e-10).
            group_idx: id of the (class, bias) combination.
            group_idx_real_gen: id of the (class, bias, real_gen) combination.
            bias_real_gen: id of the (bias, real_gen) combination.

        Ids are assigned in order of first appearance in the data.
        ``quit_`` is unused.
        """
        def first_seen_ids(*columns):
            # Number each distinct value combination by first appearance.
            ids = np.zeros_like(self.targets, dtype=np.int32)
            distinct = list(dict.fromkeys(zip(*columns)))
            for group_id, values in enumerate(distinct):
                member = np.ones(ids.shape, dtype=bool)
                for column, value in zip(columns, values):
                    member &= column == value
                ids[member] = group_id
            return ids

        self.real_weights_groups = np.zeros_like(self.targets, dtype=np.float32)

        inverse_freq = self.group_counts_ / np.sum(self.group_counts_)
        inverse_freq[inverse_freq == 0] = 1e-10
        inverse_freq = 1 / inverse_freq

        for class_value in range(2):
            in_class = self.targets == class_value
            for bias_value in range(2):
                in_group = in_class & (self.biases == bias_value)
                self.real_weights_groups[in_group] = inverse_freq[class_value, bias_value]

        self.group_idx = first_seen_ids(self.targets, self.biases)
        self.group_idx_real_gen = first_seen_ids(self.targets, self.biases, self.real_gen)
        self.bias_real_gen = first_seen_ids(self.biases, self.real_gen)


    def __len__(self):
        """Return the number of samples (length of ``self.filenames``)."""
        num_samples = len(self.filenames)
        return num_samples

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict.

        Keys: 'idx', 'img' (image after the optional transform), 'target',
        'bias', 'filenames' (path of this sample), 'real_gen',
        'group_idx', 'real_weights_groups', 'group_idx_real_gen' and
        'bias_real_gen'.
        """
        filename = self.filenames[idx]
        target = self.targets[idx]
        bias = self.biases[idx]
        real_gen = self.real_gen[idx]

        img = self.load_image(filename)
        if self.transform:
            img = self.transform(img)

        return {
            'idx': idx,
            'img': img,
            'target': target,
            'bias': bias,
            'filenames': filename,
            'real_gen': real_gen,
            'group_idx': self.group_idx[idx],
            'real_weights_groups': self.real_weights_groups[idx],
            'group_idx_real_gen': self.group_idx_real_gen[idx],
            'bias_real_gen': self.bias_real_gen[idx],
        }


# Manual entry point: builds a preprocessing transform and parses CLI options.
# NOTE: the dataset itself is not instantiated here.
if __name__ == '__main__':
    # Resize to a fixed 128x128 resolution, then convert to a tensor.
    transform = transforms.Compose(
        [transforms.Resize((128, 128)), transforms.ToTensor()]
    )

    parser = argparse.ArgumentParser()
    # Run bookkeeping.
    parser.add_argument('--output_dir', default='test_records', type=str)
    parser.add_argument('--epochs', default=50, type=int)
    parser.add_argument('--seed', default=1, type=int)
    # Optimisation settings.
    parser.add_argument('--batch_size', default=128, type=int, help='batch_size')
    parser.add_argument('--lr', default=1e-4, type=float)
    parser.add_argument('--weight_decay', default=1e-5, type=float)
    # Boolean switches.
    parser.add_argument('--gen_balance', action='store_true')
    parser.add_argument('--make_harder', action='store_true')

    opt = parser.parse_args()

