import subprocess
import re
import time
import copy
import argparse
import torch
import math
import torch.nn as nn
import torch.optim as optim
import os
from collections import namedtuple
import random
import torch.nn.functional as F
from copy import deepcopy
import torch_geometric.nn as gnn
from torch_geometric.data import Data
import matplotlib.pyplot as plt
import itertools
from torch_geometric.utils import scatter
import torch_geometric.utils as unn
from torch_geometric.utils import add_self_loops, to_dense_batch
import numpy as np 

# ToDo: Need to parameterize to support non 6b5b 
def get_synth_data_encoder(result):
    """Return the product-term count espresso reported for the encoder.

    Only ``result.stdout`` is read (``result`` is what ``subprocess.run``
    returned).  A ``.p <n>`` line gives the declared number of product
    terms; every line starting with ``0``, ``1`` or ``-`` is counted as one
    cube and the total is cross-checked against the declared count.

    Args:
        result: object exposing the espresso output as a ``stdout`` string.

    Returns:
        int: the declared (and verified) number of product terms.

    Raises:
        ValueError: if the number of cube lines differs from the ``.p``
            header.  (The previous ``assert cond, ValueError(...)`` form
            never raised ValueError and vanished under ``python -O``.)
    """
    declared_terms = 0
    counted_terms = 0

    for raw_line in result.stdout.split('\n'):
        line = raw_line.strip()

        header = re.search(r'\.p (\d+)', line)
        if header:
            # If several ".p" lines appear, the last one wins.
            declared_terms = int(header.group(1))
            continue

        # Skip empty lines and lines that don't start with "1", "0", or "-".
        if not line or line[0] not in ('1', '0', '-'):
            continue

        counted_terms += 1

    # NOTE: building a per-minterm state matrix from the cubes was disabled
    # in the original code; only the gate count is produced here.
    if counted_terms != declared_terms:
        raise ValueError('incorrect parsing in encoder function where gate count is ' + str(declared_terms))

    return declared_terms

# ToDo: Need to parameterize to support non 6b5b 
def get_synth_data_decoder(result):
    """Return the product-term count espresso reported for the decoder.

    Only ``result.stdout`` is read (``result`` is what ``subprocess.run``
    returned).  A ``.p <n>`` line gives the declared number of product
    terms; every line starting with ``0``, ``1`` or ``-`` is counted as one
    cube and the total is cross-checked against the declared count.

    Args:
        result: object exposing the espresso output as a ``stdout`` string.

    Returns:
        int: the declared (and verified) number of product terms.

    Raises:
        ValueError: if the number of cube lines differs from the ``.p``
            header.  (The previous ``assert cond, ValueError(...)`` form
            never raised ValueError and vanished under ``python -O``.)
    """
    declared_terms = 0
    counted_terms = 0

    for raw_line in result.stdout.split('\n'):
        line = raw_line.strip()

        header = re.search(r'\.p (\d+)', line)
        if header:
            # If several ".p" lines appear, the last one wins.
            declared_terms = int(header.group(1))
            continue

        # Skip empty lines and lines that don't start with "1", "0", or "-".
        if not line or line[0] not in ('1', '0', '-'):
            continue

        counted_terms += 1

    # NOTE: building a per-minterm decoder state matrix from the cubes was
    # disabled in the original code; only the gate count is produced here.
    if counted_terms != declared_terms:
        # The original message said "encoder" here (copy-paste slip).
        raise ValueError('incorrect parsing in decoder function where gate count is ' + str(declared_terms))

    return declared_terms

def log_weights_and_gradients(model):
    """Print the weight norm and gradient norm of every parameter that has a gradient."""
    with_gradient = (
        (label, tensor)
        for label, tensor in model.named_parameters()
        if tensor.grad is not None
    )
    for label, tensor in with_gradient:
        print(f"{label}: weights norm: {tensor.norm()}, gradients norm: {tensor.grad.norm()}")


def check_gradients(model):
    """Print a warning line for each parameter of ``model`` whose ``.grad`` is ``None``."""
    without_gradient = [
        label for label, tensor in model.named_parameters() if tensor.grad is None
    ]
    for label in without_gradient:
        print(f"param {label} grad is NONE!!!!!!!!!!!")


def l1_similarity(A, B, dim=1):
    """Return the negated sum of absolute differences of ``A`` and ``B`` along ``dim``.

    Larger (closer to zero) values mean the inputs are more alike.
    """
    absolute_gap = (A - B).abs()
    return -absolute_gap.sum(dim=dim)

def run_espresso(fname):
    """Run the espresso binary on ``fname`` and return the ``subprocess.run`` result.

    stdout/stderr are captured as text; the exit status is not checked.
    """
    espresso_binary = '/home/xxx/RL/espresso-logic-master/bin/espresso'
    return subprocess.run([espresso_binary, fname], capture_output=True, text=True)

def copy_enc_dec(prefix):
    """Run the espresso binary on ``prefix`` and return the ``subprocess.run`` result.

    The original body referenced an undefined name ``fname`` (raising
    NameError on every call) while ignoring its only parameter; the
    parameter is now what gets passed to espresso.

    NOTE(review): the function name suggests it was meant to copy the
    encoder/decoder files, but the body only invokes espresso — confirm the
    intended behavior.  Nothing in this file calls it.
    """
    command = ['/home/xxx/RL/espresso-logic-master/bin/espresso', prefix]
    result = subprocess.run(command, capture_output=True, text=True)
    return result


def calculate_hamming_distances(nodes):
    """Build a fully connected, bidirectional edge list weighted by inverse Hamming distance.

    Returns ``(edge_indices, edge_weights)`` where each unordered pair
    ``i < j`` contributes the edges ``(i, j)`` and ``(j, i)``, both weighted
    ``1 / hamming(nodes[i], nodes[j])``.  Identical nodes (distance 0) raise
    ZeroDivisionError.
    """
    edge_indices, edge_weights = [], []

    for (first_idx, first), (second_idx, second) in itertools.combinations(enumerate(nodes), 2):
        distance = sum(a != b for a, b in zip(first, second))
        weight = 1/distance
        edge_indices += [(first_idx, second_idx), (second_idx, first_idx)]
        edge_weights += [weight, weight]

    return edge_indices, edge_weights



def hamming_distance(str1, str2):
    """Count the positions at which ``str1`` and ``str2`` differ.

    Comparison stops at the end of the shorter input.
    """
    mismatches = 0
    for left, right in zip(str1, str2):
        if left != right:
            mismatches += 1
    return mismatches

def binary_to_base3(binary_string):
    """
    Decode a string of 2-bit symbols ('00'->0, '01'->1, '11'->2) as a base-3 integer.

    Returns -1 as soon as a chunk is not one of the three known symbols
    (including '10' and a trailing single character).
    """
    digit_for_pair = {'00': 0, '01': 1, '11': 2}
    value = 0
    for start in range(0, len(binary_string), 2):
        digit = digit_for_pair.get(binary_string[start:start + 2])
        if digit is None:
            return -1
        value = value * 3 + digit
    return value

def base3_to_list(n):
    """
    Return the base-3 digits of ``n``, most significant first.

    Values that are not positive (including 0) give an empty list.
    """
    least_significant_first = []
    remaining = n
    while remaining > 0:
        remaining, digit = divmod(remaining, 3)
        least_significant_first.append(digit)
    least_significant_first.reverse()
    return least_significant_first


def decimal_to_base3(decimal, num_digits):
    """Return ``decimal`` as a base-3 digit string, left-padded with zeros to ``num_digits``.

    Numbers needing more than ``num_digits`` digits are not truncated.
    """
    text = ''
    remaining = decimal
    while remaining > 0:
        remaining, digit = divmod(remaining, 3)
        text = str(digit) + text
    return text.zfill(num_digits)

def base3_to_binary(base3_str):
    """Expand each base-3 digit into its 2-bit symbol ('0'->'00', '1'->'01', '2'->'11').

    Characters other than '0', '1', '2' contribute nothing to the output.
    """
    pair_for_digit = {'0': '00', '1': '01', '2': '11'}
    return ''.join(pair_for_digit.get(digit, '') for digit in base3_str)



# Define the DQN agent
class DQNAgent:
    """DQN agent that picks one code per step with an online and a target Q-network.

    Both networks are ``QValueGNN`` instances.  The online network is trained
    with Adam against a one-step TD target computed from the target network;
    codes already used (availability indicator 0 in the state) are masked out
    of both action selection and the bootstrap maximum.
    """

    def __init__(self, state_dim, action_dim, hidden_dim, learning_rate, discount_factor,location_nodes,code_nodes,device):
        """Build the two networks, the loss and the optimizer on ``device``."""
        self.device = device
        self.q_network = QValueGNN(state_dim, hidden_dim, action_dim,location_nodes,code_nodes,device)
        self.target_network = QValueGNN(state_dim, hidden_dim, action_dim,location_nodes,code_nodes,device)
        self.q_network.to(device)
        self.target_network.to(device)
        self.criterion = nn.MSELoss().to(device)

        self.optimizer = optim.Adam(self.q_network.parameters(), lr=learning_rate, weight_decay=1e-5)
        self.discount_factor = discount_factor

        # Apply the custom weight initialization
        self.q_network.initialize_weights()
        self.target_network.initialize_weights()
        # Start with the target network identical to the online network;
        # otherwise the first TD targets come from unrelated random weights.
        self.update_target_network()

    def act(self, state, epsilon, assigned_codes):
        """Choose a code index for the current state (epsilon-greedy).

        Args:
            state: graph state; ``state.x_codes_indicator`` holds 1 for codes
                that are still available and 0 for used ones.
            epsilon: probability of picking a random available code.
            assigned_codes: accepted for interface compatibility; availability
                is read from ``state`` instead.

        Returns:
            A tensor of shape ``(1,)`` holding the chosen code index.
        """
        state = state.to(self.device)
        self.q_network.eval()
        available = state.x_codes_indicator.squeeze().bool().to(self.device)

        if torch.rand(1) < epsilon:
            # Exploration: uniform choice among the available codes.  The
            # action count follows the indicator instead of a fixed 256.
            valid_actions = torch.arange(available.numel()).to(self.device)[available]
            return valid_actions[torch.randint(0, len(valid_actions), (1,))]

        # Exploitation: best Q-value among the available codes.
        with torch.no_grad():
            q_values = self.q_network(state)
            q_values[~available] = float('-inf')
            return q_values.argmax().unsqueeze(0)

    def update(self, state, action, reward, next_state, done,episode,step,log_file_path="q_values.txt"):
        """Run one TD(0) gradient step on the online network and return the loss.

        The target is ``reward`` for a terminal transition, otherwise
        ``reward + discount * max_a Q_target(next_state, a)`` over the codes
        still available in ``next_state`` (plain DQN target, not Double DQN).

        Args:
            state, next_state: graph states before and after the action.
            action: tensor of shape ``(1,)`` with the chosen code index.
            reward: 0-dim tensor.
            done: True when ``next_state`` is terminal.
            episode, step, log_file_path: currently unused.

        Returns:
            float: the MSE loss value.
        """
        action = action.to(self.device)
        self.q_network.train()
        self.target_network.eval()

        # Q-values of the current state (needs gradients).
        q_values = self.q_network(state)

        if done:
            # Terminal transition: nothing to bootstrap from.  Keyed on
            # ``done`` so the branch and ``next_action`` can never disagree.
            target_value = reward.clone().detach().unsqueeze(0).unsqueeze(1)
        else:
            with torch.no_grad():
                target_q_values = self.target_network(next_state)
                # Mask codes that are no longer available in the next state.
                valid_actions_mask = next_state.x_codes_indicator.squeeze().bool()
                target_q_values[~valid_actions_mask] = float('-inf')
                next_action = target_q_values.argmax()
            target_value = reward + self.discount_factor * target_q_values[next_action].unsqueeze(0)

        q_value = q_values[action]

        loss = self.criterion(q_value, target_value)
        self.optimizer.zero_grad()
        loss.backward()

        # Clip gradients
        torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), max_norm=5.0)

        self.optimizer.step()

        return loss.item()

    def update_target_network(self):
        """Copy the online network's weights into the target network."""
        self.target_network.load_state_dict(self.q_network.state_dict())

def decimal_to_binary(decimal, num_bits):
    """Return ``decimal`` as a binary string zero-padded to at least ``num_bits`` characters."""
    return f'{decimal:0{num_bits}b}'

def normalize_hamming_distance(hd, min_value=1, max_value=6):
    """Linearly rescale ``hd`` so that ``min_value`` maps to 0 and ``max_value`` maps to 1."""
    offset = hd - min_value
    span = max_value - min_value
    return offset / span


def print_attention_weights(gat_layer, edge_index, name):
    """Print a header naming ``name`` followed by ``gat_layer.attn_weights``.

    ``edge_index`` is accepted but not used.
    """
    with torch.no_grad():
        weights = getattr(gat_layer, "attn_weights")
        for item in (f"Attention weights for {name}:", weights):
            print(item)

def create_code_edges(num_selected_codes):
    """Return ``(edge_index, edge_attr)`` for a complete graph with uniform weights.

    For more than one node every unordered pair yields both directed edges,
    each with weight 1.0.  Otherwise a single self-loop on node 0 is
    returned.  ``edge_index`` has shape ``(2, num_edges)``.
    """
    if num_selected_codes > 1:
        pairs = []
        for low, high in itertools.combinations(range(num_selected_codes), 2):
            # One edge per direction.
            pairs += [[low, high], [high, low]]
    else:
        # Fallback: a self-loop on node 0.
        pairs = [[0, 0]]

    uniform_weights = [1.0] * len(pairs)

    edge_index = torch.tensor(pairs, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(uniform_weights, dtype=torch.float)

    return edge_index, edge_attr


class MLP(torch.nn.Module):
    """Three-layer perceptron with LayerNorm after every linear layer.

    Layer widths are ``input_dim -> hidden_dim -> 2*hidden_dim -> hidden_dim``;
    ``output_dim`` is accepted but the final width is ``hidden_dim``.
    Dropout (p=0.3) follows the first two activations.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, device):
        super().__init__()
        wide_dim = hidden_dim*2
        # Attribute names and creation order are kept so parameter names and
        # random initialisation stay the same.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, wide_dim)
        self.fc3 = nn.Linear(wide_dim, hidden_dim)
        self.ln1 = nn.LayerNorm(hidden_dim)
        self.ln2 = nn.LayerNorm(wide_dim)
        self.ln3 = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.device = device

    def forward(self, x):
        """Apply ReLU on the first stage, ELU on the other two."""
        hidden = self.dropout(F.relu(self.ln1(self.fc1(x))))
        hidden = self.dropout(F.elu(self.ln2(self.fc2(hidden))))
        return F.elu(self.ln3(self.fc3(hidden)))

class QValueGNN(nn.Module):
    """Graph network producing one Q-value per code for the current location.

    Location and code bit-strings are embedded by separate MLPs, mixed by a
    single GATv2 layer over a combined graph (location nodes first, then code
    nodes), and each code's embedding is scored by a small head together with
    a history-based similarity feature and the assignment progress.

    All sizes (number of locations/codes, bit width) are taken from the node
    lists passed to the constructor instead of being fixed at 3**5 / 2**8 / 10.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, location_nodes,code_nodes,device, num_heads=4):
        """
        Args:
            input_dim: width of the per-node bit features fed to the MLPs.
                NOTE(review): must equal the longest node string — confirm
                against callers.
            hidden_dim: embedding width.
            output_dim: accepted but unused; the head emits one value per code.
            location_nodes, code_nodes: lists of '0'/'1' strings.
            device: device on which the fixed bit features are stored.
            num_heads: number of GATv2 attention heads.
        """
        super(QValueGNN, self).__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.device = device

        self.location_nodes = location_nodes
        self.code_nodes = code_nodes
        self.num_locations = len(location_nodes)
        self.num_codes = len(code_nodes)

        self.location_mlp = MLP(input_dim, int(hidden_dim), hidden_dim,device)
        self.code_mlp = MLP(input_dim, int(hidden_dim), hidden_dim,device)

        # GAT Layer for combined
        self.gat_combined = gnn.GATv2Conv(hidden_dim, hidden_dim, heads=num_heads, concat=False,edge_dim=None)

        # Output head: GAT embedding plus two scalar features per code.
        self.fc1 = nn.Linear(hidden_dim+2, int(hidden_dim *0.5))
        self.ln_fc1 = nn.LayerNorm(int(hidden_dim*0.5))
        self.fc2 = nn.Linear(int(hidden_dim * 0.5), 1)
        self.dropout = nn.Dropout(p=0.3)

        # Left-pad every bit-string with zeros to a common length.
        max_length = max(len(node) for node in self.location_nodes + self.code_nodes)
        padded_location_nodes = [list(map(int, node.zfill(max_length))) for node in self.location_nodes]
        padded_code_nodes = [list(map(int, node.zfill(max_length))) for node in self.code_nodes]

        # Fixed bit features: rows [0, num_locations) are locations, the rest
        # are codes.  Plain attribute (not a buffer), so ``.to()`` on the
        # module does not move it.
        self.x_binary = torch.tensor(padded_location_nodes + padded_code_nodes, dtype=torch.float32).to(self.device)

    def forward(self, data):
        """Return Q-values of shape ``(num_codes, 1)`` for the state ``data``.

        ``data`` must provide ``edge_index`` (one ``[location, code_node]``
        row per assignment, code nodes offset by ``num_locations``),
        ``x_codes_indicator`` (1 = code still available), ``loc_edge_inds``,
        ``code_edge_indices`` and the integer ``num_assigned``.
        """
        num_locations = self.num_locations
        num_codes = self.num_codes
        num_assigned = data.num_assigned
        feature_width = self.x_binary.size(1)

        edge_index = data.edge_index.T  # (2, num_assigned)
        x_codes_indicator = data.x_codes_indicator
        indicators = x_codes_indicator.squeeze()

        # Assignment edges in both directions.
        edge_index_inter = torch.cat([edge_index, edge_index[[1, 0], :]], dim=1)

        all_loc = self.x_binary[:num_locations]
        code_nodes = self.x_binary[num_locations:]

        # The location to be assigned next is the one indexed by num_assigned;
        # once every location is assigned there is none.
        has_current_location = num_assigned < num_locations
        if has_current_location:
            current_loc_raw = all_loc[num_assigned]  # (feature_width)
        else:
            current_loc_raw = torch.zeros(feature_width).to(self.device)

        # Assignment history recovered from the edge list.
        assigned_loc_indices = edge_index[0][0:num_assigned]
        assigned_code_indices = edge_index[1][0:num_assigned] - num_locations

        # How similar the current location is to each already-assigned one,
        # normalised over the history.
        raw_binary_similarity_loc = l1_similarity(current_loc_raw, all_loc[assigned_loc_indices], dim=1)  # (num_assigned)
        raw_binary_similarity_loc_normalized = F.softmax(raw_binary_similarity_loc, dim=0)  # (num_assigned)

        # Similarity of every code to every historical code, weighted by the
        # location similarity above and summed over the history.
        expanded_code_nodes = code_nodes.unsqueeze(1).expand(-1, num_assigned, -1)  # (num_codes, num_assigned, feature_width)
        expanded_history_codes = code_nodes[assigned_code_indices].unsqueeze(0).expand(num_codes, -1, -1)  # (num_codes, num_assigned, feature_width)
        similarity_scores = l1_similarity(expanded_code_nodes, expanded_history_codes, dim=2)  # (num_codes, num_assigned)
        weighted_similarity_scores = similarity_scores * raw_binary_similarity_loc_normalized.unsqueeze(0)
        raw_code_similarity_scores = weighted_similarity_scores.sum(dim=1, keepdim=True)  # (num_codes, 1)

        # Candidate edges: current location <-> every still-available code.
        unassigned_codes = torch.nonzero(x_codes_indicator == 1)[:, 0]
        candidate_edges = torch.stack([
            torch.full_like(unassigned_codes, num_assigned),
            unassigned_codes + num_locations,
        ], dim=0)  # (2, num_unassigned)
        if not has_current_location:
            # Node index ``num_locations`` is the first *code* node, so at the
            # terminal state these edges would wire code 0 to the leftovers.
            candidate_edges = candidate_edges[:, :0]

        edge_index_inter_combined = torch.cat([
            data.loc_edge_inds,
            num_locations + data.code_edge_indices,
            edge_index_inter,
            candidate_edges,
            candidate_edges[[1, 0], :],
        ], dim=1)

        x_locations_mlp = self.location_mlp(all_loc)
        x_codes_mlp = self.code_mlp(code_nodes)
        x_combined_mlp = torch.cat([x_locations_mlp, x_codes_mlp],dim=0)

        x_updated = self.gat_combined(x_combined_mlp, edge_index_inter_combined,None)

        # Push used codes to ~zero weight before normalising the scores.
        code_availability_mask = (1-indicators.float()) * -1e9
        raw_code_similarity_scores = raw_code_similarity_scores + code_availability_mask.unsqueeze(1)
        raw_code_similarity_scores_normalized = F.softmax(raw_code_similarity_scores, dim=0)

        # Per-code features: GAT embedding, similarity score, progress ratio.
        progress = torch.tensor(num_assigned/num_locations).expand(num_codes,1).to(self.device)
        combined_feature = torch.cat([
            x_updated[num_locations:], raw_code_similarity_scores_normalized, progress
        ], dim=1)  # (num_codes, hidden_dim + 2)

        x = F.elu(self.ln_fc1(self.fc1(combined_feature)))
        x = self.dropout(x)
        x = self.fc2(x)

        return x

    def initialize_weights(self):
        """Re-initialise supported submodules (Xavier weights, zero biases, unit norms)."""
        def init_weights(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv1d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.MultiheadAttention):
                nn.init.xavier_uniform_(m.in_proj_weight)
                nn.init.constant_(m.in_proj_bias, 0)
                nn.init.xavier_uniform_(m.out_proj.weight)
                nn.init.constant_(m.out_proj.bias, 0)

        self.apply(init_weights)






  

class Environment:
    """Sequential code-assignment environment scored by espresso gate counts.

    Each step assigns one code to the next unassigned location (locations
    are filled in list order), writes encoder/decoder PLA files for the
    partial mapping, runs espresso on both and reports the product-term
    counts.  The state handed to the agent is a ``torch_geometric`` ``Data``
    object.
    """

    # Directory the temporary PLA files are written to unless overridden.
    DEFAULT_WORK_DIR = '/home/xxx/RL/RL_start_empty/0913/'
    # Locations are strings of 2-bit symbols in which '10' never occurs, so
    # the encoder output is a don't-care wherever a symbol slot reads '10'.
    SYMBOL_BITS = 2
    UNUSED_SYMBOL = '10'

    def __init__(self, location_nodes, code_nodes,device, work_dir=DEFAULT_WORK_DIR):
        """
        Args:
            location_nodes: list of location bit-strings (equal length).
            code_nodes: list of code bit-strings (equal length).
            device: torch device for all state tensors.
            work_dir: directory prefix for the temporary espresso files.
        """
        self.device = device
        # Guarantee a trailing separator so prefixes can simply be appended.
        self.work_dir = os.path.join(work_dir, '')
        self.location_nodes = location_nodes
        self.code_nodes = code_nodes
        self.num_locations = len(location_nodes)
        self.num_codes = len(code_nodes)
        self.assigned_codes = []
        self.state = None
        self.done = False
        self.enc_gc = 0
        self.dec_gc = 0
        self.code_valid = torch.ones(self.num_codes, dtype=torch.float32)

        self.edge_index_inter = self._empty_assignment_edges()

        # Fully connected intra-set graphs (the inverse-Hamming weights that
        # calculate_hamming_distances also returns are not used).
        temp_location_edge_indices, _ = calculate_hamming_distances(self.location_nodes)
        temp_code_edge_indices, _ = calculate_hamming_distances(self.code_nodes)

        self.location_edge_indices = torch.tensor(temp_location_edge_indices, dtype=torch.long).T.to(self.device)
        self.code_edge_indices = torch.tensor(temp_code_edge_indices, dtype=torch.long).T.to(self.device)

    def _empty_assignment_edges(self):
        """Return an empty ``(0, 2)`` edge list (one row per assignment)."""
        return torch.empty((0, 2), dtype=torch.long, device=self.device)

    def _temp_path(self, prefix, suffix):
        """Return the path of a temporary espresso file for ``prefix``."""
        return self.work_dir + prefix + suffix

    def reset(self):
        """Clear all assignments and return the initial state."""
        self.assigned_codes = []
        self.done = False
        # Clear the edges *before* building the state; previously the reset
        # state still carried the previous episode's assignment edges.
        self.edge_index_inter = self._empty_assignment_edges()
        self.state = self._create_state()
        return self.state

    def step(self, action, prefix):
        """Assign code ``action`` to the next location and evaluate the mapping.

        Args:
            action: one-element tensor holding the code index.
            prefix: file-name prefix for the temporary espresso files.

        Returns:
            ``(state, encoder_gate_count, decoder_gate_count, done)``.
        """
        # Assign the selected code to the next available location
        location_idx = len(self.assigned_codes)
        code_idx = action.item()
        self.assigned_codes.append(code_idx)

        # Code nodes are numbered after the location nodes.
        new_edge = torch.tensor([[location_idx, code_idx + self.num_locations]], dtype=torch.long).to(self.device)
        self.edge_index_inter = torch.cat([self.edge_index_inter, new_edge], dim=0).to(self.device)

        self._write_espresso_files(prefix)

        result = run_espresso(self._temp_path(prefix, '_temp_enc'))
        self.enc_gc = get_synth_data_encoder(result)
        result = run_espresso(self._temp_path(prefix, '_temp_dec'))
        self.dec_gc = get_synth_data_decoder(result)

        self.state = self._create_state()

        if len(self.assigned_codes) == self.num_locations:
            self.done = True
        return self.state, self.enc_gc, self.dec_gc, self.done

    def _write_espresso_files(self,my_prefix):
        """Write ``<prefix>_temp_enc`` and ``<prefix>_temp_dec`` PLA files.

        Encoder: location -> code for assigned pairs, don't-care outputs for
        every input containing the unused symbol and for unassigned
        locations.  Decoder: code -> location for assigned pairs, don't-care
        outputs for unused codes.  Bit widths and signal labels follow the
        node strings.
        """
        in_bits = len(self.location_nodes[0])
        out_bits = len(self.code_nodes[0])
        loc_labels = ' '.join(f'x{i}' for i in range(in_bits))
        code_labels = ' '.join(f'y{i}' for i in range(out_bits))
        loc_dont_care = '-' * in_bits
        code_dont_care = '-' * out_bits

        num_assigned = len(self.assigned_codes)
        assigned_locations = self.location_nodes[:num_assigned]
        assigned_codes = [self.code_nodes[i] for i in self.assigned_codes]

        with open(self._temp_path(my_prefix, '_temp_enc'), 'w') as f:
            f.write(f'.i {in_bits}\n')
            f.write(f'.o {out_bits}\n')
            f.write(f'.ilb {loc_labels}\n')
            f.write(f'.olb {code_labels}\n')
            for loc, code in zip(assigned_locations, assigned_codes):
                f.write(loc + ' ' + code + "\n")
            # One don't-care row per symbol slot, last slot first.
            for start in range(in_bits - self.SYMBOL_BITS, -1, -self.SYMBOL_BITS):
                pattern = '-' * start + self.UNUSED_SYMBOL + '-' * (in_bits - start - self.SYMBOL_BITS)
                f.write(pattern + ' ' + code_dont_care + "\n")
            # Remaining locations are don't-cares.
            for loc in self.location_nodes[num_assigned:]:
                f.write(f"{loc} {code_dont_care}\n")

        used_codes = set(self.assigned_codes)

        with open(self._temp_path(my_prefix, '_temp_dec'), 'w') as f:
            f.write(f'.i {out_bits}\n')
            f.write(f'.o {in_bits}\n')
            f.write(f'.ilb {code_labels}\n')
            f.write(f'.olb {loc_labels}\n')
            for code, loc in zip(assigned_codes, assigned_locations):
                f.write(code + ' ' + loc + "\n")
            for idx in range(self.num_codes):
                if idx not in used_codes:
                    # Unused code words decode to don't-care.
                    f.write(f"{self.code_nodes[idx]}  {loc_dont_care}\n")

    def _create_state(self):
        """Build the ``Data`` state for the current assignment."""
        num_assigned = len(self.assigned_codes)

        # Per-location indicator: 0 = assigned, 1 = current, -1 = pending.
        assignment_indicators = torch.zeros(self.num_locations,dtype=torch.float32).to(self.device)
        if num_assigned < self.num_locations:
            assignment_indicators[num_assigned] = 1
            assignment_indicators[num_assigned + 1:] = -1
        assignment_indicators = assignment_indicators.unsqueeze(1)

        # Per-code indicator: 1 = still available, 0 = already used.
        used_codes = set(self.assigned_codes)
        availability_indicators = torch.tensor(
            [0 if i in used_codes else 1 for i in range(self.num_codes)],
            dtype=torch.float32,
        ).unsqueeze(1).to(self.device)

        # edge_index holds one [location, code_node] row per assignment.
        state = Data(x_locations_indicator=assignment_indicators,
                     x_codes_indicator=availability_indicators,
                     edge_index=self.edge_index_inter,
                     loc_edge_inds=self.location_edge_indices,
                     code_edge_indices=self.code_edge_indices,
                     num_assigned=num_assigned)

        return state.to(self.device)

def main():
    """Train the DQN agent to assign codes to locations, one episode per full mapping.

    Per step the reward is ``2 - (increase in encoder + decoder gate count)``.
    Per episode a tab-separated step log is written; the target network is
    synchronised every 50 episodes and the online network is saved every
    100 episodes.

    Command line:
        --num_episodes  number of training episodes.  This flag used to be
                        parsed and then ignored in favour of a hard-coded
                        100000; it is now honoured, with the default set to
                        100000 so a run without the flag behaves as before.
        --file_prefix   prefix (and directory name) for all output files.
    """
    seed = 22
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    parser = argparse.ArgumentParser(description='Generate training data.')
    parser.add_argument('--num_episodes', type=int, default=100000, help='Number of episodes to simulate')
    parser.add_argument('--file_prefix', type=str, default='data', help='Prefix for output file names.')
    args = parser.parse_args()

    # Outputs go to "<prefix>/<prefix>_...".
    directory = args.file_prefix + "/"
    os.makedirs(directory, exist_ok=True)
    args.file_prefix = directory + args.file_prefix

    # Problem size: 5 ternary symbols (2 bits each) mapped onto 8-bit codes.
    num_locations = 3**5
    num_codes = 2**8
    num_symbols = 5
    symbol_length = 2
    code_length = 8

    hidden_dim = 256
    learning_rate = 0.00007
    discount_factor = 0.999
    num_episodes = args.num_episodes
    epsilon_start = 1.0
    epsilon_end = 0.001
    epsilon_decay = 0.995

    # Binary strings for codes/locations
    location_nodes = [base3_to_binary(decimal_to_base3(i, num_symbols)) for i in range(num_locations)]
    code_nodes = [decimal_to_binary(i, code_length) for i in range(num_codes)]

    env = Environment(location_nodes, code_nodes, device)

    state_dim = num_symbols*symbol_length
    action_dim = num_codes

    agent = DQNAgent(state_dim, action_dim, hidden_dim, learning_rate, discount_factor, location_nodes,code_nodes,device)

    epsilon = epsilon_start
    for episode in range(num_episodes):
        state = env.reset()
        done = False
        enc_gc = 0
        dec_gc = 0
        rewards = []
        stat_array = []
        losses = []
        lcv = -1
        while not done:
            lcv += 1

            # The first location always receives code 0; learning starts
            # with the second assignment.
            if lcv == 0:
                action = torch.tensor([0])
            else:
                action = agent.act(state, epsilon,env.assigned_codes)

            next_state, next_enc_gc,next_dec_gc,done = env.step(action,args.file_prefix)

            if lcv > 0:
                # Reward is based on the change in total gate count.
                td_gc = next_enc_gc+next_dec_gc-enc_gc-dec_gc
                reward = 1.0*(torch.tensor(2,device=device)-td_gc)

                rewards.append(reward)
                step_loss = agent.update(state, action, reward, next_state, done,episode,lcv,log_file_path=args.file_prefix + "_qvalues.txt")
                losses.append(step_loss)

                stat_array.append((episode,epsilon,location_nodes[lcv], next_enc_gc, next_dec_gc, step_loss, code_nodes[action.item()],reward))

            state = next_state
            enc_gc = next_enc_gc
            dec_gc = next_dec_gc
        print('episode is ' + str(episode) + ' epsilon is ' +  str(epsilon) + ' gate count is ' + str(enc_gc) + ' and ' + str(dec_gc) + ' and rwd is ' + str(sum(rewards)) + ' and loss mean is ' + str(torch.mean(torch.tensor(losses))),flush=True)

        with open(args.file_prefix + "_episode_results_" + str(episode) + ".txt", "w") as file1:
            for _, eps_val, enc_val, gc_enc_val, gc_dec_val, loss_val, action_val, rwd_val in stat_array:
                file1.write(f"{eps_val:.4f}\t{enc_val}\t{gc_enc_val}\t{gc_dec_val}\t{loss_val:.3f}\t{action_val}\t{rwd_val}\n")

        if episode % 50 == 0:
            agent.update_target_network()

        if episode % 100 == 0:
            torch.save(agent.q_network, args.file_prefix + "_weights_episode_" + str(episode) + ".pt")

        epsilon = max(epsilon_end, epsilon_decay * epsilon)
        # Keep a per-episode copy of the final encoder/decoder tables.
        env._write_espresso_files(args.file_prefix + '_'  + str(episode))


if __name__ == '__main__':
    main()
