from typing import Dict, Any, List, Tuple, Callable
import torch
import numpy as np
from collections import defaultdict
from data.graph_dataset import path_to_str


def is_valid_path(data, path_list):
    """Check a predicted path position-by-position against the allowed nodes.

    Args:
        data: Sample mapping. Reads ``data['path_length']`` (expected number
            of nodes) and ``data['policy_nodes']`` (indexed by position; each
            entry must support ``in``).
        path_list: Iterable of string tokens, one per predicted node. Each
            token is stripped of surrounding whitespace before ``int()``.

    Returns:
        A ``(is_valid, node_correct)`` tuple. ``node_correct`` holds one
        boolean per position and is all ``False`` when any token is not an
        integer or the number of tokens differs from ``path_length``.
    """
    expected_len = data['path_length']
    allowed_per_step = data['policy_nodes']
    all_wrong = [False] * expected_len

    try:
        parsed = [int(token.strip()) for token in path_list]
    except ValueError:
        parsed = None

    # Unparseable and wrong-length predictions are treated the same way.
    if parsed is None or len(parsed) != expected_len:
        return False, all_wrong

    per_node = []
    for step, node in enumerate(parsed):
        per_node.append(node in allowed_per_step[step])
    return all(per_node), per_node


def is_correct_path(data, path_list):
    """Tell whether the predicted tokens spell exactly the target path.

    Args:
        data: Sample mapping; ``data['paths']`` is the target the parsed
            prediction is compared against with ``==``.
        path_list: Iterable of string tokens; each is stripped and converted
            with ``int()``.

    Returns:
        ``False`` if any token is not an integer, otherwise the result of
        comparing the parsed node list with ``data['paths']``.
    """
    parsed = []
    try:
        for token in path_list:
            parsed.append(int(token.strip()))
    except ValueError:
        return False

    # The target is only looked up once the prediction parsed cleanly.
    return parsed == data['paths']

def star_generation_metrics(
    batch: Dict[str, torch.Tensor],
    dataset_samples: List[Dict[str, Any]],
    generated_outputs: torch.Tensor,
    tokenizer: Any
) -> Dict[str, np.ndarray]:
    """Score every generated response of every sample with ``is_valid_path``.

    Args:
        batch: Not read by this function.
        dataset_samples: Sample dicts, iterated in order; sample ``i`` is
            scored against ``generated_outputs[i]``.
        generated_outputs: Token ids indexed as ``[sample, response]``
            (presumably shaped (B, n_responses, seq_len) -- TODO confirm).
        tokenizer: Object whose ``decode`` converts a list of token ids into
            the path handed to ``is_valid_path``.

    Returns:
        Dict with a ``'valid'`` entry and one ``'node_correct_{k}'`` entry per
        path position. Each value is ``np.array`` over the per-sample lists of
        per-response results, in sample order.
    """
    responses_per_sample = generated_outputs.shape[1]

    # metric name -> sample index -> list with one result per response
    collected = defaultdict(lambda: defaultdict(list))

    for sample_idx, sample in enumerate(dataset_samples):
        for response_idx in range(responses_per_sample):
            token_ids = generated_outputs[sample_idx, response_idx].cpu().tolist()
            decoded_path = tokenizer.decode(token_ids)
            is_valid, node_flags = is_valid_path(sample, decoded_path)

            collected['valid'][sample_idx].append(is_valid)
            for position, flag in enumerate(node_flags):
                collected[f'node_correct_{position}'][sample_idx].append(flag)

    arrays = {}
    for metric_name, by_sample in collected.items():
        arrays[metric_name] = np.array(list(by_sample.values()))
    return arrays

def extract_response_logprobs(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Gather the log-probability of each label at the non-ignored positions.

    Args:
        logits: Scores indexed as ``[row, class, position]``; the softmax is
            taken over dimension 1 (presumably shaped (B, V, T) -- TODO
            confirm against the caller).
        labels: Target ids indexed as ``[row, position]``; positions equal to
            -100 are skipped.

    Returns:
        One row of label log-probabilities per batch row. Rows are stacked, so
        every row must keep the same number of positions.
    """
    keep = labels != -100

    per_row = []
    for row in range(keep.shape[0]):
        row_keep = keep[row]
        # (classes, kept positions): normalise over the class axis.
        log_dist = torch.log_softmax(logits[row, :, row_keep], dim=0)
        targets = labels[row, row_keep].unsqueeze(0)
        per_row.append(log_dist.gather(dim=0, index=targets))

    stacked = torch.stack(per_row)
    # Drop the singleton axis left over from gathering one class per position.
    return stacked.squeeze(1)

def compute_loss_and_accuracy(logits: torch.Tensor, labels: torch.Tensor, loss_fn: Callable, path_length: int) -> Dict[str, float]:
    """Evaluate ``loss_fn`` and report the loss scaled by ``path_length``.

    Despite the name, only a ``'loss'`` entry is returned; no accuracy is
    computed here.

    Args:
        logits: Passed unchanged as the first argument of ``loss_fn``.
        labels: Passed unchanged as the second argument of ``loss_fn``.
        loss_fn: Callable returning an object with ``.item()``.
        path_length: Multiplier applied to the scalar loss.

    Returns:
        ``{'loss': path_length * loss}``.
    """
    loss_value = loss_fn(logits, labels).item()
    return dict(loss=path_length * loss_value)

def is_ntp_valid(data, path, pad_total_length=None):
    """Flag, per position, whether a predicted token is an allowed node.

    Unlike an all-or-nothing path check, a bad token only invalidates its own
    position.

    Args:
        data: Sample mapping; ``data['policy_nodes']`` is indexed by position
            and each entry must support ``in``.
        path: Sequence of predicted tokens (strings or ints); each is
            inspected through ``str(token)``.
        pad_total_length: If given, the result is right-padded with ``False``
            up to this length. A shorter value never truncates.

    Returns:
        List of booleans, one per token (plus padding). A position is ``True``
        only when the token consists solely of decimal digits and its integer
        value is in ``policy_nodes`` at that position. Non-numeric tokens and
        positions beyond the end of ``policy_nodes`` are ``False``.
    """
    path_nodes = data['policy_nodes']
    known_positions = len(path_nodes)

    corrects = []
    for position, token in enumerate(path):
        text = str(token)
        corrects.append(
            # isdecimal() matches exactly what int() can parse; isdigit() also
            # accepts characters such as superscripts that make int() raise.
            text.isdecimal()
            # A prediction longer than the policy is wrong, not an error.
            and position < known_positions
            and int(text) in path_nodes[position]
        )

    if pad_total_length is not None:
        corrects.extend([False] * (pad_total_length - len(corrects)))
    return corrects

    
def compute_ntp_validity(
    logits: torch.Tensor,
    labels: torch.Tensor,
    dataset_samples: List[Dict[str, Any]],
    tokenizer: Any
) -> Dict[str, float]:
    """Measure how often the argmax predictions are allowed nodes.

    Args:
        logits: Scores with the class axis at dimension 1, over which the
            argmax is taken (presumably shaped (B, V, T) -- TODO confirm).
        labels: Targets indexed as ``[row, position]``; positions equal to
            -100 are excluded from the response.
        dataset_samples: One sample dict per batch row, passed to
            ``is_ntp_valid`` together with that row's decoded prediction.
        tokenizer: Object whose ``decode`` converts a list of token ids into
            the path handed to ``is_ntp_valid``.

    Returns:
        Dict with ``'ntp_correct_{i}'`` (mean correctness at position ``i``
        across the batch) and ``'avg_ntp_correct'`` (mean of those values).
    """
    response_mask = labels != -100
    predicted = logits.argmax(dim=1)

    response_tokens = [
        predicted[row][response_mask[row]].cpu().tolist()
        for row in range(logits.shape[0])
    ]
    response_paths = [tokenizer.decode(tokens) for tokens in response_tokens]

    # Computed once rather than per sample: every row is padded to the longest
    # decoded path so the rows stack into a rectangular array.
    longest_path = max((len(path) for path in response_paths), default=0)
    path_corrects = [
        is_ntp_valid(sample, path, pad_total_length=longest_path)
        for sample, path in zip(dataset_samples, response_paths)
    ]

    per_node_corrects = np.mean(np.array(path_corrects).astype(float), axis=0)
    avg_correct = np.mean(per_node_corrects)

    metrics = {
        f"ntp_correct_{position}": correct
        for position, correct in enumerate(per_node_corrects)
    }
    metrics["avg_ntp_correct"] = avg_correct
    return metrics


def per_token_probs(
        logprobs: torch.Tensor,
):
    """Average the probability assigned at each token position.

    Args:
        logprobs: Log-probabilities whose first two dimensions are merged
            before averaging (presumably shaped (B, n_responses, T) -- TODO
            confirm against the caller).

    Returns:
        Dict with ``'ntp_probs_{i}'`` (mean probability at position ``i`` over
        the merged leading dimensions) and ``'avg_ntp_probs'`` (mean of those
        per-position values), all as Python floats.
    """
    position_means = logprobs.flatten(0, 1).exp().mean(dim=0)

    summary = {}
    for position in range(len(position_means)):
        summary[f"ntp_probs_{position}"] = position_means[position].item()
    summary["avg_ntp_probs"] = position_means.mean().item()
    return summary
    

# Registry: metric-suite name -> function that computes that suite's metrics.
METRICS_FUNCTIONS = {"star": star_generation_metrics}