import numpy as np
import os
import torch
import pandas as pd
import requests
from torch.utils.data import Dataset
import sklearn.preprocessing as preprocessing
from module import check_exists, makedir_exist_ok, save, load


class Community(Dataset):
    """Communities and Crime regression dataset with a binary race attribute.

    ``target`` is ``ViolentCrimesPerPop`` min-max scaled to [0, 1]; ``sensitive``
    is 1 for rows whose largest race share is ``racePctWhite`` and 0 otherwise.
    The raw CSV is downloaded on first use, shuffled with ``seed`` and split
    80/20 into train/test, and the result is cached under
    ``<root>/processed/seed_<seed>``.
    """
    data_name = 'Community'
    # Where the raw CSV is fetched from, and the fixed name make_data() reads it under.
    url = ('https://raw.githubusercontent.com/steven7woo/fair_regression_reduction/'
           'refs/heads/master/data/communities.csv')
    raw_filename = 'communities.csv'

    def __init__(self, root, split, seed):
        """Load (building first if needed) one processed split.

        Args:
            root: dataset root directory; ``~`` is expanded.
            split: name of the processed split to load; ``process`` writes
                ``'train'`` and ``'test'``.
            seed: shuffle seed; also selects the processed sub-folder.
        """
        self.root = os.path.expanduser(root)
        self.seed = seed
        self.split = split
        if not check_exists(self.processed_folder):
            self.process()
        self.id, self.data, self.target, self.sensitive = load(os.path.join(self.processed_folder, self.split))
        # Optional extra per-sample arrays; __getitem__ merges them into every sample.
        self.other = {}
        self.metadata = load(os.path.join(self.processed_folder, 'meta'))

    def __getitem__(self, index):
        """Return sample ``index`` as a dict of tensors.

        Keys are ``id``, ``data``, ``target`` and ``sensitive``, plus one key per
        entry of ``self.other`` (merged last, so it can override the base keys).
        """
        sample = {
            'id': torch.tensor(self.id[index]),
            'data': torch.tensor(self.data[index]),
            'target': torch.tensor(self.target[index]),
            'sensitive': torch.tensor(self.sensitive[index]),
        }
        sample.update({key: torch.tensor(values[index]) for key, values in self.other.items()})
        return sample

    def __len__(self):
        """Number of samples in the loaded split."""
        return len(self.data)

    def __repr__(self):
        fmt_str = 'Dataset {}\nSize: {}\nRoot: {}\nSeed: {}\nSplit: {}\nNGroup: {}'.format(self.__class__.__name__, self.__len__(),
                                                                     self.root,
                                                                     self.seed,
                                                                     self.split,
                                                                     self.metadata['n_groups'])
        return fmt_str

    @property
    def processed_folder(self):
        """Seed-specific folder holding the processed ``train``/``test``/``meta`` files."""
        return os.path.join(self.root, 'processed', f'seed_{self.seed}')

    @property
    def raw_folder(self):
        """Folder holding the downloaded raw CSV."""
        return os.path.join(self.root, 'raw')

    def process(self):
        """Download the raw CSV if missing, then build and save train/test/meta."""
        # Check for the CSV itself, not just the folder: a folder left behind by a
        # failed or interrupted download must not suppress the retry.
        if not os.path.exists(os.path.join(self.raw_folder, self.raw_filename)):
            self.download()
        train_set, test_set, meta = self.make_data()
        save(train_set, os.path.join(self.processed_folder, 'train'))
        save(test_set, os.path.join(self.processed_folder, 'test'))
        save(meta, os.path.join(self.processed_folder, 'meta'))

    def download(self):
        """Fetch the raw CSV from ``url`` into ``raw_folder/raw_filename``.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        res = requests.get(self.url, timeout=60)
        # Fail loudly. Silently continuing would leave no CSV for make_data(), and
        # previously also left an empty raw folder that blocked later retries.
        res.raise_for_status()
        makedir_exist_ok(self.raw_folder)
        # Save under the fixed name make_data() reads, not a name derived from
        # the (possibly redirected) response URL.
        with open(os.path.join(self.raw_folder, self.raw_filename), 'wb') as file:
            file.write(res.content)
        print("File downloaded successfully.")

    def make_data(self):
        """Load, shuffle, encode and split the raw CSV.

        Missing values (``'?'`` in the CSV) become 0. The sensitive attribute is 1
        when ``racePctWhite`` is the largest of the four race-share columns and 0
        otherwise. It is appended to the features as a ``race`` column, and the
        four race-share columns are dropped. One known outlier row is removed,
        and the shuffled rows are split 80/20 into train/test.

        Returns:
            ``((train_id, train_data, train_target, train_sensitive),
            (test_id, test_data, test_target, test_sensitive), metadata)``.
            Data and targets are float32; ids and sensitive labels are int64.
            Targets have shape ``(n, 1)``. ``metadata`` is
            ``{'n_groups': ..., 'n_classes': 1}`` and is also stored on
            ``self.metadata``.
        """
        # Encoded feature row of a record that is excluded from the dataset.
        target_outlier = np.array([
            0.47, 0.42, 0.56, 0.36, 0.43, 0.22, 1.0, 0.28, 0.46, 0.14, 0.35, 0.51, 
            0.65, 0.5, 0.28, 0.27, 0.34, 0.28, 0.16, 0.26, 0.25, 0.3, 0.21, 0.45, 
            0.62, 0.69, 0.18, 0.72, 0.41, 0.59, 0.39, 0.6, 0.25, 0.56, 0.69, 0.59, 
            0.6, 0.6, 0.21, 0.2, 0.22, 0.21, 0.47, 0.45, 0.29, 0.97, 0.14, 0.35, 
            0.43, 0.49, 0.53, 0.36, 0.39, 0.4, 0.4, 0.37, 0.49, 0.46, 0.38, 0.47, 
            0.47, 0.51, 0.27, 0.31, 0.74, 0.0, 0.26, 0.65, 0.3, 1.0, 0.58, 0.19, 
            0.52, 0.41, 0.37, 0.36, 0.32, 0.4, 0.45, 0.5, 0.5, 0.72, 0.77, 0.94, 
            0.12, 0.03, 0.43, 0.51, 0.51, 0.81, 0.66, 0.07, 0.28, 0.95, 0.19, 0.08, 
            0.22, 0.26, 0.29, 0.75, 0.52, 0.37, 0.4, 0.0, 0.53, 0.09, 0.93, 0.01, 
            0.04, 0.74, 0.36, 0.1, 0.06, 0.0, 0.0, 0.63, 0.21, 0.0
        ])

        df = pd.read_csv(os.path.join(self.raw_folder, self.raw_filename), na_values='?', index_col=0)
        df = df.sample(frac=1, random_state=self.seed)
        df = df.fillna(0)
        df.reset_index(drop=True, inplace=True)
        B = "racepctblack"
        W = "racePctWhite"
        A = "racePctAsian"
        H = "racePctHisp"
        race_cols = [B, W, A, H]
        df_sens = df[race_cols]

        # Min-max scale the crime rate to [0, 1]. A constant column maps to 0
        # instead of producing 0/0 = NaN targets.
        crime = df['ViolentCrimesPerPop']
        crime_min = crime.min()
        crime_range = crime.max() - crime_min
        if crime_range > 0:
            df['ViolentCrimesPerPop'] = (crime - crime_min) / crime_range
        else:
            df['ViolentCrimesPerPop'] = 0.0
        target = df['ViolentCrimesPerPop'].to_numpy().reshape(-1, 1)
        df = df.drop('ViolentCrimesPerPop', axis=1)

        # Binary sensitive attribute: 1 if white is the largest race share, else 0.
        maj = df_sens.apply(pd.Series.idxmax, axis=1)
        sensitive = maj.map({B: 0, W: 1, A: 0, H: 0}).to_numpy()
        df['race'] = sensitive
        df = df.drop(columns=race_cols)
        data = df.to_numpy()

        # Drop the outlier record. A row can only equal it when the widths match;
        # otherwise the comparison would fail to broadcast, so skip it.
        if data.shape[1] == target_outlier.size:
            keep = ~np.all(data == target_outlier, axis=1)
            data, sensitive, target = data[keep], sensitive[keep], target[keep]

        split_idx = int(0.8 * len(data))
        train_data, test_data = data[:split_idx].astype(np.float32), data[split_idx:].astype(np.float32)
        train_target, test_target = target[:split_idx].astype(np.float32), target[split_idx:].astype(np.float32)
        train_sensitive = sensitive[:split_idx].astype(np.int64)
        test_sensitive = sensitive[split_idx:].astype(np.int64)

        # Ids are positions within each split, so both splits start at 0.
        train_id = np.arange(len(train_data), dtype=np.int64)
        test_id = np.arange(len(test_data), dtype=np.int64)
        n_groups = len(np.unique(sensitive))
        self.metadata = {'n_groups': n_groups, 'n_classes': 1}

        return ((train_id, train_data, train_target, train_sensitive),
                (test_id, test_data, test_target, test_sensitive),
                self.metadata)

