import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.utils import utils_l1_norm, utils_l2_norm, calculate_network_dims, new_utils_l2_norm
import math
# from hat.modules import HATLinear, HATConv2d, TaskIndexedLayerNorm
# from hat import HATPayload


class CReLU(nn.Module):
    """Concatenated ReLU: ``relu(cat((x, -x), dim))``, doubling one dimension.

    2-D inputs are concatenated along the last dimension; 4-D inputs along
    dimension 1 (the channel dimension, assuming NCHW layout).
    """

    def __init__(self, inplace=False):
        # ``inplace`` is accepted for signature compatibility with nn.ReLU but
        # is ignored: the concatenation always produces a new tensor.
        super(CReLU, self).__init__()

    def forward(self, x):
        """Return ``relu(cat((x, -x)))`` along the dimension chosen by ``x``'s rank.

        Raises:
            ValueError: if ``x`` is neither 2-D nor 4-D.
        """
        if x.dim() == 2:
            dim = -1
        elif x.dim() == 4:
            dim = 1
        else:
            # Previously ``raise f"..."`` (not an exception) with a ``shpe`` typo.
            raise ValueError(f"{x.shape} is invalid in CReLU")
        return F.relu(torch.cat((x, -x), dim))

class DeepFourier(nn.Module):
    """Fourier-feature activation: ``cat((cos(x), sin(x)), dim)``, doubling one dimension.

    2-D inputs are concatenated along the last dimension; 4-D inputs along
    dimension 1 (the channel dimension, assuming NCHW layout).
    """

    def __init__(self):
        super(DeepFourier, self).__init__()

    def forward(self, x):
        """Return ``cat((cos(x), sin(x)))`` along the dimension chosen by ``x``'s rank.

        Raises:
            ValueError: if ``x`` is neither 2-D nor 4-D.
        """
        if x.dim() == 2:
            dim = -1
        elif x.dim() == 4:
            dim = 1
        else:
            # Previously ``raise f"..."`` (not an exception) with a ``shpe`` typo.
            raise ValueError(f"{x.shape} is invalid in DeepFourier")
        return torch.cat((torch.cos(x), torch.sin(x)), dim)


class MixNormal(nn.Module):
    """Configurable conv + fully-connected network with optional HAT layers.

    Architecture: ``len(cnn_channels)`` blocks of
    conv -> [layer norm] -> activation -> pooling, then a flatten, then
    ``len(fc_channels)`` hidden linear layers (each followed by the activation)
    and a final ``nn.Linear`` to ``num_classes`` with no activation.

    With ``activation`` in ``{'crelu', 'deepfourier'}`` the activation doubles
    its input width, so every hidden conv/linear layer is built with half the
    requested width (which must therefore be even).

    Each ``forward`` call resets and fills two dicts keyed by layer name:

    * ``self.activations``: post-pooling conv outputs and post-activation
      hidden fc outputs.
    * ``self.activations_for_redo``: ``(tensor, layer_kind, next_layer_kind)``
      with post-activation, *pre*-pooling conv outputs and post-activation
      hidden fc outputs.

    HAT: ``use_hat=True`` requires ``HATConv2d``/``HATLinear``/
    ``TaskIndexedLayerNorm`` to be bound in this module, and ``forward`` only
    treats inputs as ``HATPayload`` when that name is bound too (the ``hat``
    imports at the top of this file may be commented out).
    """

    # The list defaults below are never mutated by this class
    # (``fc_channels + [num_classes]`` builds a new list).
    def __init__(self,
                 input_type='conv',
                 input_shape=(3, 32, 32),
                 num_classes=10,
                 cnn_channels=[8, 16, 32, 64],
                 kernel_size=[3, 3, 3, 3],
                 padding=[1, 1, 1, 1],
                 stride=[1, 1, 1, 1],
                 pooling_type=['max', 'max', 'max', 'max'],
                 pooling_kernel=[2, 2, 2, 2],
                 fc_channels=[],
                 activation='relu',
                 layer_norm=False,
                 use_hat=False,
                 hat_config=None,
                 num_tasks=0,
                 class_inc=False):
        """Build the conv and fc layers.

        Args:
            input_type: ``'conv'`` sizes the first fc layer from the conv stack
                via ``calculate_network_dims``; any other value uses
                ``math.prod(input_shape)`` instead.
            input_shape: input shape without the batch dim; ``input_shape[0]``
                is the input channel count of ``conv1``.
            num_classes: output width of the final linear layer.
            cnn_channels, kernel_size, padding, stride: per-conv-layer settings.
            pooling_type, pooling_kernel: per-conv-layer pooling ('max' or
                'avg', validated in ``forward``) and its kernel size.
            fc_channels: hidden fc widths; the ``num_classes`` layer is appended.
            activation: 'relu', 'gelu', 'sigmoid', 'tanh', 'crelu',
                'deepfourier' or 'prelu'.
            layer_norm: build layer norms (``forward`` applies them after the
                conv layers only).
            use_hat: build HAT conv/linear/layer-norm layers for hidden layers.
            hat_config: forwarded to the HAT layers.
            num_tasks: forwarded to ``TaskIndexedLayerNorm``.
            class_inc: stored on the instance; not otherwise used here.

        Raises:
            ValueError: for an unrecognised ``activation`` (checked whenever at
                least one conv or hidden fc layer is built).
        """
        super().__init__()
        self.input_type = input_type
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.cnn_channels = cnn_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.pooling_type = pooling_type
        self.pooling_kernel = pooling_kernel
        self.fc_channels = fc_channels + [num_classes]
        self.activation = activation
        self.layer_norm = layer_norm
        self.use_hat = use_hat
        self.class_inc = class_inc

        self.layer_names = []
        self.act_layers = nn.ModuleList()
        self.layer_norm_layers = nn.ModuleList()
        # Plain list (the layers themselves are registered via setattr below).
        # ``None`` is appended after every conv and hidden fc layer, presumably
        # as a placeholder for its activation.
        self.layers = []
        # CReLU / DeepFourier concatenate two views of their input, doubling
        # its width, so hidden layers are built at half the requested width.
        double = activation in ('crelu', 'deepfourier')
        in_channels = self.input_shape[0]

        if input_type == 'conv':
            layer_specs = [
                {'kernel_size': kernel_size[i], 'stride': stride[i], 'padding': padding[i],
                 'filters': cnn_channels[i], 'pool_size': pooling_kernel[i]}
                for i in range(len(self.cnn_channels))
            ]
            output_shapes = calculate_network_dims(input_shape[1],
                                                   input_shape[2],
                                                   input_shape[0],
                                                   layer_specs)
            fc_in = output_shapes[-1]
        else:
            fc_in = math.prod(input_shape)

        self.last_filter_output = fc_in

        for i, out_channels in enumerate(cnn_channels):
            self.act_layers.append(self._make_activation(activation))

            if double:
                assert out_channels % 2 == 0, f"Output dimension {out_channels} must be even for {activation} activation."
                conv_out = out_channels // 2
            else:
                conv_out = out_channels
            if use_hat:
                conv_layer = HATConv2d(in_channels=in_channels, out_channels=conv_out, kernel_size=kernel_size[i],
                                       padding=padding[i], stride=stride[i], hat_config=hat_config)  # type: ignore
            else:
                conv_layer = nn.Conv2d(in_channels, conv_out, kernel_size=kernel_size[i],
                                       padding=padding[i], stride=stride[i])

            setattr(self, f'conv{i+1}', conv_layer)
            # The activation restores the full ``out_channels`` width.
            in_channels = out_channels
            self.layers.append(conv_layer)
            self.layers.append(None)
            self.layer_names.append(f'conv{i+1}')

            if layer_norm:
                # NOTE(review): this inner loop runs once per conv layer, so
                # len(output_shapes) norms are appended per conv layer, and
                # ``output_shapes`` only exists for input_type == 'conv'.
                # ``forward`` indexes these norms by conv position ``i``; when
                # len(output_shapes) >= len(cnn_channels) it uses the norms built
                # on the first pass. The extra norms are unused but registered
                # (state_dict / named_parameters). Kept as-is so existing
                # checkpoints still load.
                for shapes in output_shapes:  # type: ignore
                    if use_hat:
                        self.layer_norm_layers.append(TaskIndexedLayerNorm(num_tasks=num_tasks))  # type: ignore
                    else:
                        self.layer_norm_layers.append(nn.LayerNorm(shapes))

        last = len(self.fc_channels) - 1
        for i, out_dim in enumerate(self.fc_channels):
            name = f'fc{i+1}'
            if i == last:
                # Output layer: always a plain nn.Linear with no activation.
                fc_layer = nn.Linear(fc_in, out_dim)
            else:
                self.act_layers.append(self._make_activation(activation))
                if double:
                    assert out_dim % 2 == 0, f"Output dimension {out_dim} must be even for {activation} activation."
                    hidden_out = out_dim // 2
                else:
                    hidden_out = out_dim
                if use_hat:
                    fc_layer = HATLinear(fc_in, hidden_out, hat_config=hat_config)  # type: ignore
                else:
                    fc_layer = nn.Linear(fc_in, hidden_out)

                if self.layer_norm:
                    # NOTE(review): these fc layer norms are built but never
                    # applied in ``forward``.
                    if use_hat:
                        self.layer_norm_layers.append(TaskIndexedLayerNorm(num_tasks=num_tasks))  # type: ignore
                    else:
                        self.layer_norm_layers.append(nn.LayerNorm(out_dim))

            setattr(self, name, fc_layer)
            self.layers.append(fc_layer)
            self.layer_names.append(name)
            if i != last:
                self.layers.append(None)
            fc_in = out_dim

    @staticmethod
    def _make_activation(activation):
        """Return a fresh activation module for ``activation``.

        Raises:
            ValueError: if ``activation`` is not a supported name.
        """
        if activation == 'crelu':
            return CReLU()
        if activation == 'deepfourier':
            return DeepFourier()
        builtin = {
            'relu': nn.ReLU,
            'gelu': nn.GELU,
            'sigmoid': nn.Sigmoid,
            'tanh': nn.Tanh,
            'prelu': nn.PReLU,
        }
        if activation not in builtin:
            raise ValueError(f"Activation function {activation} not recognized.")
        return builtin[activation]()

    @staticmethod
    def _is_hat_payload(x):
        """Return True if ``x`` is a ``HATPayload``.

        The name is looked up lazily because the ``hat`` import may be absent
        from this module; in that case nothing can be a payload.
        """
        payload_cls = globals().get('HATPayload')
        return payload_cls is not None and isinstance(x, payload_cls)

    def _forward_module(self, module, x):
        """Apply ``module`` to ``x``, going through ``forward_by`` for HAT payloads."""
        if self._is_hat_payload(x):
            return x.forward_by(module)
        return module(x)

    def _unwrap_payload(self, x):
        """Return the tensor inside a HAT payload, or ``x`` itself for plain tensors."""
        # ``.data`` only for payloads: on a plain tensor it would detach it.
        return x.data if self._is_hat_payload(x) else x

    def _pool_output(self, x, i):
        """Apply conv layer ``i``'s pooling ('max' or 'avg') to a tensor or HAT payload.

        Raises:
            ValueError: for an unrecognised pooling type.
        """
        pool_type = self.pooling_type[i]
        kernel = self.pooling_kernel[i]
        if pool_type not in ('max', 'avg'):
            raise ValueError(f"Pooling type {pool_type} not recognized.")
        if self._is_hat_payload(x):
            if pool_type == 'max':
                pool = nn.MaxPool2d(kernel_size=kernel)
            else:
                pool = nn.AvgPool2d(kernel_size=kernel)
            return x.forward_by(pool)
        if pool_type == 'max':
            return F.max_pool2d(x, kernel_size=kernel)
        return F.avg_pool2d(x, kernel_size=kernel)

    def forward(self, x):
        """Run the network on ``x`` and record intermediate activations.

        Args:
            x: input batch (dim 0 is treated as the batch dim by the flatten),
                or a ``HATPayload`` wrapping one. For ``input_type='conv'``
                presumably ``(batch, C, H, W)``; TODO confirm against callers.

        Returns:
            Output of the final linear layer for tensor inputs. For
            ``HATPayload`` inputs, the features that would enter the final
            layer (see the NOTE in the fc loop).

        Raises:
            ValueError: for an unrecognised pooling type.
        """
        self.activations = {}
        self.activations_for_redo = {}
        num_conv = len(self.cnn_channels)

        for i in range(num_conv):
            name = f'conv{i+1}'
            x = getattr(self, name)(x)
            if self.layer_norm:
                x = self.layer_norm_layers[i](x)
            x = self._forward_module(self.act_layers[i], x)
            # The last conv block feeds the first fc layer.
            next_kind = 'fc' if i == num_conv - 1 else 'conv'
            self.activations_for_redo[name] = (self._unwrap_payload(x), 'conv', next_kind)
            x = self._pool_output(x, i)
            self.activations[name] = self._unwrap_payload(x)

        if self._is_hat_payload(x):
            flat = x.data.view((x.data.size(0), -1))
            # Only reached when HATPayload is bound (checked above).
            x = HATPayload(flat, task_id=x.task_id, mask_scale=x.mask_scale)  # type: ignore
        else:
            x = x.view(x.size(0), -1)

        last = len(self.fc_channels) - 1
        for i in range(len(self.fc_channels)):
            name = f'fc{i+1}'
            fc_layer = getattr(self, name)
            if i == last:
                if self._is_hat_payload(x):
                    # NOTE(review): the final nn.Linear is *not* applied on the
                    # HAT path; the unwrapped features are returned instead.
                    # Confirm this is intended (a class-incremental head was
                    # marked TODO here).
                    return x.data
                return fc_layer(x)
            x = fc_layer(x)
            x = self._forward_module(self.act_layers[i + num_conv], x)
            x_to_save = self._unwrap_payload(x)
            self.activations[name] = x_to_save
            self.activations_for_redo[name] = (x_to_save, 'fc', 'fc')
        return x

    def get_model_weights_l2_norm(self):
        """Return ``utils_l2_norm`` evaluated over ``self.named_parameters()``."""
        return utils_l2_norm(self.named_parameters())

    def compute_l1_norm(self):
        """Return ``utils_l1_norm`` evaluated over ``self.named_parameters()``."""
        return utils_l1_norm(self.named_parameters())

    def compute_l2_norm(self):
        """Return ``new_utils_l2_norm`` evaluated over ``self.named_parameters()``."""
        return new_utils_l2_norm(self.named_parameters())

    def compute_total_params(self):
        """Count parameter elements, skipping layer-norm and snapshot parameters.

        Parameters whose name contains 'layer_norm', 'init_params' or
        'original_last_layer_params' are excluded. The result is a float
        because the count is accumulated from ``0.``.
        """
        excluded = ('layer_norm', 'init_params', 'original_last_layer_params')
        return sum(
            (param.numel() for name, param in self.named_parameters()
             if not any(tag in name for tag in excluded)),
            0.,
        )
    
    
