#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Sep  5 10:16:16 2024

@author: XXXX
"""

import torch
import numpy as np

class ContinuousMaze():
    """Continuous 2D path-following task built from a fixed set of operations.

    Each operation is a short sequence of 2D translations (path segments).
    A task chains ``n_contexts`` operations; its path is the running sum of
    all translations starting from ``task_init`` (the origin). Model inputs
    are the per-step one-hot operation labels (``data_context``) and the
    targets are the per-step coordinates (``data_coord``), both padded or
    truncated to ``n_steps``.
    """

    def __init__(self, args=None):
        # Initialise operations: different path segments
        self.set_operations()
        # Initialise task arguments (missing keys fall back to defaults)
        self.set_parameters(args)
        # Initialise tasks
        self.set_tasks()
        # Send to device
        self.set_device()

    def set_device(self, device=None):
        """Move the model-input tensors to ``device``.

        Defaults to ``cuda:0`` when CUDA is available, otherwise CPU. Only
        ``data_coord``, ``data_context`` and ``task_init`` are moved; other
        tensors (e.g. ``task_path``, ``operations``) stay where they are.
        """
        if device is None:
            device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.device = device
        # List all the tensors that need to go to the device
        self.data_coord = self.data_coord.to(self.device)
        self.data_context = self.data_context.to(self.device)
        self.task_init = self.task_init.to(self.device)

    def set_operations(self):
        """Build the fixed list of operations (sequences of 2D translations).

        Operation ``i`` has ``durations[i]`` steps and starts in the radial
        direction ``i * 2*pi / n_operations``. Even-indexed operations turn
        clockwise with growing step length; odd-indexed operations turn
        counter-clockwise with shrinking step length.
        """
        # NOTE: nothing below draws random numbers; this global seed is kept
        # only because callers may rely on numpy's global RNG state after init.
        np.random.seed(0)
        self.durations = [3, 3, 4, 4, 5, 5]
        # Each operation goes into a different radial direction
        dirs = np.arange(len(self.durations)) * 2*np.pi/len(self.durations)
        dir_var = 0.1*2*np.pi/max(self.durations)
        len_var = 0.5/max(self.durations)
        # Generate actions for each operation
        self.operations = []
        for i, d in enumerate(self.durations):
            # (-1 + 2*(i % 2)) is -1 for even i and +1 for odd i
            curr_dir = torch.tensor(dirs[i] + (-1 + 2 * np.mod(i, 2)) * dir_var * np.arange(d))
            curr_len = torch.tensor(0.5 + len_var * (np.mod(i, 2) * d + (1 - 2*np.mod(i, 2)) * np.arange(d)))
            self.operations.append(torch.stack([
                curr_len * torch.cos(curr_dir),
                curr_len * torch.sin(curr_dir)], -1))
        self.n_operations = len(self.operations)

    def set_parameters(self, args=None):
        """Read task settings from ``args``, filling missing keys from defaults.

        Keys read: do_test, n_contexts, n_steps (-1 means mean operation
        duration times n_contexts), task_train (fraction of tasks used for
        training; with do_test the remaining tasks are used), task_choose
        (> -1 restricts to that single task) and sparse_feedback.
        """
        # Merge over the defaults so partial dictionaries don't raise KeyError
        args = {**self.default_parameters(), **({} if args is None else args)}
        # Get relevant parameters from input dictionary
        self.do_test = args['do_test']
        self.n_contexts = args['n_contexts']
        self.n_tasks = self.n_operations ** self.n_contexts
        self.n_steps = int(np.mean(self.durations) * args['n_contexts']) \
            if args['n_steps'] == -1 else args['n_steps']
        self.n_dims_in = 2
        self.n_dims_out = self.operations[0].shape[-1]
        self.task_train = args['task_train']
        self.task_start = int(self.task_train*self.n_tasks) if self.do_test else 0
        self.task_stop = self.n_tasks if self.do_test else int(self.task_train*self.n_tasks)
        self.task_choose = args['task_choose']
        if self.task_choose > -1:
            self.task_start = self.task_choose
            self.task_stop = self.task_choose+1
        self.sparse_feedback = args['sparse_feedback']

    def default_parameters(self):
        """Return the default settings dictionary.

        'n_tasks' and 'task_sampling' are not read anywhere in this class
        (n_tasks is derived as n_operations ** n_contexts); they are kept so
        external code reading them keeps working.
        """
        return {'do_test': False, 'n_tasks': 50, 'n_contexts': 3, 'n_steps': -1,
                'task_train': 0.6, 'task_sampling': 4, 'task_choose': -1,
                'sparse_feedback': False}

    def set_tasks(self):
        """Enumerate all tasks and precompute their paths and model data.

        Sets per task: operation indices (``task_ops``), operation durations
        (``task_op_durs``), a name, the path after each step (``task_path``),
        the actions, per-step operation labels (``task_contexts``), the
        zero-padded/truncated ``data_coord`` and one-hot ``data_context``
        tensors of length ``n_steps``, and the path coordinate at the last
        step of each operation (``task_boundaries``).
        """
        # Create tasks
        self.tasks = self.get_tasks()
        # Recover each task's operation indices by matching its segments
        # against the operation list
        self.task_ops = np.array([[i for c in t for i, o in enumerate(self.operations) if torch.equal(c, o)]
                                  for t in self.tasks])
        # Duration of each operation within each task (tasks x contexts);
        # keep_one_context uses this to sample one feedback step per operation
        self.task_op_durs = np.asarray(self.durations)[self.task_ops]
        self.task_names = ['-'.join([str(c) for c in t]) for t in self.task_ops]
        # Predefine one-hot task vectors
        self.task_vec = torch.eye(self.n_tasks, dtype=torch.float)
        # Set initial location
        self.task_init = torch.tensor([[0, 0]], dtype=torch.float)
        # Calculate path, actions and per-step context labels for all tasks
        self.task_path = []
        self.task_actions = []
        self.task_contexts = []
        for t_i, t in enumerate(self.tasks):
            path = [self.task_init]
            actions = []
            contexts = []
            for c_i, c in enumerate(t):
                for a in c:
                    contexts.append(self.task_ops[t_i][c_i])
                    actions.append(a)
                    path.append(self.transition(path[-1], a))
            # Drop the start location: the path holds the location after each step
            self.task_path.append(torch.cat(path[1:], 0))
            self.task_actions.append(actions)
            self.task_contexts.append(contexts)
        # Largest absolute coordinate reached by any task path
        self.task_radius = torch.max(torch.abs(torch.cat(self.task_path)))
        # Zero-pad or truncate coordinates and one-hot contexts to n_steps
        self.data_coord = []
        self.data_context = []
        for p, c in zip(self.task_path, self.task_contexts):
            new_coords = torch.zeros(self.n_steps, 2)
            new_context = torch.zeros(self.n_steps, self.n_operations)
            new_len = min(len(p), self.n_steps)
            new_coords[:new_len] = p[:new_len]
            new_context[:new_len] = torch.eye(self.n_operations)[c[:new_len]]
            self.data_coord.append(new_coords)
            self.data_context.append(new_context)
        self.data_coord = torch.stack(self.data_coord)
        self.data_context = torch.stack(self.data_context)
        # Boundary locations: the path coordinate at the last step of each operation
        self.task_boundaries = [[p[int(i)-1] for i in np.cumsum([self.durations[o] for o in to])]
                                for to, p in zip(self.task_ops, self.task_path)]

    def get_tasks(self):
        """Return all tasks (lists of operation tensors) in shuffled order.

        Seeds numpy's global RNG with 0 first, so the order is the same on
        every construction.
        """
        # Fix the random seed so you get the same tasks across training
        np.random.seed(0)
        # Get all valid tasks, with their probability
        tasks, probs = self.enumerate_tasks(
            operations=self.operations, N_contexts=self.n_contexts)
        # Then shuffle the tasks
        np.random.shuffle(tasks)
        return tasks

    def enumerate_tasks(self, operations=None, transitions=None, N_contexts=3):
        """Enumerate every task of ``N_contexts`` operations with its probability.

        Args:
            operations: list of operations (defaults to ``self.operations``).
            transitions: square matrix (list of lists) where
                ``transitions[i][j]`` is the probability of operation ``j``
                following ``i``; entries of 0 forbid that transition.
                Defaults to uniform.
            N_contexts: number of operations per task.

        Returns:
            (tasks, probs): tasks as lists of operation tensors, and a list of
            matching probabilities (uniform start times transition products).
        """
        # Use default operations if not provided
        operations = self.operations if operations is None else operations
        # If transitions not defined: set them to perfectly uniform
        transitions = [[1/len(operations) for _ in operations] for _ in operations] \
            if transitions is None else transitions
        # Recursively build tasks (as operation indices) while tracking probability
        tasks, probs = [], []
        for o in range(len(operations)):
            self.extend_tasks(tasks, probs, [o], transitions, 1/len(operations), N_contexts)
        # Return tasks as operation tensors, plus the parallel list of probabilities
        return [[operations[o] for o in t] for t in tasks], probs

    def extend_tasks(self, tasks, probs, task, transitions, p, max_depth):
        """Depth-first extend ``task`` until it has ``max_depth`` operations.

        Completed tasks (lists of operation indices) are appended to
        ``tasks`` and their probabilities to ``probs``, both in place.
        """
        # In the case that tasks are already at max depth (e.g. tasks of length 1): terminate
        if len(task) == max_depth:
            tasks.append(task)
            probs.append(p)
        else:
            # Transition to each possible operation from the last one
            for o, p_o in enumerate(transitions[task[-1]]):
                if p_o > 0:
                    new_task = task + [o]
                    new_p = p * p_o
                    if len(new_task) == max_depth:
                        tasks.append(new_task)
                        probs.append(new_p)
                    else:
                        self.extend_tasks(tasks, probs, new_task, transitions, new_p, max_depth)

    def keep_one_context(self, contexts, tasks):
        """Keep the context label at one random step per operation; zero the rest.

        Args:
            contexts: tensor of shape (tasks, timesteps, operations), or
                (timesteps, operations) for a single task.
            tasks: task index, or sequence/array/tensor of task indices (one
                per row of ``contexts``), used to look up operation durations.

        Returns:
            ``contexts`` multiplied by a mask that is True at one randomly
            chosen step inside each operation's span, with the same shape as
            the input. Samples falling at or past the time length of
            ``contexts`` (paths truncated to ``n_steps``) are skipped.
        """
        # If there's no batch dimension: add it (and remove only that one later)
        added_batch = contexts.ndim == 2
        if added_batch:
            contexts = contexts.unsqueeze(0)
        # Normalise tasks (int, list, array or tensor) to a flat numpy index array
        if torch.is_tensor(tasks):
            task_idx = tasks.detach().cpu().numpy().reshape(-1)
        else:
            task_idx = np.asarray(tasks).reshape(-1)
        # Sample one time point inside each operation (uses numpy's global RNG)
        durations = self.task_op_durs[task_idx]
        select = np.random.randint(durations)
        # Offset each sample by the summed durations of the preceding operations
        select[:, 1:] += np.cumsum(durations, axis=-1)[:, :-1]
        # Skip samples beyond the (possibly truncated) time axis
        batch_idx, op_idx = np.nonzero(select < contexts.shape[1])
        time_idx = select[batch_idx, op_idx]
        # Turn selected indices into a task x timestep mask on the contexts' device
        mask = torch.zeros(contexts.shape[:-1], dtype=torch.bool, device=contexts.device)
        mask[torch.as_tensor(batch_idx, device=contexts.device),
             torch.as_tensor(time_idx, device=contexts.device)] = True
        # Return the masked contexts
        masked = contexts * mask.unsqueeze(-1)
        return masked.squeeze(0) if added_batch else masked

    def get_path(self, actions, start=None):
        """Return the start location followed by the location after each action.

        ``start`` defaults to a (1, 2) float tensor at the origin (on CPU).
        """
        path = [torch.tensor([[0, 0]], dtype=torch.float) if start is None else start]
        for a in actions:
            path.append(self.transition(path[-1], a))
        return torch.cat(path)

    def transition(self, s, a):
        """Apply translation ``a`` to location ``s``.

        States ``s`` are 2D coordinates and actions ``a`` are 2D translations.
        """
        return s + a

    def plot_ops(self):
        """Plot each operation's segments, from the origin, in its own subplot."""
        from matplotlib import pyplot as plt
        plt.figure(figsize=(self.n_operations*3, 3))
        for i, o in enumerate(self.operations):
            steps = o.cpu().numpy()
            # Accumulate positions in numpy so torch and numpy aren't mixed
            x_prev = np.zeros(steps.shape[-1])
            ax = plt.subplot(1, self.n_operations, i+1)
            for x_o in steps:
                plt.plot([x_prev[0], x_prev[0] + x_o[0]], [x_prev[1], x_prev[1] + x_o[1]])
                x_prev = x_prev + x_o
            ax.set_xticks([])
            ax.set_yticks([])
            ax.axis('equal')

    def plot_tasks(self, n_to_plot=None, n_cols=10, offset=0):
        """Plot ``n_to_plot`` tasks starting at index ``offset`` in a grid.

        The grid has ``n_cols`` columns; ``n_to_plot`` defaults to all tasks.
        """
        n_to_plot = len(self.tasks) if n_to_plot is None else n_to_plot
        from matplotlib import pyplot as plt
        plt.figure()
        for i, (path, context) in enumerate(zip(self.task_path[offset:(offset+n_to_plot)],
                                                self.task_contexts[offset:(offset+n_to_plot)])):
            ax = plt.subplot(int(np.ceil(n_to_plot/n_cols)), n_cols, i+1)
            self.plot_task(ax, path, context)

    def plot_path(self, ax, path, context=None, linewidth=1, linestyle='solid', marker=""):
        """Draw consecutive points of ``path`` as line segments on ``ax``.

        Without ``context`` the segments fade from black to light grey;
        otherwise segment ``i`` is coloured by tab10 colour ``context[i]``.
        """
        if context is None:
            # Plot path from dark to light
            colours = [[i/len(path)]*3 for i in range(len(path))]
        else:
            # Plot path with colour indicated by context
            import matplotlib
            # Get tab10 color map: discrete series of 10 colours
            tab10 = matplotlib.colormaps['tab10']
            # Collect colors along context
            colours = [tab10(c) for c in context]
        # Plot each line segment
        for c_i, (c_from, c_to) in enumerate(zip(path[:-1], path[1:])):
            ax.plot([c_from[0], c_to[0]], [c_from[1], c_to[1]],
                    color=colours[c_i], linewidth=linewidth,
                    linestyle=linestyle, marker=marker, ms=0.5)

    def plot_task(self, ax, path, context, feedback=None, plot_path=True):
        """Plot one task path on ``ax``, coloured by context.

        ``path`` excludes the start location, which is prepended from
        ``task_init``. ``feedback`` (optional) indexes path steps to mark.
        ``plot_path`` is accepted but not used by this method.
        """
        # Get path start location, which isn't included in path
        init = self.task_init.cpu()
        # Plot big black star at path start
        ax.scatter(init[:, 0], init[:, 1], color=[0, 0, 0],
                   marker='*', s=150)
        if feedback is not None:
            # Plot big grey dot at feedback locations
            ax.scatter(path[feedback, 0], path[feedback, 1], color=[0.7, 0.7, 0.7],
                       marker='o', s=75)
        # Color path by context, and make it extra wide
        self.plot_path(ax, torch.cat([init, path]), context,
                       linewidth=4, marker='o')
        # Set limits to zoom around the actual path
        ax.set_xlim([torch.min(path[:, 0])-1, torch.max(path[:, 0])+1])
        ax.set_ylim([torch.min(path[:, 1])-1, torch.max(path[:, 1])+1])
        # Create clean canvas
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('equal')

    def get_model_input(self, tasks):
        """Return model inputs and targets for the given task indices.

        Returns a dict with 'output' (coordinates from ``data_coord``),
        'context' (one-hot labels from ``data_context``, reduced to one step
        per operation when ``sparse_feedback`` is set) and 'id' (``tasks``).
        """
        # Get paths and contexts for requested tasks
        contexts = self.data_context[tasks]
        targets = self.data_coord[tasks]

        # Create one feedback step per context
        if self.sparse_feedback:
            contexts = self.keep_one_context(contexts, tasks)

        # Return a dictionary of inputs
        return {'output': targets, 'context': contexts, 'id': tasks}