import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import csv
import os

# Load pretrained LLM and Reward Model
from models.reward_model import RewardModel
from models.analogGPT_valuehead import GPTLanguageModel

torch.manual_seed(1337)

# Pin-name templates for each device family used by the rule-based checkers.
# "{}" is filled with the device id, e.g. "NM3" -> "NM3_D".
nm_np_bases = ["{}_D", "{}_G", "{}_S", "{}_B"]
npn_pnp_bases = ["{}_C", "{}_B", "{}_E"]
c_r_l_i_bases = ["{}_P", "{}_N"]
xor_bases = ["{}_A", "{}_B", "{}_VDD", "{}_VSS", "{}_Y"]
pfd_bases = ["{}_A", "{}_B", "{}_QA", "{}_QB", "{}_VDD", "{}_VSS"]
inverter_bases = ["{}_A", "{}_Q", "{}_VDD", "{}_VSS"]
transmission_gate_bases = ["{}_A", "{}_B", "{}_C", "{}_VDD", "{}_VSS"]

# External port names. Each (prefixes, count) entry expands, in order, to
# prefix1 .. prefix<count> for every prefix listed.
_PORT_GROUPS = (
    (("VIN", "IIN", "VOUT", "IOUT"), 5),
    (("VB", "IB"), 10),
    (("VCONT", "VCLK", "VCM", "VREF", "IREF", "VRF", "VLO", "VIF", "VBB"), 5),
    (("LOGICA", "LOGICB", "LOGICD", "LOGICF", "LOGICG", "LOGICQ", "LOGICQA", "LOGICQB"), 2),
)
ports = [
    f"{port_prefix}{port_index}"
    for group_prefixes, group_count in _PORT_GROUPS
    for port_prefix in group_prefixes
    for port_index in range(1, group_count + 1)
]
# Supply rails carry no numeric suffix.
ports += ["VDD", "VSS"]

def check_sequence_first_test(tokens):
    """Check that every bare device token is adjacent to one of its own pins.

    For each adjacent pair of tokens, any token that names a device (e.g.
    ``NM1``) must either be one of that device's pins itself (e.g. ``NM1_D``)
    or have one of its own pins as the other token of the pair. Otherwise a
    violation is reported for that pair.

    Args:
        tokens: sequence of token strings.

    Returns:
        list[str]: one message per violation, in the order found (empty when
        the sequence passes).
    """
    devices = {
        "NM": nm_np_bases,
        "PM": nm_np_bases,
        "NPN": npn_pnp_bases,
        "PNP": npn_pnp_bases,
        "R": c_r_l_i_bases,
        "C": c_r_l_i_bases,
        "L": c_r_l_i_bases,
        "DIO": c_r_l_i_bases,
        "XOR": xor_bases,
        "PFD": pfd_bases,
        "INVERTER": inverter_bases,
        "TRANSMISSION_GATE": transmission_gate_bases
    }

    violations = []

    def get_base_names(token):
        """Return ``(device_prefix, pin_bases)`` for a device/pin token, else ``(None, None)``.

        The prefix must be followed by a digit. Without that guard, port names
        such as ``LOGICA1`` start with the inductor prefix ``L`` and were
        misclassified as inductor devices, producing spurious violations.
        """
        for key, bases in devices.items():
            if token.startswith(key) and token[len(key):len(key) + 1].isdigit():
                return key, bases
        return None, None

    def check_device_token(token, device_prefix, pin_bases, neighbour):
        """Record a violation if ``token`` and ``neighbour`` are both not pins of ``token``'s device."""
        device_number = token[len(device_prefix):].split('_')[0]
        full_device_id = f"{device_prefix}{device_number}"
        own_pins = [base.format(full_device_id) for base in pin_bases]
        # A pin token is always fine; a bare device token needs one of its own pins next to it.
        if token in own_pins or neighbour in own_pins:
            return
        violations.append(
            f"Device {token} should connect to any of its pins ({', '.join(own_pins)}) "
            f"first before connecting to {neighbour}"
        )

    for start_token, end_token in zip(tokens, tokens[1:]):
        start_device_prefix, start_pin_bases = get_base_names(start_token)
        end_device_prefix, end_pin_bases = get_base_names(end_token)

        if start_device_prefix:
            check_device_token(start_token, start_device_prefix, start_pin_bases, end_token)
        if end_device_prefix:
            check_device_token(end_token, end_device_prefix, end_pin_bases, start_token)

    return violations

def check_sequence_second_test(tokens):
    """Report devices that have some, but not all, of their pins in ``tokens``.

    Only devices with at least one pin token present are examined; each such
    device is reported at most once, listing its absent pins in the order of
    the device family's pin templates.

    Args:
        tokens: sequence of token strings.

    Returns:
        list[str]: one "missing pins" message per incomplete device. The order
        is deterministic: family order, then sorted pin-token order.
    """
    devices = {
        "NM": nm_np_bases,
        "PM": nm_np_bases,
        "NPN": npn_pnp_bases,
        "PNP": npn_pnp_bases,
        "R": c_r_l_i_bases,
        "C": c_r_l_i_bases,
        "L": c_r_l_i_bases,
        "DIO": c_r_l_i_bases,
        "XOR": xor_bases,
        "PFD": pfd_bases,
        "INVERTER": inverter_bases,
        "TRANSMISSION_GATE": transmission_gate_bases
    }

    violations = []
    pin_presence = {device_prefix: set() for device_prefix in devices}

    def get_base_names(token):
        """Return ``(device_prefix, pin_bases)`` for a device/pin token, else ``(None, None)``.

        The prefix must be followed by a digit so that port names such as
        ``LOGICA1`` are not taken for inductor (``L``) tokens.
        """
        for key, bases in devices.items():
            if token.startswith(key) and token[len(key):len(key) + 1].isdigit():
                return key, bases
        return None, None

    # Collect every pin token (e.g. "NM1_D"), grouped by device family.
    for token in tokens:
        device_prefix, pin_bases = get_base_names(token)
        if not device_prefix:
            continue
        device_number = token[len(device_prefix):].split('_')[0]
        full_device_id = f"{device_prefix}{device_number}"
        if any(token == base.format(full_device_id) for base in pin_bases):
            pin_presence[device_prefix].add(token)

    # Report each device whose pin set is incomplete. Sorting and the template
    # order keep messages stable across runs (set order depends on hash seed).
    reported_violations = set()
    for device_prefix, pin_bases in devices.items():
        present = pin_presence[device_prefix]
        for token in sorted(present):
            device_number = token[len(device_prefix):].split('_')[0]
            full_device_id = f"{device_prefix}{device_number}"
            if full_device_id in reported_violations:
                continue
            missing_pins = [
                pin for pin in (base.format(full_device_id) for base in pin_bases)
                if pin not in present
            ]
            if missing_pins:
                violations.append(f"Device {full_device_id} is missing pins: {', '.join(missing_pins)}")
                reported_violations.add(full_device_id)

    return violations

def check_sequence_third_test(tokens):
    """Check port connectivity rules over adjacent token pairs.

    Three kinds of violation are reported:

    1. Two ports adjacent to each other.
    2. A non-port token adjacent to more than one distinct port.
    3. Two adjacent pin tokens (recognised by pin suffix) that are each attached
       to ports, but to different sets of ports.

    Args:
        tokens: sequence of token strings.

    Returns:
        list[str]: violation messages. Set contents are joined in sorted order,
        so messages are deterministic across runs.
    """
    # Build the port lookup once: membership in the module-level list is O(n).
    port_set = set(ports)
    pin_suffixes = ('_D', '_G', '_S', '_B', '_C', '_E', '_P', '_N', '_A', '_Q',
                    '_QA', '_QB', '_VDD', '_VSS', '_Y')

    port_connections = {}   # non-port token -> set of ports adjacent to it
    unique_connections = {}  # token -> set of tokens adjacent to it (both directions)
    reported_pairs = set()
    violations = []

    for start_token, end_token in zip(tokens, tokens[1:]):
        unique_connections.setdefault(start_token, set()).add(end_token)
        unique_connections.setdefault(end_token, set()).add(start_token)

        start_is_port = start_token in port_set
        end_is_port = end_token in port_set
        if start_is_port and end_is_port:
            violations.append(f"Invalid connection: both {start_token} and {end_token} are ports.")
            continue
        if start_is_port:
            port_connections.setdefault(end_token, set()).add(start_token)
        elif end_is_port:
            port_connections.setdefault(start_token, set()).add(end_token)

    # Rule 2: a single token tied to several different ports.
    for pin, connected_ports in port_connections.items():
        if len(connected_ports) > 1:
            violations.append(f"Pin {pin} is connected to multiple ports: {', '.join(sorted(connected_ports))}")

    # Rule 3: directly connected pins must agree on the ports they reach.
    for start_token, neighbours in unique_connections.items():
        if not start_token.endswith(pin_suffixes):
            continue
        start_token_ports = port_connections.get(start_token, set())
        if not start_token_ports:
            continue
        for end_token in sorted(neighbours):
            if not end_token.endswith(pin_suffixes):
                continue
            end_token_ports = port_connections.get(end_token, set())
            if not end_token_ports or start_token_ports == end_token_ports:
                continue
            # frozenset so (a, b) and (b, a) are reported only once.
            pair = frozenset([start_token, end_token])
            if pair not in reported_pairs:
                violations.append(
                    f"Pins {start_token} and {end_token} are connected to different ports: "
                    f"{', '.join(sorted(start_token_ports))} and {', '.join(sorted(end_token_ports))}"
                )
                reported_pairs.add(pair)

    return violations

def check_sequence(tokens):
    """Run all rule-based checks on a token sequence.

    Everything from the first "TRUNCATE" token onwards is ignored. The three
    checkers run in a fixed order and their messages are concatenated.

    Args:
        tokens: list of token strings (already split; not a raw string).

    Returns:
        list[str]: all violation messages; an empty list means the sequence passes.
    """
    if "TRUNCATE" in tokens:
        cut = tokens.index("TRUNCATE")
        tokens = tokens[:cut]

    checkers = (
        check_sequence_first_test,
        check_sequence_second_test,
        check_sequence_third_test,
    )
    return [message for checker in checkers for message in checker(tokens)]

# Device pin-name templates ("{}" is filled with the device id, e.g. "NM3" -> "NM3_D").
# NOTE(review): identical tables are also defined near the top of the file; keep them in sync.
nm_np_bases = ["{}_D", "{}_G", "{}_S", "{}_B"]
npn_pnp_bases = ["{}_C", "{}_B", "{}_E"]
c_r_l_i_bases = ["{}_P", "{}_N"]
xor_bases = ["{}_A", "{}_B", "{}_VDD", "{}_VSS", "{}_Y"]
pfd_bases = ["{}_A", "{}_B", "{}_QA", "{}_QB", "{}_VDD", "{}_VSS"]
inverter_bases = ["{}_A", "{}_Q", "{}_VDD", "{}_VSS"]
transmission_gate_bases = ["{}_A", "{}_B", "{}_C", "{}_VDD", "{}_VSS"]


def _build_vocabulary():
    """Return the ordered list of every token in the model vocabulary.

    The position of each token in this list is its integer id, so the order
    must match the vocabulary the loaded checkpoints were trained with. Do
    not reorder these tables.
    """
    # (prefixes, count, pin templates): for each prefix and each index 1..count,
    # emit the device token followed by its pin tokens.
    component_groups = (
        (("NM", "PM"), 25, nm_np_bases),
        (("NPN", "PNP"), 25, npn_pnp_bases),
        (("R", "C", "L", "DIO"), 25, c_r_l_i_bases),
        (("XOR",), 5, xor_bases),
        (("PFD",), 5, pfd_bases),
        (("INVERTER",), 10, inverter_bases),
        (("TRANSMISSION_GATE",), 10, transmission_gate_bases),
    )
    # (prefixes, count): one bare token per prefix and index; ports have no pins.
    # Current-source devices ("I1".."I25") are intentionally not in the vocabulary.
    port_groups = (
        (("VIN", "IIN", "VOUT", "IOUT"), 5),
        (("VB", "IB"), 10),
        (("VCONT", "VCLK", "VCM", "VREF", "IREF", "VRF", "VLO", "VIF", "VBB"), 5),
        (("LOGICA", "LOGICB", "LOGICD", "LOGICF", "LOGICG", "LOGICQ", "LOGICQA", "LOGICQB"), 2),
    )

    vocabulary = []
    for prefixes, count, templates in component_groups:
        for prefix in prefixes:
            for index in range(1, count + 1):
                device_id = f"{prefix}{index}"
                vocabulary.append(device_id)
                vocabulary.extend(template.format(device_id) for template in templates)

    for prefixes, count in port_groups:
        for prefix in prefixes:
            vocabulary.extend(f"{prefix}{index}" for index in range(1, count + 1))

    # Supply rails, then the "TRUNCATE" marker that check_sequence cuts at.
    vocabulary.extend(["VDD", "VSS", "TRUNCATE"])
    return vocabulary


devices = _build_vocabulary()

# Token <-> id lookup tables.
stoi = {token: token_id for token_id, token in enumerate(devices)}
itos = dict(enumerate(devices))
vocab_size = len(devices)


def encode(s):
    """Map a sequence of token strings to their integer ids."""
    return [stoi[token] for token in s]


def decode(l):
    """Map a sequence of integer ids back to a list of token strings."""
    return [itos[token_id] for token_id in l]

# ---- Runtime device ---------------------------------------------------------
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)

# ---- Model architecture (passed to the model constructors) -----------------
block_size = 1024   # also the number of new tokens generated per rollout
n_embd = 384
n_head = 6
n_layer = 6
dropout = 0.2
num_labels = 2      # number of classes predicted by the reward model

# ---- PPO / optimisation ------------------------------------------------------
learning_rate = 1e-5
epochs = 250            # rollout/update cycles
ppo_epochs = 4          # minibatch updates per rollout batch
beta = 0.01             # KL-divergence penalty scale
batch_size = 64
minibatch_size = 64
ratio_threshold = 10.0  # updates are skipped when the mean PPO ratio exceeds this

# ---- Output / generation ---------------------------------------------------------
filename = 'RLHF_ppo'   # stem of the CSV log and saved checkpoint
# Fixed one-token prompt for generation. 1027 is the id of "VSS" in the
# vocabulary built above (NOTE(review): confirm this is the intended start token).
context = torch.full((1, 1), 1027, dtype=torch.long, device=device)

# Initialize the models: the trainable policy, a frozen copy used as the KL
# reference ("baseline"), and the reward model that scores rollouts.
policy_model = GPTLanguageModel(vocab_size, n_embd, block_size, n_head, n_layer, dropout).to(device)
reward_model = RewardModel(vocab_size, n_embd, n_head, n_layer, num_labels, block_size, dropout).to(device)
baseline_model = GPTLanguageModel(vocab_size, n_embd, block_size, n_head, n_layer, dropout).to(device)

# Load the pretrained weights. map_location lets a checkpoint saved on GPU load
# on the CPU fallback device chosen above (and vice versa).
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
pretrained_lm_state = torch.load('Train_Align_Val_Align_Aug_temp.pth', map_location=device)
# strict=False: presumably the pretrained LM checkpoint lacks the value-head
# weights, which then keep their fresh initialisation - TODO confirm.
# load_state_dict copies tensors, so both models can share one loaded dict.
policy_model.load_state_dict(pretrained_lm_state, strict=False)
baseline_model.load_state_dict(pretrained_lm_state, strict=False)
reward_model.load_state_dict(torch.load('Train_Align_Val_Align_Aug_temp_withlabel.pth', map_location=device))

# Only the policy model is optimised.
optimizer = torch.optim.AdamW(policy_model.parameters(), lr=learning_rate)

# Highest mean reward seen so far; used to decide when to checkpoint.
best_reward = float('-inf')

# CSV log of per-step training statistics; kept open for the training loop.
csv_file = open(filename+'.csv', mode='w', newline='')
csv_writer = csv.writer(csv_file)
csv_writer.writerow(['Step', 'Loss', 'Reward', 'Minireward', 'kl_divergence', 'loss_p', 'loss_v', 'ratio', 'entropy'])

num_step = 0

@torch.no_grad()
def get_batch(model, reward_model_cur, batch_size):
    """Generate ``batch_size`` rollouts from ``model`` and score each one.

    Every rollout starts from the module-level ``context`` and is extended by
    ``block_size`` tokens with ``model.generate``. A rollout that breaks any
    rule in ``check_sequence`` gets reward ``-1.0``. Otherwise its reward is
    the reward model's argmax class, as a float.

    Args:
        model: policy model with a ``generate`` method. It is switched to eval
            mode while sampling and back to train mode before returning.
        reward_model_cur: reward model applied to the rollout minus its last token.
        batch_size: number of rollouts.

    Returns:
        (x_batch, y_batch, reward_batch), all on ``device``. x/y are each rollout
        without its last / first token, stacked along a new batch dimension.
        reward_batch stacks per-rollout rewards (presumably shape
        ``(batch_size, 1)`` when the reward model emits ``(1, num_labels)``
        logits - TODO confirm).
    """
    inputs, targets, sample_rewards = [], [], []
    model.eval()
    for _ in range(batch_size):
        sample = model.generate(context, max_new_tokens=block_size)
        token_names = decode(sample[0].tolist())

        # Rule violations override the learned reward.
        if check_sequence(token_names):
            reward = torch.tensor([-1.0], device=device)
        else:
            reward = reward_model_cur(sample[:, :-1]).argmax(dim=-1).float()

        sequence = sample[0]
        inputs.append(sequence[:-1])
        targets.append(sequence[1:])
        sample_rewards.append(reward)

    stacked = (torch.stack(inputs), torch.stack(targets), torch.stack(sample_rewards))
    x_batch, y_batch, reward_batch = (tensor.to(device) for tensor in stacked)
    model.train()
    return x_batch, y_batch, reward_batch

@torch.no_grad()
def get_minibatch(data, labels, policy_logprob, values, rewards, minibatch_size):
    """
    Select a random minibatch of data, labels, policy log probabilities, values, and rewards.

    Rows are drawn without replacement whenever ``minibatch_size`` does not
    exceed the number of rows. This way a minibatch as large as the batch
    covers every rollout exactly once (PPO-style shuffling), instead of
    duplicating some rollouts and dropping others. Only when more rows are
    requested than exist does it fall back to sampling with replacement.

    Args:
        data (torch.Tensor): input tensor; rows are indexed along dim 0.
        labels (torch.Tensor): labels tensor, same leading dim as ``data``.
        policy_logprob (torch.Tensor): old-policy log probabilities, same leading dim.
        values (torch.Tensor): value predictions, same leading dim.
        rewards (torch.Tensor): per-token rewards, same leading dim.
        minibatch_size (int): number of rows to select.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        (data, labels, policy_logprob, values, rewards) restricted to the same
        selected rows, each with leading dimension ``minibatch_size``.
    """
    num_rows = data.size(0)
    if minibatch_size <= num_rows:
        indices = torch.randperm(num_rows)[:minibatch_size]
    else:
        indices = torch.randint(0, num_rows, (minibatch_size,))

    return (
        data[indices],
        labels[indices],
        policy_logprob[indices],
        values[indices],
        rewards[indices],
    )

def logprobs_from_logits(logits, labels):
    """
    Return the log-probability the logits assign to each label.

    Args:
        logits: unnormalised scores; log-softmax is taken over the last dim and
            the label is gathered along dim 2, so a 3-D tensor is expected
            (e.g. batch, time, vocab).
        labels: integer label ids matching the first two dims of ``logits``.
    Returns:
        Tensor with the shape of ``labels`` holding ``log p(label)``.
    """
    normalised = F.log_softmax(logits, dim=-1)
    picked = normalised.gather(dim=2, index=labels.unsqueeze(2))
    return picked.squeeze(-1)

def compute_rewards(scores, logprobs, ref_logprobs, beta):
    """
    Build per-token rewards from sequence scores and a KL penalty.

    For each row (zipped along dim 0), the penalty is
    ``-beta * (logprob - ref_logprob)``, and the row's score is added to it.

    NOTE(review): the score is broadcast-added to EVERY token, not only the
    last one as in common PPO-for-LM implementations. With gamma=1 this scales
    returns with sequence length; confirm this is intended.

    Returns:
        (rewards, non_score_rewards): the stacked per-row rewards and the
        stacked per-row KL penalties.
    """
    per_row_rewards = []
    per_row_penalties = []
    for row_score, row_logprob, row_ref in zip(scores, logprobs, ref_logprobs):
        penalty = -beta * (row_logprob - row_ref)
        per_row_penalties.append(penalty)
        # clone + in-place add keeps the penalty's dtype/shape, like the original.
        per_row_rewards.append(penalty.clone().add_(row_score))
    return torch.stack(per_row_rewards), torch.stack(per_row_penalties)

def compute_advantages(rewards, values, gamma=1, lam=0.95):
    """Generalised Advantage Estimation over the time dimension (dim 1).

    Walks backwards in time, treating the value after the last step as 0:
    ``delta_t = r_t + gamma * V_{t+1} - V_t`` and
    ``A_t = delta_t + gamma * lam * A_{t+1}``.

    Returns:
        (advantages, returns), where ``returns = advantages + values``.
    """
    horizon = rewards.size(1)
    advantages = torch.zeros_like(rewards)
    running_advantage = 0
    for step in range(horizon - 1, -1, -1):
        bootstrap = values[:, step + 1] if step + 1 < horizon else 0
        td_error = rewards[:, step] + gamma * bootstrap - values[:, step]
        running_advantage = td_error + gamma * lam * running_advantage
        advantages[:, step] = running_advantage
    return advantages, advantages + values

def miniloss(old_logprobs, values, rewards, logits, vpreds, logprobs, cliprange=0.2, vf_coef=0.1, gamma=1, lam=0.95):
    """
    Compute the clipped PPO policy loss and the clipped value loss.

    Args:
        old_logprobs (torch.FloatTensor): log-probs of the taken tokens under the rollout policy.
        values (torch.FloatTensor): value predictions recorded at rollout time.
        rewards (torch.FloatTensor): per-token rewards.
        logits (torch.FloatTensor): current-policy logits (currently unused).
        vpreds (torch.FloatTensor): current value-head predictions.
        logprobs (torch.FloatTensor): log-probs of the taken tokens under the current policy.
        cliprange (float): clip range for both the probability ratio and the value update.
        vf_coef (float): scale applied to the value loss before it is returned.
        gamma (float): discount factor for GAE.
        lam (float): GAE lambda.

    Returns:
        (pg_loss, scaled_vf_loss, ratio): the scalar policy loss, the scalar value
        loss already multiplied by ``vf_coef``, and the element-wise ratio
        ``exp(logprobs - old_logprobs)``.
    """
    advantages, returns = compute_advantages(rewards, values.detach(), gamma=gamma, lam=lam)

    # Value loss: keep new predictions within +/- cliprange of the rollout values
    # and take the larger (pessimistic) squared error.
    clipped_vpreds = vpreds.clamp(values - cliprange, values + cliprange)
    value_error = torch.max((vpreds - returns) ** 2, (clipped_vpreds - returns) ** 2)
    vf_loss = 0.5 * value_error.mean()

    # Policy loss: clipped surrogate objective (negated for minimisation).
    ratio = (logprobs - old_logprobs).exp()
    clipped_ratio = ratio.clamp(1.0 - cliprange, 1.0 + cliprange)
    pg_loss = torch.max(-advantages * ratio, -advantages * clipped_ratio).mean()

    return pg_loss, vf_coef * vf_loss, ratio

def entropy_from_logits(logits):
    """Return the entropy of the categorical distribution(s) given by ``logits``.

    Reduces over the last dimension using
    ``H = logsumexp(z) - sum(softmax(z) * z)``, which equals ``-sum(p * log p)``.
    """
    probs = torch.softmax(logits, dim=-1)
    log_partition = torch.logsumexp(logits, dim=-1)
    return log_partition - (probs * logits).sum(dim=-1)

import math

# PPO training loop: each epoch collects one rollout batch, then runs
# `ppo_epochs` minibatch updates on it.
for epoch in range(epochs):
    policy_model.train()
    baseline_model.eval()  # Baseline (reference) model is frozen; it anchors the KL penalty
    reward_model.eval()

    # Roll out sequences from the current policy and score them.
    xb, yb, scores = get_batch(policy_model, reward_model, batch_size)

    # Log-probs / values of the rollout under the current policy and the reference model.
    with torch.no_grad():
        logits_policy, values = policy_model(xb, yb)
        logits_baseline, _ = baseline_model(xb, yb)
        log_probs_policy = logprobs_from_logits(logits_policy, yb)
        log_probs_ref = logprobs_from_logits(logits_baseline, yb)

    rewards, non_score_rewards = compute_rewards(scores, log_probs_policy, log_probs_ref, beta)

    for _ in range(ppo_epochs):
        minibatch_data, minibatch_labels, minibatch_policy_logprob, minibatch_values, minibatch_rewards = get_minibatch(
            xb, yb, log_probs_policy, values, rewards, minibatch_size)
        mini_logits_policy, vpredict = policy_model(minibatch_data, minibatch_labels)
        # The frozen reference model only feeds the KL diagnostic, so no autograd graph is needed.
        with torch.no_grad():
            mini_logits_baseline, _ = baseline_model(minibatch_data, minibatch_labels)

        minibatch_policy_logprob_new = logprobs_from_logits(mini_logits_policy, minibatch_labels)

        loss_p, loss_v, ratio = miniloss(minibatch_policy_logprob, minibatch_values, minibatch_rewards, mini_logits_policy, vpredict, minibatch_policy_logprob_new)
        loss = loss_p + loss_v

        avg_ratio = ratio.mean().item()
        # A truly skipped batch must not call backward()/step(). AdamW still moves
        # weights via momentum and weight decay even with zero gradients, and an
        # inf ratio turns `loss * 0.0` into NaN gradients. A NaN ratio never
        # compares greater than the threshold, so it is checked explicitly.
        skip_update = not math.isfinite(avg_ratio) or avg_ratio > ratio_threshold
        if skip_update:
            print(f"The average ratio of batch ({avg_ratio:.2f}) exceeds threshold {ratio_threshold:.2f}. Skipping batch.")
            # Zeroed only so the logged losses mark the batch as skipped.
            loss_p = loss_p * 0.0
            loss_v = loss_v * 0.0
            loss = loss * 0.0
        else:
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

        # Diagnostics only: KL(policy || baseline) over the vocabulary, and policy entropy.
        with torch.no_grad():
            log_probs_policyforkl = torch.log_softmax(mini_logits_policy, dim=-1)
            log_probs_baselineforkl = torch.log_softmax(mini_logits_baseline, dim=-1)
            probs_policy = log_probs_policyforkl.exp()
            kl_divergence = (probs_policy * (log_probs_policyforkl - log_probs_baselineforkl)).sum(dim=-1)
            # Mean over the minibatch and time dimensions.
            kl_divergence_mean = kl_divergence.mean()
            entropy = entropy_from_logits(mini_logits_policy)

        print(f"Step {num_step}, Loss: {loss.item():.4f}, Reward: {rewards.mean().item():.4f}, Minireward: {minibatch_rewards.mean().item():.4f}, kl_divergence: {kl_divergence_mean.item():.4f}, loss_p: {loss_p.item():.4f}, loss_v: {loss_v.item():.4f}, ratio: {ratio.mean().item():.4f}, entropy: {entropy.mean().item():.4f}")
        csv_writer.writerow([num_step, loss.item(), rewards.mean().item(), minibatch_rewards.mean().item(), kl_divergence_mean.item(), loss_p.item(), loss_v.item(), ratio.mean().item(), entropy.mean().item()])  # Write losses to CSV

        # NOTE(review): `rewards` is fixed for the whole epoch, so at most the first
        # PPO step of an epoch can save. The saved weights are those AFTER that
        # update, not the weights that produced the rollout - confirm this is intended.
        if rewards.mean().item() > best_reward:
            savemodel_name = filename + '.pth'
            torch.save(policy_model.state_dict(), savemodel_name)
            best_reward = rewards.mean().item()

        num_step += 1

    # Persist the log after every epoch so a crash does not lose buffered rows.
    csv_file.flush()