import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import numpy as np
from models.base import MultitaskLearning, WeightMethod
from typing import Tuple

# https://arxiv.org/abs/1803.10704 MTAN

# https://github.com/lorenmt/mtan/blob/master/im2im_pred/model_segnet_mtan.py

def conv3x3(in_planes, out_planes, stride=1):
    """
    Build a 3x3 convolution with padding 1 and a learnable bias.

    With stride 1 the spatial size of the input is preserved.

    Args:
        in_planes (int): Channels of the incoming feature map.
        out_planes (int): Channels produced by the convolution.
        stride (int): Convolution stride. Default is 1.

    Returns:
        nn.Conv2d: The configured convolution layer.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=True)
    return nn.Conv2d(in_planes, out_planes, **conv_options)

def conv_init(m):
    """
    Initialize convolution and batch-normalization layers in place.

    Intended for use with ``module.apply(conv_init)``. Layers of any other type
    are left untouched.

    - ``nn.Conv2d``: Xavier-uniform weights with gain sqrt(2); bias set to 0.
    - ``nn.BatchNorm2d``: weight set to 1 and bias set to 0.

    Optional parameters that do not exist are skipped. This covers a Conv2d
    built with ``bias=False`` and a BatchNorm2d built with ``affine=False``,
    both of which would otherwise make ``init.constant_`` fail on ``None``.

    Args:
        m (nn.Module): A layer in the neural network.
    """
    if isinstance(m, nn.Conv2d):
        init.xavier_uniform_(m.weight, gain=np.sqrt(2))
        if m.bias is not None:
            init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        if m.weight is not None:
            init.constant_(m.weight, 1)
        if m.bias is not None:
            init.constant_(m.bias, 0)

class WideBasicBlock(nn.Module):
    """
    Pre-activation residual block used by the WideResNet backbone.

    Residual branch: BN -> ReLU -> 3x3 conv -> BN -> ReLU -> 3x3 conv. The
    stride is applied on the second convolution. The skip branch is the
    identity, or a strided 1x1 convolution when either the stride or the
    channel count changes.

    Args:
        in_planes (int): Number of input channels.
        planes (int): Number of output channels.
        stride (int): Stride of the second 3x3 convolution and of the
            projection shortcut (if any). Default is 1.
    """
    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        needs_projection = stride != 1 or in_planes != planes

        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)

        # An empty Sequential behaves as the identity. Otherwise a 1x1 conv
        # reshapes the input to match the residual branch's output.
        if needs_projection:
            projection = [nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True)]
        else:
            projection = []
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """
        Apply the residual branch and add the (possibly projected) input.

        Args:
            x (torch.Tensor): Input tensor with ``in_planes`` channels.

        Returns:
            torch.Tensor: Residual output plus the shortcut of ``x``.
        """
        branch = self.conv1(F.relu(self.bn1(x)))
        branch = F.relu(self.bn2(branch))
        branch = self.conv2(branch)
        return branch + self.shortcut(x)
    
class DynamicWeightAverage(WeightMethod):
    """Dynamic Weight Average from `End-to-End Multi-Task Learning with Attention`.
    Modification of: https://github.com/lorenmt/mtan/blob/master/im2im_pred/model_segnet_split.py#L242

    Instead of epoch-level losses, the per-task losses of the most recent
    ``2 * iteration_window`` calls are kept in a FIFO buffer. For each task,
    the ratio ``mean(latest window) / mean(previous window)`` serves as the
    loss-descent rate. Weights are a temperature-scaled softmax of those
    ratios, rescaled to sum to ``n_tasks``. Until the buffer has been filled
    with real losses, the weights stay uniform (all ones).
    """

    def __init__(
        self, n_tasks, device: torch.device, iteration_window: int = 25, temp=2.0
    ):
        """

        Parameters
        ----------
        n_tasks : int
            Number of tasks; the number of per-task losses given to
            ``get_weighted_loss``.
        device : torch.device
            Forwarded to ``WeightMethod.__init__``.
        iteration_window : int
            'iteration' loss is averaged over the last 'iteration_window' losses.
            The history buffer holds two such windows.
        temp : float
            Softmax temperature; larger values give more uniform weights.
        """
        super().__init__(n_tasks, device=device)
        self.iteration_window = iteration_window
        self.temp = temp
        self.running_iterations = 0
        # FIFO history of per-task losses; row -1 is the most recent call.
        # The initial ones are placeholders only: weights are not computed
        # from the buffer until every row holds an observed loss.
        self.costs = np.ones((iteration_window * 2, n_tasks), dtype=np.float32)
        self.weights = np.ones(n_tasks, dtype=np.float32)

    def get_weighted_loss(self, losses, **kwargs):
        """
        Combine per-task losses with the current DWA weights.

        Parameters
        ----------
        losses : torch.Tensor
            Per-task losses, one entry per task (presumably a 1-D tensor of
            length ``n_tasks`` -- TODO confirm with callers).

        Returns
        -------
        tuple
            ``(loss, dict(weights=task_weights))``. ``loss`` is the MEAN (not
            the sum) of ``task_weights * losses``. ``task_weights`` is a
            float32 tensor on ``losses.device``.
        """
        cost = losses.detach().cpu().numpy()

        # update costs - fifo: drop the oldest row, append the newest losses
        self.costs[:-1, :] = self.costs[1:, :]
        self.costs[-1, :] = cost

        # This call's losses are already in the buffer, so it holds
        # running_iterations + 1 observed rows (capped at its size). Only
        # recompute once both windows are entirely observed. Before that, the
        # older window would still contain the placeholder ones, and the ratio
        # would be measured against an arbitrary constant.
        if self.running_iterations + 1 >= 2 * self.iteration_window:
            latest = self.costs[self.iteration_window :, :].mean(0)
            previous = self.costs[: self.iteration_window, :].mean(0)
            ws = latest / previous
            exp_ws = np.exp(ws / self.temp)
            self.weights = (self.n_tasks * exp_ws) / exp_ws.sum()

        task_weights = torch.from_numpy(self.weights.astype(np.float32)).to(
            losses.device
        )
        loss = (task_weights * losses).mean()

        self.running_iterations += 1

        return loss, dict(weights=task_weights)

#
class MTAN(MultitaskLearning):
    """
    WideResNet model with task-specific attention mechanisms for multi-task learning.

    A shared WideResNet encoder produces feature maps at four stages. For each
    task there is:

    - a chain of soft-attention modules (1x1 convs ending in a sigmoid) that
      gate the shared features at every stage;
    - conv blocks that carry the attended features to the next stage;
    - a linear classifier on the globally average-pooled output of the last stage.

    During training, task losses are combined with Dynamic Weight Average.

    Inputs are expected to have 3 channels. The spatial size should be a
    multiple of 8; otherwise the max-pooled task features and the strided
    shared features disagree in size and ``torch.cat`` fails.

    Args:
        cur_encoder: Ignored; MTAN builds its own shared encoder.
        tasks_name_to_cls_num (dict): A dictionary mapping task names to the number of classes for each task.
        lr (float): Learning rate of the Adam optimizer. Default 0.001.
        depth (int): Depth of the WideResNet; each stage gets
            ``(depth - 4) // 6`` blocks (at least one).
        widen_factor (int): Widening factor for the network width.
        **kwargs: Ignored.
    """
    def __init__(self,
                 cur_encoder,  # don't need encoder in this case
                 tasks_name_to_cls_num: dict,
                 lr=0.001,
                 depth=10,  # used this for visual decathlon in original code
                 widen_factor=1,  # used this for visual decathlon in original code
                 **kwargs):
        super(MTAN, self).__init__(encoder=None,
                                   tasks_name_to_cls_num=tasks_name_to_cls_num,
                                   lr=lr,
                                   cls_output_dim=2)
        self.in_planes = 16
        num_blocks = (depth - 4) // 6
        filters = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        self.tasks = tasks_name_to_cls_num
        # Number of stages in the encoder
        self.num_stages = len(filters)
        # NOTE(review): device is hard-coded to 'cuda' and only forwarded to
        # WeightMethod; confirm how the base class uses it before running on CPU.
        self.DWA = DynamicWeightAverage(len(self.tasks), device='cuda')

        # Shared convolutional layers (common for all tasks)
        self.conv1 = conv3x3(3, filters[0])
        self.layer1 = self._make_layer(WideBasicBlock, filters[1], num_blocks, stride=2)
        self.layer2 = self._make_layer(WideBasicBlock, filters[2], num_blocks, stride=2)
        self.layer3 = self._make_layer(WideBasicBlock, filters[3], num_blocks, stride=2)
        self.bn1 = nn.BatchNorm2d(filters[3])

        # Task-specific layers: attention modules and classifiers
        self.task_names = list(self.tasks.keys())
        self.num_tasks = len(self.tasks)

        self.attention_modules = nn.ModuleDict()
        self.encoder_blocks = nn.ModuleDict()
        self.classifiers = nn.ModuleDict()

        for task_name, num_classes in self.tasks.items():
            # Define task-specific classifier
            self.classifiers[task_name] = nn.Sequential(
                nn.Linear(filters[-1], num_classes))

            # Stage 0 attends to the shared features alone. Later stages attend
            # to [shared features, previous task features] concatenated, hence
            # 2 * filters[stage_idx] input channels.
            attention_layers = nn.ModuleList([self._make_attention_layer(filters[0], filters[0], filters[0])])
            encoder_layers = nn.ModuleList([self._make_conv_layer(filters[0], filters[1])])

            for stage_idx in range(1, self.num_stages):
                attention_layers.append(self._make_attention_layer(2 * filters[stage_idx], filters[stage_idx], filters[stage_idx]))
                if stage_idx < self.num_stages - 1:
                    encoder_layers.append(self._make_conv_layer(filters[stage_idx], filters[stage_idx + 1]))
                else:
                    encoder_layers.append(self._make_conv_layer(filters[stage_idx], filters[stage_idx]))

            # Store task-specific modules in ModuleDict for easy access
            self.attention_modules[task_name] = attention_layers
            self.encoder_blocks[task_name] = encoder_layers

        # Honour the requested learning rate (previously hard-coded to 0.001).
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)

    def _make_layer(self, block, planes, num_blocks, stride):
        """
        Creates a WideResNet layer composed of several WideBasicBlock blocks.

        Side effect: advances ``self.in_planes`` to ``planes`` so the next layer
        starts from this layer's output width.

        Args:
            block (nn.Module): The block type (WideBasicBlock) to be used in the layer.
            planes (int): Number of output channels for the blocks.
            num_blocks (int): Number of blocks in the layer (values < 1 still build one block).
            stride (int): Stride for the first block in the layer.

        Returns:
            nn.Sequential: A sequence of WideBasicBlock layers.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes
        return nn.Sequential(*layers)

    def _make_conv_layer(self, in_channels, out_channels):
        """
        Creates a convolutional block with a 3x3 convolution, BatchNorm, and ReLU activation.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.

        Returns:
            nn.Sequential: A convolutional block (spatial size preserved).
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def _make_attention_layer(self, in_channels, mid_channels, out_channels):
        """
        Creates an attention block with 1x1 convolutions, BatchNorm, and ReLU/Sigmoid activations.

        The final sigmoid yields a mask in [0, 1] that is multiplied element-wise
        with the shared features.

        Args:
            in_channels (int): Number of input channels.
            mid_channels (int): Number of intermediate channels.
            out_channels (int): Number of output channels.

        Returns:
            nn.Sequential: An attention block.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.Sigmoid(),
        )

    def _encode_shared(self, x):
        """
        Run the shared WideResNet encoder once and return all four stage outputs.

        Args:
            x (torch.Tensor): Input image batch.

        Returns:
            list[torch.Tensor]: ``[conv1(x), layer1, layer2, relu(bn1(layer3))]``.
        """
        stage0 = self.conv1(x)
        stage1 = self.layer1(stage0)
        stage2 = self.layer2(stage1)
        stage3 = F.relu(self.bn1(self.layer3(stage2)))
        return [stage0, stage1, stage2, stage3]

    def _forward_head(self, shared_features, task_name):
        """
        Apply one task's attention chain and classifier to precomputed shared features.

        Args:
            shared_features (list[torch.Tensor]): Output of ``_encode_shared``.
            task_name (str): Task whose attention modules and classifier are used.

        Returns:
            torch.Tensor: Classifier output of shape ``(batch, num_classes)``.
        """
        attention_layers = self.attention_modules[task_name]
        encoder_blocks = self.encoder_blocks[task_name]

        task_features = None
        for stage_idx, stage_features in enumerate(shared_features):
            # Stage 0 gates the shared features alone. Later stages condition
            # the mask on the previous stage's task-specific features as well.
            if stage_idx == 0:
                attention_input = stage_features
            else:
                attention_input = torch.cat([stage_features, task_features], dim=1)
            task_features = attention_layers[stage_idx](attention_input) * stage_features
            task_features = encoder_blocks[stage_idx](task_features)
            # Halve the resolution to match the next (stride-2) shared stage;
            # the last stage is followed by global pooling instead.
            if stage_idx < self.num_stages - 1:
                task_features = F.max_pool2d(task_features, kernel_size=2, stride=2)

        # Global average pooling and classification
        # adaptive_avg_pool2d is used to handle different input sizes
        pooled = F.adaptive_avg_pool2d(task_features, output_size=(1, 1))
        pooled = pooled.squeeze(-1).squeeze(-1)
        return self.classifiers[task_name](pooled)

    def forward_task(self, x, task_name):
        """
        Forward pass of the model for a specific task.

        Args:
            x (torch.Tensor): Input tensor (image batch).
            task_name (str): The name of the task for which the forward pass is being performed.

        Returns:
            torch.Tensor: The output prediction for the specified task.
        """
        return self._forward_head(self._encode_shared(x), task_name)

    def forward(self, x):
        """
        Forward pass of the model for all tasks.

        The shared encoder runs once per call and its features are reused by
        every task head. (Previously it was recomputed per task, which also
        updated the shared BatchNorm running statistics once per task in
        training mode.)

        Args:
            x (torch.Tensor): Input tensor (image batch).

        Returns:
            dict: A dictionary mapping task names to the output predictions for each task.
        """
        shared_features = self._encode_shared(x)
        return {task_name: self._forward_head(shared_features, task_name)
                for task_name in self.task_names}

    def compute_loss(self,
                     inputs: torch.Tensor,
                     labels: dict,
                     loss_func: nn.Module) -> torch.Tensor:
        """
        Perform one optimization step on a batch.

        Per-task losses are stacked and handed to ``self.DWA.backward``, which
        presumably back-propagates the DWA-weighted loss (defined in
        WeightMethod, not visible here). The Adam optimizer is then stepped.

        NOTE(review): losses are stacked in ``labels`` iteration order, and DWA
        tracks per-task history by position. ``labels`` should therefore hold
        every task in a consistent order across calls -- confirm with the data
        pipeline.

        Returns:
            torch.Tensor: The unweighted sum of the per-task losses.
        """
        self.optimizer.zero_grad()
        outputs = self.forward(inputs)
        losses = [loss_func(outputs[task_name], labels[task_name]) for task_name in labels]

        self.DWA.backward(torch.stack(losses))
        self.optimizer.step()
        return sum(losses)

    def compute_loss_nograd(self,
                            inputs: torch.Tensor,
                            labels: dict,
                            loss_func: nn.Module) -> torch.Tensor:
        """
        Return the unweighted sum of per-task losses without tracking gradients.

        Does not change train/eval mode. Returns ``0`` if ``labels`` is empty.
        """
        tot_loss = 0
        with torch.no_grad():
            outputs = self.forward(inputs)
            for task_name in labels:
                tot_loss += loss_func(outputs[task_name], labels[task_name])
        return tot_loss

    def calculate_accuraciess(self,
                              valid_loader: torch.utils.data.DataLoader,
                              tasks_name: Tuple[str, ...],
                              device: torch.device) -> dict:
        """
        Compute per-task top-1 accuracy over ``valid_loader``.

        Each batch is expected to be a dict holding an ``"image"`` tensor plus
        one label tensor per task name. Side effect: the model is switched to
        eval mode and left there.

        Returns:
            dict: Task name -> fraction of correctly classified samples,
            relative to ``len(valid_loader.dataset)``.
        """
        correct = [0] * len(tasks_name)
        total = len(valid_loader.dataset)
        self.eval()
        with torch.no_grad():
            for sample in valid_loader:
                images = sample["image"].to(device)
                # Run the shared backbone once per batch and reuse it for every task.
                shared_features = self._encode_shared(images)
                for idx, task_name in enumerate(tasks_name):
                    cur_task_y = sample[task_name].to(device=device, dtype=torch.long)
                    predicted = self._forward_head(shared_features, task_name).argmax(dim=1)
                    correct[idx] += (predicted == cur_task_y).sum().item()
        return {task_name: correct[idx] / total for idx, task_name in enumerate(tasks_name)}
    
