import os
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
from IPython import embed
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from evaluator import evaluate_model  # 👈 确保导入 evaluate_model
            
class QLearningAgent:
    """Q-learning agent that scores next-token actions with a sequence model.

    The model is called as ``model(tokens)`` with a ``(batch, seq)`` LongTensor
    and its first return value is indexed as ``[batch, position, action]``.
    Transitions are kept in a bounded replay buffer and sampled uniformly.

    Action ids and ``available_nodes`` indices differ by ``NODE_OFFSET``:
    ``available_nodes[i]`` controls action id ``i + NODE_OFFSET``, so ids
    below ``NODE_OFFSET`` are never selectable by ``get_action`` or
    ``get_action_soft``.
    """

    # Shift between a position in ``available_nodes`` and the action id it
    # refers to (the availability mask is applied to ``q_values[2:]``).
    NODE_OFFSET = 2

    def __init__(self, model, device, lr=0, gamma=0):
        """Move ``model`` to ``device`` and set up the optimizer and replay buffer.

        Args:
            model: torch module whose first output holds per-action scores.
            device: torch device used for the model and all tensors.
            lr: Adam learning rate. With the default of 0 an optimizer step
                leaves the weights unchanged, so pass a real value to train.
            gamma: discount factor applied to the bootstrapped next-state value.
        """
        self.model = model.to(device)
        self.optimizer = optim.Adam(model.parameters(), lr=lr)
        self.gamma = gamma
        self.device = device
        self.memory = deque(maxlen=10000)
        self.batch_size = 512
        self.pad_token = 0

    def _random_action(self, available_nodes):
        """Uniformly sample one available node and return its action id.

        Raises:
            ValueError: if ``available_nodes`` has no non-zero entry.
        """
        candidates = available_nodes.nonzero().flatten()
        if candidates.numel() == 0:
            raise ValueError("available_nodes has no selectable node")
        pick = candidates[torch.randint(len(candidates), (1,))].item()
        # Convert the mask position into the action id used by the greedy
        # branch, which applies the same mask to q_values[NODE_OFFSET:].
        return pick + self.NODE_OFFSET

    def _masked_q_values(self, state, available_nodes):
        """Return last-position Q-values with unavailable actions set to -inf.

        Must be called under ``torch.no_grad()``; ``available_nodes`` must have
        exactly ``num_actions - NODE_OFFSET`` entries.
        """
        state_tensor = torch.LongTensor(state).unsqueeze(0).to(self.device)
        q_values = self.model(state_tensor)[0][0, -1]
        mask = torch.full_like(q_values, -float('inf'))
        mask[self.NODE_OFFSET:][available_nodes == 1] = 0
        return q_values + mask

    def get_action(self, state, epsilon, available_nodes):
        """Epsilon-greedy action selection.

        Args:
            state: sequence of token ids forming the current state.
            epsilon: probability of picking a uniformly random available node.
            available_nodes: 1-D tensor; entry ``i`` equal to 1 marks action id
                ``i + NODE_OFFSET`` as selectable.

        Returns:
            The chosen action id as a Python int.
        """
        if random.random() < epsilon:
            return self._random_action(available_nodes)
        with torch.no_grad():
            q_values = self._masked_q_values(state, available_nodes)
            return torch.argmax(q_values).item()

    def get_action_soft(self, state, epsilon, available_nodes):
        """Like ``get_action`` but samples from softmax(Q) instead of argmax.

        With probability ``epsilon`` a uniformly random available node is
        returned; otherwise the action is drawn from the softmax over the
        masked Q-values.
        """
        if random.random() < epsilon:
            return self._random_action(available_nodes)
        with torch.no_grad():
            q_values = self._masked_q_values(state, available_nodes)
            probabilities = torch.softmax(q_values, dim=0)
            return torch.multinomial(probabilities, num_samples=1).item()

    def get_action_test(self, state, available_nodes):
        """Greedy action restricted to ``available_nodes``.

        Unlike ``get_action``, ``available_nodes`` is used here directly as an
        index into the Q-value vector (no ``NODE_OFFSET`` shift), so it must
        already be expressed in action ids.
        """
        with torch.no_grad():
            state_tensor = torch.LongTensor(state).unsqueeze(0).to(self.device)
            q_values = self.model(state_tensor)[0][0, -1]
            mask = torch.full_like(q_values, -float('inf'))
            mask[available_nodes] = 0
            return torch.argmax(q_values + mask).item()

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer (oldest entries drop out)."""
        self.memory.append((state, action, reward, next_state, done))

    def _last_token_q(self, sequences):
        """Q-values at the last real token of each variable-length sequence.

        Sequences are right-padded with ``pad_token``. When the model returns
        one output per input position, the row at ``len(seq) - 1`` is used so
        shorter sequences are not read at a padding position. If the model
        returns a different number of positions (e.g. only the final one),
        the original ``[:, -1]`` read is kept.
        """
        lengths = torch.tensor(
            [len(seq) for seq in sequences], dtype=torch.long, device=self.device
        )
        padded = pad_sequence(
            [torch.LongTensor(seq) for seq in sequences],
            batch_first=True,
            padding_value=self.pad_token,
        ).to(self.device)
        q_all = self.model(padded)[0]
        if q_all.size(1) == padded.size(1):
            rows = torch.arange(padded.size(0), device=self.device)
            return q_all[rows, lengths - 1]
        # NOTE(review): with a model that only emits the final position,
        # padded rows are still scored at a pad token; confirm model behavior.
        return q_all[:, -1]

    def replay(self):
        """Run one optimization step on a uniformly sampled replay batch.

        Returns:
            The scalar MSE loss, or ``None`` when the buffer holds fewer than
            ``batch_size`` transitions (no update is performed).
        """
        if len(self.memory) < self.batch_size:
            return None

        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)

        self.optimizer.zero_grad()
        current_q = (
            self._last_token_q(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        )
        # NOTE(review): the bootstrap max ranges over every output index,
        # including ids that get_action never selects.
        with torch.no_grad():
            next_q = self._last_token_q(next_states).max(1)[0]
        target = rewards + (1 - dones) * self.gamma * next_q

        loss = nn.MSELoss()(current_q, target)
        loss.backward()
        self.optimizer.step()
        return loss.item()

# <<< MODIFICATION START >>>
def calculate_q_stats(agent, sample_pairs, device, sample_size=1000):
    """Estimate statistics of the model's initial Q-values.

    Draws up to ``sample_size`` (source, target) pairs with replacement,
    evaluates the model on the initial state ``[s, t, s]`` and records the
    maximum and minimum Q-value at the last position.

    Args:
        agent: object exposing ``model``; the model is switched to eval mode
            while sampling and back to train mode afterwards.
        sample_pairs: sequence of ``(s, t)`` pairs to sample from.
        device: torch device on which the state tensors are created.
        sample_size: upper bound on the number of sampled pairs.

    Returns:
        Tuple ``(q_mean, q_mean_min, q_std)``: mean of the per-sample maxima,
        mean of the per-sample minima, and standard deviation of the maxima.
        When fewer than two samples are available the neutral triple
        ``(0.0, 0.0, 1.0)`` is returned.
    """
    print("Calculating initial Q-value statistics for reward transformation...")
    max_q_values = []
    min_q_values = []

    num_samples = min(sample_size, len(sample_pairs))

    agent.model.eval()
    try:
        with torch.no_grad():
            for _ in tqdm(range(num_samples), desc="Sampling Q-values"):
                s, t = random.choice(sample_pairs)
                state_tensor = torch.LongTensor([s, t, s]).unsqueeze(0).to(device)
                q_values = agent.model(state_tensor)[0][0, -1]
                max_q_values.append(q_values.max().item())
                min_q_values.append(q_values.min().item())
    finally:
        # Always hand the model back in training mode, even if sampling fails.
        agent.model.train()

    # At least two points are needed for a standard deviation; return the
    # same 3-tuple shape as the normal path so callers can always unpack it.
    if len(max_q_values) < 2:
        return 0.0, 0.0, 1.0

    max_q_values_tensor = torch.tensor(max_q_values)
    min_q_values_tensor = torch.tensor(min_q_values)

    q_mean = max_q_values_tensor.mean().item()
    q_mean_min = min_q_values_tensor.mean().item()
    q_std = max_q_values_tensor.std().item()

    # Guard against a (near-)zero standard deviation.
    if q_std < 1e-6:
        q_std = 1.0

    print(f"✅ Calculated Q-stats: mean={q_mean:.4f}, mean-min={q_mean_min:.4f}, std={q_std:.4f}")
    return q_mean, q_mean_min, q_std
# <<< MODIFICATION END >>>


def train_rl(agent, train_pairs, true_adj, true_reach, max_steps=60, out_dir=None, model_args=None, config=None, soft=False, args=None):
    """Train ``agent`` with Q-learning on path-finding episodes.

    Each episode samples a ``(s, t)`` pair, rolls out at most ``max_steps``
    actions from the state ``[s, t, s]``, stores every transition in the
    agent's replay buffer and then performs one ``agent.replay()`` update.
    Epsilon decays linearly from ``args.epsilon`` to 0 over the run.

    Args:
        agent: QLearningAgent-like object (``get_action``, ``get_action_soft``,
            ``remember``, ``replay``, ``model``, ``optimizer``, ``device``).
        train_pairs: sequence of ``(s, t)`` pairs; ``s`` must be a tensor.
        true_adj: adjacency matrix indexed by ``token_id - 2``; a zero entry
            marks an invalid move.
        true_reach: only ``true_reach.shape[0]`` is used, as the number of
            selectable nodes.
        max_steps: maximum number of actions per episode.
        out_dir: directory for ``eval_log.txt`` and checkpoints. When falsy,
            nothing is written to disk.
        model_args: stored verbatim in each checkpoint.
        config: stored verbatim in each checkpoint.
        soft: use ``agent.get_action_soft`` instead of ``agent.get_action``.
        args: namespace providing ``epsilon``, ``num_episodes``,
            ``reward_type`` ('step' or 'final'), ``eval_interval``,
            ``train_type`` ('simple' or 'aug'), ``dataset``, ``num_nodes``,
            ``eval_batch_size``, ``eval_temperature``, ``fix_att``, ``device``.

    Raises:
        ValueError: if ``args`` is None.
        KeyError: if ``args.reward_type`` or ``args.train_type`` is unknown.
    """
    if args is None:
        raise ValueError("args must be provided for evaluation during training")
    # Fail fast on configuration errors instead of after the first rollout.
    if args.reward_type not in ('step', 'final'):
        raise KeyError(args.reward_type)
    if args.train_type == 'simple':
        datas = ['simple_train', 'simple_test']
    elif args.train_type == 'aug':
        datas = ['train2train', 'train2test', 'test2train', 'test2test']
    else:
        raise KeyError(args.train_type)

    class EvalArgs:
        """Minimal attribute bag passed to ``evaluate_model``."""

        def __init__(self, **kwargs):
            self.__dict__.update(kwargs)

    checkpoint_interval = 5000
    correct_cnt = 0

    cur_epsilon = args.epsilon
    # max(1, ...) avoids a division by zero when num_episodes is 0.
    epsilon_lr = args.epsilon / max(1, args.num_episodes)
    DEVICE = agent.device
    true_adj = torch.tensor(true_adj).to(DEVICE)

    print(true_adj)
    # Called for its logging and for leaving the model in train mode; the
    # returned statistics are not used by the current reward scaling.
    calculate_q_stats(agent, train_pairs, DEVICE)

    # Evaluation log and checkpoints are only written when out_dir is given.
    log_file = None
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
        log_file = os.path.join(out_dir, "eval_log.txt")

    choose_action = agent.get_action_soft if soft else agent.get_action
    # Every node is offered as a candidate at every step.
    available_nodes = torch.ones(true_reach.shape[0]).to(DEVICE)

    for episode in tqdm(range(args.num_episodes+1)):
        s, t = random.choice(train_pairs)
        state = [s, t, s]
        current_node = s.to(torch.long).cpu().numpy()
        correct = True
        for _ in range(max_steps):
            next_node = choose_action(state, epsilon=cur_epsilon, available_nodes=available_nodes)
            next_state = state + [next_node]

            reward = 0
            if true_adj[current_node - 2][next_node - 2] == 0:
                # Move along a non-existent edge: the path is no longer valid.
                correct = False
                if args.reward_type == 'step':
                    reward = -1

            if next_node == t:
                if args.reward_type == 'step':
                    reward += 1
                elif correct:
                    # 'final': reward only a fully valid path to the target.
                    reward += 100
                correct_cnt += 1 if correct else 0
                done = True
            else:
                # Episodes also end once the chosen id exceeds the target id.
                done = bool(next_node > t)

            scaled_reward = reward * 100
            agent.remember(state, next_node, scaled_reward, next_state, done)

            state = next_state
            current_node = next_node
            if done:
                break

        loss = agent.replay()
        cur_epsilon = max(0, cur_epsilon - epsilon_lr)  # never go negative

        # Evaluate every args.eval_interval episodes.
        if episode % args.eval_interval == 0:
            for eval_type_data in datas:
                eval_args = EvalArgs(
                    dataset=args.dataset,
                    data_dir=f"{args.num_nodes}_1_1",
                    type_data=eval_type_data,
                    test_num=None,
                    batch_size=args.eval_batch_size,
                    temperature=args.eval_temperature,
                    write_result=False,
                    fix_att=args.fix_att,
                    result_name=None,
                    out_dir=out_dir,
                )

                accuracy = evaluate_model(agent.model, eval_args, args.device, log_file=log_file, step=episode)
                print(f"[{episode}] Eval Accuracy on {eval_type_data}: {accuracy:.4f}")
                if log_file is not None:
                    with open(log_file, 'a') as f:
                        f.write(f"[{episode}] Eval Accuracy on {eval_type_data}: {accuracy:.4f}\n")

        # Report the loss and save a checkpoint every checkpoint_interval episodes.
        if episode % checkpoint_interval == 0:
            print(f"Episode {episode}, Loss: {loss if loss else 0}")
            if out_dir:
                checkpoint = {
                    'model': agent.model.state_dict(),
                    'optimizer': agent.optimizer.state_dict(),
                    'model_args': model_args,
                    'iter_num': episode,
                    'best_val_loss': loss,
                    'config': config,
                }
                print(f"saving checkpoint to {out_dir}")
                torch.save(checkpoint, os.path.join(out_dir, f'{episode}_ckpt.pt'))
        
      
def test_rl(agent, test_pairs, true_adj, true_reach, path_file_path, max_steps=100):
    """Greedily roll out the agent on every test pair and log the paths.

    For each ``(s, t)`` pair the agent acts with ``epsilon=0`` from the state
    ``[s, t, s]`` until it reaches ``t``, takes a non-existent edge, moves to
    a node from which ``true_reach`` marks ``t`` as unreachable, or uses up
    ``max_steps`` actions. One line per pair is appended to
    ``path_file_path``: ``"<s> <t> <path nodes...> <feedback>"``.

    Args:
        agent: object exposing ``device`` and ``get_action``.
        test_pairs: sequence of pairs; element 0 is the source (a tensor),
            element 1 the target.
        true_adj: adjacency matrix indexed by ``token_id - 2``.
        true_reach: reachability matrix indexed as
            ``true_reach[t - 2][node - 2]``; its first dimension is also used
            as the number of selectable nodes.
        path_file_path: file that receives the per-pair result lines.
        max_steps: maximum number of actions per pair.

    Returns:
        Fraction of pairs solved correctly; ``0.0`` when ``test_pairs`` is empty.
    """
    num_pairs = len(test_pairs)
    if num_pairs == 0:
        # Nothing to evaluate; avoid dividing by zero below.
        print(f"Test Accuracy {0.0}")
        return 0.0

    correct_cnt = 0
    DEVICE = agent.device
    true_adj = torch.tensor(true_adj).to(DEVICE)
    # Every node is offered as a candidate at every step.
    available_nodes = torch.ones(true_reach.shape[0]).to(DEVICE)

    with open(path_file_path, 'a') as f:
        for pair in tqdm(test_pairs):
            s, t = pair[0], pair[1]
            state = [s, t, s]  # initial state is [start, target, start]
            path = [s.to(torch.long).cpu().numpy()]

            # Default outcome when the step budget runs out without a verdict.
            feedback = 'Exceeded max steps'
            for _ in range(max_steps):
                current_node = path[-1]
                next_node = agent.get_action(state, epsilon=0.0, available_nodes=available_nodes)
                path.append(next_node)

                if true_adj[current_node - 2][next_node - 2] == 0:
                    feedback = 'Not in adj'
                    break
                if next_node == t:
                    correct_cnt += 1
                    feedback = 'Correct'
                    break
                if true_reach[t - 2][next_node - 2] == 0:
                    feedback = 'Not in reach'
                    break
                state = state + [next_node]

            nodes_text = "".join(f'{node} ' for node in path)
            f.write(f'{s} {t} {nodes_text}{feedback}\n')

    accuracy = correct_cnt / num_pairs
    print(f"Test Accuracy {accuracy}")
    return accuracy
