import torch
import numpy as np


class Weight_acc:
    """Fuse several task heads' predictions using per-class accuracy weights.

    For every task ``t`` and class ``c`` this keeps a weight
    ``self.weigh[t][c]``. ``update`` recomputes the weights as a
    temperature-scaled softmax, taken across tasks, of each task's accuracy
    on class ``c``. It also stores in ``self.max_weight_task[c]`` the most
    accurate task for that class.

    Combination rules:
      * ``cat_out``     -- accuracy-weighted sum of per-task softmax
        probabilities.
      * ``cat_targets`` -- keep the label where all tasks agree; otherwise use
        the prediction of the best task for the sample's ground-truth class.
    """

    # cat_mode = 'targets' or 'out' (presumably selects cat_targets vs cat_out
    # in the calling code -- not referenced inside this class).
    def __init__(self, num_class, tasks, eoss_temperature=1.0):
        """
        Args:
            num_class (int): Number of classes.
            tasks (Sequence): Task names; must be non-empty. The first task is
                the default / tie-break task.
            eoss_temperature (float): Softmax temperature used by ``update``.
                Invalid (non-positive or NaN) values are replaced by 1.0, with
                a warning, at update time.

        Raises:
            ValueError: If ``tasks`` is empty.
        """
        self.tasks = tasks
        if not tasks:  # Handle empty task list
            raise ValueError("Task list cannot be empty for Weight_acc")
        self.num_class = num_class
        self.eoss_temperature = eoss_temperature
        # Uniform weights on CPU to start; cuda()/cpu() move them later.
        self.weigh = {t: torch.ones(num_class, device='cpu') for t in tasks}
        # Best task per class; the first task until update() has run.
        self.max_weight_task = [tasks[0]] * num_class
        # Not read by any method of this class; kept for potential future use.
        # NOTE(review): a task literally named 'max' or 'conflict' would
        # collide with the extra keys below.
        self.weigh_save_list = {t: [] for t in tasks}
        self.weigh_save_list['max'] = []
        self.weigh_save_list['conflict'] = np.zeros(
            [400, 2])  # Assuming max 400 epochs

    def update(self, all_preds_np, all_targets_np):
        """
        Update EOSS weights based on the overall accuracy of each task
        on each class over the entire validation set.

        For each class ``c`` the new weights are
        ``softmax over tasks of (acc[t, c] / T)`` with
        ``T = self.eoss_temperature``. An invalid ``T`` (non-positive or NaN)
        falls back to 1.0 with a warning. ``self.max_weight_task[c]`` becomes
        the task with the highest accuracy on ``c``; ties go to the task
        listed first in ``self.tasks``.

        Args:
            all_preds_np (dict): Maps task names to array-likes of predicted
                labels for the entire validation set. Only read, never
                modified.
            all_targets_np (array-like): True labels for the entire
                validation set.

        A task is treated as having zero accuracy on every class, with a
        warning printed, if its predictions:
          * are missing from ``all_preds_np``,
          * cannot be converted to an array, or
          * differ in length from the targets.
        """
        if not self.tasks:
            print("Warning: No tasks defined for Weight_acc update.")
            return

        if not isinstance(all_targets_np, np.ndarray):
            try:
                all_targets_np = np.array(all_targets_np)
            except Exception as e:
                print(f"Error converting all_targets_np to NumPy array: {e}")
                return
        # A 0-d array has no sample axis (shape[0] would raise): view as 1 sample.
        all_targets_np = np.atleast_1d(all_targets_np)

        # Keep the weights on whatever device cuda()/cpu() last moved them to.
        device = (next(iter(self.weigh.values())).device
                  if self.weigh else torch.device('cpu'))

        # (num_tasks, num_class) float64 accuracy matrix, rows in self.tasks order.
        acc = np.stack([
            self._per_class_accuracy(t, all_preds_np, all_targets_np)
            for t in self.tasks
        ])

        temperature = self.eoss_temperature
        if not temperature > 0:  # Rejects <= 0 and NaN; prevents division by zero
            print(
                f"Warning: Invalid eoss_temperature ({self.eoss_temperature}). Using 1.0 instead.")
            temperature = 1.0

        # Numerically stable softmax across tasks (axis 0), independently per
        # class. Subtracting the per-class max keeps exp() finite. A plain
        # float32 exp(acc / T) overflows to inf once acc / T exceeds ~88,
        # and inf / inf would make the normalised weights NaN.
        scaled = acc / temperature
        scaled -= scaled.max(axis=0, keepdims=True)
        exp_scaled = np.exp(scaled)
        weights = exp_scaled / exp_scaled.sum(axis=0, keepdims=True)
        for t, task_weights in zip(self.tasks, weights):
            self.weigh[t] = torch.tensor(
                task_weights, dtype=torch.float32, device=device)

        # The ranking by accuracy is the ranking by softmax weight, but without
        # overflow-induced ties. np.argmax returns the first maximum, so ties
        # resolve to the earliest task. Assign in place to keep list identity.
        self.max_weight_task[:] = [
            self.tasks[int(k)] for k in acc.argmax(axis=0)]

    def _per_class_accuracy(self, t, all_preds_np, all_targets_np):
        """Return task ``t``'s accuracy on each class as a ``(num_class,)`` array.

        Returns all zeros, after printing a warning, when the task's
        predictions are missing, cannot be converted, or are length-mismatched.
        Classes with no samples in ``all_targets_np`` get accuracy 0.
        """
        zeros = np.zeros(self.num_class)  # Assume zero accuracy on failure
        if t not in all_preds_np:
            print(
                f"Warning: Task '{t}' not found in all_preds_np during weight update. Skipping.")
            return zeros
        preds = all_preds_np[t]
        if not isinstance(preds, np.ndarray):
            try:
                # Convert into a local so the caller's dict is left untouched.
                preds = np.array(preds)
            except Exception as e:
                print(
                    f"Error converting preds for task '{t}' to NumPy array: {e}. Skipping.")
                return zeros
        preds = np.atleast_1d(preds)
        # Check only first dimension
        if preds.shape[0] != all_targets_np.shape[0]:
            print(
                f"Warning: Shape mismatch for task '{t}'. Preds: {preds.shape[0]}, Targets: {all_targets_np.shape[0]}. Skipping update for this task.")
            return zeros

        correct = np.zeros(self.num_class)
        total = np.zeros(self.num_class)
        for c in range(self.num_class):
            class_mask = all_targets_np == c
            total[c] = class_mask.sum()
            if total[c] > 0:
                correct[c] = (preds[class_mask] == c).sum()
        # Division only where the class occurs; elsewhere the 0 from `out` stays.
        return np.divide(correct, total, out=np.zeros_like(correct),
                         where=total != 0)

    def cat_out(self, logits):
        """Weighted sum of softmax probabilities.

        Computes ``sum_t self.weigh[t] * softmax(logits[t], dim=1)``. The
        ``(C,)`` weight vector is broadcast over the rows of each ``(N, C)``
        logits tensor.

        Args:
            logits (dict): Task name -> logits tensor. Tasks without a tensor
                entry are skipped. The mapping is not modified.

        Returns:
            The fused tensor, with the same shape and dtype as the first
            available logits tensor, or None when no task has logits.
        """
        if not self.tasks:
            return None  # Or raise error

        # Device: first task's logits, else the weights, else CUDA/CPU.
        first_task_logit = logits.get(self.tasks[0], None)
        if isinstance(first_task_logit, torch.Tensor):
            device = first_task_logit.device
        elif self.weigh:
            device = next(iter(self.weigh.values())).device
        else:
            device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')

        output_sum = None
        for t in self.tasks:
            task_logits = logits.get(t, None)
            if not isinstance(task_logits, torch.Tensor):
                continue
            task_logits = task_logits.to(device)
            task_weights = self.weigh[t].to(device)
            if output_sum is None:  # Shape/dtype/device of the first logits seen
                output_sum = torch.zeros_like(task_logits)
            # (C,) -> (1, C) so the weights broadcast over the batch rows.
            output_sum += task_weights.unsqueeze(0) * \
                torch.softmax(task_logits, dim=1)
        return output_sum

    def cat_targets(self, logits, targets, epoch):
        """Select predictions per sample, resolving conflicts via max_weight_task.

        Where every task with logits predicts the same label, that label is
        kept. Where they disagree, the prediction of
        ``self.max_weight_task[targets[i]]`` is used. It falls back to the
        first task with logits if that best task has no logits, or if the
        target lies outside ``[0, num_class)`` (a warning is printed).

        NOTE(review): conflicts are resolved using the ground-truth
        ``targets``, so this combination is only computable when labels are
        known.

        Args:
            logits (dict): Task name -> (N, C) logits tensor. Not modified.
            targets: Class labels for the N samples (tensor or array-like).
            epoch: Unused; kept for interface compatibility.

        Returns:
            (N, num_class) float one-hot tensor, or None if no task has logits.
        """
        if not self.tasks:
            return None  # Or raise error

        # Device: targets, then the first task's logits, then the weights.
        first_logits = logits[self.tasks[0]] if self.tasks[0] in logits else None
        if isinstance(targets, torch.Tensor):
            device = targets.device
        elif isinstance(first_logits, torch.Tensor):
            device = first_logits.device
        elif self.weigh:
            device = next(iter(self.weigh.values())).device
        else:
            device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')
        targets = torch.as_tensor(targets, device=device)

        # Hard predictions per task, in self.tasks order. Collected in a local
        # dict so the caller's `logits` mapping is never written to.
        out_puts = {
            t: torch.argmax(logits[t].to(device), dim=1)
            for t in self.tasks
            if t in logits and isinstance(logits[t], torch.Tensor)
        }
        if not out_puts:
            return None  # No predictions to work with
        task_order = list(out_puts)
        first_task_name = task_order[0]
        base = out_puts[first_task_name]

        # same_index[n] is True where every available task agrees with the first.
        same_index = torch.ones_like(base, dtype=torch.bool)
        for t in task_order[1:]:
            same_index &= out_puts[t] == base

        # Agreeing samples keep the shared label; conflicting samples default
        # to the first task's label unless overridden below.
        catout = base.clone()

        conflict_indices = torch.nonzero(~same_index).squeeze(1)
        if conflict_indices.numel() > 0:
            raw_targets = targets[conflict_indices]
            valid = (raw_targets >= 0) & (raw_targets < self.num_class)

            for i, c in zip(conflict_indices[~valid].tolist(),
                            raw_targets[~valid].tolist()):
                print(
                    f"Warning: Invalid target class {c} at index {i}. Using default prediction.")

            # best_pos[c] = row in `task_order` of the task trusted for class c.
            # Classes whose best task has no logits fall back to the first task.
            position = {t: k for k, t in enumerate(task_order)}
            best_pos = torch.tensor(
                [position.get(t, 0) for t in self.max_weight_task],
                dtype=torch.long, device=device)
            idx = conflict_indices[valid]
            chosen = best_pos[raw_targets[valid].long()]
            stacked = torch.stack([out_puts[t] for t in task_order])  # (tasks, N)
            catout[idx] = stacked[chosen, idx]

        # Return one-hot encoded predictions
        return torch.nn.functional.one_hot(catout, num_classes=self.num_class).float()

    def cuda(self):
        """Move the weight tensors to CUDA when available; return self for chaining."""
        if torch.cuda.is_available():
            device = torch.device("cuda")
            for t in self.tasks:
                if t in self.weigh:
                    self.weigh[t] = self.weigh[t].to(device)
        return self  # Allow chaining

    def cpu(self):
        """Move the weight tensors to CPU; return self for chaining."""
        device = torch.device("cpu")
        for t in self.tasks:
            if t in self.weigh:
                self.weigh[t] = self.weigh[t].to(device)
        return self  # Allow chaining
