import logging
import pickle
from pathlib import Path

import numpy as np
import torch
from debias.datasets.utils import TwoCropTransform, get_confusion_matrix
from torch.utils.data import WeightedRandomSampler
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms as T
from torchvision.datasets.celeba import CelebA
from debias.utils.utils import us_alg_bin, us_alg_ce

import random 

# import torch.multiprocessing
# torch.multiprocessing.set_sharing_strategy('file_system')

class BiasedCelebASplit:
    """CelebA split with a binary target attribute and a fixed binary bias attribute (column 20).

    Items are ``(img, target_bin, bias, index, target, gc, gc_imbalance)``:

    * ``target_bin`` -- 2-column one-hot encoding of ``target``; binary under-sampling
      ('bin', 'analysis', 'paper') sets an entry to -1 to exclude that sample from that column.
    * ``gc`` -- inverse frequency of the sample's (target, bias) group, computed before any
      resampling.
    * ``gc_imbalance`` -- per-column inverse frequency of the sample's ``target_bin`` value
      (1 for excluded entries).
    """

    # Column of ``CelebA.attr`` used as the target for each supported ``target_attr``
    # (Blond_Hair, Heavy_Makeup, Black_Hair in the standard list_attr_celeba.txt order).
    _TARGET_IDXS = {'blonde': 9, 'makeup': 18, 'black': 8}

    def __init__(self, root, split, transform, target_attr, under_sample, resample_blonde, diff_analysis,
                 split_seed=0, **kwargs):
        """Load the CelebA split, select target/bias columns and optionally resample.

        Args:
            root: CelebA root directory; resampled 'blonde' indices are cached under
                ``root/pickles/blonde``.
            split: 'train' / 'train_valid' (the 80% / 20% partitions of CelebA's train split)
                or any other split name, passed straight to ``CelebA``.
            transform: image transform applied by ``CelebA``.
            target_attr: one of 'blonde', 'makeup', 'black'.
            under_sample: falsy, or one of 'bin', 'ce', 'analysis', 'paper', 'os'; resampling is
                only applied when ``split == 'train'``.
            resample_blonde: for 'blonde' on 'train'/'train_valid', keep at most 2000 samples of
                the (target=0, bias=0) group (see ``build_blonde``).
            diff_analysis: forwarded as ``diff`` to ``us_alg_bin`` when ``under_sample == 'analysis'``.
            split_seed: seed of the 80/20 permutation, so separately constructed 'train' and
                'train_valid' datasets are disjoint; ``None`` uses the global torch RNG instead.
            **kwargs: ignored.

        Raises:
            AttributeError: if ``target_attr`` is not supported.
        """
        # Fail fast, before the (slow) CelebA metadata load.
        if target_attr not in self._TARGET_IDXS:
            raise AttributeError(
                f'unsupported target_attr {target_attr!r}; expected one of {sorted(self._TARGET_IDXS)}')

        self.transform = transform
        self.target_attr = target_attr

        self.celeba = CelebA(
            root=root,
            # 'train_valid' is carved out of CelebA's own train split below.
            split="train" if split == "train_valid" else split,
            target_type="attr",
            transform=transform,
        )
        self.bias_idx = 20  # 'Male' in the standard list_attr_celeba.txt order
        self.target_idx = self._TARGET_IDXS[target_attr]

        if target_attr == 'blonde' and split in ['train', 'train_valid'] and resample_blonde:
            self.indices = self._load_or_build_blonde(root)
            self.attr = self.celeba.attr[self.indices]
        else:
            self.attr = self.celeba.attr
            self.indices = torch.arange(len(self.celeba))

        if split in ['train', 'train_valid']:
            # A dedicated, seeded generator yields the same permutation in the separately
            # constructed 'train' and 'train_valid' datasets, so the partitions never overlap.
            generator = None if split_seed is None else torch.Generator().manual_seed(split_seed)
            rand_indices = torch.randperm(len(self.indices), generator=generator)
            num_train = int(0.8 * len(rand_indices))
            keep = rand_indices[:num_train] if split == 'train' else rand_indices[num_train:]
            self.indices = self.indices[keep]
            self.attr = self.attr[keep]

        self.targets = self.attr[:, self.target_idx]
        self.biases = self.attr[:, self.bias_idx]

        self.calculate_bias_weights()
        self.groups_counts = torch.Tensor(self.groups_counts)

        # One-hot targets; binary under-sampling later marks excluded entries with -1.
        self.targets_bin = torch.zeros((len(self.targets), 2))
        self.targets_bin[torch.arange(len(self.targets)), self.targets] = 1

        self.set_min_idxs()

        if under_sample:
            # Pristine copies that the resampling routines start from.
            self.targets_bin_ = torch.clone(self.targets_bin)
            self.targets_ = torch.clone(self.targets)
            self.biases_ = torch.clone(self.biases)
            self.indices_ = torch.clone(self.indices)
            self.groups_counts_ = torch.clone(self.groups_counts)

        if split == 'train':
            if under_sample == 'bin':
                self.under_sample_bin()
            elif under_sample == 'ce':
                self.under_sample_ce()
            elif under_sample == 'analysis':
                self.under_sample_bin(diff=diff_analysis)
            elif under_sample == 'paper':
                self.under_sample_paper()
            elif under_sample == 'os':
                self.over_sample_ce()

        self.calculate_class_imbalance_weights()
        self.confusion_matrix_org, self.confusion_matrix, self.confusion_matrix_by = get_confusion_matrix(
            num_classes=2, targets=self.targets, biases=self.biases)

        print(f'Use BiasedCelebASplit \n target_attr: {target_attr} split: {split} \n {self.confusion_matrix_org}')

    def _load_or_build_blonde(self, root):
        """Return the blonde-resampled indices cached under ``root/pickles/blonde``, building them if absent."""
        save_path = Path(root) / 'pickles' / 'blonde'
        if save_path.is_dir():
            print(f'use existing blonde indices from {save_path}')
            # NOTE: pickle can execute arbitrary code on load; only trust caches under a trusted root.
            # Caches written before build_blonde stopped using np.random.shuffle on a tensor may
            # contain duplicate indices; delete the directory to rebuild them.
            with open(save_path / 'indices.pkl', 'rb') as f:
                return pickle.load(f)
        indices = self.build_blonde()
        print(f'save blonde indices to {save_path}')
        save_path.mkdir(parents=True, exist_ok=True)
        with open(save_path / 'indices.pkl', 'wb') as f:
            pickle.dump(indices, f)
        return indices

    def calculate_class_imbalance_weights(self):
        """Set ``imbalance_weights``: per column, 1 / (share of that column's 0/1 entries with the same value).

        Entries equal to -1 (excluded by under-sampling) do not count towards the shares and get weight 1.
        """
        self.imbalance_weights = torch.ones_like(self.targets_bin)
        for i in range(2):
            is_zero = self.targets_bin[:, i] == 0
            is_one = self.targets_bin[:, i] == 1
            sum_0 = torch.sum(is_zero).float()
            sum_1 = torch.sum(is_one).float()
            self.imbalance_weights[is_zero, i] = sum_0 / (sum_0 + sum_1)
            self.imbalance_weights[is_one, i] = sum_1 / (sum_0 + sum_1)

        self.imbalance_weights = 1 / self.imbalance_weights

    def get_bin_gc(self):
        """Return, per ``targets_bin`` column, the inverse share of non-excluded (!= -1) entries."""
        gc = torch.sum(self.targets_bin != -1, dim=0)
        gc = gc / torch.sum(gc)
        return 1 / gc

    def count_pos_neg(self, targets, bias):
        """Split sample positions into (target, bias) groups.

        Args:
            targets: 1-D tensor (or sequence); entries equal to 1 are positives, 0 negatives,
                anything else (e.g. -1 exclusions) is ignored.
            bias: 1-D tensor (or sequence) of 0/1 bias values, same length as ``targets``.

        Returns:
            ``(count_pos, count_neg, neg_idx, pos_idx)`` where ``count_*[b]`` is the group size for
            bias ``b`` and ``*_idx[b]`` lists that group's positions in ascending order.
        """
        targets = torch.as_tensor(targets)
        bias = torch.as_tensor(bias)
        pos_idx, neg_idx = {}, {}
        for b in (0, 1):
            in_bias = bias == b
            pos_idx[b] = torch.nonzero(in_bias & (targets == 1)).flatten().tolist()
            neg_idx[b] = torch.nonzero(in_bias & (targets == 0)).flatten().tolist()
        count_pos = [len(pos_idx[0]), len(pos_idx[1])]
        count_neg = [len(neg_idx[0]), len(neg_idx[1])]
        return count_pos, count_neg, neg_idx, pos_idx

    def calculate_bias_weights(self):
        """Set ``groups_counts`` to a float64 numpy array of per-sample inverse (target, bias) group frequencies."""
        group_ids = (self.targets + 2 * self.biases).long()
        counts = torch.bincount(group_ids, minlength=4).double()
        inv_freq = 1 / (counts / counts.sum())  # same operation order as the frequency -> inverse it replaces
        self.groups_counts = inv_freq[group_ids].numpy()

    def under_sample_analysis(self, groups_samples):
        """Under-sample to explicit group sizes given as a comma-separated string.

        The counts are ordered (target=1, bias=0), (target=1, bias=1), (target=0, bias=0), (target=0, bias=1).
        """
        groups_samples = [int(x) for x in groups_samples.split(',')]
        self.set_groups_samples_ce(groups_samples)
        count_pos, count_neg, _, _ = self.count_pos_neg(self.targets, self.biases)

        print('Counts Positive: ', count_pos)
        print('Counts Negative: ', count_neg)

    def _restore_from_copies(self):
        """Reset the resampled tensors to the pristine copies taken in ``__init__`` (requires ``under_sample``)."""
        self.targets = torch.clone(self.targets_)
        self.biases = torch.clone(self.biases_)
        self.indices = torch.clone(self.indices_)
        self.groups_counts = torch.clone(self.groups_counts_)
        self.targets_bin = torch.clone(self.targets_bin_)

    def under_sample_ce(self):
        """Drop samples so the (target, bias) group sizes match those chosen by ``us_alg_ce``."""
        self._restore_from_copies()

        count_pos, count_neg, _, _ = self.count_pos_neg(self.targets, self.biases)
        print('Counts Positive: ', count_pos)
        print('Counts Negative: ', count_neg)

        count_neg, count_pos = us_alg_ce(count_neg, count_pos)
        self.set_groups_samples_ce(list(count_pos) + list(count_neg))

        count_pos, count_neg, _, _ = self.count_pos_neg(self.targets, self.biases)
        print('After US: Counts Positive: ', count_pos)
        print('After US: Counts Negative: ', count_neg)

    def over_sample_ce(self, verbose=False):
        """Resample every (target, bias) group with replacement up to the size of the largest group."""
        self._restore_from_copies()

        count_pos, count_neg, neg_idx, pos_idx = self.count_pos_neg(self.targets, self.biases)
        max_count = np.max(count_pos + count_neg)  # list concatenation: largest of the four groups

        # Draw order (pos/bias0, neg/bias0, pos/bias1, neg/bias1) fixes the numpy RNG stream usage.
        to_keep_samples = []
        for bias_idx in range(2):
            to_keep_samples.extend(np.random.choice(pos_idx[bias_idx], max_count, replace=True))
            to_keep_samples.extend(np.random.choice(neg_idx[bias_idx], max_count, replace=True))

        self.targets = self.targets[to_keep_samples]
        self.biases = self.biases[to_keep_samples]
        self.groups_counts = self.groups_counts[to_keep_samples]
        self.indices = self.indices[to_keep_samples]
        self.targets_bin = self.targets_bin[to_keep_samples]

        if verbose:
            count_pos, count_neg, _, _ = self.count_pos_neg(self.targets, self.biases)
            print('After OS: Counts Positive: ', count_pos)
            print('After OS: Counts Negative: ', count_neg)

    def under_sample_paper(self):
        """Binary under-sampling per ``targets_bin`` column, dropping only samples whose target is 0.

        Column 1 keeps all its 1-entries and reduces its 0-entries; column 0 keeps all its 0-entries
        and reduces its 1-entries. New sizes come from ``us_alg_bin``; dropped entries become -1.
        """
        for target_idx in range(2):
            self.targets_bin[:, target_idx] = self.targets_bin_[:, target_idx]
            count_pos, count_neg, _, _ = self.count_pos_neg(self.targets_bin[:, target_idx], self.biases)

            print('Target %d' % target_idx)
            print('Counts Positive: ', count_pos)
            print('Counts Negative: ', count_neg)

            if target_idx == 1:
                groups_samples = list(count_pos) + list(us_alg_bin(count_neg, count_pos))
            else:
                groups_samples = list(us_alg_bin(count_pos, count_neg)) + list(count_neg)

            self.set_groups_samples_bin(groups_samples, target_idx)
            count_pos, count_neg, _, _ = self.count_pos_neg(self.targets_bin[:, target_idx], self.biases)
            print('Counts Positive: ', count_pos)
            print('Counts Negative: ', count_neg)

    def under_sample_bin(self, verbose=True, diff=0):
        """Binary under-sampling per ``targets_bin`` column: keep every 1-entry, reduce 0-entries via ``us_alg_bin``.

        Dropped entries are set to -1; ``diff`` is forwarded to ``us_alg_bin``.
        """
        for target_idx in range(2):
            self.targets_bin[:, target_idx] = self.targets_bin_[:, target_idx]
            count_pos, count_neg, _, _ = self.count_pos_neg(self.targets_bin[:, target_idx], self.biases)

            if verbose:
                print('Target %d' % target_idx)
                print('Counts Positive: ', count_pos)
                print('Counts Negative: ', count_neg)

            new_count_neg = us_alg_bin(count_neg, count_pos, diff)
            self.set_groups_samples_bin(list(count_pos) + list(new_count_neg), target_idx)

            count_pos, count_neg, _, _ = self.count_pos_neg(self.targets_bin[:, target_idx], self.biases)
            if verbose:
                print('After US: Counts Positive: ', count_pos)
                print('After US: Counts Negative: ', count_neg)

    def _sample_group_indices(self, targets, groups_samples):
        """Randomly pick ``groups_samples[k]`` positions (without replacement) from the k-th group.

        Groups are ordered (target=1, bias=0), (target=1, bias=1), (target=0, bias=0), (target=0, bias=1);
        ``random.sample`` raises ValueError if a group is smaller than its requested count.
        """
        _, _, neg_idx, pos_idx = self.count_pos_neg(targets, self.biases)
        groups = [pos_idx[0], pos_idx[1], neg_idx[0], neg_idx[1]]
        to_keep_idx = []
        for group, n_samples in zip(groups, groups_samples):
            to_keep_idx.extend(random.sample(group, n_samples))
        return to_keep_idx

    def set_groups_samples_ce(self, groups_samples):
        """Keep only a random subset of each (target, bias) group, sized by ``groups_samples``."""
        to_keep_idx = self._sample_group_indices(self.targets, groups_samples)
        self.targets = self.targets[to_keep_idx]
        self.biases = self.biases[to_keep_idx]
        self.indices = self.indices[to_keep_idx]
        self.groups_counts = self.groups_counts[to_keep_idx]
        # targets_bin must shrink too, or __getitem__ would pair samples with other samples' labels.
        self.targets_bin = self.targets_bin[to_keep_idx]

    def set_groups_samples_bin(self, groups_samples, target_idx):
        """Mark with -1 every entry of ``targets_bin[:, target_idx]`` not drawn by the group sampling."""
        to_keep_idx = self._sample_group_indices(self.targets_bin[:, target_idx], groups_samples)
        drop = torch.ones(len(self.targets), dtype=torch.bool)
        drop[to_keep_idx] = False
        self.targets_bin[drop, target_idx] = -1

    def set_min_idxs(self):
        """Store the bias value with the fewest negatives (``min_idxs[0]``) and fewest positives (``min_idxs[1]``)."""
        count_pos, count_neg, _, _ = self.count_pos_neg(self.targets, self.biases)
        self.min_idxs = [np.argmin(count_neg), np.argmin(count_pos)]

    def build_blonde(self):
        """Return CelebA indices keeping at most 2000 random (target=0, bias=0) samples plus all other samples."""
        biases = self.celeba.attr[:, self.bias_idx]
        targets = self.celeba.attr[:, self.target_idx]
        in_group_00 = (biases == 0) & (targets == 0)
        all_idx = torch.arange(len(self.celeba))
        selects = all_idx[in_group_00]
        # torch.randperm instead of np.random.shuffle: numpy's generic swap path writes through
        # tensor views and leaves duplicated indices behind.
        selects = selects[torch.randperm(len(selects))]
        return torch.cat([selects[:2000], all_idx[~in_group_00]])

    def __getitem__(self, index):
        """Return ``(img, target_bin, bias, index, target, gc, gc_imbalance)`` for position ``index``."""
        img, _ = self.celeba[self.indices[index]]
        return (
            img,
            self.targets_bin[index],
            self.biases[index],
            index,
            self.targets[index],
            self.groups_counts[index],
            self.imbalance_weights[index],
        )

    def __len__(self):
        return len(self.targets)


def get_celeba(root, batch_size, target_attr='blonde', split='train', num_workers=2, aug=True, two_crop=False, ratio=0,
               img_size=224, given_y=True, under_sample=None, resample_blonde=True, diff_analysis=0.0):
    """Build a DataLoader over a ``BiasedCelebASplit``.

    Args:
        root: CelebA root directory.
        batch_size: loader batch size.
        target_attr: forwarded to ``BiasedCelebASplit`` ('blonde', 'makeup' or 'black').
        split: dataset split; ``'eval'`` selects the deterministic resize/normalize transform.
        num_workers: DataLoader worker processes.
        aug: for non-'eval' splits, use random-resized-crop/color-jitter augmentation instead of
            resize + horizontal flip.
        two_crop: wrap the transform in ``TwoCropTransform`` and drop the last incomplete batch.
        ratio: if non-zero, sample with replacement via ``WeightedRandomSampler`` using weights
            inversely proportional to each sample's confusion-matrix entry; if positive, each weight
            is additionally capped at ``ratio`` times the smallest weight.
        img_size: output image side length.
        given_y: use ``confusion_matrix_by[target, bias]`` (True) or ``confusion_matrix[bias, target]``
            (False) for the sampling weights.
        under_sample, resample_blonde, diff_analysis: forwarded to ``BiasedCelebASplit``.

    Returns:
        A ``DataLoader`` that shuffles when no sampler is used.
    """
    logging.info('get_celeba - split:%s, aug: %s, given_y: %s, ratio: %s', split, aug, given_y, ratio)

    normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # NOTE(review): torchvision's CelebA accepts only 'train'/'valid'/'test'/'all', so split == 'eval'
    # looks unreachable through BiasedCelebASplit, and 'valid'/'test' fall through to the randomized
    # transforms below. Confirm this is intended.
    if split == 'eval':
        transform = T.Compose([
            T.Resize((img_size, img_size)),
            T.ToTensor(),
            normalize,
        ])
    elif aug:
        transform = T.Compose([
            T.RandomResizedCrop(size=img_size, scale=(0.2, 1.)),
            T.RandomHorizontalFlip(),
            T.RandomApply([
                T.ColorJitter(0.4, 0.4, 0.4, 0.1)
            ], p=0.8),
            T.RandomGrayscale(p=0.2),
            T.ToTensor(),
            normalize,
        ])
    else:
        transform = T.Compose([
            T.Resize((img_size, img_size)),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            normalize,
        ])

    if two_crop:
        transform = TwoCropTransform(transform)

    dataset = BiasedCelebASplit(
        root=root,
        split=split,
        transform=transform,
        target_attr=target_attr,
        under_sample=under_sample,
        resample_blonde=resample_blonde,
        diff_analysis=diff_analysis,
    )

    sampler = None
    if ratio != 0:
        if given_y:
            weights = [1 / dataset.confusion_matrix_by[c, b] for c, b in zip(dataset.targets, dataset.biases)]
        else:
            weights = [1 / dataset.confusion_matrix[b, c] for c, b in zip(dataset.targets, dataset.biases)]
        if ratio > 0:
            # Cap every weight at `ratio` times the smallest one.
            weights = np.array(weights)
            weights = np.clip(weights, None, weights.min() * ratio)
        sampler = WeightedRandomSampler(weights, len(weights), replacement=True)

    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=sampler is None,  # DataLoader forbids shuffle together with a sampler
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=two_crop,
    )
