import logging
from operator import ge
from pyexpat import model
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
from torch.optim import Adam
from torch.nn.parameter import Parameter
from tqdm import tqdm
from copy import deepcopy
from typing import Dict, Iterator, List, Optional
from src.datasets.common import maybe_dictionarize
from src.route_merged_model import RouteMergedModel
import geomloss

# Logger for this module, named after the module itself.
log = logging.getLogger(name=__name__)
# Annotation alias for "a state dict" (name -> tensor mapping); at runtime it is
# just the builtin ``dict``.
StateDict = dict


class ShortestRouteMask(nn.Module):
    """Learnable element-wise masks, one per entry of a state dict.

    The masks live in a plain ``dict`` attribute (``self.masks``), so they are NOT
    registered with ``nn.Module``: ``named_parameters()`` and ``state_dict()`` do not
    see them. ``parameters()`` and ``to()`` are overridden below to cover them.
    NOTE(review): a plain dict is presumably used because state-dict keys contain
    '.', which ``nn.ParameterDict`` rejects -- confirm before changing the storage.
    """

    def __init__(
        self,
        state_dict: StateDict,
        init_value: float = 1.0,
    ):
        """Create one mask per entry of ``state_dict``.

        Args:
            state_dict: Mapping of name -> tensor. Each mask copies the shape and
                device of its tensor.
            init_value: Value every mask element starts at.
        """
        super().__init__()
        masks = {}
        for k, v in state_dict.items():
            # Only floating/complex tensors can require grad. Integer/bool entries
            # (e.g. integer buffers) get a default-float mask instead of crashing
            # when ``init_value`` is an int.
            if v.is_floating_point() or v.is_complex():
                dtype = v.dtype
            else:
                dtype = torch.get_default_dtype()
            masks[k] = nn.Parameter(
                torch.full_like(v, init_value, dtype=dtype), requires_grad=True
            )
        self.masks = masks

    def _draw_mask(self, binary_mask: bool = False):
        """Return a dict of all-ones tensors shaped like the masks.

        Deterministic placeholder: ``binary_mask`` is currently ignored, and the
        learned mask values are not used (only their shape/dtype/device). It can be
        replaced with stochastic sampling if needed.
        """
        return {k: torch.ones_like(param) for k, param in self.masks.items()}

    def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]:
        """Return an iterator over the mask parameters, as ``nn.Module.parameters`` does.

        ``recurse`` is accepted for signature compatibility and ignored.
        """
        return iter(self.masks.values())

    def to(self, *args, **kwargs):
        """Move/cast the module and every mask; accepts the same arguments as ``nn.Module.to``.

        The masks are not registered parameters, so ``super().to`` does not touch them.
        They are updated through ``.data`` so the ``Parameter`` objects keep their
        identity (an optimizer already holding them stays valid).
        """
        super().to(*args, **kwargs)
        for mask in self.masks.values():
            mask.data = mask.data.to(*args, **kwargs)
        return self


def _next_batch(iterator, loader):
    """Fetch the next batch from ``iterator``, restarting it from ``loader`` once exhausted.

    Only this loader is restarted, so the other task's iterator keeps its position.

    Returns:
        ``(batch, iterator)``. The iterator may be a fresh one; callers must keep it.

    Raises:
        ValueError: if ``loader`` yields no batches even after a restart.
    """
    try:
        return next(iterator), iterator
    except StopIteration:
        pass
    iterator = iter(loader)
    try:
        return next(iterator), iterator
    except StopIteration:
        raise ValueError("dataloader yielded no batches") from None


def compute_sr_mask(
    task_vector_pre: StateDict,
    task_vector_post: StateDict,
    pretrained_model: nn.Module,
    masks_pre: Dict[str, Tensor],
    masks_post: Dict[str, Tensor],
    lr: float = 0.01,
    max_epochs: int = 100,
    pre_task_dataloader: Optional[torch.utils.data.DataLoader] = None,
    post_task_dataloader: Optional[torch.utils.data.DataLoader] = None,
    mask_alpha: float = 0.5,
    device: str = "cuda:1",
):
    """Learn masks that merge two task vectors, and return the merged weights.

    Two teacher models are built with ``build_model``: pretrained + ``task_vector_pre``
    and pretrained + ``task_vector_post``. A ``RouteMergedModel`` student is built from
    the pretrained model, both task vectors and the learnable masks.

    Steps alternate between the two tasks. On each step only that task's masks
    receive gradients, and the student's logits are pulled toward that task's
    teacher logits with a Sinkhorn divergence.

    Args:
        task_vector_pre: Deltas added to the pretrained weights for the "pre" teacher.
        task_vector_post: Deltas added to the pretrained weights for the "post" teacher.
        pretrained_model: Frozen IN PLACE (every parameter gets ``requires_grad=False``).
        masks_pre: Only keys and tensor shape/dtype/device are used. Every entry is
            replaced IN PLACE by a fresh all-ones ``Parameter``.
        masks_post: Same treatment as ``masks_pre``.
        lr: Adam learning rate for both mask groups.
        max_epochs: Number of optimisation steps. Each step consumes a single batch.
        pre_task_dataloader: Required despite the ``None`` default. Batches must be
            accepted by ``maybe_dictionarize`` and contain an ``"images"`` entry.
        post_task_dataloader: Required; same batch format as ``pre_task_dataloader``.
        mask_alpha: Forwarded to ``RouteMergedModel``.
        device: Device for the teachers, the student and the input batches.

    Returns:
        Dict of CPU tensors copied from ``model_merged.merged_state_dict`` after a final
        merge with the lowest-loss masks.

    Raises:
        ValueError: if a dataloader is missing or yields no batches.
    """
    # Validate before mutating the caller's model/masks.
    if pre_task_dataloader is None or post_task_dataloader is None:
        raise ValueError(
            "compute_sr_mask requires both pre_task_dataloader and post_task_dataloader"
        )

    # Freeze pretrained model parameters
    for p in pretrained_model.parameters():
        p.detach_().requires_grad_(False)

    # Re-initialize masks as learnable all-ones parameters. Each dict is iterated over
    # its own keys, so the two dicts need not share an identical key set.
    for masks in (masks_pre, masks_post):
        for k in masks:
            masks[k] = Parameter(torch.ones_like(masks[k]), requires_grad=True)

    # Teachers (task-specific models) and the student (merged model)
    model_pre = build_model(pretrained_model, task_vector_pre, device)
    model_post = build_model(pretrained_model, task_vector_post, device)
    model_merged = RouteMergedModel(
        pretrained_model,
        task_vector_pre,
        task_vector_post,
        masks_pre,
        masks_post,
        mask_alpha,
        device,
    )

    model_pre.to(device).eval()
    model_post.to(device).eval()
    model_merged.to(device).train()

    # The optimizer holds model_merged's masks, so every read/write below goes through
    # model_merged.masks_* to stay in sync with what is actually being optimised.
    optimizer = Adam(
        [
            {'params': model_merged.masks_pre.values(), 'lr': lr, 'betas': (0.9, 0.999), 'weight_decay': 0.},
            {'params': model_merged.masks_post.values(), 'lr': lr, 'betas': (0.9, 0.999), 'weight_decay': 0.},
        ]
    )

    # Optimal-transport (Sinkhorn divergence) loss; built once, reused every step.
    sinkhorn_loss = geomloss.SamplesLoss(loss="sinkhorn", p=2, blur=0.05)

    pre_iter = iter(pre_task_dataloader)
    post_iter = iter(post_task_dataloader)

    best_loss = float('inf')
    best_masks_pre = None
    best_masks_post = None

    pbar = tqdm(range(max_epochs), desc="Training masks")
    for epoch in pbar:
        model_merged.train()

        # Alternate: even steps train the pre-task masks, odd steps the post-task masks.
        train_pre = epoch % 2 == 0
        for param in model_merged.masks_pre.values():
            param.requires_grad = train_pre
        for param in model_merged.masks_post.values():
            param.requires_grad = not train_pre

        # Only the loader being consumed is restarted when exhausted.
        if train_pre:
            batch, pre_iter = _next_batch(pre_iter, pre_task_dataloader)
            teacher = model_pre
        else:
            batch, post_iter = _next_batch(post_iter, post_task_dataloader)
            teacher = model_post

        x = maybe_dictionarize(batch)["images"].to(device)

        # Student forward with the current masks; teacher forward needs no graph.
        model_merged.merge_weights()
        logits_student = model_merged(x)
        with torch.no_grad():
            logits_teacher = teacher(x)

        loss = sinkhorn_loss(logits_student, logits_teacher)
        loss_value = loss.item()  # single host sync per step

        # Snapshot BEFORE the update: loss_value was produced by the current masks,
        # not by the masks optimizer.step() is about to write.
        # NOTE(review): this compares single-batch losses from alternating tasks, so
        # "best" is a noisy criterion -- consider a held-out evaluation instead.
        if loss_value < best_loss:
            best_loss = loss_value
            best_masks_pre = {k: v.detach().clone() for k, v in model_merged.masks_pre.items()}
            best_masks_post = {k: v.detach().clone() for k, v in model_merged.masks_post.items()}

        # set_to_none=True leaves the frozen group's grads as None, so Adam skips it.
        # Zero-filled grads would still let Adam's momentum move those masks.
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        pbar.set_postfix({
            "epoch": epoch,
            "loss": f"{loss_value:.4f}",
        })

    # Restore the lowest-loss masks in place (Parameter identity is preserved).
    if best_masks_pre is not None:
        with torch.no_grad():
            for k, v in model_merged.masks_pre.items():
                v.copy_(best_masks_pre[k])
            for k, v in model_merged.masks_post.items():
                v.copy_(best_masks_post[k])

    # Final merge with best masks
    model_merged.merge_weights()

    return {k: v.detach().cpu() for k, v in model_merged.merged_state_dict.items()}


def build_model(
    pretrained_model: nn.Module,
    task_state_dict: StateDict,
    device: str = "cuda:1",
):
    """Return a copy of ``pretrained_model`` on ``device`` with task deltas added.

    The copy's ``state_dict()`` entry becomes ``pretrained + delta`` for every key of
    ``task_state_dict`` that the copy also has. Keys the model lacks are skipped and
    reported with a warning. ``pretrained_model`` itself is not modified.

    Args:
        pretrained_model: Source model; deep-copied.
        task_state_dict: Mapping of state-dict key -> delta tensor.
        device: Device for the returned model.

    Returns:
        The new model, already on ``device``.
    """
    # Move the copy first so its state dict, the deltas and the sums all live on
    # ``device``. The old order computed on ``device``, copied the results back in
    # load_state_dict, and then moved the whole model again.
    model = deepcopy(pretrained_model).to(device)
    model_sd = model.state_dict()

    skipped = []
    for name, delta in task_state_dict.items():
        if name in model_sd:
            model_sd[name] = model_sd[name] + delta.to(device)
        else:
            skipped.append(name)

    # Mismatched keys (e.g. a different prefix) would otherwise vanish silently.
    if skipped:
        log.warning(
            "build_model: %d task-vector key(s) not found in the model were ignored (e.g. %s)",
            len(skipped),
            skipped[:5],
        )

    model.load_state_dict(model_sd)
    return model
