import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from wilds import get_dataset
from wilds.common.data_loaders import get_train_loader, get_eval_loader
from wilds.common.grouper import CombinatorialGrouper
from expl_dist_shift.graph_utils import verify_cgm
import torchvision.transforms as transforms
import torch
from tqdm import tqdm
import timm
from sklearn.preprocessing import LabelEncoder
from imblearn.over_sampling import RandomOverSampler
from torch.utils.data import Subset
from causalgraphicalmodels import CausalGraphicalModel
from pathlib import Path
from expl_dist_shift.data import metashift

class WildsDataset:
    """Tabular view of an image dataset: pretrained-CNN embeddings plus label and group.

    Images from a WILDS dataset ('waterbirds' / 'celebA') or the MetaShift
    cats-vs-dogs dataset are embedded with a pretrained ``timm`` model. The result
    is stored as DataFrames with columns ``Y`` (label), ``G`` (group, read from
    metadata column 0) and ``X0..X{d-1}`` (embedding features). Embeddings are
    cached as pickles under ``<data_dir>/cache/<dataset_name>``.
    """
    TARGET_NAME = 'Y'

    TASK_TYPE = 'classification'
    # Causal structure: G -> Y, G -> X, Y -> X.
    GRAPH = verify_cgm(CausalGraphicalModel(
        nodes = ['G', 'X', 'Y'],
        edges = [ ('G', 'Y'),
            ('G', 'X'),
            ('Y', 'X')]
    ))

    # Embedding width of each supported backbone (timm model with num_classes=0).
    _EMB_DIMS = {'resnet50': 2048, 'resnet18': 512}

    def __init__(self, dataset_name, hparams, device, cache_embs = True):
        """Load (or compute and optionally cache) embeddings, then re-split train into train/test.

        Args:
            dataset_name: 'waterbirds', 'celebA' or 'metashift'.
            hparams: mapping with 'data_dir', 'data_seed', 'test_pct' and optionally
                'emb_model' ('resnet50' by default, or 'resnet18').
            device: torch device on which the embedding model is run.
            cache_embs: if True, freshly computed embeddings are pickled for reuse.

        Raises:
            ValueError: if 'emb_model' is unsupported, or if embeddings must be
                computed for an unsupported ``dataset_name``.
        """
        self.device = device
        self.hparams = hparams
        self.data_seed = hparams['data_seed']
        self.dataset_name = dataset_name
        self.cnn_model = hparams['emb_model'] if 'emb_model' in hparams else 'resnet50'
        if self.cnn_model not in self._EMB_DIMS:
            raise ValueError(
                f"Unsupported emb_model {self.cnn_model!r}; "
                f"expected one of {sorted(self._EMB_DIMS)}")
        n_features = self._EMB_DIMS[self.cnn_model]

        self.VAR_CATEGORIES = {
            'G': ['G'],
            'X': [f'X{i}' for i in range(n_features)],
            # NOTE(review): unlike 'G'/'X' this is a bare string rather than a list;
            # kept as-is since callers may rely on it — confirm whether ['Y'] was meant.
            'Y': 'Y'
        }
        self.TRAIN_FEATURES = [f'X{i}' for i in range(n_features)]

        cache_path = Path(hparams['data_dir']) / 'cache' / dataset_name
        train_cache = cache_path / f'train_df_{self.cnn_model}.pkl'
        test_cache = cache_path / f'test_df_{self.cnn_model}.pkl'

        # Require both pickles; a half-written cache is recomputed instead of crashing.
        if train_cache.is_file() and test_cache.is_file():
            # These pickles are written by this class; only load trusted cache files.
            self.train_df = pd.read_pickle(train_cache)
            self.test_df = pd.read_pickle(test_cache)
        else:
            self.train_df, self.test_df = self._compute_embedding_dfs(n_features)
            if cache_embs:
                cache_path.mkdir(parents = True, exist_ok = True)
                self.train_df.to_pickle(train_cache)
                self.test_df.to_pickle(test_cache)

        # Regenerate train and test so they come from the same distribution: both are
        # drawn from the original training split, and the original test split is discarded.
        self.train_df, self.test_df = train_test_split(
            self.train_df, test_size = hparams['test_pct'], random_state = self.data_seed)

    def _compute_embedding_dfs(self, n_features):
        """Embed the dataset's 'train' and 'test' subsets and return ``(train_df, test_df)``.

        Sets ``self.dataset`` and ``self.m`` (the embedding model) as side effects.
        """
        data_dir = self.hparams['data_dir']
        if self.dataset_name in ['waterbirds', 'celebA']:
            self.dataset = get_dataset(dataset=self.dataset_name, download=True, root_dir = data_dir)
        elif self.dataset_name == 'metashift':
            self.dataset = metashift.MetaShiftCatsDogsDataset(
                root_dir = data_dir, test_pct = 0.25, val_pct = 0.0, data_seed = self.data_seed)
        else:
            raise ValueError(
                f"Unsupported dataset_name {self.dataset_name!r}; "
                "expected 'waterbirds', 'celebA' or 'metashift'")

        self.m = timm.create_model(self.cnn_model, pretrained=True, num_classes=0).to(self.device).eval()

        # NOTE(review): images are only resized and converted to tensors, with no
        # mean/std normalization; pretrained timm ResNets may expect ImageNet
        # normalization — confirm. Left unchanged so existing caches stay valid.
        transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

        dfs = []
        for split in ('train', 'test'):
            subset = self.dataset.get_subset(split, transform=transform)
            loader = get_eval_loader("standard", subset, batch_size=256)
            embs, y, g = self.get_embs(loader)
            if embs.shape[1] != n_features:
                raise RuntimeError(
                    f"{self.cnn_model} produced {embs.shape[1]}-dim embeddings, "
                    f"expected {n_features}")
            dfs.append(self.data_to_df(embs, y, g))
        return dfs[0], dfs[1]

    def data_to_df(self, embs, y, g):
        """Build a DataFrame with columns ``Y``, ``G``, ``X0..X{d-1}``.

        Args:
            embs: 2-D array of shape (n_samples, d).
            y, g: 1-D arrays of length n_samples.

        Raises:
            ValueError: if ``embs`` and ``y``/``g`` disagree on the number of rows.
        """
        meta = pd.DataFrame({'Y': y, 'G': g})
        if len(embs) != len(meta):
            raise ValueError(f"embs has {len(embs)} rows but y/g have {len(meta)}")
        # Build all feature columns at once; inserting them one by one fragments the frame.
        feats = pd.DataFrame(embs, columns=[f'X{i}' for i in range(embs.shape[-1])])
        return pd.concat([meta, feats], axis=1)

    def get_embs(self, loader):
        """Run ``self.m`` over ``loader`` and return ``(embeddings, labels, groups)`` numpy arrays.

        Batches are unpacked as ``(x, y, metadata)``. The group is taken from
        metadata column 0.
        """
        embs, y, g = [], [], []
        with torch.no_grad():
            for x, y_true, metadata in tqdm(loader):
                g.append(metadata[:, 0].numpy())
                y.append(y_true.numpy())
                embs.append(self.m(x.to(self.device)).cpu().numpy())
        return np.concatenate(embs), np.concatenate(y), np.concatenate(g)

    def balance_groups(self, df, random_seed):
        """Randomly oversample rows of ``df`` so that every value of ``G`` is equally frequent.

        Uses imblearn's ``RandomOverSampler`` with its default sampling strategy,
        which resamples every class except the majority.
        """
        balanced_df, _ = RandomOverSampler(random_state = random_seed).fit_resample(df, df['G'])
        return balanced_df

    def get_source_train_test(self):
        """Return the group-balanced ``(train_df, test_df)`` source splits."""
        return self.balance_groups(self.train_df, self.data_seed), self.balance_groups(self.test_df, self.data_seed)

    def oversample(self, df, ratio, random_seed):
        """Shift the joint (G, Y) distribution of ``df`` and return a group-balanced result.

        After balancing ``G``, the two largest (G, Y) cells are subsampled without
        replacement to ``ratio`` of their size. Every other cell is resampled with
        replacement to ``1 / ratio`` of its size. ``G`` is balanced once more at
        the end.

        Args:
            df: DataFrame with ``G`` and ``Y`` columns.
            ratio: value in (0, 1]. Smaller values shrink the largest cells and
                grow the others more.
            random_seed: seed for all sampling steps.

        Raises:
            ValueError: if ``ratio`` is not in (0, 1].
        """
        if not 0 < ratio <= 1:
            raise ValueError(f"oversample ratio must be in (0, 1], got {ratio!r}")
        df = self.balance_groups(df, random_seed)
        # (G, Y) cell sizes, largest first; ties are ordered arbitrarily by pandas.
        counts = df[['G', 'Y']].value_counts()
        new_dfs = []
        for rank, (g, y) in enumerate(counts.index):
            cell = df[(df.G == g) & (df.Y == y)]
            if rank <= 1:  # undersample the two majority cells
                cell = cell.sample(replace = False, frac = ratio, random_state = random_seed)
            else:  # oversample the remaining (minority) cells
                cell = cell.sample(replace = True, frac = 1 / ratio, random_state = random_seed)
            new_dfs.append(cell)
        new_df = pd.concat(new_dfs, ignore_index = True)
        return self.balance_groups(new_df, random_seed)

    def get_target_train_test(self, shift_hparams):
        """Return shifted ``(train_df, test_df)`` built with ``shift_hparams['oversample_ratio']`` and ``['data_seed']``."""
        return (self.oversample(self.train_df, shift_hparams['oversample_ratio'], shift_hparams['data_seed']),
            self.oversample(self.test_df, shift_hparams['oversample_ratio'], shift_hparams['data_seed']))
