import os
import json
import random

from collections import defaultdict, Counter

import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import datasets, transforms
from torch.distributions.categorical import Categorical


class Celeba(Dataset):
    '''
    CelebA wrapper that splits the data into environments by an attribute.

    ``self.envs`` holds four dicts, each with an ``idx_list`` of row indices
    into its ``data`` split:
      0: train rows whose env attribute is 0
      1: train rows whose env attribute is non-zero
      2: validation rows whose env attribute is 0
      3: test rows whose env attribute is 0

    Only one half of each split is used: the first half of a shuffled index
    list when ``target`` is true, the second half otherwise (see
    ``split_celeba_for_src_tar``).
    '''

    def __init__(self, file_path, tar_att, env_att, target, seed=None):
        '''
        Args:
          file_path: root directory passed to ``datasets.CelebA`` (called with
            ``download=True``).
          tar_att: name of the attribute used as the label ``Y``.
          env_att: name of the attribute used to define environments
            (returned as ``C``).
          target: if true build the target-task half of every split,
            otherwise the source-task half.
          seed: optional seed for the source/target split. ``None`` (default)
            shuffles with the global ``random`` state as before; pass the same
            seed when building the source and the target dataset so their
            halves are guaranteed to be disjoint.
        '''
        # load the entire celeba data
        print('loading CelebA data')
        print('target attribute is ', tar_att)
        print('env attribute is ', env_att)
        if target:
            print('This is target data')
        else:
            print('This is source data')

        self.train_data = datasets.CelebA(file_path, split='train',
                                          target_type='attr', transform=None,
                                          target_transform=None, download=True)

        self.val_data = datasets.CelebA(file_path, split='valid',
                                        target_type='attr', transform=None,
                                        target_transform=None, download=True)

        self.test_data = datasets.CelebA(file_path, split='test',
                                         target_type='attr', transform=None,
                                         target_transform=None, download=True)

        # column indices of the label and the env ("cor") attribute
        cor_name = env_att
        self.label_idx = self.train_data.attr_names.index(tar_att)
        self.cor_idx = self.train_data.attr_names.index(cor_name)

        train_idx_list, val_idx_list, test_idx_list = self.split_celeba_for_src_tar(
            len(self.train_data.attr), len(self.val_data.attr),
            len(self.test_data.attr), target, seed=seed)

        # Pull the env column out once as Python scalars: indexing the tensor
        # element by element inside a loop over every row is very slow.
        train_cor = self.train_data.attr[:, self.cor_idx].tolist()
        val_cor = self.val_data.attr[:, self.cor_idx].tolist()
        test_cor = self.test_data.attr[:, self.cor_idx].tolist()

        # training environments: split this task's train half on the env attribute
        self.envs = [
            {
                'idx_list': [i for i in train_idx_list if train_cor[i] == 0],
                'data': self.train_data,
            },
            {
                'idx_list': [i for i in train_idx_list if train_cor[i] != 0],
                'data': self.train_data,
            },
        ]

        # validation and test environments keep only env-attribute == 0 rows
        val_env0_list = [i for i in val_idx_list if val_cor[i] == 0]
        # NOTE(review): despite its name this keeps env-attribute == 0 rows,
        # same as the val env -- confirm this is intended.
        test_env1_list = [i for i in test_idx_list if test_cor[i] == 0]

        self.envs.append({
            'idx_list': val_env0_list,
            'data': self.val_data,
        })

        self.envs.append({
            'idx_list': test_env1_list,
            'data': self.test_data,
        })

        for env in self.envs:
            print('size: ', len(env['idx_list']))

        # For the val split, group row indices by '<label>_<env attribute>'.
        # NOTE(review): this covers every validation row, not only this
        # task's half -- confirm intended.
        val_labels = self.val_data.attr[:, self.label_idx].tolist()
        self.val_att_idx_dict = {
            cor_name: {'0_0': [], '0_1': [], '1_0': [], '1_1': []}
        }
        for i, (y, c) in enumerate(zip(val_labels, val_cor)):
            self.val_att_idx_dict[cor_name]['{}_{}'.format(y, c)].append(i)

        # For every other attribute, group the rows of the test env by
        # '<label>_<attribute>' (test set only).
        self.test_att_idx_dict = {}

        if self.test_data.attr_names[-1] == '':
            self.test_data.attr_names = self.test_data.attr_names[:-1]

        # Ascending order reproduces a full row scan; precomputing it avoids
        # an O(n) list-membership test per row for every attribute.
        test_rows = sorted(test_env1_list)
        test_attr = self.test_data.attr.tolist()

        for idx, att in enumerate(self.test_data.attr_names):
            if idx == self.label_idx or idx == self.cor_idx:
                continue

            data_dict = {
                '0_0': [],
                '0_1': [],
                '1_0': [],
                '1_1': [],
            }

            for i in test_rows:
                row = test_attr[i]
                data_dict['{}_{}'.format(row[self.label_idx], row[idx])].append(i)

            # print data stats
            print('{:>20}'.format(att), end=' ')
            for k, v in data_dict.items():
                print(k, ' ', '{:>8}'.format(len(v)), end=', ')
            print()

            self.test_att_idx_dict[att] = data_dict

        # total rows across the three full splits (not the env sizes)
        self.length = len(self.train_data.attr) + len(self.val_data.attr) + len(self.test_data.attr)

        self.latent_domain = None  # for dg_mmld

    @staticmethod
    def split_celeba_for_src_tar(train_len, val_len, test_len, target, seed=None):
        '''
          Use half of the data for the source task and half for the target task.

          The indices ``0..len-1`` of each split are shuffled; the target task
          takes the first half and the source task the second half (which
          gets the extra element when the length is odd).

          Args:
            train_len, val_len, test_len: sizes of the three splits.
            target: true for the target-task half, false for the source half.
            seed: ``None`` (default) shuffles with the global ``random`` state.
              Otherwise a private ``random.Random(seed)`` is used, so calls
              with the same seed return complementary, disjoint halves.

          Returns:
            tuple (train_idx_list, val_idx_list, test_idx_list).
        '''
        rng = random if seed is None else random.Random(seed)

        halves = []
        for n in (train_len, val_len, test_len):
            idx_list = list(range(n))
            rng.shuffle(idx_list)
            # first half for the target task, second half for the source task
            halves.append(idx_list[:n // 2] if target else idx_list[n // 2:])
        return tuple(halves)

    def __len__(self):
        '''Total number of rows in the train, valid and test splits combined.'''
        return self.length

    def __getitem__(self, keys):
        '''
        Build a batch from a sequence of ``(row_idx, env_id)`` keys.

        All rows are read from the environment of the LAST key, so callers
        should pass keys from a single environment.

        Returns a dict with ``Y`` (label attribute), ``C`` (env attribute),
        ``idx`` (row indices as a long tensor), ``D`` (latent domain, only
        when one is set for this env) and ``X`` (stacked normalized images).

        Raises:
          ValueError: if ``keys`` is empty.
        '''
        idx = []
        env_id = None
        for key in keys:
            env_id = int(key[1])
            idx.append(key[0])

        if env_id is None:
            raise ValueError('Celeba.__getitem__ requires at least one key')

        data = self.envs[env_id]['data']

        batch = {}
        batch['Y'] = data.attr[:, self.label_idx][idx]
        batch['C'] = data.attr[:, self.cor_idx][idx]
        batch['idx'] = torch.tensor(idx).long()

        if self.latent_domain is not None and env_id in self.latent_domain:
            # return the latent domain for dg mmld
            batch['D'] = self.latent_domain[env_id][idx]

        # image -> tensor, then per-channel normalization (the mean/std values
        # match the usual ImageNet statistics)
        img2tensor = transforms.ToTensor()
        normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        batch['X'] = torch.stack([normalize(img2tensor(data[i][0])) for i in idx])

        return batch

    def get_all_y(self, env_id):
        '''Label attribute of every row in environment ``env_id``, as a list.'''
        env = self.envs[env_id]
        return env['data'].attr[env['idx_list'], self.label_idx].tolist()

    def get_all_c(self, env_id):
        '''Env attribute of every row in environment ``env_id``, as a list.'''
        env = self.envs[env_id]
        return env['data'].attr[env['idx_list'], self.cor_idx].tolist()

    def set_domain_label(self, env_id, idx_list, domain_list):
        '''
        Set latent domain labels for dg mmld.

        Replaces any previously set domains (for all envs) with a single
        entry for ``env_id``: zeros for every row of that env's split, with
        rows ``idx_list`` set to ``domain_list``.
        '''
        domains = torch.zeros_like(self.envs[env_id]['data'].attr[:, 0])
        domains[idx_list] = torch.LongTensor(domain_list)
        self.latent_domain = {env_id: domains}

    def get_all_att(self, env_id):
        '''Full attribute tensor of env ``env_id``'s split (all rows, not just idx_list).'''
        return self.envs[env_id]['data'].attr

    def get_att_names(self, i):
        '''Name of attribute column ``i``, taken from the train split.'''
        return self.envs[0]['data'].attr_names[i]
