from ast import arg
import logging
from pyexpat import model
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
from torch.optim import Adam
from torch.nn.parameter import Parameter
from tqdm import tqdm
from copy import deepcopy
from typing import Dict, Iterator, List, Optional
from src.datasets.common import maybe_dictionarize
from src.route_merged_model import RouteMergedModel
from src.tasks.shortest_route_mask import build_model

# Module-level logger named after this module.
log = logging.getLogger(__name__)

# StateDict is simply defined as an alias of the built-in dict type.
StateDict = dict

def compute_sr_classification_heads_single(
    task_vector_pre: StateDict,
    pretrained_model: nn.Module,
    merged_state_dict: StateDict,
    classification_head_pre: nn.Module,
    lr: float = 0.01,
    max_epochs: int = 100,
    pre_task_dataloader: Optional[torch.utils.data.DataLoader] = None,
    device: str = "cuda:1",
) -> nn.Module:
    """Train a copy of a classification head on top of the frozen merged model.

    A deep copy of ``classification_head_pre`` (the "student") is optimised
    with Adam and a cross-entropy loss against the ground-truth labels. Its
    input is the output of a frozen copy of ``pretrained_model`` loaded with
    ``merged_state_dict``.

    Args:
        task_vector_pre: Forwarded to ``build_model`` together with
            ``pretrained_model``.
        pretrained_model: Base model. Its parameters are detached and frozen
            **in place**; a deep copy of it receives ``merged_state_dict``.
        merged_state_dict: State dict loaded into the copy of
            ``pretrained_model``.
        classification_head_pre: Head to start from. It must expose a
            ``weight`` parameter. As a side effect its ``weight`` is frozen
            and the module is moved to ``device``.
        lr: Learning rate for Adam.
        max_epochs: Number of optimisation steps. Each step consumes a single
            batch (not a full pass over the data); the loader is restarted
            when it is exhausted.
        pre_task_dataloader: Loader whose batches, after
            ``maybe_dictionarize``, provide ``"images"`` and ``"labels"``.
            Required even though the signature defaults it to ``None``.
        device: Device on which the models, the heads and the batches are
            placed.

    Returns:
        The trained student head, located on ``device``, with
        ``weight.requires_grad`` set to ``False`` and in the same
        train/eval mode the copied head had.

    Raises:
        ValueError: If ``pre_task_dataloader`` is ``None``.
    """
    # Fail fast, before any model is copied or moved.
    if pre_task_dataloader is None:
        raise ValueError(
            "pre_task_dataloader is required to train the classification head"
        )

    # Freeze the caller's base model in place.
    for p in pretrained_model.parameters():
        p.detach_().requires_grad_(False)

    # NOTE(review): the model returned here used to feed a "teacher" forward
    # pass on every step whose logits were never used by the loss; that dead
    # forward pass has been removed. The call itself is kept because any side
    # effects of build_model on pretrained_model are not visible from this
    # file -- confirm and drop it if there are none. If distillation against
    # the teacher was intended, the loss below needs to change as well.
    build_model(pretrained_model, task_vector_pre, device)

    # Frozen model carrying the merged weights.
    model_merged = deepcopy(pretrained_model)
    model_merged.load_state_dict(merged_state_dict)
    for param in model_merged.parameters():
        param.requires_grad = False
    model_merged.to(device)
    model_merged.eval()

    # Copy the head before touching the original so the student inherits the
    # caller's original parameter flags.
    classification_head_pre_student = deepcopy(classification_head_pre)

    # Caller-visible state of the original head is preserved from the
    # previous implementation: weight frozen, module moved to `device`.
    classification_head_pre.weight.requires_grad = False
    classification_head_pre.to(device)

    # The student must live on the same device as the merged model's output;
    # previously only the original head was moved.
    classification_head_pre_student.to(device)
    classification_head_pre_student.weight.requires_grad = True

    # Put the module that is actually optimised into train mode, and remember
    # the mode it had so it can be handed back unchanged.
    was_training = classification_head_pre_student.training
    classification_head_pre_student.train()

    optimizer = Adam(
        classification_head_pre_student.parameters(),
        lr=lr,
        betas=(0.9, 0.999),
        weight_decay=0.0,
    )
    criterion = nn.CrossEntropyLoss()

    pre_iter = iter(pre_task_dataloader)
    pbar = tqdm(range(max_epochs), desc="Training classification heads")
    for step in pbar:
        try:
            batch = next(pre_iter)
        except StopIteration:
            # Loader exhausted: start over so training can run for the
            # requested number of steps.
            pre_iter = iter(pre_task_dataloader)
            batch = next(pre_iter)

        batch = maybe_dictionarize(batch)
        x = batch["images"].to(device)
        y = batch["labels"].to(device)

        # The merged model is frozen, so no graph is needed for it.
        with torch.no_grad():
            merged_output = model_merged(x)
        logits_student = classification_head_pre_student(merged_output)

        loss = criterion(logits_student, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_postfix({
            "epoch": step,
            "loss": f"{loss.item():.4f}",
        })

    classification_head_pre_student.weight.requires_grad = False
    classification_head_pre_student.train(was_training)

    return classification_head_pre_student