import copy
import os

import torch

from src.models import ImageEncoder
from src.models.task_vectors import NonLinearTaskVector
from src.utils.variables_and_paths import get_finetuned_path


def create_task_vector(dataset_name, ptm_check, args):
    """Build the task vector θ_task = θ_expert - θ_0 for a single dataset.

    The fine-tuned checkpoint is looked up under the validation-split name
    ``dataset_name + "Val"`` via ``get_finetuned_path``.

    Args:
        dataset_name: Base dataset name; "Val" is appended to locate the
            fine-tuned checkpoint.
        ptm_check: Passed through unchanged to ``NonLinearTaskVector``
            (presumably the pretrained θ_0 checkpoint — confirm against
            NonLinearTaskVector's signature).
        args: Object exposing ``model_location`` and ``model``.

    Returns:
        A ``NonLinearTaskVector``, or ``None`` if no fine-tuned checkpoint
        file exists at the resolved path (a warning is printed).
    """
    finetuned_path = get_finetuned_path(
        args.model_location, dataset_name + "Val", model=args.model
    )

    if not os.path.exists(finetuned_path):
        print(f"Warning: Fine-tuned model for task {dataset_name} not found")
        return None

    print(f"Creating task vector for task {dataset_name}")
    # Load onto CPU so building the vector does not require a GPU.
    # NOTE: torch.load unpickles the file — only use trusted checkpoints.
    finetuned_check = torch.load(finetuned_path, map_location="cpu")
    return NonLinearTaskVector(args.model, ptm_check, finetuned_check)


def apply_task_vector(pretrained_model, task_vector, scaling_coef=1.0):
    """Apply a task vector to a pretrained model: θ_merged = θ_0 + α*θ_task.

    The pretrained model is deep-copied, so ``pretrained_model`` itself is
    never modified.

    Args:
        pretrained_model: ``torch.nn.Module`` holding the pretrained weights θ_0.
        task_vector: Object whose ``vector`` attribute maps state-dict keys to
            delta tensors, or ``None``.
        scaling_coef: α, the multiplier applied to every delta.

    Returns:
        A new module (same type and device as ``pretrained_model``) whose
        state-dict entries present in ``task_vector.vector`` have
        ``scaling_coef * delta`` added. Entries without a matching delta keep
        their pretrained values, and deltas whose key the model lacks are
        ignored. If ``task_vector`` is ``None``, ``pretrained_model`` itself
        (not a copy) is returned.
    """
    if task_vector is None:
        return pretrained_model

    # Deep-copy instead of constructing a fresh encoder and loading weights
    # into it: this avoids re-running model construction and works for any
    # nn.Module, not just one rebuildable from the task vector's model name.
    merged_model = copy.deepcopy(pretrained_model)

    merged_state_dict = merged_model.state_dict()
    deltas = task_vector.vector
    for key, base in merged_state_dict.items():
        if key in deltas:
            # Task vectors are typically built from CPU-loaded checkpoints,
            # while the copy lives wherever pretrained_model lives.
            delta = deltas[key].to(base.device)
            merged_state_dict[key] = base + scaling_coef * delta

    merged_model.load_state_dict(merged_state_dict)
    return merged_model

def set_20_dataset_order(order_num):
    """Select one of ten predefined processing orders for the 20 datasets.

    The chosen order is stored in the module-level ``DATASETS`` and also
    returned.

    Args:
        order_num: Key of the order to use (1-10). Any key that is not
            defined falls back to order 1.

    Returns:
        list[str]: A newly built list of the 20 dataset names in the selected
        order — the same list object that was just assigned to ``DATASETS``.
    """
    global DATASETS

    # Dataset names in task-ID order: task ID k maps to names_by_id[k - 1].
    names_by_id = (
        # IDs 1-8
        "SUN397", "Cars", "RESISC45", "EuroSAT",
        "SVHN", "GTSRB", "MNIST", "DTD",
        # IDs 9-14
        "Flowers102", "PCAM", "FER2013", "OxfordIIITPet",
        "STL10", "CIFAR100",
        # IDs 15-20
        "CIFAR10", "Food101", "FashionMNIST", "EMNIST",
        "KMNIST", "RenderedSST2",
    )

    # Every order permutes IDs 1-8 first, then 9-14, then 15-20, so any
    # prefix of length 8 or 14 always covers the same group of datasets.
    orderings = {
        1: [6, 5, 7, 2, 3, 8, 1, 4, 10, 14, 13, 11, 12, 9, 20, 15, 16, 19, 18, 17],
        2: [7, 8, 5, 4, 2, 6, 3, 1, 13, 12, 9, 14, 10, 11, 15, 16, 17, 20, 18, 19],
        3: [3, 8, 2, 1, 5, 7, 6, 4, 9, 11, 13, 10, 12, 14, 15, 16, 20, 19, 18, 17],
        4: [4, 7, 8, 2, 1, 6, 5, 3, 11, 10, 12, 13, 14, 9, 17, 19, 18, 15, 20, 16],
        5: [2, 6, 4, 8, 1, 7, 5, 3, 10, 13, 9, 11, 14, 12, 17, 19, 18, 16, 20, 15],
        6: [4, 7, 6, 1, 5, 8, 2, 3, 14, 11, 12, 9, 10, 13, 16, 20, 15, 17, 18, 19],
        7: [1, 6, 4, 8, 2, 3, 7, 5, 9, 14, 13, 12, 10, 11, 19, 20, 17, 15, 16, 18],
        8: [7, 8, 2, 6, 5, 1, 4, 3, 10, 9, 13, 11, 14, 12, 15, 17, 20, 19, 16, 18],
        9: [5, 7, 1, 3, 4, 2, 6, 8, 10, 11, 14, 13, 12, 9, 17, 15, 18, 19, 20, 16],
        10: [1, 2, 3, 7, 5, 4, 6, 8, 11, 10, 12, 13, 9, 14, 15, 19, 16, 20, 18, 17],
    }

    # Unknown keys silently fall back to order 1.
    if order_num in orderings:
        chosen_ids = orderings[order_num]
    else:
        chosen_ids = orderings[1]

    DATASETS = [names_by_id[task_id - 1] for task_id in chosen_ids]
    return DATASETS


# Import-time default dataset processing order. set_20_dataset_order also
# assigns the module-level DATASETS itself; binding it here keeps the name
# explicitly defined at module scope.
_DEFAULT_DATASET_ORDER = 2
DATASETS = set_20_dataset_order(_DEFAULT_DATASET_ORDER)