#!/usr/bin/env python
# -*-coding:utf-8 -*-
import numpy as np 
import torch
import os
import data.prepare_data as pd 
import data.prepare_partition as pp 
import data.prepare_cifar as prepare_cifar
import data.prepare_mnist as prepare_mnist 
import torchvision.transforms as transforms
from torch.utils.data import ConcatDataset
import utils.utils as utils 
import utils.fake_data_generate as fdg 


def load_data_batch(_input, _target, is_on_cuda=True):
    """Bundle one mini-batch into a dict, optionally moving it to the GPU.

    Args:
        _input: batch of inputs, expected as [batch_size, channel, imh, imw]
        _target: batch of targets, expected as [batch_size]
        is_on_cuda: when truthy, ``.cuda()`` is called on both ``_input``
            and ``_target`` before they are bundled.

    Returns:
        dict with the keys "input" and "target".
    """
    if is_on_cuda:
        _input, _target = _input.cuda(), _target.cuda()
    return {"input": _input, "target": _target}


def define_val_dataset(conf, train_dataset, test_dataset):
    """Carve a validation split out of the training dataset.

    The training data is partitioned into three parts with the relative sizes
    ``(1 - val) * train``, ``(1 - val) * (1 - train)`` and ``val``, where
    ``val`` is ``conf.val_data_ratio`` and ``train`` is
    ``conf.train_data_ratio``. Partition 0 becomes the training set and
    partition 2 the validation set; partition 1 is not used by this function.

    Args:
        conf: the argparse namespace; reads ``val_data_ratio``,
            ``train_data_ratio`` and (only when ``val_data_ratio > 0``)
            ``partitioned_by_user``.
        train_dataset: dataset class, the data to be partitioned
        test_dataset: dataset class, returned unchanged

    Returns:
        tuple ``(train_dataset, val_dataset, test_dataset)``; ``val_dataset``
        is ``None`` when ``conf.val_data_ratio`` is 0.

    Raises:
        AssertionError: if ``conf.val_data_ratio`` is negative, or if a
            validation split is requested while ``conf.partitioned_by_user``
            is not ``False``.
    """
    assert conf.val_data_ratio >= 0

    partition_sizes = [
        (1 - conf.val_data_ratio) * conf.train_data_ratio,
        (1 - conf.val_data_ratio) * (1 - conf.train_data_ratio),
        conf.val_data_ratio,
    ]
    data_partitioner = pp.DataPartitioner(
        conf,
        train_dataset,
        partition_sizes,
        partition_type="original",
        consistent_indices=False,
        partition_obj=False
    )

    tr_index = data_partitioner.partitions[0]
    val_index = data_partitioner.partitions[2]
    # Sanity check: COUNT the validation indices that also appear in the
    # training partition. Summing the index values themselves would report a
    # meaningless number (and hide an overlap at index 0).
    n_overlapped = int(np.isin(np.asarray(val_index), np.asarray(tr_index)).sum())
    print("The number of overlapped index between training and validating", n_overlapped)

    train_dataset = data_partitioner.use(0)
    # split for val data.
    if conf.val_data_ratio > 0:
        assert conf.partitioned_by_user is False
        val_dataset = data_partitioner.use(2)
        return train_dataset, val_dataset, test_dataset
    return train_dataset, None, test_dataset
    
    
def define_dataset(conf, display_log=True):
    """Build the train / val / test datasets described by ``conf``.

    Args:
        conf: the argparse namespace; reads ``data`` and ``load_opt`` (plus
            ``logger`` when ``display_log`` is set) and writes
            ``partitioned_by_user``.
        display_log: bool variable, log the size of each split when True
    Ops:
        ``conf.partitioned_by_user`` is set to True only for the "femnist"
        data. The validation split is created by ``define_val_dataset``,
        which requires ``conf.val_data_ratio`` to be larger or equal to zero.
    Returns:
        dict with the keys "train", "val" and "test"; "val" may be None.
    """
    # prepare general train/test.
    conf.partitioned_by_user = conf.data == "femnist"
    full_train, full_test = pd.get_dsprint_data(opt=conf.load_opt)

    # create the validation from train.
    splits = define_val_dataset(conf, full_train, full_test)
    train_dataset, val_dataset, test_dataset = splits
    # NOTE(review): an earlier comment here claimed that len() of the
    # validation split can match the train split while iterating it yields
    # the expected number of samples - confirm against pp.DataPartitioner.

    if display_log:
        split_sizes = (
            len(train_dataset),
            0 if val_dataset is None else len(val_dataset),
            len(test_dataset),
        )
        conf.logger.log(
            "Data stat for original dataset: we have {} samples for train, {} samples for val, {} samples for test.".format(
                *split_sizes
            )
        )
    return {"train": train_dataset, "val": val_dataset, "test": test_dataset}


def define_data_loader_dsprite(conf, tr_dataset, is_train, shuffle=True):
    """Create the DataLoader for the dsprite data, optionally mixing in fake data.

    Fake data is used only when "add_fake" is in ``conf.align_data``,
    ``is_train`` is truthy and ``conf.use_original_client_data`` is not
    "only_real". In that case ``conf.use_original_client_data`` selects:
        "only_sync": load the fake data only
        "combine":   load the real data concatenated with the fake data
    In every other situation the real ``tr_dataset`` is loaded.

    Args:
        conf: the argparse namespace; reads ``align_data``, ``batch_size``,
            ``num_workers`` and, on the fake-data path,
            ``use_original_client_data`` and ``use_local_id``.
        tr_dataset: the real dataset
        is_train: bool variable
        shuffle: bool variable

    Returns:
        a ``torch.utils.data.DataLoader`` built with ``pin_memory=True`` and
        ``drop_last=True``.

    Raises:
        ValueError: if fake data is requested with an unsupported
            ``conf.use_original_client_data`` value.
    """
    use_fake = bool(
        "add_fake" in conf.align_data
        and is_train
        and conf.use_original_client_data != "only_real"
    )

    data_to_load = tr_dataset
    if use_fake:
        mode = conf.use_original_client_data
        # Validate before the fake data is generated, so a typo fails fast
        # with a clear message instead of an unbound local further down.
        if mode not in ("only_sync", "combine"):
            raise ValueError(
                "Unsupported use_original_client_data {!r}; expected "
                "'only_sync', 'only_real' or 'combine'".format(mode)
            )
        fake_data = fdg.get_fake_data(conf, 0, conf.use_local_id)
        print(mode)
        if mode == "only_sync":
            data_to_load = fake_data
        else:
            data_to_load = ConcatDataset([tr_dataset, fake_data])

    return torch.utils.data.DataLoader(
        data_to_load,
        batch_size=conf.batch_size,
        shuffle=shuffle,
        num_workers=conf.num_workers,
        pin_memory=True,
        drop_last=True,
    )


def define_data_loader(conf, dataset, localdata_id=None, is_train=True, shuffle=True, 
                       data_partitioner=None, drop_last=True):
    """Create the DataLoader for one client (training) or for a whole dataset.

    Args: 
        conf: the argparse namespace; reads ``align_data``, ``batch_size``,
            ``num_workers`` and ``local_n_epochs`` (plus ``n_clients``,
            ``partition_type`` and ``use_original_client_data`` on the
            training / fake-data paths).
        dataset: a dictionary or i.e., train_dataset, val_dataset from the define_dataset function
        localdata_id: client id, required when ``is_train`` is truthy
        is_train: bool variable; when truthy only the partition of
            ``localdata_id`` is loaded, otherwise the whole ``dataset``
        shuffle: bool variable 
        data_partitioner: a class: pp.DataPartitioner; created here with
            equal partition sizes when it is None and ``is_train`` is truthy
        drop_last: bool variable, forwarded to the DataLoader
    Ops:
        Fake data is used only when "add_fake" is in ``conf.align_data``,
        ``is_train`` is truthy and ``conf.use_original_client_data`` is not
        "only_real": "only_sync" loads the fake data only, "combine" loads
        the real partition concatenated with the fake data, and any other
        value keeps the real partition.
        ``conf.num_batches_per_device_per_epoch`` and
        ``conf.num_whole_batches_per_worker`` are overwritten on every call.
    Returns:
        tuple ``(data_loader, data_partitioner)``
    Raises:
        AssertionError: if ``is_train`` is truthy and ``localdata_id`` is None.
    """
    # One truthiness test for is_train everywhere, so the fake-data switch
    # and the partitioning below can never disagree.
    use_fake = bool(
        "add_fake" in conf.align_data
        and is_train
        and conf.use_original_client_data != "only_real"
    )
    fake_loader = None
    if use_fake:
        fake_loader = fdg.get_fake_data(conf, len(dataset.index) // conf.n_clients, localdata_id)

    if is_train:
        partition_size = [1.0 / conf.n_clients for _ in range(conf.n_clients)]
        assert localdata_id is not None 
        if data_partitioner is None:
            data_partitioner = pp.DataPartitioner(conf, dataset, partition_sizes=partition_size,
                                                  partition_type=conf.partition_type)
        data_to_load = data_partitioner.use(localdata_id)
    else:
        data_to_load = dataset 

    if use_fake:
        print(conf.use_original_client_data)
        if conf.use_original_client_data == "only_sync":
            data_to_load = fake_loader
        elif conf.use_original_client_data == "combine":
            data_to_load = ConcatDataset([data_to_load, fake_loader])
        # any other value: keep the real partition selected above

    data_loader = torch.utils.data.DataLoader(
        data_to_load,
        batch_size=conf.batch_size,
        shuffle=shuffle,
        num_workers=conf.num_workers,
        pin_memory=True,
        drop_last=drop_last,
    )
    conf.num_batches_per_device_per_epoch = len(data_loader)
    conf.num_whole_batches_per_worker = (
        conf.num_batches_per_device_per_epoch * conf.local_n_epochs
    )
    return data_loader, data_partitioner


class GetSpecificClass(object):
    """View of a dataset restricted to the samples of a single class label."""

    def __init__(self, data, targets, cls_of_interst):
        """Select the samples whose target equals ``cls_of_interst``.

        Args:
            data: object, full dataset, indexable by sample position
            targets: the label of every sample in ``data``
            cls_of_interst: the label to keep
        """
        self.data = data
        self.targets = targets
        self.cls_of_interst = cls_of_interst
        # ToTensor followed by a fixed per-channel normalisation
        # (presumably the CIFAR-10 mean / std - TODO confirm).
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        # positions in the full dataset that carry the requested label
        label_array = np.array(self.targets)
        self.index = np.where(label_array == self.cls_of_interst)[0]
        print("The length of the dataset", len(self.index))

    def __len__(self):
        """Number of samples that carry the selected label."""
        return len(self.index)

    def __getitem__(self, sub_index):
        """Return the ``(sample, target)`` pair at a position of the subset.

        Args:
            sub_index: the sub index for a particular partition of the dataset
        """
        data_idx = self.index[sub_index]
        sample = self.data[data_idx]
        if self.transform:
            sample = self.transform(sample)
        return sample, self.targets[data_idx]
    

def get_test_dataset(conf, shuffle=False):
    """Load the test split selected by ``conf.dataset`` and wrap it in a loader.

    Args:
        conf: the argparse namespace; reads ``seed_use`` and ``dataset`` here
            and is forwarded to ``define_data_loader``.
        shuffle: bool variable, forwarded to the DataLoader
    Ops:
        ``utils.seed_everything(conf.seed_use)`` is called first.
        "dsprite" uses the first element returned by
        ``pd.get_dsprite_data_tt()``; "cifar10" / "cifar100" and "mnist" load
        their "test" split from "../image_dataset/".
        The loader covers the whole test set (``is_train=False``,
        ``drop_last=False``).
    Returns:
        tuple ``(test_dataset, tt_data_loader)``
    Raises:
        ValueError: if ``conf.dataset`` is none of the supported names.
    """
    utils.seed_everything(conf.seed_use)
    if conf.dataset == "dsprite":
        test_dataset, _ = pd.get_dsprite_data_tt()
    elif conf.dataset in ("cifar10", "cifar100"):
        test_dataset = prepare_cifar.get_dataset(conf, conf.dataset, "../image_dataset/", split="test")
    elif conf.dataset == "mnist":
        test_dataset = prepare_mnist.get_dataset(conf, conf.dataset, "../image_dataset/", split="test")
    else:
        # Fail with a clear message instead of an unbound local further down.
        raise ValueError(
            "Unsupported dataset {!r}; expected 'dsprite', 'cifar10', "
            "'cifar100' or 'mnist'".format(conf.dataset)
        )

    tt_data_loader, _ = define_data_loader(conf, dataset=test_dataset, 
                                        localdata_id=0, 
                                        is_train=False,
                                        shuffle=shuffle,
                                        drop_last=False)
    return test_dataset, tt_data_loader 

