import os
import time
import string
from typing import Callable
import sys
from datetime import datetime

from absl import logging

from circuit_training.environment import environment as oss_environment
from circuit_training.learning import agent
from circuit_training.learning import learner as learner_lib

import reverb
import tensorflow as tf
from tensorflow.python.summary.summary_iterator import summary_iterator
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if len(physical_devices) > 0:
    tf.config.experimental.set_visible_devices(physical_devices[0], 'GPU')
    tf.config.experimental.set_memory_growth(physical_devices[0], True)

import numpy as np
import random

from tf_agents.experimental.distributed import reverb_variable_container
from tf_agents.networks import network
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.train import learner as actor_learner
from tf_agents.train import triggers
from tf_agents.train.utils import spec_utils
from tf_agents.train.utils import train_utils
from tf_agents.utils import common
import keras
from tf_agents.train import learner


def train(
    root_dir: str,
    parent_chkpt_dir: str,
    parent_policy_dir: str,
    strategy: tf.distribute.Strategy,
    replay_buffer_server_address: str,
    variable_container_server_address: str,
    create_env_fn: Callable[[], oss_environment.CircuitEnv],
    sequence_length: int,
    actor_net: network.Network,
    value_net: network.Network,
    use_grl: bool = True,
    per_replica_batch_size: int = 32,
    num_epochs: int = 4,
    num_iterations: int = 10000,
    num_episodes_per_iteration: int = 1024,
    allow_variable_length_episodes: bool = False,
    population: int = 5,
    generations: int = 25) -> None:
  """Searches over weight-mutated copies of a checkpointed PPO policy.

  Despite its name, this function performs no gradient training. It:

  1. Builds a PPO agent from `actor_net` / `value_net` and restores it from
     the checkpoint in `parent_chkpt_dir`.
  2. Plays one episode in an environment from `create_env_fn` and scores the
     placement as 0.5 * density + 0.5 * congestion + wirelength.
  3. Grows a tree of `Node`s whose networks are produced by randomly
     perturbing the kernel of one `Conv2DTranspose` layer in the policy
     location head; every new node is scored the same way as in step 2.
  4. Once the tree holds at least 200 nodes, prints the lowest-cost node, the
     path from the root to it, and the recorded cost history.

  Args:
    root_dir: Not referenced in the visible body of this function.
    parent_chkpt_dir: Directory of the checkpoint to restore the agent from.
    parent_policy_dir: Not referenced in the visible body of this function.
    strategy: Distribution strategy under whose scope agents are created.
    replay_buffer_server_address: Not referenced in the visible body.
    variable_container_server_address: Not referenced in the visible body.
    create_env_fn: Zero-argument factory returning a fresh environment.
    sequence_length: Not referenced in the visible body.
    actor_net: Policy network of the root agent; also the seed for mutations.
    value_net: Value network of the root agent.
    use_grl: Not referenced in the visible body.
    per_replica_batch_size: Not referenced in the visible body.
    num_epochs: Not referenced in the visible body.
    num_iterations: Not referenced in the visible body.
    num_episodes_per_iteration: Not referenced in the visible body.
    allow_variable_length_episodes: Not referenced in the visible body.
    population: Not referenced in the visible body.
    generations: Not referenced in the visible body.
  """
  print("Started creating the environment!")


  class Node:
    """One vertex of the search tree: a pair of models plus their cost.

    `parent` is None for the root. Creating a node does not register it with
    its parent; callers link it explicitly through `add_child`.
    """

    def __init__(self, father_model, mother_model, cost=0, parent=None):
        self.father_model, self.mother_model = father_model, mother_model
        self.parent = parent
        self.children = []
        self.cost = cost

    def add_child(self, child):
        """Record `child` as a direct descendant of this node."""
        self.children.append(child)

    def is_root(self):
        """Return True when this node has no parent."""
        return self.parent is None

    def __repr__(self):
        if self.is_root():
            parent_info = "Root Node"
        else:
            parent_info = f"Child of ({self.parent.father_model}, {self.parent.mother_model})"
        return f"Node({self.father_model}, {self.mother_model}, Cost: {self.cost}, {parent_info})"

    def print_tree(self, level=0):
        """Print this node and its descendants, indented 4 spaces per level."""
        print(" " * (level * 4) + str(self))
        deeper = level + 1
        for descendant in self.children:
            descendant.print_tree(deeper)


  def count_nodes(node):
    """Return how many nodes the subtree rooted at `node` has, itself included."""
    total = 0
    pending = [node]
    while pending:
        current = pending.pop()
        total += 1
        pending.extend(current.children)
    return total

  def print_path_to_node(node):
    """Print each node on the way from the tree root down to `node`."""
    lineage = []
    cursor = node
    # Walk upwards through the parent links, collecting leaf-to-root order.
    while cursor:
        lineage.append(cursor)
        cursor = cursor.parent
    # Emit in root-to-leaf order.
    for ancestor in reversed(lineage):
        print(ancestor)

  def find_lowest_cost_node(node):
    """Return the cheapest node in the subtree of `node`, or None.

    A node with a negative cost is treated as invalid: it and its whole
    subtree are skipped. Ties are resolved in favour of the node visited
    first (the node itself, then its children in order).
    """
    if node.cost < 0:
        return None
    subtree_bests = (find_lowest_cost_node(child) for child in node.children)
    contenders = [node]
    contenders.extend(best for best in subtree_bests if best)
    # min() keeps the earliest entry on ties, matching a strict '<' update.
    return min(contenders, key=lambda contender: contender.cost)


  def rollout(node, rollout_count, max_nodes=200):
    """Extend the tree below `node` with a chain of mutated, scored children.

    Each step mutates the networks of the most recently added node (starting
    with `node` itself), builds a fresh environment and PPO agent around
    them, plays one episode with the agent's policy and scores the result as
    0.5 * density + 0.5 * congestion + wirelength. The scored networks are
    attached as a child of the previous step's node, so one call produces a
    single chain rather than siblings.

    Side effects: appends to the enclosing `tf_agent_track` and `cost_track`
    lists and stores a timestamp -> cost entry in `cost_time_dict`.

    Args:
      node: Node to start from; new nodes are chained beneath it.
      rollout_count: Maximum number of nodes to add in this call.
      max_nodes: Stop early once the subtree rooted at `node` has this many
        nodes.
    """
    current_node = node
    added_count = 0
    while added_count < rollout_count and count_nodes(node) < max_nodes:
        # NOTE(review): the mutation operators defined in this function modify
        # the networks in place and return the same objects, so a child node
        # references the very same network objects as its parent. Confirm
        # whether the networks should be copied before being mutated.
        mutated_father_model, mutated_mother_model = mutations(
            current_node.father_model, current_node.mother_model)

        env = create_env_fn()
        _, action_tensor_spec, time_step_tensor_spec = (
            spec_utils.get_tensor_specs(env))

        with strategy.scope():
            train_step = train_utils.create_train_step()
            creat_agent_fn = agent.create_circuit_ppo_grl_agent
            tf_agent_1 = creat_agent_fn(
                train_step,
                action_tensor_spec,
                time_step_tensor_spec,
                mutated_father_model,
                mutated_mother_model,
                strategy,
                )
            tf_agent_1.initialize()

        tf_agent_track.append(tf_agent_1)

        # Play one full episode with the mutated policy.
        obs = env.reset()
        while not obs.is_last():
            action = tf_agent_1.policy.action(obs)
            obs = env.step(action.action)

        # Run the analytical placer once and reuse its metrics, so that all
        # three cost terms come from the same placement.
        placer_metrics = env.call_analytical_placer_and_get_cost()[1]
        cost = (float(0.5 * placer_metrics['density'])
                + float(0.5 * placer_metrics['congestion'])
                + float(placer_metrics['wirelength']))
        cost_track.append(cost)
        # Keys have one-second resolution and omit month/year, so two
        # evaluations finishing within the same second share one entry.
        timestamp = datetime.now().strftime("%d %H:%M:%S")
        cost_time_dict[timestamp] = cost

        child_node = Node(mutated_father_model, mutated_mother_model, cost, current_node)
        current_node.add_child(child_node)
        current_node = child_node
        added_count += 1

  env = create_env_fn()
  _, action_tensor_spec, time_step_tensor_spec = (spec_utils.get_tensor_specs(env))
  # Maps a wall-clock timestamp string to the cost measured at that time.
  cost_time_dict = {}

  # Create the agent.
  with strategy.scope():
    train_step = train_utils.create_train_step()
    creat_agent_fn = agent.create_circuit_ppo_grl_agent
    tf_agent = creat_agent_fn(
        train_step,
        action_tensor_spec,
        time_step_tensor_spec,
        actor_net,
        value_net,
        strategy,
    )
    tf_agent.initialize()

  print("Start loading the latest Checkpoints!")
  parent_checkpoint_dir = parent_chkpt_dir

  # Restore the agent and its policy from the parent checkpoint directory.
  load_checkpointer = common.Checkpointer(
    ckpt_dir=parent_checkpoint_dir,
    max_to_keep=20,
    agent=tf_agent,
    policy=tf_agent.policy)
  load_checkpointer.initialize_or_restore()

  # Report the train-step variable handed to the agent. The graph-collection
  # lookup via tf.compat.v1.train.get_global_step() yields None when no
  # global step was registered, which is the case in this function.
  global_step = train_step.numpy()
  print("The global step is:", global_step)

  # Bookkeeping lists; index 0 always describes the unmutated base model.
  wirelength_track = []
  density_track = []
  congestion_track = []
  valid_placement = []
  tf_agent_track = []
  actor_net_track = []
  value_net_track = []
  cost_track = []

  # Play one full episode with the restored policy. `last` stays False only
  # if the environment reports a terminal step right after reset().
  obs = env.reset()
  last = False
  while not obs.is_last():
    action = tf_agent.policy.action(obs)
    obs = env.step(action.action)
    last = obs.is_last()

  print("******* Base Model info *******")
  # Run the analytical placer once and reuse its metrics everywhere below.
  base_metrics = env.call_analytical_placer_and_get_cost()[1]
  print("Congestion values", base_metrics['congestion'])
  print("Wirelength values", base_metrics['wirelength'])
  print("Density values", base_metrics['density'])

  wirelength_track.append(base_metrics['wirelength'])
  density_track.append(base_metrics['density'])
  congestion_track.append(base_metrics['congestion'])
  parent_root_cost = (float(0.5 * density_track[-1])
                      + float(0.5 * congestion_track[-1])
                      + float(wirelength_track[-1]))

  valid_placement.append(last)
  tf_agent_track.append(tf_agent)
  actor_net_track.append(actor_net)
  value_net_track.append(value_net)
  cost_track.append(parent_root_cost)
  timestamp = datetime.now().strftime("%d %H:%M:%S")
  cost_time_dict[timestamp] = parent_root_cost

  # summary() prints the table itself and returns None, so it must not be
  # passed to print() as an argument.
  print("Actor Net: ")
  actor_net._shared_network._model._policy_location_head.summary()


  # Weight-mutation operators. Each one picks a random Conv2DTranspose layer
  # of the policy location head and perturbs its kernel with uniform noise.
  def _mutate_random_deconv(policy_network, combine):
    """Perturb the kernel of one randomly chosen Conv2DTranspose layer in place.

    Args:
      policy_network: Network whose
        `_shared_network._model._policy_location_head` holds the layers.
      combine: Callable `(kernel, noise) -> new_kernel`, where `noise` is
        drawn uniformly from [0, 1) with the kernel's shape.

    Raises:
      ValueError: If the location head has no Conv2DTranspose layer.
    """
    head = policy_network._shared_network._model._policy_location_head
    deconv_layers = [
        layer for layer in head.layers
        if isinstance(layer, keras.layers.Conv2DTranspose)
    ]
    if not deconv_layers:
      raise ValueError(
          "Policy location head has no Conv2DTranspose layer to mutate.")
    target_layer = random.choice(deconv_layers)

    layer_weights = target_layer.get_weights()
    noise = np.random.uniform(0, 1, size=layer_weights[0].shape)
    # Only the kernel (index 0) changes; the bias, if any, is kept as is.
    layer_weights[0] = combine(layer_weights[0], noise)
    target_layer.set_weights(layer_weights)

  # Add mutation operations.
  def mutation_1(policy_network, value_network):
    """Add uniform [0, 1) noise to one deconvolution kernel.

    The policy network is modified in place; both arguments are returned
    unchanged in identity.
    """
    _mutate_random_deconv(policy_network, np.add)
    return policy_network, value_network

  # Subtract mutation operations.
  def mutation_2(policy_network, value_network):
    """Subtract uniform [0, 1) noise from one deconvolution kernel.

    The policy network is modified in place; both arguments are returned
    unchanged in identity.
    """
    _mutate_random_deconv(policy_network, np.subtract)
    return policy_network, value_network

  # Multiply mutation operations.
  def mutation_3(policy_network, value_network):
    """Scale one deconvolution kernel element-wise by uniform [0, 1) noise.

    The policy network is modified in place; both arguments are returned
    unchanged in identity.
    """
    _mutate_random_deconv(policy_network, np.multiply)
    return policy_network, value_network

  def mutations(actor_model, value_model):
    """Apply one of the three mutation operators, chosen uniformly at random.

    Returns:
      The `(actor_model, value_model)` pair. These are the same objects that
      were passed in, with the actor model's weights modified in place.
    """
    operators = (mutation_1, mutation_2, mutation_3)
    chosen = operators[random.randint(0, len(operators) - 1)]
    return chosen(actor_model, value_model)


  def select_explore_or_exploit(node):
    """Choose the next node to expand, together with a label for the choice.

    With probability 0.9 the lowest-cost node of the subtree under `node` is
    returned ("Exploitation"); otherwise a random direct child with a
    non-negative cost is returned ("Exploration"). When no candidate exists,
    `node` itself is returned with a "No action ..." label.
    """
    exploit = random.random() < 0.9
    if exploit:
        best = find_lowest_cost_node(node)
        if best is None:  # The whole subtree is invalid.
            return node, "No action (no valid children)"
        return best, "Exploitation"

    if not node.children:
        # Nothing to explore below this node.
        return node, "No action (no children)"

    candidates = [kid for kid in node.children if kid.cost >= 0]
    if not candidates:
        return node, "No action (no valid children)"
    return random.choice(candidates), "Exploration"


  # Seed the search tree with the networks passed to train() and their
  # baseline cost, then grow an initial chain of up to 10 mutated descendants.
  parent_actor_model = actor_net_track[0]
  parent_value_model= value_net_track[0]
  parent_root = Node(parent_actor_model, parent_value_model, parent_root_cost)
  rollout(parent_root, 10)

  # Keep expanding until the whole tree holds at least 200 nodes. Selection
  # only looks at the subtree of `current_node`, so the search moves
  # downwards and does not revisit branches above the node chosen last.
  current_node = parent_root
  while count_nodes(parent_root) < 200:
    # `action` is the textual label of the choice; it is not used afterwards.
    selected_node, action = select_explore_or_exploit(current_node)
    rollout(selected_node, 10)  # Add up to 10 chained nodes below the selection.
    current_node = selected_node  # Update current node to latest expanded

  # Find and print the path to the node with the lowest cost
  lowest_cost_node = find_lowest_cost_node(parent_root)
  print("Lowest Cost Node:")
  print(lowest_cost_node)
  print("Path to the Lowest Cost Node:")
  print_path_to_node(lowest_cost_node)
  print("-----------")
  print("Cost_Time_Dict", cost_time_dict)
  print("-----------")
  print("Cost Track:", cost_track)
