import torch
from moelayer import MoELoRA, MoELoRAQwen,WrappedLoRALayer
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from datasets import load_from_disk, concatenate_datasets
from torch.utils.data import DataLoader
from sklearn.metrics import matthews_corrcoef, accuracy_score, f1_score
from scipy.stats import pearsonr

def load_and_concatenate_datasets(paths):
    """Load the datasets stored at ``paths`` and merge them into one.

    Each path is loaded with ``load_from_disk``. A path that fails to load
    is reported on stdout and skipped, so one bad path does not abort the
    whole evaluation.

    Args:
        paths: Iterable of on-disk dataset locations.

    Returns:
        The only loaded dataset when exactly one path succeeded, otherwise
        the concatenation of all loaded datasets in the order of ``paths``.

    Raises:
        ValueError: If none of the paths could be loaded.
    """
    loaded = []
    for location in paths:
        try:
            current = load_from_disk(location)
            print(f"Loaded dataset from {location} with {len(current)} samples")
            loaded.append(current)
        except Exception as err:
            # Best effort: report the failure and move on to the next path.
            print(f"Error loading {location}: {str(err)}")

    if len(loaded) == 0:
        raise ValueError("No valid datasets loaded")
    if len(loaded) == 1:
        # Nothing to merge; hand back the dataset object itself.
        return loaded[0]
    return concatenate_datasets(loaded)


# Task name -> dataset path(s) used by eval_main_qwen. It is empty in this
# module, so it has to be filled in before eval_main_qwen is called.
valid_path = {}

# Task name -> integer task id that evaluate_qwen hands to the model.
# Two task families share the id range 0-7, so an id is only unique within
# its own family.
task_id_map = {
    # First family: "rte,mnli,mrpc,sst2,qqp,qnli,cola,stsb"
    "rte": 0,
    "mnli": 1,
    "mrpc": 2,
    "sst2": 3,
    "qqp": 4,
    "qnli": 5,
    "cola": 6,
    "stsb": 7,
    # Second family. Keys use underscores ("arc_e", "arc_c"); an earlier
    # five-task variant of this map with hyphenated keys ("arc-e", "arc-c")
    # is disabled.
    "boolq": 0,
    "obqa": 1,
    "piqa": 2,
    "arc_e": 3,
    "arc_c": 4,
    "siqa": 5,
    "winogrande": 6,
    "hellaswag": 7,
}


def _score_predictions(task_name, predictions, labels):
    """Compute the metric dict for one task from generated text and gold labels.

    Args:
        task_name: Task key; it selects which metric is reported.
        predictions: Generated answer strings, one per example.
        labels: Gold labels; they are compared after ``str()`` conversion.

    Returns:
        ``{"mcc": ...}`` for "cola", ``{"pearson": ...}`` for "stsb",
        ``{"accuracy": ..., "f1": ...}`` for "mrpc"/"qqp", and
        ``{"accuracy": ...}`` for every other task.
    """
    if task_name == "cola":
        label_map = {"acceptable": 1, "unacceptable": 0}
        try:
            labels_int = [label_map[str(label).lower()] for label in labels]
        except KeyError as e:
            # A gold label outside the two known classes makes the score
            # meaningless, so report it and fall back to 0.0.
            print(f"Invalid label found in COLA: {e}")
            return {"mcc": 0.0}
        # Unrecognised generations are counted as class 0.
        predictions_int = [label_map.get(str(pred).lower(), 0) for pred in predictions]
        return {"mcc": matthews_corrcoef(labels_int, predictions_int)}

    if task_name == "stsb":
        def safe_float(x):
            # Values that cannot be parsed as a number fall back to 2.5.
            try:
                return float(x)
            except (ValueError, TypeError):
                return 2.5

        labels_float = [safe_float(label) for label in labels]
        preds_float = [safe_float(pred) for pred in predictions]
        return {"pearson": pearsonr(labels_float, preds_float)[0]}

    # Remaining tasks are scored by exact string match.
    correct = sum(str(pred) == str(label) for pred, label in zip(predictions, labels))
    accuracy = correct / len(labels)
    if task_name in ("mrpc", "qqp"):
        f1 = f1_score(
            [str(label) for label in labels],
            [str(pred) for pred in predictions],
            average='macro'
        )
        return {"accuracy": accuracy, "f1": f1}
    return {"accuracy": accuracy}


def evaluate_qwen(model, tokenizer, dataset, task_name, batch_size=16, label_length=8):
    """Generate answers for one task's prompts and score them against the labels.

    Args:
        model: Model to evaluate. It must expose ``current_task_ids`` and
            ``set_current_task_ids_to_layers()``.
        tokenizer: Tokenizer passed to the text-generation pipeline.
        dataset: Object whose ``'prompt'`` and ``'label'`` entries are
            equally long sequences.
        task_name: Key into ``task_id_map``; it also selects the metric.
        batch_size: Number of prompts sent to the pipeline per call.
        label_length: Maximum number of new tokens generated per prompt.

    Returns:
        A metric dict: ``{"mcc"}`` for "cola", ``{"pearson"}`` for "stsb",
        ``{"accuracy", "f1"}`` for "mrpc"/"qqp", otherwise ``{"accuracy"}``.

    Raises:
        ValueError: If the dataset contains no prompts.
        KeyError: If ``task_name`` is not in ``task_id_map``.

    Side effects:
        Sets ``tokenizer.padding_side`` to ``"left"`` (it is not restored)
        and overwrites ``model.current_task_ids``.
    """
    labels = dataset['label']
    prompts = dataset['prompt']
    if len(prompts) == 0:
        # Fail early with a clear message instead of a ZeroDivisionError or a
        # metric-library error after the pipeline has already been built.
        raise ValueError(f"Cannot evaluate task {task_name}: dataset has no prompts")

    task_id = task_id_map[task_name]
    model.current_task_ids = torch.full((batch_size,), task_id, dtype=torch.long)
    model.set_current_task_ids_to_layers()

    # Generated tokens are appended on the right, so batched prompts for a
    # decoder-only model are padded on the left.
    tokenizer.padding_side = "left"
    reason = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        return_full_text=False,
    )

    predictions = []
    for start in range(0, len(prompts), batch_size):
        batch_prompts = prompts[start:start + batch_size]
        # NOTE(review): the task ids registered above are sized for
        # `batch_size`. They are not re-registered for a shorter final batch;
        # confirm the MoE layers tolerate that.
        with torch.no_grad():
            outputs = reason(
                batch_prompts,
                batch_size=len(batch_prompts),
                max_new_tokens=label_length,
                do_sample=False,
            )
        predictions.extend(o[0]["generated_text"].strip() for o in outputs)  # type: ignore
    print(f"The predicitons of task {task_name} is \n{predictions[:100]}")

    return _score_predictions(task_name, predictions, labels)

from transformers import pipeline

def eval_main_qwen(model, tokenizer, task_to_evaluate, batch_size=32, label_length=8):
    """Evaluate the model on each task, loading its dataset from ``valid_path``.

    Args:
        model: Model to evaluate; it is switched to eval mode first.
        tokenizer: Tokenizer forwarded to ``evaluate_qwen``.
        task_to_evaluate: Iterable of task names. Each must be a key of
            ``valid_path`` and of ``task_id_map``.
        batch_size: Generation batch size forwarded to ``evaluate_qwen``.
        label_length: Maximum number of new tokens generated per prompt,
            forwarded to ``evaluate_qwen``.

    Returns:
        The mean over tasks of each task's primary metric: accuracy if
        present, otherwise mcc, otherwise pearson.

    Raises:
        ValueError: If no task was evaluated, or a task's result contains
            none of the supported metrics.
        KeyError: If a task name is missing from ``valid_path``.
    """
    model.eval()
    results = {}
    for task_name in task_to_evaluate:
        print(valid_path[task_name])
        dataset = load_and_concatenate_datasets(valid_path[task_name])
        print(f"Evaluating {task_name}...")
        score = evaluate_qwen(
            model,
            tokenizer,
            dataset,
            task_name,
            batch_size=batch_size,
            label_length=label_length,
        )
        results[task_name] = score
        print(f"{task_name.upper()} Results: {score}")
        # Release cached GPU memory before the next task.
        torch.cuda.empty_cache()

    if not results:
        # Without this guard the average below divides by zero.
        raise ValueError("No tasks were evaluated: task_to_evaluate is empty")

    print("\nFinal Results:")
    for task, metrics in results.items():
        print(f"{task.upper()}: {metrics}")

    # One score per task; the first metric present in this order is used.
    metric_priority = ("accuracy", "mcc", "pearson")
    selected_scores = []
    for task, metrics in results.items():
        chosen = next((name for name in metric_priority if name in metrics), None)
        if chosen is None:
            raise ValueError(f"No supported metric found for task {task}: {metrics}")
        selected_scores.append(metrics[chosen])

    avg_score = sum(selected_scores) / len(selected_scores)
    print(f"\nAverage score across tasks: {avg_score:.4f}")
    return avg_score

def eval_main_qwen2(model, tokenizer, task_to_evaluate, valid_dataset, batch_size=32, label_length=8):
    """Evaluate the model on each task using already-loaded datasets.

    Args:
        model: Model to evaluate; it is switched to eval mode first.
        tokenizer: Tokenizer forwarded to ``evaluate_qwen``.
        task_to_evaluate: Iterable of task names. Each must be a key of
            ``valid_dataset`` and of ``task_id_map``.
        valid_dataset: Mapping from task name to that task's dataset.
        batch_size: Generation batch size forwarded to ``evaluate_qwen``.
        label_length: Maximum number of new tokens generated per prompt,
            forwarded to ``evaluate_qwen``.

    Returns:
        The mean over tasks of each task's primary metric: accuracy if
        present, otherwise mcc, otherwise pearson.

    Raises:
        ValueError: If no task was evaluated, or a task's result contains
            none of the supported metrics.
    """
    model.eval()
    results = {}
    for task_name in task_to_evaluate:
        dataset = valid_dataset[task_name]
        print(f"Evaluating {task_name}...")
        score = evaluate_qwen(
            model,
            tokenizer,
            dataset,
            task_name,
            batch_size=batch_size,
            label_length=label_length,
        )
        results[task_name] = score
        print(f"{task_name.upper()} Results: {score}")
        # Release cached GPU memory before the next task.
        torch.cuda.empty_cache()

    if not results:
        # Without this guard the average below divides by zero.
        raise ValueError("No tasks were evaluated: task_to_evaluate is empty")

    print("\nFinal Results:")
    for task, metrics in results.items():
        print(f"{task.upper()}: {metrics}")

    # One score per task; the first metric present in this order is used.
    metric_priority = ("accuracy", "mcc", "pearson")
    selected_scores = []
    for task, metrics in results.items():
        chosen = next((name for name in metric_priority if name in metrics), None)
        if chosen is None:
            raise ValueError(f"No supported metric found for task {task}: {metrics}")
        selected_scores.append(metrics[chosen])

    avg_score = sum(selected_scores) / len(selected_scores)
    print(f"\nAverage score across tasks: {avg_score:.4f}")
    return avg_score
