"""
author: Anonymous
"""

import os
import sys
import inspect

# Put the parent of this file's directory at the front of sys.path.
# NOTE(review): presumably this is what makes the `scheduling` package imports
# below resolvable when the file is run directly — confirm.
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)

import numpy as np
import torch
import torch.nn as nn

import time
from scheduling.environment import SchedulingEnvironment
from scheduling.agent import Agent
from scheduling.task import Task

class GeneticAlgorithm:
    """Genetic-algorithm metaheuristic that refines a task schedule.

    A schedule is a list of ``(task_id, agent_id)`` actions, one per task.
    The initial schedule is either produced by a warm-start ``solver`` or
    drawn at random; it is then improved by repeated mutation and selection,
    scoring every candidate by replaying it in the environment.
    """

    def __init__(self, solver=None, num_generations=1, population_size=100,
                 survival_rate=0.1, time_limit=60 * 60):
        """Configure the search.

        Args:
            solver: Optional warm-start solver exposing ``name`` and
                ``get_action(obs)``. ``None`` selects random initialization.
            num_generations (int): Number of mutate/select rounds.
            population_size (int): Number of schedules evaluated per round.
            survival_rate (float): Fraction of ``population_size`` kept after
                each round (at least one schedule always survives).
            time_limit (float | None): Wall-clock budget in seconds after
                which the search stops early. ``None`` disables the limit.
        """
        if solver is None:
            # random initialization
            self.name = "Genetic Algorithm with Random Initialization"
        else:
            self.name = f"Genetic Algorithm with {solver.name}"
        self.solver = solver
        self.num_generations = num_generations
        self.population_size = population_size
        self.survival_rate = survival_rate
        self.time_limit = time_limit

    def set_environment(self, environment, rerun=False):
        """Attach ``environment`` and compute the optimized schedule for it.

        The result is stored in ``self.schedule`` and the environment is reset
        afterwards so that it can be replayed through the ``get_*`` methods.

        Args:
            environment: Object providing ``num_tasks``, ``num_agents``,
                ``reset()``, ``step(action)`` and ``get_raw_score()``.
            rerun (bool): Currently unused; kept for interface compatibility.
        """
        self.env = environment
        if self.solver is None:
            # Random schedule: every task exactly once, in random order, each
            # assigned to a uniformly drawn agent.
            num_tasks = self.env.num_tasks
            task_order = np.random.choice(np.arange(num_tasks), num_tasks, replace=False)
            agents = np.random.choice(np.arange(self.env.num_agents), num_tasks, replace=True)
            schedule = list(zip(task_order, agents))
        else:
            # Warm start: roll the solver out once and record its actions.
            schedule = []
            obs = environment.reset()[0]
            for _ in range(self.env.num_tasks):
                action = self.solver.get_action(obs)[0]
                obs, _, _, _ = self.env.step(action)
                schedule.append(action)
        # we have the schedule, now we need to optimize it using Genetic Algorithm
        self.schedule, score = self.genetic_algorithm(schedule)
        print(f"Expected score: {score}, {self.schedule}")
        self.env.reset()

    def genetic_algorithm(self, schedule):
        """Evolve ``schedule`` and return the best candidate found.

        Args:
            schedule (list): Initial list of ``(task_id, agent_id)`` actions.

        Returns:
            tuple: ``(best_schedule, best_score)``; higher scores are better.
        """
        start = time.monotonic()
        population = [schedule]
        scores = [self.get_reward(schedule)]
        # Always keep at least one survivor, otherwise the next generation
        # would have no parent to mutate.
        num_survivors = max(1, int(self.survival_rate * self.population_size))
        for generation in range(self.num_generations):
            print(f"Generation {generation + 1}/{self.num_generations}")
            # Fill the population with mutated copies of random members.
            for _ in range(len(population), self.population_size):
                parent = population[np.random.randint(len(population))]
                child = self.mutate(parent)
                population.append(child)
                scores.append(self.get_reward(child))
            # Indices ordered by score, best first.
            ranking = np.argsort(scores)[::-1]
            if generation == self.num_generations - 1:  # last generation
                return population[ranking[0]], scores[ranking[0]]
            # Keep only the best performers for the next round.
            survivors = ranking[:num_survivors]
            population = [population[i] for i in survivors]
            scores = [scores[i] for i in survivors]
            if self.time_limit is not None and time.monotonic() - start > self.time_limit:
                print(f"Time limit exceeded, stopping at generation {generation + 1}")
                break
        # Survivors are sorted best-first, so index 0 is the best candidate.
        return population[0], scores[0]

    def mutate(self, schedule):
        """Return a mutated copy of ``schedule`` (the input is not modified).

        With equal probability either one entry is given a random agent or
        two entries swap positions. Schedules with fewer than two entries can
        only receive the agent mutation; an empty schedule is returned as is.
        """
        new_schedule = schedule.copy()
        if not schedule:
            return new_schedule
        if len(schedule) < 2 or np.random.rand() < 0.5:
            # randomly change a single agent
            task_index = np.random.randint(len(schedule))
            agent_id = np.random.randint(self.env.num_agents)
            new_schedule[task_index] = (schedule[task_index][0], agent_id)
        else:
            # swap the positions of two scheduled entries
            first, second = np.random.choice(len(schedule), 2, replace=False)
            new_schedule[first], new_schedule[second] = new_schedule[second], new_schedule[first]
        return new_schedule

    def get_reward(self, schedule):
        """Replay ``schedule`` in a freshly reset environment and return its raw score."""
        self.env.reset()
        for action in schedule:
            self.env.step(action)
        return self.env.get_raw_score()

    @staticmethod
    def _current_step(observation):
        """Return ``(step, tasks_left)`` derived from ``observation``.

        ``tasks_left`` is the length of
        ``observation['task_to_task_select'].source_nodes`` and ``step`` is
        ``len(observation['task']) - tasks_left``, i.e. the index into the
        stored schedule.
        """
        tasks_left = len(observation['task_to_task_select'].source_nodes)
        return len(observation['task']) - tasks_left, tasks_left

    def get_action(self, observation, greedy=False):
        """Return the scheduled action for the step implied by ``observation``.

        Args:
            observation: Mapping with a ``'task'`` entry and a
                ``'task_to_task_select'`` entry exposing ``source_nodes``.
            greedy (bool): Unused; the schedule is deterministic.

        Returns:
            tuple: ``((task_id, agent_id), None, None, None)``. The trailing
            ``None`` values presumably pad to the solver interface — confirm
            against callers.
        """
        step, _ = self._current_step(observation)
        return self.schedule[step], None, None, None

    def get_agent(self, observation, greedy=False):
        """Return the scheduled agent for the step implied by ``observation``.

        Args:
            observation: Same structure as for :meth:`get_action`.
            greedy (bool): Unused; the schedule is deterministic.

        Returns:
            tuple: ``(agent_id, None, None, None)``.
        """
        step, _ = self._current_step(observation)
        return self.schedule[step][1], None, None, None

    def get_agent_probs(self, observation, greedy=False, adaptive_temperature=False):
        """Return the scheduled agent as a one-hot distribution.

        Args:
            observation: Same structure as for :meth:`get_action`.
            greedy (bool): Unused.
            adaptive_temperature (bool): Unused.

        Returns:
            tuple: ``(agent_id, probs, log_probs)`` where ``probs`` is a
            ``(1, num_agents)`` float64 tensor and ``log_probs`` is a NumPy
            array holding ``log(probs + 1e-8)``.
        """
        step, _ = self._current_step(observation)
        agent_id = self.schedule[step][1]
        agent_probs = np.zeros(self.env.num_agents)
        agent_probs[agent_id] = 1.0
        log_probs = np.log(agent_probs + 1e-8)
        # Add the batch axis in NumPy; building a tensor from a list of
        # ndarrays is slow and triggers a warning.
        return agent_id, torch.tensor(agent_probs[np.newaxis, :]), log_probs

    def get_task_probs(self, observation, agent_id, greedy=True):
        """Return the scheduled task as a one-hot distribution over open tasks.

        Args:
            observation: Same structure as for :meth:`get_action`.
            agent_id: Unused; accepted for interface compatibility.
            greedy (bool): Unused.

        Returns:
            tuple: ``(task_id, probs, log_probs)`` where ``probs`` is a
            ``(1, tasks_left)`` float64 tensor with a one at the position of
            ``task_id`` inside ``source_nodes`` and ``log_probs`` is a NumPy
            array holding ``log(probs + 1e-8)``.
        """
        step, tasks_left = self._current_step(observation)
        task_id = self.schedule[step][0]
        task_probs = np.zeros(tasks_left)
        task_index = np.where(observation['task_to_task_select'].source_nodes == task_id)[0]
        task_probs[task_index] = 1.0
        log_probs = np.log(task_probs + 1e-8)
        return task_id, torch.tensor(task_probs[np.newaxis, :]), log_probs