#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Jul  1 12:21:48 2024

@author: XXXX
"""

import torch
import numpy as np
from torch.utils.data import Dataset
import os


class TaskSequenceDataset(Dataset):
    """Dataset of sequence tasks built by chaining shift-and-add operations.

    A task is a sequence of ``n_contexts`` operations drawn from
    ``self.operations``. Each operation runs for a fixed number of consecutive
    steps (``self.durations``). At every step it rolls the previous target by
    a fixed shift and adds the current random input.

    Generated input/target arrays are cached per task as ``.npy`` files under
    ``args['data_dir']/<task name>/`` and read back lazily through memmaps.

    Parameters
    ----------
    args : dict
        Configuration. Keys read: ``'do_test'``, ``'n_contexts'``,
        ``'n_steps'`` (-1 selects mean duration * n_contexts), ``'n_dims'``,
        ``'task_train'`` (fraction of tasks used for training),
        ``'task_choose'`` (-1, or one task index to use exclusively),
        ``'sparse_feedback'`` and ``'data_dir'``.
    task_samples : int or float
        Number of cached examples per task.
    name : optional
        Unused; kept for interface compatibility.
    """

    def __init__(self, args, task_samples=1e4, name=None):
        # Initialise operations
        self.set_operations()
        # Initialise parameters
        self.set_parameters(args)
        # Initialise tasks
        self.set_tasks()
        # Prepare dataset
        self.set_dataset(task_samples, args['data_dir'])

    def set_operations(self):
        """Define the available operations and how many steps each one lasts.

        The function names (``s0`` ... ``s5``) end up in the task names, and
        therefore in the cache directory names, so they must not change.
        """
        def s0(state, data):
            return np.roll(state, 0) + data
        def s1(state, data):
            return np.roll(state, 1) + data
        def s2(state, data):
            return np.roll(state, 2) + data
        def s3(state, data):
            return np.roll(state, 3) + data
        def s4(state, data):
            return np.roll(state, 4) + data
        def s5(state, data):
            return np.roll(state, 5) + data
        self.operations = [s0, s1, s2, s3, s4, s5]
        # Number of consecutive steps each operation is applied for
        self.durations = [3, 3, 4, 4, 5, 5]

    def set_parameters(self, args):
        """Read the configuration dictionary and derive the task index range."""
        self.do_test = args['do_test']
        self.n_contexts = args['n_contexts']
        self.n_steps = int(np.mean(self.durations) * args['n_contexts']) \
            if args['n_steps'] == -1 else args['n_steps']
        self.n_dims_in = args['n_dims']
        self.n_dims_out = args['n_dims']
        self.n_operations = len(self.operations)
        self.n_tasks = self.n_operations ** self.n_contexts
        # Training uses the first task_train fraction of the (shuffled) tasks,
        # testing uses the remainder
        self.task_train = args['task_train']
        self.task_start = int(self.task_train * self.n_tasks) if self.do_test else 0
        self.task_stop = self.n_tasks if self.do_test else int(self.task_train * self.n_tasks)
        # A non-negative task_choose restricts the dataset to that single task
        self.task_choose = args['task_choose']
        if self.task_choose > -1:
            self.task_start = self.task_choose
            self.task_stop = self.task_choose + 1
        self.sparse_feedback = args['sparse_feedback']

    def set_tasks(self):
        """Build the task list and the per-task lookup tables.

        Sets ``tasks``, ``task_names`` (e.g. ``'s0-s3-s5'``, also the cache
        directory name), ``task_ops`` (operation indices), ``task_op_durs``
        (operation durations), ``task_vec`` (one-hot task codes) and
        ``task_contexts``. The last is a per-step one-hot code of the active
        operation, with shape ``(n_tasks, n_steps, n_operations)``.
        """
        self.tasks = self.get_tasks()
        self.task_names = ['-'.join(c.__name__[0:2] for c in t) for t in self.tasks]
        self.task_ops = [[self.operations.index(c) for c in t] for t in self.tasks]
        self.task_op_durs = [[self.durations[o] for o in t] for t in self.task_ops]
        self.task_vec = torch.eye(self.n_tasks, dtype=torch.float)
        # The context code is a one-hot over operations, so its last dimension
        # must be n_operations (sizing it by n_dims_in broke for n_dims != 6)
        one_hot = torch.eye(self.n_operations, dtype=torch.float)
        self.task_contexts = torch.zeros((self.n_tasks, self.n_steps, self.n_operations),
                                         dtype=torch.float)
        for task, (ops, durs) in enumerate(zip(self.task_ops, self.task_op_durs)):
            t = 0
            for op, duration in zip(ops, durs):
                # Operations running past n_steps are truncated (or dropped entirely)
                self.task_contexts[task, t:min(t + duration, self.n_steps), :] = one_hot[op]
                t += duration

    def set_dataset(self, task_samples, base):
        """Generate (if missing) and memory-map the cached data of each selected task.

        Task ``t`` is stored in ``base/<task_names[t]>/input_data.npy`` and
        ``target_data.npy``. Only tasks in ``[task_start, task_stop)`` are
        loaded; the other entries of ``data_in``/``data_out`` stay empty lists.

        Raises
        ------
        ValueError
            If a cached array holds fewer than ``task_samples`` examples or
            does not match the current ``n_steps``/``n_dims``.
        """
        self.task_samples = int(task_samples)
        self.data_in, self.data_out = [[[] for _ in self.tasks] for _ in range(2)]
        for t in range(self.task_start, self.task_stop):
            task = self.task_names[t]
            task_dir = os.path.join(base, task)
            in_file = os.path.join(task_dir, 'input_data.npy')
            out_file = os.path.join(task_dir, 'target_data.npy')
            # Check the files rather than the directory, so that a directory left
            # behind by an interrupted run is regenerated instead of failing to load
            if not (os.path.isfile(in_file) and os.path.isfile(out_file)):
                input_data, output_data = self.generate_data([t] * self.task_samples)
                os.makedirs(task_dir, exist_ok=True)
                # Stored as float32: the dtype the data is consumed in, at half the disk size
                np.save(in_file, input_data.astype(np.float32))
                np.save(out_file, output_data.astype(np.float32))
                print(f'Generated task {task}, {t} / {self.n_tasks}')
            # Memory-map so examples are read from disk on demand instead of filling RAM
            self.data_in[t] = self._load_cached(in_file, self.n_dims_in)
            self.data_out[t] = self._load_cached(out_file, self.n_dims_out)

    def _load_cached(self, path, n_dims):
        """Memory-map a cached ``.npy`` array read-only and check it fits the current settings."""
        # In read mode open_memmap takes dtype and shape from the file header
        data = np.lib.format.open_memmap(path, mode='readonly')
        if (data.ndim != 3 or data.shape[0] < self.task_samples
                or data.shape[1:] != (self.n_steps, n_dims)):
            raise ValueError(
                f'Cached data {path} has shape {data.shape}; expected at least '
                f'({self.task_samples}, {self.n_steps}, {n_dims}). '
                'Delete the cache directory to regenerate it.')
        return data

    def generate_data(self, data_tasks):
        """Generate random inputs and the corresponding targets.

        Parameters
        ----------
        data_tasks : sequence of int
            Task index of each example to generate.

        Returns
        -------
        input_data, output_data : np.ndarray
            Float64 numpy arrays of shape ``(len(data_tasks), n_steps, n_dims_in)``.
            Inputs are standard normal. The target at each step is the active
            operation applied to the previous target (zeros before the first
            step) and the current input. Steps of operations that extend past
            ``n_steps`` are dropped.
        """
        input_data = np.random.randn(len(data_tasks), self.n_steps, self.n_dims_in)
        output_data = np.zeros_like(input_data)
        for b, (task, ops) in enumerate([[self.tasks[d], self.task_ops[d]] for d in data_tasks]):
            # Each operation may have a different duration: get durations and onsets
            op_durations = [self.durations[o] for o in ops]
            op_onsets = [0] + list(np.cumsum(op_durations[:-1]))
            for operation, steps, onset in zip(task, op_durations, op_onsets):
                # State starts at zero and carries over from one operation to the next
                if onset == 0:
                    previous_step = np.zeros(self.n_dims_in)
                for s in range(steps):
                    if onset + s < self.n_steps:
                        output_data[b, onset + s, :] = \
                            operation(previous_step, input_data[b, onset + s, :])
                        previous_step = output_data[b, onset + s, :]
        return input_data, output_data

    def generate_labels(self, data_tasks):
        """Return ``(input_context, input_task)`` for the given task index or indices.

        ``input_context`` is the per-step one-hot operation code, reduced to one
        step per operation when ``sparse_feedback`` is set. ``input_task`` is
        the one-hot task code.
        """
        input_context = self.task_contexts[data_tasks]
        if self.sparse_feedback:
            input_context = self.keep_one_context(input_context, data_tasks)
        input_task = self.task_vec[data_tasks]
        return input_context, input_task

    def keep_one_context(self, contexts, tasks):
        """Keep the context cue at exactly one random step of each operation.

        Parameters
        ----------
        contexts : torch.Tensor
            Context codes of shape ``(batch, n_steps, n_operations)``, or
            ``(n_steps, n_operations)`` for a single task.
        tasks : int or sequence of int
            Task index (or one index per batch row) the contexts belong to.

        Returns
        -------
        torch.Tensor
            ``contexts`` with every step zeroed except one random step inside
            each operation's span. The shape is the same as the input.
            Operations truncated entirely by ``n_steps`` keep no step.
        """
        single = contexts.ndim == 2
        if single:
            contexts = contexts.unsqueeze(0)
        if np.ndim(tasks) == 0:
            tasks = [tasks]
        n_steps = contexts.shape[1]
        durations = np.stack([self.task_op_durs[int(t)] for t in tasks])
        onsets = np.zeros_like(durations)
        onsets[:, 1:] = np.cumsum(durations, axis=-1)[:, :-1]
        # Only sample among steps that fit within n_steps; otherwise long tasks
        # would index past the end of the sequence
        visible = np.clip(n_steps - onsets, 0, durations)
        select = onsets + np.random.randint(np.maximum(visible, 1))
        keep = visible > 0
        rows = np.broadcast_to(np.arange(contexts.shape[0])[:, None], select.shape)
        mask = torch.zeros(contexts.shape[:-1], dtype=torch.bool)
        mask[torch.as_tensor(rows[keep]), torch.as_tensor(select[keep])] = True
        masked = contexts * mask.unsqueeze(-1)
        # Only drop the batch dimension that was added above, so that a
        # one-element batch keeps its batch dimension
        return masked.squeeze(0) if single else masked

    def get_model_input(self, data_tasks):
        """Generate fresh data for ``data_tasks`` and return
        ``(input_data, input_context, input_task, output_data)`` as float tensors.
        """
        input_data, output_data = (torch.tensor(y, dtype=torch.float)
                                   for y in self.generate_data(data_tasks))
        input_context, input_task = self.generate_labels(data_tasks)
        return input_data, input_context, input_task, output_data

    def get_tasks(self):
        """Return every possible task in a fixed shuffled order.

        NOTE: this reseeds numpy's global RNG to 0 so the task order (and hence
        the train/test split) is reproducible; later numpy random draws in the
        process are affected by this as well.
        """
        np.random.seed(0)
        tasks = list(self.enumerate_tasks(self.operations, N_contexts=self.n_contexts))
        np.random.shuffle(tasks)
        return tasks

    def enumerate_tasks(self, operations, transitions=None, N_contexts=3):
        """Enumerate all operation sequences of length ``N_contexts`` with their probabilities.

        ``transitions[i][j]`` is the probability of moving from operation ``i``
        to operation ``j``; it defaults to uniform. The first operation is
        drawn uniformly. Returns a dict that maps each task (a tuple of
        operations) to its probability.
        """
        transitions = [[1 / len(operations) for _ in operations] for _ in operations] \
            if transitions is None else transitions
        tasks = {}
        for o in operations:
            self.extend_tasks(tasks, [o], operations, transitions, 1 / len(operations), N_contexts)
        return tasks

    def extend_tasks(self, tasks, task, operations, transitions, p, max_depth):
        """Recursively extend ``task`` (probability ``p``) and record complete tasks in ``tasks``."""
        # Check the length before extending, so that max_depth == 1 terminates
        # instead of recursing forever
        if len(task) >= max_depth:
            tasks[tuple(task)] = p
            return
        for o, p_o in zip(operations, transitions[operations.index(task[-1])]):
            if p_o > 0:
                self.extend_tasks(tasks, task + [o], operations, transitions, p * p_o, max_depth)

    def __len__(self):
        return self.task_samples * (self.task_stop - self.task_start)

    def __getitem__(self, idx):
        """Return one cached example.

        The flat index is ``(task - task_start) * task_samples + example``.
        """
        task, example = divmod(int(idx), self.task_samples)
        task += self.task_start
        input_context, input_task = self.generate_labels(task)
        return {'input': torch.tensor(self.data_in[task][example], dtype=torch.float),
                'output': torch.tensor(self.data_out[task][example], dtype=torch.float),
                'context': input_context, 'task': input_task, 'id': task}

# Create a sampler that balances the tasks per batch
class UniformSampler:
    """Batch sampler that spreads every batch as evenly as possible over tasks.

    Each batch starts from a fresh random permutation of the task indices,
    repeated until ``batch_size`` slots are filled. Every task therefore
    appears either ``batch_size // n_tasks`` times or once more. Each slot then
    gets a uniformly random example index.

    Yielded arrays hold flat indices ``task * task_samples + example``, where
    ``task`` is counted from ``dataset.task_start``.
    """

    def __init__(self, dataset, n_batches, batch_size):
        self.n_tasks = dataset.task_stop - dataset.task_start
        self.task_samples = dataset.task_samples
        self.batch_size = batch_size
        self.n_batches = n_batches

    def __iter__(self):
        for _ in range(self.n_batches):
            order = np.random.permutation(self.n_tasks)
            # Integer ceiling: how many whole permutations are needed to cover the batch;
            # shuffling first keeps the leftover slots spread uniformly over tasks
            n_repeats = -(-self.batch_size // self.n_tasks)
            batch_tasks = np.tile(order, n_repeats)[:self.batch_size]
            batch_examples = np.random.randint(self.task_samples, size=batch_tasks.shape)
            yield batch_tasks * self.task_samples + batch_examples

    def __len__(self):
        return self.n_batches
