from dataclasses import dataclass
import torch
from torch.utils.data import Dataset, Subset
import logging
from typing import Sequence, Dict

from util import set_seed

@dataclass
class Synthetic_DataCollator_Sample_With_Context(object):
    """
    Samples observations with replacement

    Collate function: stacks the per-row tensors of a batch and, for every
    row, draws `num_clicks` column indices uniformly with replacement. X,
    click rates and click observations are gathered with the same indices,
    so every sampled X stays paired with its click rate / observation.

    Fields:
        use_dataset_Y: if True, click_obs is taken from the instances'
            "click_obs" at the sampled columns; if False, click_obs is
            re-drawn as Bernoulli(click_rates) at the sampled columns.
        num_clicks: number of columns sampled per row.
        one_X_per_column: if True, one set of sampled column indices is
            shared by every row of the batch; otherwise each row gets its
            own independent draw.
    """
    # These need real defaults: a bare `name: False` is only an annotation and
    # would make every field a required constructor argument.
    use_dataset_Y: bool = False
    num_clicks: int = 500
    one_X_per_column: bool = False

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """
        Build one batch from a sequence of dataset items.

        Each instance must provide "Z", "X" (2-D, columns x features) and
        "click_rates" (1-D over columns, or None). "click_obs" (1-D over
        columns) is required when `use_dataset_Y` is True or when the click
        rates are None.

        Returns a dict with keys "Z", "click_obs", "click_rates", "X" and
        "click_length_mask" (all ones, same shape as click_obs).
        "click_rates" is None when the instances carry no click rates.
        """
        X = torch.stack([instance["X"] for instance in instances], 0)
        Z = torch.stack([instance["Z"] for instance in instances], 0)
        assert len(X.shape) == 3 # B, N, dimension
        B, N = X.shape[0], X.shape[1]

        # Column indices to sample, shape (B, num_clicks)
        if self.one_X_per_column:
            # a single draw, copied to every row
            indices = torch.randint(0, N, size=(self.num_clicks,)).unsqueeze(0).repeat(B, 1)
        else:
            indices = torch.randint(0, N, size=(B, self.num_clicks))

        # gather needs an index with the same rank as X: copy it over the feature axis
        X_indices = indices.unsqueeze(-1).repeat(1, 1, X.shape[-1])
        X = torch.gather(X, dim=1, index=X_indices)

        # resample click rates, regardless of whether we re-draw click obs
        rates = [instance["click_rates"] for instance in instances]
        if any(rate is None for rate in rates):
            click_rates = None
        else:
            click_rates = torch.gather(torch.stack(rates, 0), dim=1, index=indices)

        if not self.use_dataset_Y and click_rates is not None:
            # re-draw Y's from the already resampled click rates
            click_obs = torch.bernoulli(click_rates)
        else:
            # otherwise, resample the existing click_obs with the same indices
            click_obs = torch.stack([instance["click_obs"] for instance in instances], 0)
            click_obs = torch.gather(click_obs, dim=1, index=indices)

        return dict(
            Z = Z,
            click_obs = click_obs,
            click_rates = click_rates,
            X = X,
            click_length_mask = torch.ones_like(click_obs)
        )

@dataclass
class Synthetic_DataCollator_Fixed_With_Context(object):
    """
    Uses fixed observation click sequence

    Collate function that performs no sampling: it stacks the per-row
    tensors of a batch and keeps the first `num_clicks` columns of X,
    click_obs and click_rates in their stored order.

    Fields:
        num_clicks: number of leading columns kept per row.
    """

    # Needs a real default: a bare `num_clicks: 500` is only an annotation and
    # would make the field a required constructor argument.
    num_clicks: int = 500

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """
        Build one batch from a sequence of dataset items.

        Each instance must provide "Z", "X" (2-D, columns x features),
        "click_obs" (1-D over columns) and "click_rates" (1-D over columns,
        or None).

        Returns a dict with keys "Z", "click_obs", "click_rates", "X" and
        "click_length_mask" (all ones, same shape as click_obs).
        "click_rates" is None when the instances carry no click rates.
        """
        X = torch.stack([instance["X"] for instance in instances], 0)
        Z = torch.stack([instance["Z"] for instance in instances], 0)

        # The per-instance entries (not the list itself) are what may be None.
        rates = [instance["click_rates"] for instance in instances]
        if any(rate is None for rate in rates):
            click_rates = None
        else:
            click_rates = torch.stack(rates, 0)[:, :self.num_clicks]

        click_obs = torch.stack([instance["click_obs"] for instance in instances], 0)
        click_obs = click_obs[:, :self.num_clicks]
        return dict(
            Z = Z,
            click_obs = click_obs,
            click_rates = click_rates,
            X = X[:, :self.num_clicks, :],
            click_length_mask = torch.ones_like(click_obs)
        )


class VectorAndClickRateDatasetWithContextFromDict(Dataset):
    """
    Dataset object that makes data loaders (can be used for train and/or eval)

    Stores, per row, Z[row], the per-column X[row], and per-column click rates
    and/or click observations Y. When `bootstrap_seed` is given, rows are
    resampled with replacement and the columns of every row are shuffled once
    in __init__; the X, Y and click rate of a column always stay together.

    Currently assumes fixed column length, which is appropriate for synthetic examples
    but probably not for real data
    """

    def __init__(self, Z, X, Y, click_rates,
                 num_loader_obs=500, 
                 generator_seed=230498,
                 bootstrap_seed=None,
                 use_dataset_Y=False, # use dataset Y instead of click rates to resample
                 one_X_per_column=False): # use same X ordering across rows (e.g. for training something like DPT)
        """
        Args:
            Z: per-row data, indexed by row.
            X: tensor whose first two dimensions are (rows, columns).
            Y: click observations indexed [row][column], or None.
            click_rates: click rates indexed [row][column], or None.
                At least one of Y / click_rates must be given.
            num_loader_obs: number of observations per row used by the
                collators that make_loader creates.
            generator_seed: seed for the fixed row-subset order.
            bootstrap_seed: if not None, seeds the bootstrap of rows (with
                replacement) and the per-row column shuffle.
            use_dataset_Y: forwarded to the training collator.
            one_X_per_column: forwarded to the training collator.
        """
        assert Y is not None or click_rates is not None 
        
        self.click_rates = click_rates
        self.Z = Z
        self.X = X
        self.num_rows = X.shape[0] # number of rows
        self.num_cols = X.shape[1] # number of columns
        self.loader_obs = Y
        self.num_loader_obs = num_loader_obs
        self.use_dataset_Y = use_dataset_Y
        self.one_X_per_column = one_X_per_column
        logging.info(f"Total rows: {len(self.Z)}")

        # Without a bootstrap seed both index maps are the identity.
        self.bootstrap_row_idxs = torch.arange(0, self.num_rows)
        self.bootstrap_col_idxs = torch.arange(0, self.num_cols).unsqueeze(0).repeat((self.num_rows,1))
        if bootstrap_seed is not None:
            boot_generator = torch.Generator()
            boot_generator.manual_seed(bootstrap_seed+238954)
            # rows: sampled WITH replacement
            num_z_rows = len(self.Z)
            self.bootstrap_row_idxs = torch.randint(high=num_z_rows, size=(num_z_rows,),
                    dtype=torch.int64, generator=boot_generator)
            # columns: one random permutation per row (argsort of uniform noise),
            # i.e. a shuffle WITHOUT replacement; the sampling collator later
            # samples with replacement within each row.
            # NOTE(review): the permutation differs per row, so with
            # one_X_per_column=True the column order is no longer shared
            # across rows when a bootstrap seed is set -- confirm intent.
            noise = torch.rand(self.num_rows, self.num_cols, generator=boot_generator)
            self.bootstrap_col_idxs = torch.argsort(noise, dim=-1)
        
        # Fixed order of dataset positions used by make_loader(num_subset_rows=...).
        # It depends only on generator_seed, but positions are still mapped through
        # bootstrap_row_idxs, so the underlying rows differ across bootstrap seeds.
        generator = torch.Generator()
        generator.manual_seed(generator_seed)
        self.fixed_article_subset_order = torch.randperm(self.num_rows, generator=generator)

    def __len__(self):
        """Number of rows; falls back to X's row count when click_rates is None."""
        if self.click_rates is not None:
            return len(self.click_rates)
        return self.num_rows

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """
        Return the i-th (bootstrapped) row as a dict with keys "X",
        "click_obs", "click_rates" and "Z". "click_obs" / "click_rates" are
        None when the dataset was built without Y / click_rates.
        """
        row = self.bootstrap_row_idxs[i]
        cols = self.bootstrap_col_idxs[i]
        # X must be reordered with the same column permutation as the click
        # data, otherwise each X would be paired with another column's outcome.
        click_obs = None if self.loader_obs is None else self.loader_obs[row][cols]
        click_rates = None if self.click_rates is None else self.click_rates[row][cols]
        return dict(
                    X = self.X[row][cols],
                    click_obs = click_obs,
                    click_rates = click_rates,
                    Z = self.Z[row])

    def make_loader(self, batch_size, train=False, num_subset_rows=None, train_deterministic_row_order=False):
        """
        Create a DataLoader over this dataset.

        Args:
            batch_size: rows per batch.
            train: if True, use the sampling collator and shuffle rows;
                otherwise use the fixed collator without shuffling.
            num_subset_rows: if not None, restrict the loader to the first
                `num_subset_rows` entries of the fixed subset order.
            train_deterministic_row_order: if True, do not shuffle rows even
                when `train` is True.
        """
        if num_subset_rows is not None:
            # Take a subset of the number of rows
            idxs = self.fixed_article_subset_order[:num_subset_rows]
            ds = Subset(self, idxs)
        else:
            ds = self

        if train:
            collate_fn = Synthetic_DataCollator_Sample_With_Context(num_clicks=self.num_loader_obs, 
                    use_dataset_Y=self.use_dataset_Y, 
                    one_X_per_column=self.one_X_per_column)
        else:
            collate_fn = Synthetic_DataCollator_Fixed_With_Context(num_clicks=self.num_loader_obs)

        # Rows are shuffled only for training, and only if a deterministic
        # order was not requested.
        shuffle = bool(train) and not train_deterministic_row_order
        return torch.utils.data.DataLoader(ds,
               batch_size=batch_size,
               collate_fn=collate_fn, shuffle=shuffle)


class VectorAndClickRateDatasetWithContext(VectorAndClickRateDatasetWithContextFromDict):
    """
    Same dataset as the parent class, but with its tensors read from a file.

    The file is loaded with `torch.load` and must contain a mapping with the
    keys 'Z', 'X', 'Y' and 'click_rate'.
    """

    def __init__(self, dataset_file,
                 num_loader_obs=500,
                 generator_seed=230498,
                 bootstrap_seed=None,
                 use_dataset_Y=False,
                 one_X_per_column=False):
        # NOTE: torch.load may unpickle arbitrary objects; only point this at
        # dataset files you trust.
        payload = torch.load(dataset_file)
        Z, X, Y = (payload[key] for key in ('Z', 'X', 'Y'))
        click_rates = payload['click_rate']

        # Z, Y and click_rates must agree on the number of rows ...
        assert len(Z) == len(click_rates) == len(Y)
        # ... and X, Y and click_rates on the number of columns.
        assert X.shape[1] == click_rates.shape[1] == Y.shape[1]

        super().__init__(
            Z, X, Y, click_rates,
            num_loader_obs=num_loader_obs,
            generator_seed=generator_seed,
            bootstrap_seed=bootstrap_seed,
            use_dataset_Y=use_dataset_Y,
            one_X_per_column=one_X_per_column,
        )


# no splits implemented
def get_loaders_synthetic_with_context(config, train_deterministic_row_order=False, extras=True):
    """
    Build the train / eval datasets and loaders for the synthetic setting with context.

    Reads `<config.data_dir>/train_data.pt` and `<config.data_dir>/eval_data.pt`.

    Args:
        config: object with attributes `seed`, `bootstrap_seed`, `data_dir`,
            `batch_size` and `eval_batch_size`. Optional attributes:
            `use_dataset_Y` and `one_X_per_column` (set to False ON THE CONFIG
            OBJECT when missing), `extra_eval_data` (path of an additional
            eval file) and `embed_data_dir` (must be falsy when
            `extra_eval_data` is used).
        train_deterministic_row_order: if True, the train loader does not
            shuffle rows.
        extras: unused; kept for backwards compatibility.

    Returns:
        dict with 'train_loader', 'val_loader', 'train_fixed_subset_loader',
        'train_dataset', 'val_dataset' and, when `extra_eval_data` is set,
        'extra_eval_dataset' and 'extra_eval_loader'.
    """
    set_seed(config.seed)
    logging.info('Making train dataset')
    if not hasattr(config, 'use_dataset_Y'):
        config.use_dataset_Y = False
    if not hasattr(config, 'one_X_per_column'):
        config.one_X_per_column = False

    train_kwargs = {
            'bootstrap_seed': config.bootstrap_seed, 
            'use_dataset_Y': config.use_dataset_Y, 
    }
    train_path = config.data_dir + '/train_data.pt'
    print(train_path)
    train_dataset = VectorAndClickRateDatasetWithContext(train_path, **train_kwargs, one_X_per_column=config.one_X_per_column)
    train_loader = train_dataset.make_loader(
        batch_size=config.batch_size, train=True, train_deterministic_row_order=train_deterministic_row_order)

    # re-seed so the eval dataset does not depend on how much randomness
    # building the train dataset consumed
    set_seed(config.seed)
    logging.info('Making eval dataset')
    eval_path = config.data_dir + '/eval_data.pt'
    print(eval_path)
    eval_dataset = VectorAndClickRateDatasetWithContext(eval_path)
    val_loader = eval_dataset.make_loader(
            batch_size=config.eval_batch_size, train=False)

    if config.one_X_per_column:
        # sanity check: the first column's X must match between the first two rows
        for dataset in (train_dataset, eval_dataset):
            X = dataset.X
            if len(X) > 1: # nothing to compare in a single-row dataset
                assert torch.allclose(X[0][0], X[1][0], 1e-4)
    print("train dataset size: {}".format(len(train_dataset)))
    print("eval dataset size: {}".format(len(eval_dataset)))

    # at every epoch, evaluate not only on the val set, but also a fixed subset of the training set
    # this is to measure overfitting

    train_subset_rows = len(eval_dataset)
    train_fixed_subset_loader = train_dataset.make_loader(
            batch_size=config.batch_size,
            train=False,
            num_subset_rows=train_subset_rows)

    res = {'train_loader':train_loader, 'val_loader':val_loader,
            'train_fixed_subset_loader':train_fixed_subset_loader,
            'train_dataset': train_dataset,
            'val_dataset':eval_dataset}

    # extra_eval_data / embed_data_dir are optional: a config without them
    # simply gets no extra eval loader instead of an AttributeError
    extra_eval_data = getattr(config, 'extra_eval_data', None)
    if extra_eval_data is not None:
        assert not getattr(config, 'embed_data_dir', None) # not implemented
        extra_eval_dataset = VectorAndClickRateDatasetWithContext(extra_eval_data)
        extra_eval_loader = extra_eval_dataset.make_loader(batch_size=config.batch_size, train=False)
        res['extra_eval_dataset'] = extra_eval_dataset
        res['extra_eval_loader'] = extra_eval_loader

    if extras:
        pass # we do nothing here, argument mostly for backwards compatibility
        
    return res
