import torch
import json
import torch.distributed
from tqdm import tqdm

from megatron.core import parallel_state, tensor_parallel
from megatron.core.pipeline_parallel.p2p_communication import recv_forward, send_forward

from megatron.training import get_args, get_tokenizer, get_model, print_rank_0, is_last_rank
from megatron.training.checkpointing import load_checkpoint
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.utils import get_ltor_masks_and_position_ids, unwrap_model

from tasks.qa_zeroshot_gpt.utils import model_provider, track_eval_moe_metrics, clear_aux_losses_tracker
from tasks.qa_zeroshot_gpt.datasets import build_data_loader, build_dataset
from tasks.qa_zeroshot_gpt.param_calc import *

# Per-task result records. evaluate_and_print_results() appends one dict per
# evaluated task (on the last rank only); main() dumps the list to JSON at the
# end of an 'ALL-QA-MULTI-TH' run.
GLOBAL_METRICS = []
# Micro batch size captured from args by main(). all_datasets_eval() writes it
# back into args.micro_batch_size before each task -- presumably because
# forward_step() overwrites that argument with the size of the last batch.
MICRO_BATCH_SIZE = 0

def process_batch(batch):
    """Move one data-loader batch to the GPU and derive the model inputs.

    Reads the 'qa_sample_id', 'qa_label', 'loss_mask', 'ae_mask' and 'text'
    entries of ``batch``. The 'text' tensor is split into inputs (all but the
    last column) and next-token labels (all but the first column).

    Returns:
        Tuple ``(tokens, labels, attention_mask, position_ids, loss_mask,
        ae_mask, qa_idx, qa_label)``.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    def _cuda_long(key):
        # Every field goes through the same cast / transfer chain.
        return batch[key].long().cuda().contiguous()

    sample_ids = _cuda_long('qa_sample_id')
    sample_labels = _cuda_long('qa_label')

    # Both masks end up as uint8 tensors on the GPU.
    loss_mask = _cuda_long('loss_mask').byte()
    ae_mask = _cuda_long('ae_mask').byte()

    # Shift by one position: the model sees text[:-1] and predicts text[1:].
    text = _cuda_long('text')
    tokens = text[:, :-1].contiguous()
    labels = text[:, 1:].contiguous()

    # Only the attention mask and position ids are used; the loss mask
    # returned by the helper is discarded in favour of the batch's own.
    masks_and_positions = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)
    attention_mask = masks_and_positions[0]
    position_ids = masks_and_positions[2]

    return (tokens, labels, attention_mask, position_ids,
            loss_mask, ae_mask, sample_ids, sample_labels)

def forward_step(batch, model, config, total_output):
    """Run one forward pass and record per-sample losses in ``total_output``.

    Receives the activation tensor from the previous pipeline stage, runs the
    model, and forwards the output to the next stage. On the last pipeline
    stage the masked mean cross-entropy of every sample is written into
    ``total_output``: column 0 holds the loss and column 1 the sample's label,
    in the row selected by the sample's 'qa_sample_id'.

    Side effect: ``args.micro_batch_size`` is overwritten with the size of
    this batch.

    Returns:
        None. Results are delivered through ``total_output``.
    """
    (tokens, labels, attention_mask, position_ids,
     loss_mask, ae_mask, qa_idxs, qa_labels) = process_batch(batch)

    # Tell the model what our actual batch size will be.
    args = get_args()
    args.micro_batch_size = len(labels)

    # Shape of the activation tensor expected from the previous stage.
    recv_shape = (args.inference_max_seq_length,
                  args.micro_batch_size,
                  args.hidden_size)
    input_tensor = recv_forward(recv_shape, config)

    # Forward pass through the model.
    unwrap_model(model).set_input_tensor(input_tensor)
    output = model(tokens, position_ids, attention_mask, ae_mask=ae_mask)

    send_forward(output, config)

    # Only the last pipeline stage has what it needs to compute a loss.
    if not parallel_state.is_pipeline_last_stage():
        return

    token_losses = tensor_parallel.vocab_parallel_cross_entropy(
        output.contiguous().float(), labels.contiguous())

    # Sum the loss over masked-in positions and normalise by their count;
    # the clamp avoids dividing by zero for a fully masked sample.
    mask = loss_mask.contiguous().float()
    masked_loss_sums = torch.sum(token_losses * mask, dim=1)
    token_counts = loss_mask.sum(dim=1)
    per_sample_loss = masked_loss_sums / token_counts.clamp(min=1)

    for position, sample_id in enumerate(qa_idxs):
        row = total_output[sample_id]
        row[0] = per_sample_loss[position]
        row[1] = qa_labels[position]

def process_output(total_output, answers_per_questions):
    """Score multiple-choice predictions from per-candidate losses.

    Args:
        total_output: Tensor with one row per candidate answer and two
            columns, ``[score, label]``. The rows of one question must be
            consecutive, ``answers_per_questions`` rows per question. The
            label is read from the first row of each question.
        answers_per_questions: Number of candidate answers per question.

    Returns:
        Tuple ``(correct, avg_correct_score)``: the number of questions whose
        lowest-score candidate index equals the label, and the mean score of
        the chosen candidate over those correctly answered questions
        (``0.0`` when no question is answered correctly).

    Raises:
        AssertionError: If the number of rows is not a multiple of
            ``answers_per_questions``.
    """
    num_rows = total_output.size(0)
    assert num_rows % answers_per_questions == 0, (
        f"Expected the number of rows in total_output ({num_rows}) to be a "
        f"multiple of answers_per_questions ({answers_per_questions})")

    # Group the rows per question: [num_questions, answers_per_questions, 2].
    # The question count is spelled out (instead of -1) so that an empty
    # input reshapes cleanly; reshape also tolerates non-contiguous input.
    num_questions = num_rows // answers_per_questions
    paired = total_output.reshape(num_questions, answers_per_questions, 2)

    # Column 0 holds the score of every candidate: [num_questions, answers].
    scores = paired[:, :, 0]

    # Scores are losses, so the candidate with the LOWEST score is predicted.
    preds = torch.argmin(scores, dim=1)  # shape: [num_questions]

    # Take the label from the first candidate row of each question.
    labels = paired[:, 0, 1].long()  # shape: [num_questions]

    correct_mask = preds == labels
    correct = correct_mask.sum().item()
    if correct == 0:
        return correct, 0.0

    # Score of the predicted candidate per question; gather keeps the index
    # on the same device as the scores.
    pred_scores = scores.gather(1, preds.unsqueeze(1)).squeeze(1)

    # Average score over the correctly answered questions only.
    avg_correct_score = pred_scores[correct_mask].mean().item()

    return correct, avg_correct_score

def evaluate(task, data_loader, model):
    """Evaluate ``model`` on every batch of ``data_loader``.

    Each batch is run through ``forward_step``, which on the last pipeline
    stage fills a ``[len(dataset), 2]`` buffer with ``[loss, label]`` rows.
    The per-batch buffers are accumulated locally and summed across the
    data-parallel group once, after the loop. Because the reduction is a
    SUM, this is equivalent to reducing every batch separately but needs a
    single collective instead of one per batch.

    Args:
        task: Task name, forwarded to ``track_eval_moe_metrics``.
        data_loader: Loader whose dataset exposes
            ``get_answers_per_questions()``.
        model: Model to evaluate; it is switched to eval mode.

    Returns:
        Tuple ``(correct, avg_correct_score, avg_active_experts)``: the two
        values produced by ``process_output`` followed by the value returned
        by ``track_eval_moe_metrics``. On ranks outside the last pipeline
        stage the accumulated buffer stays all-zero, so the first two values
        carry no information there.
    """
    args = get_args()
    config = core_transformer_config_from_args(args)

    # Turn on evaluation mode which disables dropout.
    model.eval()

    print_rank_0(f'> amount of iterations: {len(data_loader)}')

    answers_per_questions = data_loader.dataset.get_answers_per_questions()
    total_output = torch.zeros(len(data_loader.dataset), 2).cuda()

    # One scratch buffer reused for every batch instead of allocating a
    # dataset-sized tensor per iteration.
    batch_output = torch.zeros_like(total_output)
    is_last_stage = parallel_state.is_pipeline_last_stage()

    with torch.no_grad():
        # For all the batches in the dataset.
        for batch in data_loader:
            # forward_step assigns (rather than adds) into the buffer, so it
            # has to start from zero for every batch.
            batch_output.zero_()
            forward_step(batch, model, config, batch_output)

            if is_last_stage:
                total_output += batch_output

        # Reduce across data-parallel processes, once for the whole dataset.
        if is_last_stage:
            torch.distributed.all_reduce(
                total_output,
                group=parallel_state.get_data_parallel_group(),
                op=torch.distributed.ReduceOp.SUM)

    avg_active_experts = track_eval_moe_metrics(task, len(data_loader)) # potentially args.log_interval

    accuracy_output, avg_correct_score = process_output(total_output, answers_per_questions)

    return accuracy_output, avg_correct_score, avg_active_experts

def evaluate_and_print_results(task, data_loader, model, scores_th=0.0):
    """Evaluate ``model`` on one task, print a summary and record the metrics.

    Every rank runs the evaluation. Only the last rank prints the summary
    and appends a record to the module-level ``GLOBAL_METRICS`` list.

    Args:
        task: Task name used in the printed summary and the metrics record.
        data_loader: Loader for the task's dataset; the dataset must expose
            ``get_answers_per_questions()``.
        model: Model to evaluate.
        scores_th: Value stored under "scores_th" in the metrics record.
            Defaults to ``0.0`` (the same default ``all_datasets_eval``
            uses) so callers that do not sweep a threshold can omit it.
    """
    # Evaluate and get results.
    accuracy_output, avg_correct_score, avg_active_experts = evaluate(task, data_loader, model)

    if not is_last_rank():
        return

    dataset = data_loader.dataset
    # process_output() has already asserted that the row count is a multiple
    # of the answers per question, so integer division is exact here.
    num_examples = len(dataset) // dataset.get_answers_per_questions()
    # Guard the division so an empty dataset reports 0% instead of raising.
    acc = (accuracy_output / num_examples) * 100 if num_examples else 0.0

    string = ' validation results on {} | '.format(task)
    string += f'number correct: {int(accuracy_output)} | '
    string += f'total examples: {int(num_examples)} | '
    string += f'avg accuracy: {acc:.4f}% | '
    string += f'avg loss: {avg_correct_score:.4f}'

    length = len(string) + 1
    print('-' * length)
    print(string)
    print('-' * length)

    # Appending mutates the module-level list in place; no `global` needed.
    GLOBAL_METRICS.append({
        "task": task,
        "num_correct": int(accuracy_output),
        "num_examples": int(num_examples),
        "avg_accuracy_pct": acc,
        "avg_loss": avg_correct_score,
        "avg_active_experts": avg_active_experts,
        "scores_th": scores_th
    })

def all_datasets_eval(args, model, valid_data_copy, scores_th = 0.0):
    """Evaluate ``model`` on PIQA, HELLASWAG and ARC-E, in that order.

    Args:
        args: Global argument namespace. ``micro_batch_size`` and
            ``valid_data`` are overwritten for every task.
        model: Model to evaluate.
        valid_data_copy: Dataset paths, paired positionally with the task
            names above.
        scores_th: Value forwarded to ``evaluate_and_print_results`` for the
            metrics record.
    """
    global MICRO_BATCH_SIZE
    task_names = ['PIQA', 'HELLASWAG', 'ARC-E']

    print_rank_0('preparing for all tasks')

    for task, data_path in zip(task_names, valid_data_copy):
        print_rank_0('building Dataset')

        # Put back the micro batch size saved in the module-level variable
        # before building the loader for this task.
        args.micro_batch_size = MICRO_BATCH_SIZE
        # The dataset builder is handed a one-element list with this task's
        # path, so it can be reused unchanged.
        args.valid_data = [data_path]

        task_dataset = build_dataset(task)
        task_loader = build_data_loader(
            task_dataset, args.micro_batch_size, args.num_workers)

        # Run evaluation.
        print_rank_0(f'start {task} evaluation')
        evaluate_and_print_results(task, task_loader, model, scores_th)
        print_rank_0(f'{task} done :-)')

def main():
    """Main program: build the model, report its size and run QA evaluation.

    Behaviour depends on ``args.task``:

    * ``'ALL-QA'``: evaluate on the three datasets listed in
      ``args.valid_data``.
    * ``'ALL-QA-MULTI-TH'``: for every value in ``args.multi_scores_th``, set
      the router score threshold on each layer that has an MLP router, run
      the three-dataset evaluation, and finally write the collected metrics
      to ``./logs/metrics.json``.
    * anything else: evaluate on the single dataset built for ``args.task``.
    """
    import os

    global MICRO_BATCH_SIZE

    args = get_args()
    # Remember the configured micro batch size; all_datasets_eval() restores
    # it from this module-level variable before every task.
    MICRO_BATCH_SIZE = args.micro_batch_size

    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for text generation.")
        # Same effect as the site-provided exit(), without depending on it.
        raise SystemExit

    # The three-path requirement only concerns the multi-dataset modes;
    # single-task evaluation must not be rejected by it.
    if args.task in ('ALL-QA', 'ALL-QA-MULTI-TH'):
        assert len(args.valid_data) == 3, "There should be paths to all 3 datasets in order PIQA, HELLASWAG AND ARC-E"
    # Keep a handle on the original list: all_datasets_eval() rebinds
    # args.valid_data for every task.
    valid_data_copy = args.valid_data

    # Set up model and load checkpoint.
    model = get_model(model_provider, wrap_with_ddp=False)
    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    assert len(model) == 1, "Above condition should have caught this"
    model = model[0]

    print_model_report(collect_empirical_report(model))
    print_model_report(collect_analytical_report(args))

    if args.task == 'ALL-QA':
        all_datasets_eval(args, model, valid_data_copy)

    elif args.task == 'ALL-QA-MULTI-TH':
        # Strip wrapper modules the same way forward_step() does, so this
        # also works when the model is not wrapped and has no `.module`.
        router_layers = [
            layer for layer in unwrap_model(model).decoder.layers
            if getattr(layer, "mlp", None) and getattr(layer.mlp, "router", None)
        ]

        for scores_th in args.multi_scores_th:
            print(f'Setup coeff to {scores_th}')

            for layer in router_layers:
                layer.mlp.router.config.moe_router_soft_topk_routing_scores_threshold = float(scores_th)

            # Read the values back to confirm what the routers will use.
            coeffs = [
                layer.mlp.router.config.moe_router_soft_topk_routing_scores_threshold
                for layer in router_layers
            ]
            print(f'Coeffs from configs are now {coeffs}')

            all_datasets_eval(args, model, valid_data_copy, scores_th)

            print('\n\n\n')

        # GLOBAL_METRICS is only filled on the last rank. If every rank wrote
        # the file, the others could overwrite it with an empty list.
        if is_last_rank():
            os.makedirs('./logs', exist_ok=True)
            with open('./logs/metrics.json', "w") as f:
                json.dump(GLOBAL_METRICS, f, indent=2)

    else:
        # Data stuff.
        print_rank_0('building Dataset')
        dataset = build_dataset(args.task)
        dataloader = build_data_loader(dataset, args.micro_batch_size, args.num_workers)

        # Run evaluation.
        print_rank_0('start evaluation')
        # No threshold sweep in this mode; record the same default value
        # that all_datasets_eval() uses.
        evaluate_and_print_results(args.task, dataloader, model, 0.0)

        print_rank_0('done :-)')