import os
import torch
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import pandas as pd
import random
from tqdm import tqdm
import pickle
import shutil
import argparse
import sys
import csv
from collections import namedtuple

# Print numpy arrays in full and without scientific notation.
np.set_printoptions(threshold=sys.maxsize, suppress=True)

# Folder names of the dogs subset: <size>/<place>.
SMALL_DOGS = "small_dogs"
BIG_DOGS = "big_dogs"
INDOOR = "indoor"
OUTDOOR = "outdoor"

# Folder names of the birds subset: <kind>/<background>.
LANDBIRDS = "landbirds"
WATERBIRDS = "waterbirds"
LAND = "land"
WATER = "water"

# Expected number of files in a majority-group folder, per split.
MAJORITY_SIZE = {
    "train": 10000,
    "val": 500,
    "test": 500,
}
# Expected number of files in a minority-group folder, per split.
MINORITY_SIZE = {
    "train": 500,
    "val": 25,
    "test": 500,
}

# Every (target, bias) pair of the merged dataset: labels 0-1 belong to the
# birds subset, labels 2-3 to the dogs subset (dog labels are offset by 2).
GROUPS = [
    (0, 0), (0, 1), (1, 0), (1, 1),
    (2, 2), (2, 3), (3, 2), (3, 3),
]

class CelebADataset(Dataset):
    def __init__(self, root_dir, split='train', transform=None, opt=None):
        """Build one split of the dataset, reusing a pickled cache when one exists.

        Despite the class name, the samples are read from the ``spuco_birds``
        and ``spuco_dogs`` folders (see ``load_data``) and cached under
        ``data/spuco_animals_<split>_<minority_to_keep>.pkl``.

        Args:
            root_dir: Directory that holds the ``spuco_birds`` and
                ``spuco_dogs`` folders.
            split: Name of the split sub-folder. The "train" and "val" splits
                are additionally subsampled with ``make_harder`` the first
                time they are built.
            transform: Optional callable applied to every image in
                ``__getitem__``.
            opt: Options object. ``opt.minority_to_keep`` is read here, so the
                ``None`` default fails with ``AttributeError``.
        """
        self.bias_rate = 0.
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.opt = opt

        # Not read by the code in this class; kept for external users.
        self.target_attr = "gender"
        self.bias_attr = "age"

        # Label value -> name; the names also form the generated-image folder
        # names "<class name>_<bias name>" under ``self.gen_dir``.
        self.class_to_name = {
            0: "landbird",
            1: "waterbird",
            2: "small_dog",
            3: "big_dog",
        }
        self.bias_to_name = {
            0: "Land",
            1: "Water",
            2: "Indoors",
            3: "Outdoors",
        }

        self.targets_values = range(4)
        self.biases_values = range(4)

        self.gen_dir = "./generated_images"

        cache_path = f"data/spuco_animals_{split}_{self.opt.minority_to_keep}.pkl"

        if os.path.exists(cache_path):
            # NOTE: unpickling can execute arbitrary code; the cache is assumed
            # to have been written by this class.
            with open(cache_path, "rb") as cache_file:
                data = pickle.load(cache_file)
            self.filenames = data['filenames']
            self.targets = data['targets']
            self.biases = data['biases']
            self.real_gen = data['real_gen']
        else:
            self.load_data()
            # make_harder reads self.group_counts_, so compute it first.
            self.set_group_counts()

            if split == "train":
                self.make_harder(500, 10000, False)
            elif split == 'val':
                self.make_harder(25, 500)

            # Save filenames, targets and biases so later runs skip the
            # directory scan and the random subsampling.
            data = {
                'filenames': self.filenames,
                'targets': self.targets,
                'biases': self.biases,
                'real_gen': self.real_gen
            }
            os.makedirs(os.path.dirname(cache_path), exist_ok=True)
            with open(cache_path, "wb") as cache_file:
                pickle.dump(data, cache_file)

        # Other methods call list-only operations (.extend/.index) on the
        # filenames, which may be a numpy array here (e.g. from an old cache).
        self.filenames = list(self.filenames)

        self.set_group_counts()
        self.create_real_gen_weights()

    def get_gen_ratio(self): 

        num_gen = np.sum(self.real_gen == 1)
        num_real = np.sum(self.real_gen == 0)

        print("Ratio REAL/GEN: ", num_gen/num_real)

    def upweight_misclassified_samples(self, alpha, misclassified_filenames):
        """Append ``alpha`` extra copies of every misclassified sample.

        The selected labels are also stored in ``self.targets_misclassified``,
        ``self.biases_misclassified`` and ``self.real_gen_misclassified``.

        Args:
            alpha: Number of extra copies appended per misclassified sample.
            misclassified_filenames: Filenames to duplicate; each one must be
                present in ``self.filenames``.

        Raises:
            ValueError: If a filename is not part of the dataset.
        """
        # Position of the first occurrence of every filename (what list.index
        # returns), built once instead of one linear scan per lookup.
        first_position = {}
        for position, name in enumerate(self.filenames):
            first_position.setdefault(name, position)

        idx_list = []
        for fname in misclassified_filenames:
            if fname not in first_position:
                raise ValueError(f"{fname!r} is not in the dataset filenames")
            idx_list.append(first_position[fname])
        # Explicit integer dtype keeps indexing valid when the list is empty.
        idx_misclassified = np.asarray(idx_list, dtype=np.intp)

        self.targets_misclassified = self.targets[idx_misclassified]
        self.biases_misclassified = self.biases[idx_misclassified]
        self.real_gen_misclassified = self.real_gen[idx_misclassified]

        # np.tile repeats the whole selection alpha times: [a, b, a, b, ...].
        self.targets = np.concatenate((self.targets, np.tile(self.targets_misclassified, alpha)))
        self.biases = np.concatenate((self.biases, np.tile(self.biases_misclassified, alpha)))
        self.real_gen = np.concatenate((self.real_gen, np.tile(self.real_gen_misclassified, alpha)))

        # The filenames must follow the same [a, b, a, b, ...] order as the
        # tiled label arrays, otherwise labels and images get out of step.
        if not isinstance(self.filenames, list):
            self.filenames = list(self.filenames)
        self.filenames.extend([self.filenames[i] for i in idx_misclassified] * alpha)

        # NOTE(review): self.group_counts_ is not recomputed here, so the
        # per-sample weights still reflect the counts from before the
        # duplication - confirm this is intended.
        self.create_real_gen_weights()

    def cut_data(self, data_amount):
        """Keep only the leading samples of every group.

        For the group at position ``g`` of ``GROUPS`` the first
        ``data_amount[g]`` samples (in current dataset order) are kept; all
        other samples are dropped from filenames, targets, biases and
        real_gen.
        """
        kept = []
        for group_pos, (target_value, bias_value) in enumerate(GROUPS):
            in_group = (self.targets == target_value) & (self.biases == bias_value)
            members = np.flatnonzero(in_group)
            kept += list(members[:data_amount[group_pos]])

        self.targets = self.targets[kept]
        self.biases = self.biases[kept]
        self.real_gen = self.real_gen[kept]
        self.filenames = [self.filenames[position] for position in kept]

    def get_remove_idx(self, min_bias, counts_class, class_idx): 
            max_bias = 1 - min_bias 
        
            ratio_minority = 1 - self.opt.minority_to_keep 
            min_bias_new_count = int((ratio_minority * counts_class[max_bias])/(1-ratio_minority))
            to_remove_samples_num = int(counts_class[min_bias] - min_bias_new_count)

            idx_to_remove = np.where((self.targets == class_idx) & (self.biases == min_bias))[0]
            idx_to_remove = np.random.choice(idx_to_remove, to_remove_samples_num, replace=False)

            return idx_to_remove

    def make_harder(self, min_group_count, max_group_count, keep_min = True):
        """Randomly drop samples from the small groups.

        Every group of ``GROUPS`` whose size in ``self.group_counts_`` is at
        most ``min_group_count`` is shrunk to
        ``max_group_count * r / (1 - r)`` samples, with
        ``r = 1 - opt.minority_to_keep``. Larger groups are left untouched.
        ``self.group_counts_`` must be up to date when this is called and is
        refreshed before returning.

        Args:
            min_group_count: Groups with more samples than this are skipped.
            max_group_count: Majority-group size used in the formula above.
            keep_min: When true, a shrunk group keeps at least 15 samples.
        """
        to_remove_samples_total = []

        for idx, (class_idx_group, bias_idx_group) in enumerate(GROUPS):
            if class_idx_group not in self.targets_values:
                continue

            count_group = self.group_counts_[idx]
            if count_group > min_group_count:
                continue

            ratio = 1 - self.opt.minority_to_keep
            to_keep = (max_group_count * ratio) / (1 - ratio)
            if keep_min:
                to_keep = max(to_keep, 15)

            group_members = np.where((self.targets == class_idx_group) & (self.biases == bias_idx_group))[0]

            # A group that is already at or below its target size yields a
            # negative count, which np.random.choice rejects - clamp it.
            to_remove_samples_num = int(count_group - to_keep)
            to_remove_samples_num = min(max(to_remove_samples_num, 0), len(group_members))

            to_remove_samples_total.extend(
                np.random.choice(group_members, to_remove_samples_num, replace=False)
            )

        # Explicit integer dtype so an empty removal list is a valid index.
        to_remove = np.asarray(to_remove_samples_total, dtype=np.intp)

        self.targets = np.delete(self.targets, to_remove)
        self.biases = np.delete(self.biases, to_remove)
        self.real_gen = np.delete(self.real_gen, to_remove)
        # tolist() returns plain Python strings rather than numpy string scalars.
        self.filenames = np.delete(np.array(self.filenames), to_remove).tolist()

        self.set_group_counts()

    def load_gen_data_balance(self):
        """Top up every group with generated images towards the largest group.

        For the group at position ``g`` of ``GROUPS`` up to
        ``max(group_counts_) - group_counts_[g]`` files are taken from
        ``<gen_dir>/<class name>_<bias name>``; fewer are added when the
        folder holds fewer files. Added samples get ``real_gen == 1``.

        With ``opt.limit_to_gen`` set the dataset is replaced by the generated
        samples, otherwise they are appended. Group counts and per-sample
        weights are recomputed afterwards.
        """
        to_add_all_filenames = []
        to_add_all_targets = []
        to_add_all_biases = []
        to_add_all_real_gen = []

        max_count = np.max(self.group_counts_)
        for idx, (target_value, bias_value) in enumerate(GROUPS):
            to_add = int(max_count - self.group_counts_[idx])
            to_load_path = os.path.join(self.gen_dir, f"{self.class_to_name[target_value]}_{self.bias_to_name[bias_value]}")

            # os.listdir returns entries in arbitrary order; sorting makes the
            # selected subset the same on every run and machine.
            filenames = sorted(os.listdir(to_load_path))[:to_add]
            filenames = [os.path.join(to_load_path, filename) for filename in filenames]

            to_add_all_targets.extend([target_value] * len(filenames))
            to_add_all_biases.extend([bias_value] * len(filenames))
            to_add_all_real_gen.extend([1] * len(filenames))
            to_add_all_filenames.extend(filenames)

        # Reuse the existing dtypes: an empty list would otherwise become a
        # float64 array and silently turn the label arrays into floats.
        new_targets = np.array(to_add_all_targets, dtype=self.targets.dtype)
        new_biases = np.array(to_add_all_biases, dtype=self.biases.dtype)
        new_real_gen = np.array(to_add_all_real_gen, dtype=self.real_gen.dtype)

        if self.opt.limit_to_gen:
            self.filenames = to_add_all_filenames
            self.targets = new_targets
            self.biases = new_biases
            self.real_gen = new_real_gen
        else:
            # The filenames may be a numpy array, which has no .extend().
            if not isinstance(self.filenames, list):
                self.filenames = list(self.filenames)
            self.filenames.extend(to_add_all_filenames)
            self.targets = np.concatenate((self.targets, new_targets), axis=0)
            self.biases = np.concatenate((self.biases, new_biases), axis=0)
            self.real_gen = np.concatenate((self.real_gen, new_real_gen), axis=0)

        self.set_group_counts()
        self.create_real_gen_weights()

    def get_classes(self):
        """Return ``self.targets_values``, the target label values of the dataset."""
        class_values = self.targets_values
        return class_values

    def get_biases(self):
        """Return ``self.biases_values``, the bias label values of the dataset."""
        bias_values = self.biases_values
        return bias_values

    def num_heads_di_real_gen(self): 
        return len(np.unique(self.bias_real_gen))

    def num_heads_di(self): 
        return len(np.unique(self.biases))

    def group_counts_dro(self): 
        group_counts = [] 
        for idx in np.unique(self.group_idx): 
            group_counts.append(np.sum(self.group_idx == idx))
        return group_counts

    def group_counts_dro_real_gen(self):
        group_counts = [] 
        for idx in np.unique(self.group_idx_real_gen): 
            group_counts.append(np.sum(self.group_idx_real_gen == idx))
        return group_counts

    def n_groups_dro(self): 
        return len(np.unique(self.group_idx))

    def n_groups_dro_real_gen(self):
        return len(np.unique(self.group_idx_real_gen))

    def load_dogs(self):
        """List the dog images of ``self.split`` together with their labels.

        Scans ``<root_dir>/spuco_dogs/<split>/<size>/<place>`` for the four
        size/place combinations. Targets are 0 for small and 1 for big dogs;
        biases are 0 for indoor and 1 for outdoor.

        Returns:
            tuple: ``(filenames, targets, biases)`` as three parallel lists.

        Raises:
            AssertionError: If a folder does not hold the number of files
                expected for this split.
        """
        root_dir_dogs = os.path.join(self.root_dir, "spuco_dogs", self.split)

        # (class folder, place folder, target, bias, expected sizes per split)
        layout = [
            (SMALL_DOGS, INDOOR, 0, 0, MAJORITY_SIZE),
            (SMALL_DOGS, OUTDOOR, 0, 1, MINORITY_SIZE),
            (BIG_DOGS, INDOOR, 1, 0, MINORITY_SIZE),
            (BIG_DOGS, OUTDOOR, 1, 1, MAJORITY_SIZE),
        ]

        filenames = []
        targets = []
        biases = []

        for class_dir, place_dir, target, bias, expected_sizes in layout:
            group_dir = os.path.join(root_dir_dogs, f"{class_dir}/{place_dir}")
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # sample order, and any index-based subsampling, reproducible.
            entries = sorted(os.listdir(group_dir))

            expected = expected_sizes[self.split]
            if len(entries) != expected:
                # Still an AssertionError for backward compatibility, but
                # raised explicitly so the check survives ``python -O``.
                raise AssertionError(f"Dataset corrupted or missing files. Expected {expected} files got {len(entries)}")

            filenames.extend(str(os.path.join(group_dir, entry)) for entry in entries)
            targets.extend([target] * len(entries))
            biases.extend([bias] * len(entries))

        return filenames, targets, biases

    def load_birds(self):
        """List the bird images of ``self.split`` together with their labels.

        Scans ``<root_dir>/spuco_birds/<split>/<kind>/<background>`` for the
        four kind/background combinations. Targets are 0 for landbirds and 1
        for waterbirds; biases are 0 for land and 1 for water.

        Returns:
            tuple: ``(filenames, targets, biases)`` as three parallel lists.

        Raises:
            AssertionError: If a folder does not hold the number of files
                expected for this split.
        """
        root_dir_birds = os.path.join(self.root_dir, "spuco_birds", self.split)

        # (class folder, background folder, target, bias, expected sizes per split)
        layout = [
            (LANDBIRDS, LAND, 0, 0, MAJORITY_SIZE),
            (LANDBIRDS, WATER, 0, 1, MINORITY_SIZE),
            (WATERBIRDS, LAND, 1, 0, MINORITY_SIZE),
            (WATERBIRDS, WATER, 1, 1, MAJORITY_SIZE),
        ]

        filenames = []
        targets = []
        biases = []

        for class_dir, place_dir, target, bias, expected_sizes in layout:
            group_dir = os.path.join(root_dir_birds, f"{class_dir}/{place_dir}")
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # sample order, and any index-based subsampling, reproducible.
            entries = sorted(os.listdir(group_dir))

            expected = expected_sizes[self.split]
            if len(entries) != expected:
                # Still an AssertionError for backward compatibility, but
                # raised explicitly so the check survives ``python -O``.
                raise AssertionError(f"Dataset corrupted or missing files [{class_dir}_{place_dir}]. Expected {expected} files got {len(entries)}")

            filenames.extend(str(os.path.join(group_dir, entry)) for entry in entries)
            targets.extend([target] * len(entries))
            biases.extend([bias] * len(entries))

        return filenames, targets, biases

    def load_data(self):
        """Load the bird and dog file lists and merge them into one label space.

        Bird samples come first and keep the targets/biases returned by
        ``load_birds``; the dog targets and biases from ``load_dogs`` are
        shifted by 2. Every sample is marked as real (``real_gen == 0``).

        Sets ``self.filenames`` (list), ``self.targets``, ``self.biases`` and
        ``self.real_gen`` (numpy arrays).
        """
        dogs_filenames, dogs_targets, dogs_biases = self.load_dogs()
        birds_filenames, birds_targets, birds_biases = self.load_birds()

        # The filenames stay a plain list: other methods of this class call
        # list-only operations (.extend, .index) on them.
        self.filenames = list(birds_filenames) + list(dogs_filenames)

        self.targets = np.array(list(birds_targets) + [label + 2 for label in dogs_targets])
        self.biases = np.array(list(birds_biases) + [bias + 2 for bias in dogs_biases])
        self.real_gen = np.zeros_like(self.targets)


    def get_class_distribution(self):
        """Print the number of samples of every (target, bias) pair in ``GROUPS``."""
        for class_value_, bias_value in GROUPS:
            in_group = np.logical_and(self.targets == class_value_, self.biases == bias_value)
            num_images = np.sum(in_group)
            print(f'Class Value {class_value_}, Bias {bias_value}:  {num_images}')

        print("====================================")

    def set_group_counts(self):
        """Recompute ``self.group_counts_``.

        The result is a float array with one entry per ``GROUPS`` pair holding
        the number of samples whose target and bias match that pair.
        """
        per_group = [
            np.sum((self.targets == class_value) & (self.biases == bias_value))
            for class_value, bias_value in GROUPS
        ]
        self.group_counts_ = np.array(per_group, dtype=np.float64)

    def create_real_gen_weights(self, quit_ = False):
        """Recompute the per-sample weights and group ids.

        Sets four arrays aligned with ``self.targets``:

        * ``real_weights_groups``: inverse relative frequency of the sample's
          (target, bias) group, based on ``self.group_counts_``.
        * ``group_idx``: id of the sample's (target, bias) pair.
        * ``group_idx_real_gen``: id of its (target, bias, real_gen) triple.
        * ``bias_real_gen``: id of its (bias, real_gen) pair.

        Ids are numbered in order of first appearance in the dataset.
        ``quit_`` is not used.
        """
        def first_seen_ids(*columns):
            # Number every distinct combination of column values by the order
            # in which it first shows up.
            ids = np.zeros_like(self.targets, dtype=np.int32)
            seen = {}
            for position, key in enumerate(zip(*columns)):
                ids[position] = seen.setdefault(key, len(seen))
            return ids

        weights = np.zeros_like(self.targets, dtype=np.float32)

        inverse_freq = self.group_counts_ / np.sum(self.group_counts_)
        # Empty groups get a tiny frequency so the inversion stays finite.
        inverse_freq[inverse_freq == 0] = 1e-10
        inverse_freq = 1 / inverse_freq

        for slot, (class_value, bias_value) in enumerate(GROUPS):
            in_group = (self.targets == class_value) & (self.biases == bias_value)
            weights[in_group] = inverse_freq[slot]

        self.real_weights_groups = weights
        self.group_idx = first_seen_ids(self.targets, self.biases)
        self.group_idx_real_gen = first_seen_ids(self.targets, self.biases, self.real_gen)
        self.bias_real_gen = first_seen_ids(self.biases, self.real_gen)


    def load_gen_data(self):
        """Append generated images of every group to the dataset.

        For each (target, bias) pair in ``GROUPS`` the first
        ``opt.num_per_group`` files (sorted by name) of
        ``<gen_dir>/<class name>_<bias name>`` are added with
        ``real_gen == 1``. Group counts and per-sample weights are recomputed
        afterwards.
        """
        filenames_all = []
        targets_all = []
        biases_all = []

        for target_value, bias_value in GROUPS:
            to_load_path = os.path.join(self.gen_dir, f"{self.class_to_name[target_value]}_{self.bias_to_name[bias_value]}")

            # os.listdir returns entries in arbitrary order; sorting makes the
            # selected subset the same on every run and machine.
            listed = sorted(os.listdir(to_load_path))
            chosen = [os.path.join(to_load_path, name) for name in listed[:self.opt.num_per_group]]

            filenames_all.extend(chosen)
            targets_all.extend([target_value] * len(chosen))
            biases_all.extend([bias_value] * len(chosen))

        # The filenames may be a numpy array, which has no .extend().
        if not isinstance(self.filenames, list):
            self.filenames = list(self.filenames)
        self.filenames.extend(filenames_all)

        # Reuse the existing dtypes: np.ones() and empty lists default to
        # float64 and would silently turn the label arrays into floats.
        self.targets = np.concatenate((self.targets, np.array(targets_all, dtype=self.targets.dtype)), axis=0)
        self.biases = np.concatenate((self.biases, np.array(biases_all, dtype=self.biases.dtype)), axis=0)
        self.real_gen = np.concatenate((self.real_gen, np.ones(len(targets_all), dtype=self.real_gen.dtype)), axis=0)

        self.set_group_counts()
        self.create_real_gen_weights()


    def load_image(self, filename):
        """Open ``filename`` and return its content as an RGB PIL image."""
        # Image.open is lazy and keeps the file open; the context manager
        # releases the handle, and convert() returns an independent copy that
        # stays usable afterwards.
        with Image.open(filename) as img:
            return img.convert('RGB')

    def __len__(self):
        """Return the number of samples, i.e. the number of stored filenames."""
        n_samples = len(self.filenames)
        return n_samples

    def __getitem__(self, idx):
        """Load sample ``idx`` and return the image with its metadata.

        Returns:
            dict: With keys ``idx`` (the requested index), ``img`` (RGB image,
            passed through ``self.transform`` when one is set), ``target``,
            ``bias``, ``filenames`` (path of this sample), ``real_gen``,
            ``group_idx``, ``real_weights_groups``, ``group_idx_real_gen`` and
            ``bias_real_gen``.
        """
        filename = self.filenames[idx]

        img = self.load_image(filename)
        if self.transform:
            img = self.transform(img)

        return {
            'idx': idx,
            'img': img,
            'target': self.targets[idx],
            'bias': self.biases[idx],
            'filenames': filename,
            'real_gen': self.real_gen[idx],
            'group_idx': self.group_idx[idx],
            'real_weights_groups': self.real_weights_groups[idx],
            'group_idx_real_gen': self.group_idx_real_gen[idx],
            'bias_real_gen': self.bias_real_gen[idx],
        }

