import logging
import os
import pickle
from pathlib import Path

import PIL
import numpy as np
import torch
import torch.utils.data
from debias.datasets.utils import TwoCropTransform, get_confusion_matrix
from torch.utils.data.sampler import WeightedRandomSampler
from torchvision import transforms
from debias.utils.utils import us_alg_bin, us_alg_ce

import random

class UTKFace:
    """Plain UTKFace image folder labelled by age.

    Every entry of ``root`` is opened as an image; its label is the integer
    before the first underscore of the filename (e.g. ``'25_...'`` -> 25).
    """

    def __init__(self, root, transform, **kwargs):
        # Extra keyword arguments are accepted for interface compatibility and ignored.
        self.root = root
        self.filenames = os.listdir(root)
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, age)``; ``transform`` is applied to the image when set."""
        name = self.filenames[index]
        image = PIL.Image.open(os.path.join(self.root, name))
        age_field, _, _ = name.partition('_')
        age = int(age_field)

        if self.transform is None:
            return image, age
        return self.transform(image), age

    def __len__(self):
        return len(self.filenames)


class BiasedUTKFace:
    """UTKFace split with an artificially controlled target/bias correlation.

    The target attribute is always ``'gender'``; the bias attribute is chosen by
    ``bias_attr`` (``'age'``, ``'gender'`` or ``'race'``). For each bias value ``i``
    the split holds all "major" samples (target ``i``, bias ``i``) plus "minor"
    samples (target ``1 - i``, bias ``i``). In the training split the number of
    minors is capped at ``int(num_major * (1 - bias_rate))``; the test split keeps
    all of them (see ``build``). Splits are cached as pickles under
    ``<root>/pickles/`` and reused on later runs.

    ``__getitem__`` returns ``(image, target_bin, bias, index, target, gc, gc_imbalance)``.
    """

    def __init__(self, root, transform, split,
                 bias_attr='race', bias_rate=0.9,
                 under_sample=False,
                 diff_analysis=0.0,
                 **kwargs):
        """Load (or build and cache) the requested split.

        Args:
            root: dataset root containing an ``images/`` directory.
            transform: callable applied to each PIL image, or ``None``.
            split: ``'train'`` loads the train pickle; any other value loads the
                test pickle. ``'valid'`` and ``'test'`` each keep one half of a
                persisted random permutation of it.
            bias_attr: attribute used as the bias label.
            bias_rate: controls how many minority samples the train split keeps.
            under_sample: ``False`` or one of ``'bin'``, ``'analysis'``, ``'ce'``,
                ``'os'``; resampling is only applied when ``split == 'train'``.
            diff_analysis: forwarded to ``under_sample_bin`` for ``'analysis'``.
            **kwargs: ignored.
        """
        self.root = Path(root) / 'images'
        filenames = np.array(os.listdir(self.root))
        np.random.shuffle(filenames)
        num_train = int(len(filenames) * 0.8)
        target_attr = 'gender'

        self.transform = transform
        self.target_attr = target_attr
        self.bias_rate = bias_rate
        self.bias_attr = bias_attr
        self.train = split == 'train'

        save_path = Path(root) / 'pickles' / f'biased_utk_face-target_{target_attr}-bias_{bias_attr}-{bias_rate}'
        if save_path.is_dir():
            print(f'use existing biased_utk_face from {save_path}')
            data_split = 'train' if self.train else 'test'
            # Locally generated cache; pickle must never be fed untrusted files.
            with open(save_path / f'{data_split}_dataset.pkl', 'rb') as f:
                self.files, self.targets, self.bias_targets = pickle.load(f)
        else:
            train_dataset = self.build(filenames[:num_train], train=True)
            test_dataset = self.build(filenames[num_train:], train=False)

            print(f'save biased_utk_face to {save_path}')
            save_path.mkdir(parents=True, exist_ok=True)
            with open(save_path / 'train_dataset.pkl', 'wb') as f:
                pickle.dump(train_dataset, f)
            with open(save_path / 'test_dataset.pkl', 'wb') as f:
                pickle.dump(test_dataset, f)

            self.files, self.targets, self.bias_targets = train_dataset if self.train else test_dataset

        # Applied on both cache hit and cache miss so 'valid'/'test' are always disjoint halves.
        if split in ['valid', 'test']:
            self._split_valid_test(split)

        # One-hot targets (N, 2); under_sample_bin later marks dropped entries with -1.
        targets_ = torch.zeros((len(self.targets), 2))
        targets_[torch.arange(len(self.targets)), torch.as_tensor(self.targets).long()] = 1
        self.targets_bin = targets_

        self.calculate_bias_weights()
        self.set_min_idxs()

        self.targets = torch.from_numpy(self.targets).long()
        self.bias_targets = torch.from_numpy(self.bias_targets).long()
        self.groups_counts = torch.from_numpy(self.groups_counts).float()

        idx_shuffle = torch.randperm(self.targets.size(0))
        self.targets = self.targets[idx_shuffle]
        self.bias_targets = self.bias_targets[idx_shuffle]
        # Index the numpy array with a numpy array (a 1-element tensor would act as a scalar index).
        self.files = self.files[idx_shuffle.numpy()]
        self.groups_counts = self.groups_counts[idx_shuffle]
        self.targets_bin = self.targets_bin[idx_shuffle]

        if under_sample:
            # Pristine copies so each resampling method can start from the full split.
            self.targets_bin_ = torch.clone(self.targets_bin)
            self.targets_ = torch.clone(self.targets)
            self.bias_targets_ = torch.clone(self.bias_targets)
            self.groups_counts_ = torch.clone(self.groups_counts)
            self.files_ = self.files.copy()

        if split == 'train':
            if under_sample == 'bin':
                self.under_sample_bin()
            elif under_sample == 'analysis':
                self.under_sample_bin(diff=diff_analysis)
            elif under_sample == 'ce':
                self.under_sample_ce()
            elif under_sample == 'os':
                self.over_sample_ce()

        self.calculate_class_imbalance_weights()
        self.confusion_matrix_org, self.confusion_matrix, self.confusion_matrix_by = get_confusion_matrix(num_classes=2,
                                                                                                          targets=self.targets,
                                                                                                          biases=self.bias_targets)

        print(f'Use BiasedUTKFace - target_attr: {target_attr}')

        print(
            f'BiasedUTKFace -- total: {len(self.files)}, target_attr: {self.target_attr}, bias_attr: {self.bias_attr} ' \
            f'bias_rate: {self.bias_rate}')

        print(
            [f'[{split}] target_{i}-bias_{j}: {sum((self.targets == i) & (self.bias_targets == j))}' for i in (0, 1) for
             j in (0, 1)])

    def _split_valid_test(self, split):
        """Keep the first (``'valid'``) or second (``'test'``) half of a persisted permutation.

        The permutation is stored at ``clusters/utk_face_rand_indices_<bias_attr>.pkl``
        relative to the current working directory, so the halves stay disjoint across runs.
        """
        indices_path = Path(f'clusters/utk_face_rand_indices_{self.bias_attr}.pkl')
        if not indices_path.exists():
            rand_indices = torch.randperm(len(self.targets))
            indices_path.parent.mkdir(parents=True, exist_ok=True)
            with open(indices_path, 'wb') as f:
                pickle.dump(rand_indices, f)
        else:
            with open(indices_path, 'rb') as f:
                rand_indices = pickle.load(f)
        # NOTE(review): the permutation is keyed on bias_attr only, while the test pickle is keyed
        # on bias_rate too; confirm len(rand_indices) == len(self.targets) when bias_rate varies.
        num_valid = int(0.5 * len(rand_indices))
        indices = rand_indices[:num_valid] if split == 'valid' else rand_indices[num_valid:]
        indices = indices.numpy()

        self.files = self.files[indices]
        self.targets = self.targets[indices]
        self.bias_targets = self.bias_targets[indices]

    def build(self, filenames, train=False):
        """Select files for a biased split.

        Attribute classes (from the filename fields): age -> 0 if >= 20, 1 if <= 10;
        gender -> 0/1; race -> 0 if race == 0 else 1. For each bias value ``i``,
        all samples with target ``i`` and bias ``i`` are kept ("major"), plus a
        random subset of samples with target ``1 - i`` and bias ``i`` ("minor"):
        at most ``int(num_major * (1 - bias_rate))`` when ``train`` else all.

        Returns:
            ``(files, targets, bias_targets)`` numpy arrays; labels are float.
        """
        attr_dict = {
            'age': (0, lambda x: x >= 20, lambda x: x <= 10,),
            'gender': (1, lambda x: x == 0, lambda x: x == 1),
            'race': (2, lambda x: x == 0, lambda x: x != 0),
        }
        assert self.target_attr in attr_dict.keys()
        target_cls_idx, *target_filters = attr_dict[self.target_attr]
        bias_cls_idx, *bias_filters = attr_dict[self.bias_attr]

        target_classes = self.get_class_from_filename(filenames, target_cls_idx)
        bias_classes = self.get_class_from_filename(filenames, bias_cls_idx)

        total_files = []
        total_targets = []
        total_bias_targets = []

        for i in (0, 1):
            major_idx = np.where(target_filters[i](target_classes) & bias_filters[i](bias_classes))[0]
            minor_idx = np.where(target_filters[1 - i](target_classes) & bias_filters[i](bias_classes))[0]
            np.random.shuffle(minor_idx)

            num_major = major_idx.shape[0]
            num_minor_org = minor_idx.shape[0]
            if train:
                num_minor = int(num_major * (1 - self.bias_rate))
            else:
                num_minor = num_minor_org
            num_minor = min(num_minor, num_minor_org)
            num_total = num_major + num_minor

            majors = filenames[major_idx]
            minors = filenames[minor_idx][:num_minor]

            total_files.append(np.concatenate((majors, minors)))
            total_bias_targets.append(np.ones(num_total) * i)
            total_targets.append(np.concatenate((np.ones(num_major) * i, np.ones(num_minor) * (1 - i))))

        files = np.concatenate(total_files)
        targets = np.concatenate(total_targets)
        bias_targets = np.concatenate(total_bias_targets)
        return files, targets, bias_targets

    def get_bin_gc(self):
        """Return inverse normalized counts of kept (``!= -1``) entries per ``targets_bin`` column."""
        gc = torch.sum(self.targets_bin != -1, dim=0)
        gc = gc / torch.sum(gc)
        gc = 1 / gc
        return gc

    def get_class_from_filename(self, filenames, cls_idx):
        """Parse field ``cls_idx`` of ``'<age>_<gender>_<race>_<rest>'`` filenames.

        Filenames that do not have exactly four ``_``-separated fields get the
        sentinel 10. It matches neither gender filter in ``build``, so such files
        are dropped while the target is gender.
        """
        classes = []
        for fname in filenames:
            parts = fname.split('_')
            classes.append(int(parts[cls_idx]) if len(parts) == 4 else 10)
        return np.array(classes)

    def count_pos_neg(self, targets, bias):
        """Count and index samples per bias group, split by binary target.

        Returns ``(count_pos, count_neg, neg_idx, pos_idx)``: ``count_pos[b]`` and
        ``pos_idx[b]`` cover samples with target 1 and bias ``b`` (``neg`` likewise
        for target 0). Samples whose target is neither 0 nor 1 (e.g. -1) are skipped.
        """
        count_pos = [0, 0]
        count_neg = [0, 0]
        neg_idx = {0: [], 1: []}
        pos_idx = {0: [], 1: []}
        for idx, (bi, tar) in enumerate(zip(bias, targets)):
            if tar == 1:
                count_pos[int(bi)] += 1
                pos_idx[int(bi)].append(idx)
            elif tar == 0:
                count_neg[int(bi)] += 1
                neg_idx[int(bi)].append(idx)
        return count_pos, count_neg, neg_idx, pos_idx

    def calculate_class_imbalance_weights(self):
        """Set ``imbalance_weights`` (same shape as ``targets_bin``).

        In each column, an entry equal to 0 or 1 gets weight 1 / (fraction of that
        column's 0/1 entries sharing its value); other entries (e.g. -1) keep weight 1.
        """
        self.imbalance_weights = torch.ones_like(self.targets_bin)
        for i in range(2):
            sum_0 = torch.sum(self.targets_bin[:, i] == 0).float()
            sum_1 = torch.sum(self.targets_bin[:, i] == 1).float()
            self.imbalance_weights[self.targets_bin[:, i] == 0, i] = sum_0 / (sum_0 + sum_1)
            self.imbalance_weights[self.targets_bin[:, i] == 1, i] = sum_1 / (sum_0 + sum_1)

        self.imbalance_weights = 1 / self.imbalance_weights

    def calculate_bias_weights(self):
        """Set ``self.groups_counts`` to per-sample inverse group frequencies.

        Groups are indexed as ``target + 2 * bias``. Each sample's weight is
        ``N / n_group``; an empty group gives ``inf`` (with a NumPy warning), but
        no sample maps to it.
        """
        group_ids = (np.asarray(self.targets) + 2 * np.asarray(self.bias_targets)).astype(int)
        counts = np.bincount(group_ids, minlength=4)
        inverse_freq = 1 / (counts / np.sum(counts))
        self.groups_counts = inverse_freq[group_ids].astype(np.float64)

    def set_min_idxs(self):
        """Record the bias group with the fewest negatives (index 0) and fewest positives (index 1)."""
        self.min_idxs = []
        count_pos, count_neg, _, _ = self.count_pos_neg(self.targets, self.bias_targets)
        self.min_idxs.append(np.argmin(count_neg))
        self.min_idxs.append(np.argmin(count_pos))

    def under_sample_ce(self):
        """Restore the full split, then randomly under-sample groups to ``us_alg_ce`` counts.

        All per-sample arrays, including ``targets_bin``, are subset together.
        Requires the backups made when ``under_sample`` is truthy.
        """
        self.targets = torch.clone(self.targets_)
        self.bias_targets = torch.clone(self.bias_targets_)
        self.groups_counts = torch.clone(self.groups_counts_)
        self.targets_bin = torch.clone(self.targets_bin_)
        self.files = self.files_.copy()

        count_pos, count_neg, _, _ = self.count_pos_neg(self.targets, self.bias_targets)

        print('Counts Positive: ', count_pos)
        print('Counts Negative: ', count_neg)

        count_neg, count_pos = us_alg_ce(count_neg, count_pos)
        # Order must match set_groups_samples_ce: pos b0, pos b1, neg b0, neg b1.
        groups_samples = [*count_pos, *count_neg]

        self.set_groups_samples_ce(groups_samples)
        count_pos, count_neg, _, _ = self.count_pos_neg(self.targets, self.bias_targets)
        print('Counts Positive: ', count_pos)
        print('Counts Negative: ', count_neg)

    def under_sample_bin(self, diff=0):
        """Per one-hot column, mark samples not chosen by ``us_alg_bin`` with -1.

        Each column is first restored from ``targets_bin_``. Only negatives are
        resized (via ``us_alg_bin(count_neg, count_pos, diff)``); positives are all kept.
        """
        for target_idx in range(2):
            self.targets_bin[:, target_idx] = self.targets_bin_[:, target_idx]
            count_pos, count_neg, _, _ = self.count_pos_neg(
                self.targets_bin[:, target_idx], self.bias_targets)

            print('Target %d' % target_idx)
            print('Counts Positive: ', count_pos)
            print('Counts Negative: ', count_neg)

            new_count_neg = us_alg_bin(count_neg, count_pos, diff)
            groups_samples = [*count_pos, *new_count_neg]

            self.set_groups_samples_bin(groups_samples, target_idx)
            count_pos, count_neg, _, _ = self.count_pos_neg(
                self.targets_bin[:, target_idx], self.bias_targets)
            print('After US: Counts Positive: ', count_pos)
            print('After US: Counts Negative: ', count_neg)

    def over_sample_ce(self, verbose=False):
        """Restore the full split, then over-sample every (target, bias) group to the largest size.

        Each of the four groups is drawn ``max_count`` times with replacement.
        ``verbose`` is accepted but unused.
        """
        self.targets = torch.clone(self.targets_)
        self.bias_targets = torch.clone(self.bias_targets_)
        self.groups_counts = torch.clone(self.groups_counts_)
        self.targets_bin = torch.clone(self.targets_bin_)
        self.files = self.files_.copy()

        count_pos, count_neg, neg_idx, pos_idx = self.count_pos_neg(self.targets, self.bias_targets)

        # List concatenation: max over all four group sizes.
        max_count = np.max(count_pos + count_neg)
        to_keep_samples = []
        for bias_idx in range(2):
            to_keep_samples.extend(np.random.choice(pos_idx[bias_idx], max_count, replace=True))
            to_keep_samples.extend(np.random.choice(neg_idx[bias_idx], max_count, replace=True))

        self.targets = self.targets[to_keep_samples]
        self.bias_targets = self.bias_targets[to_keep_samples]
        self.groups_counts = self.groups_counts[to_keep_samples]
        self.files = self.files[to_keep_samples]
        self.targets_bin = self.targets_bin[to_keep_samples]

    def set_to_keep(self, to_keep_idx):
        """Subset every per-sample array (including ``targets_bin``) to ``to_keep_idx``."""
        self.targets = self.targets[to_keep_idx]
        self.bias_targets = self.bias_targets[to_keep_idx]
        self.files = self.files[to_keep_idx]
        self.groups_counts = self.groups_counts[to_keep_idx]
        # Previously omitted, leaving targets_bin misaligned with files after under_sample_ce.
        self.targets_bin = self.targets_bin[to_keep_idx]

    def set_groups_samples_ce(self, groups_samples):
        """Keep ``groups_samples[k]`` random samples of each group (pos b0, pos b1, neg b0, neg b1)."""
        _, _, neg_idx, pos_idx = self.count_pos_neg(self.targets, self.bias_targets)
        all_idx = [pos_idx[0], pos_idx[1], neg_idx[0], neg_idx[1]]

        to_keep_idx = []
        for group_idx, group_sample_num in zip(all_idx, groups_samples):
            to_keep_idx.extend(random.sample(group_idx, group_sample_num))

        self.set_to_keep(to_keep_idx)

    def set_groups_samples_bin(self, groups_samples, target_idx):
        """Mark with -1 every sample of column ``target_idx`` not randomly chosen.

        ``groups_samples`` gives how many to keep per group, ordered pos b0,
        pos b1, neg b0, neg b1.
        """
        _, _, neg_idx, pos_idx = self.count_pos_neg(self.targets_bin[:, target_idx], self.bias_targets)
        all_idx = [pos_idx[0], pos_idx[1], neg_idx[0], neg_idx[1]]

        to_keep_idx = []
        for group_idx, group_sample_num in zip(all_idx, groups_samples):
            to_keep_idx.extend(random.sample(group_idx, group_sample_num))

        full_idxs = torch.arange(len(self.targets))
        to_select = torch.ones(len(self.targets))
        to_select[to_keep_idx] = 0
        full_idxs = full_idxs[to_select.bool()]

        self.targets_bin[full_idxs, target_idx] = -1

    def __getitem__(self, index):
        """Return ``(image, target_bin, bias, index, target, gc, gc_imbalance)``."""
        filename, target, bias, target_bin = self.files[index], int(self.targets[index]), int(self.bias_targets[index]), self.targets_bin[index]

        gc = self.groups_counts[index]
        gc_imbalance = self.imbalance_weights[index]
        X = PIL.Image.open(os.path.join(self.root, filename))

        if self.transform is not None:
            X = self.transform(X)

        return X, target_bin, bias, index, target, gc, gc_imbalance

    def __len__(self):
        return len(self.files)


def get_utk_face(root, batch_size, split, bias_attr='race', bias_rate=0.9, num_workers=4,
                 aug=False, image_size=64, two_crop=False, ratio=0, given_y=True, under_sample=False,
                 repair=False, diff_analysis=0.0):
    """Create a DataLoader over ``BiasedUTKFace`` for ``split``.

    Args:
        root: dataset root forwarded to ``BiasedUTKFace``.
        batch_size: loader batch size.
        split: ``'train'`` selects the training augmentations; any other value
            uses resize + center-crop.
        bias_attr, bias_rate, under_sample, diff_analysis: forwarded to ``BiasedUTKFace``.
        num_workers: number of loader worker processes.
        aug: for training, use the stronger crop/color-jitter/grayscale pipeline.
        image_size: output crop size. Evaluation supports only 64, 128 and 224
            (``KeyError`` otherwise); training accepts any size.
        two_crop: wrap the transform in ``TwoCropTransform`` and drop the last
            incomplete batch.
        ratio: when non-zero, draw samples with ``WeightedRandomSampler`` using
            inverse confusion-matrix weights; a positive ratio also caps each
            weight at ``ratio`` times the smallest weight.
        given_y: use ``confusion_matrix_by[target, bias]`` (True) or
            ``confusion_matrix[bias, target]`` (False) for those weights.
        repair: disable shuffling, so dataset order is kept when no sampler is used.

    Returns:
        A ``torch.utils.data.DataLoader``.
    """
    logging.info(f'get_utk_face - split: {split}, aug: {aug}, given_y: {given_y}, ratio: {ratio}')
    train = split == 'train'

    if train:
        if aug:
            transform = transforms.Compose([
                transforms.RandomResizedCrop(size=image_size, scale=(0.2, 1.)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomApply([
                    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
                ], p=0.8),
                transforms.RandomGrayscale(p=0.2),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ])
        else:
            transform = transforms.Compose([
                transforms.RandomResizedCrop(image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))])
    else:
        # Resize slightly larger than the crop, then center-crop. Only evaluation needs
        # this lookup, so training with other image sizes no longer raises KeyError.
        size_dict = {64: 72, 128: 144, 224: 256}
        load_size = size_dict[image_size]
        transform = transforms.Compose([
            transforms.Resize(load_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))])

    if two_crop:
        transform = TwoCropTransform(transform)

    dataset = BiasedUTKFace(root, transform=transform, split=split, bias_rate=bias_rate, bias_attr=bias_attr,
                            under_sample=under_sample, diff_analysis=diff_analysis)

    def clip_max_ratio(score):
        # Cap every weight at `ratio` times the smallest weight.
        upper_bd = score.min() * ratio
        return np.clip(score, None, upper_bd)

    if ratio != 0:
        if given_y:
            weights = [1 / dataset.confusion_matrix_by[c, b] for c, b in zip(dataset.targets, dataset.bias_targets)]
        else:
            weights = [1 / dataset.confusion_matrix[b, c] for c, b in zip(dataset.targets, dataset.bias_targets)]
        if ratio > 0:
            weights = clip_max_ratio(np.array(weights))
        sampler = WeightedRandomSampler(weights, len(weights), replacement=True)
    else:
        sampler = None

    # DataLoader rejects shuffle=True together with a sampler; `repair` also disables it.
    shuffle = sampler is None and not repair

    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=False,
        drop_last=two_crop
    )

    return dataloader
