from cProfile import label
import logging
from operator import ge
from pyexpat import model
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
from torch.optim import Adam
from torch.nn.parameter import Parameter
from tqdm import tqdm
from copy import deepcopy
from typing import Dict, Iterator, List, Optional
from src.route_merged_model import RouteMergedModel
import geomloss
from src.task_wise_fusion import *

log = logging.getLogger(__name__)  # module-wide logger, named after this module

# Readability alias used in annotations for name -> tensor mappings (state dicts,
# task vectors). It is the builtin `dict` itself, not a distinct type.
StateDict = dict

class ShortestRouteMask(nn.Module):
    """Container holding one learnable mask per entry of a state dict.

    The masks are stored in a plain ``dict`` rather than an ``nn.ParameterDict``
    (state-dict keys typically contain dots, which ``nn.Module`` rejects as
    parameter names). As a consequence they are NOT registered with
    ``nn.Module``: ``named_parameters()`` and ``state_dict()`` do not see them,
    which is why ``parameters()`` and ``to()`` are overridden here.
    """

    def __init__(
        self,
        state_dict: StateDict,
        init_value: float = 1.0,
    ):
        """
        Args:
            state_dict: Mapping whose keys name the masks and whose tensors give
                each mask's shape/dtype/device.
            init_value: Initial value written into every mask element.
        """
        super().__init__()
        # One trainable mask per entry, same shape as the entry, filled with init_value.
        self.masks = {
            k: nn.Parameter(torch.ones_like(v) * init_value, requires_grad=True)
            for k, v in state_dict.items()
        }

    def _draw_mask(self, binary_mask: bool = False) -> Dict[str, Tensor]:
        """Return an all-ones tensor per mask key.

        Deterministic placeholder: both ``binary_mask`` and the learned mask
        values are currently ignored (could be replaced by stochastic sampling).
        The returned tensors do not require grad.
        """
        return {k: torch.ones_like(param) for k, param in self.masks.items()}

    def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]:
        """Yield the mask parameters for an optimizer.

        ``recurse`` is accepted only for signature compatibility with
        ``nn.Module.parameters``; the masks are a flat collection.
        """
        # Return a real iterator (not a dict view) so that idioms such as
        # `next(module.parameters())` work, as the annotation promises.
        return iter(self.masks.values())

    def to(self, device):
        """Move the module and every mask tensor to ``device``; returns ``self``."""
        super().to(device)
        # The masks are unregistered, so super().to() does not move them.
        for mask in self.masks.values():
            mask.data = mask.data.to(device)
        return self


def _next_batch(batch_iter, loader):
    """Return ``(batch, iterator)``, restarting ``loader`` once it is exhausted.

    Each loader is cycled on its own, so exhausting one stream never rewinds
    another. An empty loader still raises ``StopIteration``.
    """
    try:
        return next(batch_iter), batch_iter
    except StopIteration:
        batch_iter = iter(loader)
        return next(batch_iter), batch_iter


def _prepare_inputs(batch, device):
    """Return ``(input_ids, attention_mask)`` from ``batch``, trimmed and on ``device``.

    Trailing columns whose attention mask is 0 in every row are dropped. At least
    one column is always kept, so an all-padding batch cannot index out of range.
    """
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    length = attention_mask.size(1)
    while length > 1 and torch.all(attention_mask[:, length - 1] == 0):
        length -= 1
    # Trim before moving so only the kept columns are transferred.
    return input_ids[:, :length].to(device), attention_mask[:, :length].to(device)


def _first_step_probs(model, input_ids, attention_mask, pad_token_id):
    """Softmax of ``model``'s first decoder-step logits (over the last dimension).

    The decoder receives a single ``pad_token_id`` token per row, used as the
    decoder start token (as T5 does).
    """
    decoder_input_ids = torch.full(
        (input_ids.size(0), 1), pad_token_id, dtype=torch.long, device=input_ids.device
    )
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
    )
    return outputs.logits[:, 0, :].softmax(1)


def compute_sr_mask(
    task_vector_pre: StateDict,
    task_vector_post: StateDict,
    pretrained_model: nn.Module,
    masks_pre: Dict[str, Tensor],
    masks_post: Dict[str, Tensor],
    lr: float = 0.01,
    max_epochs: int = 200,
    pre_task_dataloader: Optional[torch.utils.data.DataLoader] = None,
    post_task_dataloader: Optional[torch.utils.data.DataLoader] = None,
    mask_alpha: float = 0.5,
    device: str = "cuda:1",
    pad_token_id: int = 0,
):
    """Learn masks that merge two task vectors into one model, then return its weights.

    A student ``RouteMergedModel`` is built from the pretrained model, both task
    vectors and both mask sets. It is trained so that, on batches of each task,
    the softmax of its first decoder-step logits matches that of the
    corresponding teacher (pretrained + task vector) under a Sinkhorn
    divergence. Only the masks are optimised (Adam). The masks with the lowest
    observed loss are restored before the final merge.

    Side effects: freezes ``pretrained_model``'s parameters in place, and
    replaces every entry of ``masks_pre``/``masks_post`` with a trainable
    Parameter of ones. The incoming values only supply shape, dtype and device.

    Args:
        task_vector_pre, task_vector_post: Per-key weight deltas for each task.
        pretrained_model: Base model shared by teachers and student.
        masks_pre, masks_post: Per-key mask templates (overwritten, see above).
        lr: Adam learning rate for both mask groups.
        max_epochs: Number of optimisation steps. Each step consumes one batch
            from each loader; the loaders are cycled independently.
        pre_task_dataloader, post_task_dataloader: Loaders yielding dicts with
            ``"input_ids"`` and ``"attention_mask"`` tensors. Both are required.
        mask_alpha: Forwarded to ``RouteMergedModel``; its meaning is defined there.
        device: Device for all models and batches.
        pad_token_id: Token fed as the single decoder input (decoder start token).

    Returns:
        ``model_merged.merged_state_dict`` with every tensor detached and on CPU.

    Raises:
        ValueError: If either dataloader is ``None``.
    """
    if pre_task_dataloader is None or post_task_dataloader is None:
        raise ValueError(
            "compute_sr_mask requires both pre_task_dataloader and post_task_dataloader"
        )

    # Freeze pretrained parameters (in place, on the caller's model).
    for p in pretrained_model.parameters():
        p.detach_().requires_grad_(False)

    # Turn every mask into a trainable Parameter initialised to 1. Each dict is
    # handled on its own, so key sets need not coincide.
    for masks in (masks_pre, masks_post):
        for k in list(masks):
            masks[k] = Parameter(torch.ones_like(masks[k]), requires_grad=True)

    # Teachers: pretrained + task vector, frozen, in eval mode.
    model_pre = build_model(pretrained_model, task_vector_pre, device)
    model_post = build_model(pretrained_model, task_vector_post, device)
    model_pre.to(device).eval()
    model_post.to(device).eval()

    # Student: merges weights according to the masks (see RouteMergedModel).
    model_merged = RouteMergedModel(
        pretrained_model,
        task_vector_pre,
        task_vector_post,
        masks_pre,
        masks_post,
        mask_alpha,
        device,
    )
    model_merged.to(device).train()

    optimizer = Adam(
        [
            {'params': model_merged.masks_pre.values(), 'lr': lr, 'betas': (0.9, 0.999), 'weight_decay': 0.},
            {'params': model_merged.masks_post.values(), 'lr': lr, 'betas': (0.9, 0.999), 'weight_decay': 0.},
        ]
    )

    # The divergence module holds only configuration, so one instance is reused.
    sinkhorn = geomloss.SamplesLoss(loss="sinkhorn", p=2, blur=0.05)

    pre_iter = iter(pre_task_dataloader)
    post_iter = iter(post_task_dataloader)

    best_loss = float('inf')
    best_masks_pre = None
    best_masks_post = None

    pbar = tqdm(range(max_epochs), desc="Training masks")
    for epoch in pbar:
        model_merged.train()

        batch_pre, pre_iter = _next_batch(pre_iter, pre_task_dataloader)
        batch_post, post_iter = _next_batch(post_iter, post_task_dataloader)

        # Rebuild the student's weights from the current masks.
        model_merged.merge_weights()

        input_ids_pre, attention_mask_pre = _prepare_inputs(batch_pre, device)
        input_ids_post, attention_mask_post = _prepare_inputs(batch_post, device)

        probs_student_pre = _first_step_probs(
            model_merged, input_ids_pre, attention_mask_pre, pad_token_id
        )
        probs_student_post = _first_step_probs(
            model_merged, input_ids_post, attention_mask_post, pad_token_id
        )

        # Teachers are fixed targets: no autograd graph is needed.
        with torch.no_grad():
            probs_teacher_pre = _first_step_probs(
                model_pre, input_ids_pre, attention_mask_pre, pad_token_id
            )
            probs_teacher_post = _first_step_probs(
                model_post, input_ids_post, attention_mask_post, pad_token_id
            )

        loss = sinkhorn(probs_student_pre, probs_teacher_pre) + sinkhorn(
            probs_student_post, probs_teacher_post
        )
        loss_value = loss.item()

        # Snapshot BEFORE optimizer.step(): these are the masks that produced
        # loss_value. Read from model_merged, which is where they are restored.
        if loss_value < best_loss:
            best_loss = loss_value
            best_masks_pre = {k: v.detach().clone() for k, v in model_merged.masks_pre.items()}
            best_masks_post = {k: v.detach().clone() for k, v in model_merged.masks_post.items()}

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_postfix({
            "epoch": epoch,
            "loss": f"{loss_value:.4f}",
        })

    # ============ After training ============

    # Restore the best masks if any step was recorded.
    if best_masks_pre is not None:
        for k in model_merged.masks_pre:
            model_merged.masks_pre[k].data = best_masks_pre[k]
        for k in model_merged.masks_post:
            model_merged.masks_post[k].data = best_masks_post[k]

    # Final merge with the restored masks.
    model_merged.merge_weights()

    # Detached CPU copy of the merged weights, ready for inference.
    return {k: v.detach().cpu() for k, v in model_merged.merged_state_dict.items()}

def build_model(
    pretrained_model: nn.Module,
    task_state_dict: StateDict,
    device: str = "cuda:1",
):
    """
    Build a task model whose weights are ``pretrained + task vector``.

    ``pretrained_model`` is deep-copied, so the caller's model is not modified.
    For every key of ``task_state_dict`` that also exists in the copy's state
    dict, the entry is replaced by ``pretrained + delta`` (computed on
    ``device``). The result is loaded back, and the model is moved to ``device``.
    Keys absent from the model are ignored, and a warning is logged.

    Args:
        pretrained_model: Base model supplying the starting weights.
        task_state_dict: Mapping from state-dict key to delta tensor. Each delta
            is assumed to match the shape of the weight it is added to.
        device: Device on which the sums are computed and the model is placed.

    Returns:
        The new model, on ``device``.
    """
    model = deepcopy(pretrained_model)
    model_sd = model.state_dict()

    unmatched = []
    for name, delta in task_state_dict.items():
        if name not in model_sd:
            unmatched.append(name)
            continue
        # Out-of-place addition: stays correct even if several keys share
        # storage (tied weights), since each becomes pretrained + delta once.
        model_sd[name] = model_sd[name].to(device) + delta.to(device)

    if unmatched:
        # A silently dropped delta would leave this "task" model equal to the
        # pretrained one, so make the mismatch visible.
        log.warning(
            "build_model: %d task-vector key(s) not found in the model were ignored, e.g. %s",
            len(unmatched),
            unmatched[:5],
        )

    model.load_state_dict(model_sd)
    return model.to(device)

def softmax_entropy(x: Tensor):
    """
    Entropy of the softmax distribution taken along dimension 1.

    Args:
        x (Tensor): Unnormalised scores; dimension 1 is the one normalised over.

    Returns:
        Tensor: ``-sum_j softmax(x)_j * log_softmax(x)_j`` over dimension 1,
        i.e. ``x`` with dimension 1 reduced away.
    """
    probs = torch.softmax(x, dim=1)
    log_probs = torch.log_softmax(x, dim=1)
    weighted = probs * log_probs
    return -weighted.sum(dim=1)
