from typing import List, Dict, Any, Tuple, Callable
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from collections import defaultdict
from tqdm import tqdm
import wandb
from callbacks.metrics import METRICS_FUNCTIONS
from callbacks.metrics import compute_loss_and_accuracy, extract_response_logprobs, per_token_probs

class Evaluator:
    """Run evaluation passes over one or more eval dataloaders.

    For each dataloader this computes teacher-forced loss/accuracy metrics
    (via ``compute_loss_and_accuracy``), optionally per-token probabilities,
    optionally sampled generations scored by a dataset-specific metrics
    function, and logs the aggregated scalars to wandb when enabled.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer: Any,
        eval_dataloaders: List[DataLoader],
        loss_fn: Callable,
        dataset_type: str = 'star',
        use_tqdm: bool = False,
        cfg: Any = None
    ):
        """Store evaluation settings.

        Args:
            model: module under evaluation. Its first parameter determines the
                device, and it must provide ``batch_generate`` for generations.
            tokenizer: must expose ``delimiter_id``.
            eval_dataloaders: each ``dataset`` must expose ``data`` (indexable
                records; the first one carries ``num_paths``) and
                ``target_max_length``.
            loss_fn: forwarded to ``compute_loss_and_accuracy``.
            dataset_type: key into ``METRICS_FUNCTIONS``.
            use_tqdm: show a progress bar while evaluating.
            cfg: config object read by attribute access (``cfg.eval.*``,
                ``cfg.use_wandb``, ``cfg.data.params``). Required despite the
                ``None`` default, which is kept for signature compatibility.

        Raises:
            ValueError: if ``cfg`` is missing or ``dataset_type`` is unknown.
        """
        if cfg is None:
            raise ValueError("Evaluator requires a cfg with `eval` and `data` sections")
        self.model = model
        self.tokenizer = tokenizer
        self.eval_dataloaders = eval_dataloaders
        self.loss_fn = loss_fn
        self.n_responses = cfg.eval.n_responses
        self.prompt_batch_size = cfg.eval.prompt_batch_size
        # Cap paths per prompt so prompt_batch_size * num_paths <= 2048; max(1, ...)
        # keeps the `[::num_paths]` slicing step valid for very large prompt batches.
        max_repeats = max(1, 2048 // self.prompt_batch_size)
        self.num_paths = [min(dl.dataset.data[0]['num_paths'], max_repeats) for dl in eval_dataloaders]
        self.n_table = cfg.eval.n_table
        self.use_wandb = cfg.use_wandb
        self.dataset_type = dataset_type
        self.max_new_tokens = [dl.dataset.target_max_length for dl in eval_dataloaders]
        self.device = next(model.parameters()).device
        if dataset_type not in METRICS_FUNCTIONS:
            raise ValueError(f"Unknown dataset type: {dataset_type}. Must be one of {list(METRICS_FUNCTIONS.keys())}")
        self.generation_metrics_fn = METRICS_FUNCTIONS[dataset_type]
        # One entry per cfg.data.params, indexed in parallel with eval_dataloaders
        # (assumed aligned by position — TODO confirm).
        self.path_lengths = [param.layers + 2 for param in cfg.data.params]
        self.use_tqdm = use_tqdm

    def generations_and_metrics(
        self,
        batch: Dict[str, torch.Tensor],
        dataset_samples: List[Dict[str, Any]],
        dataloader_idx: int = 0,
    ) -> Dict[str, Any]:
        """Sample responses for the unique prompts in ``batch`` and score them.

        Args:
            batch: collated batch. Only every ``num_paths``-th row of
                ``input_ids`` / ``attention_mask`` is used, so rows are assumed
                to repeat each prompt ``num_paths`` times consecutively.
            dataset_samples: raw dataset records for the unique prompts,
                forwarded to the dataset-specific metrics function.
            dataloader_idx: index of the eval dataloader the batch came from.
                Selects that dataloader's ``num_paths`` and ``max_new_tokens``.

        Returns:
            ``{'metrics': {name: aggregated_metrics},
            name: {'outputs': responses, 'logprobs': logprobs}}`` for each
            sampling configuration ``name`` (currently only ``'temp'``).

        Raises:
            ValueError: if any prompt does not contain exactly one delimiter token.
        """
        num_paths = self.num_paths[dataloader_idx]
        max_new_tokens = self.max_new_tokens[dataloader_idx]
        batch = {k: v.to(self.device) if torch.is_tensor(v) else v for k, v in batch.items()}
        input_ids = batch["input_ids"][::num_paths]
        attention_mask = batch["attention_mask"][::num_paths]

        delimiter_mask = input_ids == self.tokenizer.delimiter_id
        delimiter_counts = delimiter_mask.sum(dim=1)
        # Check per row: a global count can match the batch size while one row
        # has two delimiters and another has none.
        if not bool((delimiter_counts == 1).all()):
            raise ValueError(
                f"Expected exactly one delimiter token per sequence, got {int(delimiter_counts.sum())} "
                f"for batch size {len(input_ids)} (per-sequence counts: {delimiter_counts.tolist()})"
            )
        # With exactly one hit per row, the column indices come back in row order.
        delimiter_pos = delimiter_mask.nonzero(as_tuple=True)[1]
        # Keep each prompt up to and including its delimiter. torch.stack needs equal
        # lengths, so this assumes the delimiter sits at the same position in every
        # prompt — TODO confirm for datasets with variable-length prompts.
        input_ids = torch.stack([ids[:pos + 1] for ids, pos in zip(input_ids, delimiter_pos)])
        attention_mask = torch.stack([mask[:pos + 1] for mask, pos in zip(attention_mask, delimiter_pos)])

        outputs_and_metrics = {'metrics': {}}
        for name, temperature in (('temp', 1.0),):
            responses, logprobs = self.model.batch_generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                num_return_sequences=1 if name == 'greedy' else self.n_responses,
                temperature=temperature,
            )
            per_sample_metrics = self.generation_metrics_fn(batch, dataset_samples, responses, self.tokenizer)
            aggregated_metrics = {'avg_logprobs': torch.sum(logprobs, dim=2).mean().item()}
            for metric_name, metric_array in per_sample_metrics.items():
                if metric_name == 'node_correct':
                    # Average over the first two axes, keeping one value per remaining position.
                    position_accuracies = np.mean(metric_array, axis=(0, 1))
                    aggregated_metrics["match_per_token"] = [float(acc) for acc in position_accuracies]
                else:
                    sample_means = np.mean(metric_array, axis=1)
                    aggregated_metrics[f"avg_{metric_name}"] = float(np.mean(sample_means))
                    # Fraction of rows where any entry along axis 1 is truthy.
                    aggregated_metrics[f"avg_exists_{metric_name}"] = float(np.mean(np.any(metric_array, axis=1)))
            outputs_and_metrics['metrics'][name] = aggregated_metrics
            outputs_and_metrics[name] = {
                'outputs': responses,
                'logprobs': logprobs,
            }
        return outputs_and_metrics

    def _loop(
            self,
            dataloader_idx: int,
            compute_logprobs: bool,
            compute_generations: bool,
            compute_accuracies: bool,
            ) -> Tuple[Dict[str, Dict[str, np.ndarray]], Dict[str, float], Any, Any]:
        """Evaluate a single dataloader.

        Returns:
            ``(final_outputs, final_metrics, vis_data, final_logprobs)``:

            - ``final_outputs`` maps each generation config name to its
              concatenated numpy outputs.
            - ``final_metrics`` holds the mean of every recorded metric.
            - ``vis_data`` is always ``None`` here.
            - ``final_logprobs`` is the concatenated numpy logprobs, or
              ``None`` / ``[]`` when nothing was collected.
        """
        self.model.eval()
        dataloader = self.eval_dataloaders[dataloader_idx]
        num_paths = self.num_paths[dataloader_idx]
        path_length = self.path_lengths[dataloader_idx]
        iterator = tqdm(dataloader, desc="Evaluating") if self.use_tqdm else dataloader
        running_metrics = defaultdict(list)
        running_outputs = defaultdict(lambda: defaultdict(list))
        vis_data = None  # nothing in this method populates it
        running_logprobs = [] if compute_logprobs else None

        with torch.inference_mode():
            for batch in iterator:
                dataset_idxs = batch['idx'][::num_paths]
                dataset_samples = [dataloader.dataset.data[i] for i in dataset_idxs]
                batch = {k: v.to(self.device) if torch.is_tensor(v) else v for k, v in batch.items()}
                unique_input_ids = batch["input_ids"][::num_paths]
                unique_attention_mask = batch["attention_mask"][::num_paths]
                # The forward pass runs once per unique prompt and its logits are
                # repeated num_paths times. This presumes the num_paths rows of a
                # prompt share identical inputs — TODO confirm with the collator.
                unique_logits = self.model(
                    input_ids=unique_input_ids,
                    attention_mask=unique_attention_mask
                )
                logits = unique_logits.repeat_interleave(num_paths, dim=0)
                labels = batch["labels"]
                logprobs = extract_response_logprobs(logits, labels)
                # Size the leading dim from this batch rather than prompt_batch_size,
                # so a smaller final batch does not break the reshape.
                n_prompts = unique_input_ids.shape[0]
                logprobs = logprobs.reshape(n_prompts, num_paths, -1)
                if compute_logprobs:
                    running_logprobs.append(logprobs)
                loss_and_accuracy_metrics = compute_loss_and_accuracy(logits, labels, self.loss_fn, path_length)
                for k, v in loss_and_accuracy_metrics.items():
                    running_metrics[k].append(v)
                if compute_accuracies:
                    for k, v in per_token_probs(logprobs).items():
                        running_metrics[k].append(v)
                if compute_generations:
                    outputs_and_metrics = self.generations_and_metrics(
                        batch, dataset_samples, dataloader_idx=dataloader_idx
                    )
                    for config_name, config_metrics in outputs_and_metrics['metrics'].items():
                        for metric_name, value in config_metrics.items():
                            running_metrics[f"{config_name}/{metric_name}"].append(value)
                    for name, config_outputs in outputs_and_metrics.items():
                        if name == 'metrics':
                            continue
                        for k, v in config_outputs.items():
                            running_outputs[name][k].append(v)
                if self.use_tqdm and running_metrics.get('loss'):
                    iterator.set_postfix(loss=f"{running_metrics['loss'][-1]:.4f}")

        final_metrics = {k: float(np.mean(v)) for k, v in running_metrics.items()}
        final_outputs = {
            name: {k: torch.cat(v, dim=0).cpu().numpy() for k, v in config_outputs.items()}
            for name, config_outputs in running_outputs.items()
        }
        if running_logprobs:
            final_logprobs = torch.cat(running_logprobs, dim=0).cpu().numpy()
        else:
            final_logprobs = running_logprobs
        return final_outputs, final_metrics, vis_data, final_logprobs

    def evaluate(
            self,
            iters: int,
            compute_logprobs: bool,
            compute_generations: bool,
            compute_accuracies: bool,
    ) -> Tuple[List[Dict[str, Dict[str, np.ndarray]]], List[Any]]:
        """Evaluate every dataloader and optionally log metrics to wandb.

        Metrics are prefixed ``eval_{i}/`` and logged at ``step=iters``.
        Keys containing ``node_correct``, ``ntp_correct_`` or ``ntp_probs_`` are
        not logged.

        Returns:
            ``(val_outputs, val_logprobs)``, with one entry per dataloader.
        """
        val_results = []
        for idx, _ in enumerate(self.eval_dataloaders):
            outputs, metrics, vis_data, logprobs = self._loop(
                idx,
                compute_logprobs=compute_logprobs,
                compute_generations=compute_generations,
                compute_accuracies=compute_accuracies
            )
            val_results.append({
                'outputs': outputs,
                'metrics': metrics,
                'vis_data': vis_data,
                'logprobs': logprobs
            })
        metrics = {}
        for i, res in enumerate(val_results):
            metrics.update({f'eval_{i}/{k}': v for k, v in res['metrics'].items()})
        if self.use_wandb and wandb.run:
            wandb.log({k: v for k, v in metrics.items() if ('node_correct' not in k and 'ntp_correct_' not in k and 'ntp_probs_' not in k)}, step=iters)
        val_outputs = [res['outputs'] for res in val_results]
        val_logprobs = [res['logprobs'] for res in val_results]
        return val_outputs, val_logprobs