import torch
from torch import nn
import math
from torch.autograd import Function
from torch.nn.utils import weight_norm
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
from .resnet18 import resnet18
import copy

def get_backbone_class(backbone_name):
    """Look up a class (or any object) defined at module level by its name.

    Args:
        backbone_name: Module-level name to resolve, e.g. ``"CNN"`` or ``"TCN"``.

    Returns:
        The object bound to ``backbone_name`` in this module's globals.

    Raises:
        NotImplementedError: If this module defines no such name.
    """
    registry = globals()
    if backbone_name in registry:
        return registry[backbone_name]
    raise NotImplementedError("Algorithm not found : {}".format(backbone_name))

class GradReverse(Function):
    """Gradient reversal: identity on the forward pass, gradient times ``-lambd`` on the backward pass.

    Invoke through ``GradReverse.apply(x, lambd)``. ``lambd`` is a scalar
    (callers in this module pass floats) and receives no gradient.
    """

    @staticmethod
    def forward(ctx, x, lambd):
        # Keep the scale around for the backward pass.
        ctx.lambd = lambd
        # Return a view of x rather than x itself.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        scale = -ctx.lambd
        # One gradient per forward input: reversed grad for x, None for lambd.
        return grad_output * scale, None


def grad_reverse(x, lambd=1.0):
    """Pass ``x`` through unchanged while multiplying its gradient by ``-lambd``.

    Thin convenience wrapper around ``GradReverse.apply``.
    """
    reverse_fn = GradReverse.apply
    return reverse_fn(x, lambd)


class SSDA_classifier(nn.Module):
    """Two-layer MLP classification head that owns its own Adam optimizer and StepLR schedule.

    Args:
        configs: Object exposing ``final_out_channels`` (input feature size;
            the hidden layer uses half of it) and ``num_classes``.
        hparams: Mapping with ``learning_rate``, ``weight_decay``,
            ``step_size`` and ``lr_decay`` used to build the optimizer and
            learning-rate scheduler.
    """

    def __init__(self, configs, hparams):

        super(SSDA_classifier, self).__init__()

        self.target_classifier = nn.Sequential(
            nn.Linear(configs.final_out_channels, configs.final_out_channels//2),
            nn.ReLU(),
            nn.Linear(configs.final_out_channels//2, configs.num_classes)
        )

        # Created after the layers so self.parameters() covers target_classifier.
        self.optimizer = torch.optim.Adam(
            self.parameters(),
            lr = hparams['learning_rate'],
            weight_decay = hparams['weight_decay']
        )
        self.lr_scheduler = StepLR(self.optimizer, step_size = hparams['step_size'], gamma=hparams['lr_decay'])

    def forward(self, x_in, reverse=False, eta=0.1):
        """Return class logits for ``x_in``.

        Args:
            x_in: Features with last dimension ``configs.final_out_channels``.
            reverse: If True, route ``x_in`` through gradient reversal first so
                gradients flowing back into ``x_in`` are scaled by ``-eta``.
            eta: Gradient-reversal scale, used only when ``reverse`` is True.
        """
        # Previously the reversed tensor was computed but ``x_in`` was fed to the
        # classifier, so ``reverse=True`` had no effect on gradients.
        x = grad_reverse(x_in, eta) if reverse else x_in
        predictions = self.target_classifier(x)
        return predictions


class classifier(nn.Module):
    """Single linear layer mapping flattened backbone features to class logits.

    Args:
        configs: Object exposing ``features_len``, ``final_out_channels`` and
            ``num_classes``; the input size is
            ``features_len * final_out_channels``.
    """

    def __init__(self, configs):
        super(classifier, self).__init__()
        in_features = configs.features_len * configs.final_out_channels
        self.logits = nn.Linear(in_features, configs.num_classes)
        self.configs = configs

    def forward(self, x, reverse=False, eta=0.1):
        """Return logits for ``x``; with ``reverse=True`` gradients into ``x`` are scaled by ``-eta``."""
        features = grad_reverse(x, eta) if reverse else x
        return self.logits(features)


#######################################################
################## BACKBONE NETWORKS ##################
#######################################################

########## CNN #############################
class CNN(nn.Module):
    """Three-stage 1-D convolutional feature extractor.

    Input is (N, ``configs.input_channels``, L). Each stage is
    Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d (the first also applies
    dropout); an adaptive average pool then fixes the time length to
    ``configs.features_len``. The result is flattened to
    (N, ``final_out_channels * features_len``).
    """

    def __init__(self, configs):
        super(CNN, self).__init__()
        mid = configs.mid_channels
        k = configs.kernel_size

        # Stage 1: configurable kernel/stride, "same"-style padding of k // 2.
        self.conv_block1 = nn.Sequential(
            nn.Conv1d(configs.input_channels, mid, kernel_size=k,
                      stride=configs.stride, bias=False, padding=(k // 2)),
            nn.BatchNorm1d(mid),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
            nn.Dropout(configs.dropout),
        )

        # Stage 2: fixed kernel of 8, doubles the channel count.
        self.conv_block2 = nn.Sequential(
            nn.Conv1d(mid, mid * 2, kernel_size=8, stride=1, bias=False, padding=4),
            nn.BatchNorm1d(mid * 2),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
        )

        # Stage 3: fixed kernel of 8, projects to final_out_channels.
        self.conv_block3 = nn.Sequential(
            nn.Conv1d(mid * 2, configs.final_out_channels, kernel_size=8, stride=1,
                      bias=False, padding=4),
            nn.BatchNorm1d(configs.final_out_channels),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
        )

        self.adaptive_pool = nn.AdaptiveAvgPool1d(configs.features_len)

    def forward(self, x_in):
        """Encode ``x_in`` and return one flattened feature vector per sample."""
        features = x_in
        for stage in (self.conv_block1, self.conv_block2, self.conv_block3, self.adaptive_pool):
            features = stage(features)
        return features.reshape(features.shape[0], -1)


########## TCN #############################
# Enables cuDNN benchmark mode, intended to speed up the TCN convolutions.
# NOTE(review): this is a process-wide setting applied at import time, so it affects
# every model in the process, not only TCN. Per the PyTorch reproducibility notes,
# benchmark mode may select non-deterministic algorithms — confirm this is acceptable.
torch.backends.cudnn.benchmark = True


class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` time steps of a (N, C, L) tensor.

    Used after a Conv1d padded by ``chomp_size`` on both sides to remove the
    trailing padding.

    Args:
        chomp_size: Number of trailing steps to remove. ``0`` leaves the input
            length unchanged.
    """

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        # ``x[:, :, :-0]`` is ``x[:, :, :0]`` (empty), so a zero chomp must be
        # special-cased; it happens when the conv padding is 0 (kernel size 1).
        if self.chomp_size == 0:
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()


class TCN(nn.Module):
    """Two-block temporal convolutional network returning features at the last time step.

    ``conv_block1`` (dilation 1) and ``conv_block2`` (dilation 2) each produce
    ``configs.tcn_layers[1]`` channels and are combined with residual
    connections; ``forward`` returns a tensor of shape (N, tcn_layers[1]).

    NOTE(review): ``net0`` and ``net1`` are never used by ``forward``, and
    ``configs.tcn_layers[0]`` only affects whether ``downsample1`` is created.
    ``net0``/``net1`` still register parameters (they appear in ``state_dict``
    and draw from the RNG at construction), so removing them would change
    checkpoints and initialization.
    """

    def __init__(self, configs):
        super(TCN, self).__init__()

        in_channels0 = configs.input_channels
        out_channels0 = configs.tcn_layers[1]
        kernel_size = configs.tcn_kernel_size
        stride = 1
        dilation0 = 1
        # Pad (k - 1) * dilation on both sides; Chomp1d later trims the trailing
        # padding so each block's output length matches its input length.
        # NOTE(review): padding is 0 when tcn_kernel_size == 1 — make sure Chomp1d
        # handles a zero chomp size.
        padding0 = (kernel_size - 1) * dilation0

        # NOTE(review): unused in forward(); kept only for state_dict / RNG compatibility.
        self.net0 = nn.Sequential(
            weight_norm(nn.Conv1d(in_channels0, out_channels0, kernel_size, stride=stride, padding=padding0,
                                  dilation=dilation0)),
            nn.ReLU(),
            weight_norm(nn.Conv1d(out_channels0, out_channels0, kernel_size, stride=stride, padding=padding0,
                                  dilation=dilation0)),
            nn.ReLU(),
        )

        # 1x1 conv on the skip path of block 1 when the channel count changes.
        self.downsample0 = nn.Conv1d(in_channels0, out_channels0, 1) if in_channels0 != out_channels0 else None
        self.relu = nn.ReLU()

        # in_channels1 is only used in the downsample1 condition below.
        in_channels1 = configs.tcn_layers[0]
        out_channels1 = configs.tcn_layers[1]
        dilation1 = 2
        padding1 = (kernel_size - 1) * dilation1
        # NOTE(review): unused in forward(); kept only for state_dict / RNG compatibility.
        self.net1 = nn.Sequential(
            nn.Conv1d(in_channels0, out_channels1, kernel_size, stride=stride, padding=padding1, dilation=dilation1),
            nn.ReLU(),
            nn.Conv1d(out_channels1, out_channels1, kernel_size, stride=stride, padding=padding1, dilation=dilation1),
            nn.ReLU(),
        )
        # NOTE(review): block 2's input (out_0) already has out_channels1 channels, so this
        # is an out_channels1 -> out_channels1 1x1 conv; it is created only when
        # tcn_layers[0] != tcn_layers[1], which does not reflect a real channel mismatch.
        self.downsample1 = nn.Conv1d(out_channels1, out_channels1, 1) if in_channels1 != out_channels1 else None

        # Residual block 1: two padded convs (dilation0), each trimmed by Chomp1d,
        # followed by BatchNorm and ReLU.
        self.conv_block1 = nn.Sequential(
            nn.Conv1d(in_channels0, out_channels0, kernel_size=kernel_size, stride=stride, bias=False, padding=padding0,
                      dilation=dilation0),
            Chomp1d(padding0),
            nn.BatchNorm1d(out_channels0),
            nn.ReLU(),

            nn.Conv1d(out_channels0, out_channels0, kernel_size=kernel_size, stride=stride, bias=False,
                      padding=padding0, dilation=dilation0),
            Chomp1d(padding0),
            nn.BatchNorm1d(out_channels0),
            nn.ReLU(),
        )

        # Residual block 2: same structure with dilation1 (= 2).
        self.conv_block2 = nn.Sequential(
            nn.Conv1d(out_channels0, out_channels1, kernel_size=kernel_size, stride=stride, bias=False,
                      padding=padding1, dilation=dilation1),
            Chomp1d(padding1),
            nn.BatchNorm1d(out_channels1),
            nn.ReLU(),

            nn.Conv1d(out_channels1, out_channels1, kernel_size=kernel_size, stride=stride, bias=False,
                      padding=padding1, dilation=dilation1),
            Chomp1d(padding1),
            nn.BatchNorm1d(out_channels1),
            nn.ReLU(),
        )

    def forward(self, inputs):
        """Encode ``inputs`` of shape (N, C_in, L_in) into (N, tcn_layers[1]) features.

        The returned features are the channel values at the last time step of
        the second residual block's output.
        """
        x0 = self.conv_block1(inputs)
        res0 = inputs if self.downsample0 is None else self.downsample0(inputs)
        out_0 = self.relu(x0 + res0)

        x1 = self.conv_block2(out_0)
        res1 = out_0 if self.downsample1 is None else self.downsample1(out_0)
        out_1 = self.relu(x1 + res1)

        out = out_1[:, :, -1]  # last time step only -> (N, out_channels1)
        return out



######## RESNET ##############################################
class RESNET18(nn.Module):
    """Wrap the project's ``resnet18`` backbone and flatten its output to one vector per sample."""

    def __init__(self, configs):
        super(RESNET18, self).__init__()
        self.resnet = resnet18(configs)

    def forward(self, x_in):
        feats = self.resnet(x_in)
        batch = feats.shape[0]
        return feats.reshape(batch, -1)

class BasicBlock(nn.Module):
    """Residual block: 1x1 Conv1d -> BatchNorm1d -> ReLU, plus a skip connection, then ReLU.

    Args:
        inplanes: Input channels.
        planes: Output channels of the 1x1 convolution.
        stride: Stride of the convolution.
        downsample: Optional module applied to the input to form the skip path.
            Needed whenever ``stride != 1`` or ``inplanes != planes``, otherwise
            the residual addition has mismatched shapes.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=1, stride=stride,
                               bias=False)
        self.bn1 = nn.BatchNorm1d(planes)

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))

        residual = x if self.downsample is None else self.downsample(x)

        # Out-of-place add: F.relu saves its output for the backward pass, so an
        # in-place ``out += residual`` would modify a tensor autograd still needs
        # and make ``backward()`` raise a RuntimeError.
        out = out + residual
        return F.relu(out)

######## MoSSDA ##############################################
# Momentum Encoder
class MomentumEncoder(nn.Module):
    """Pair a trainable backbone with a frozen deep copy updated as an exponential moving average.

    Args:
        backbone: Trainable module; a deep copy of it becomes the momentum encoder.
        momentum: Fraction of the old momentum weights kept on each update.
    """

    def __init__(self, backbone, momentum = 0.999):
        super(MomentumEncoder, self).__init__()
        self.backbone = backbone
        self.momentum_encoder = copy.deepcopy(backbone)
        self.momentum = momentum
        # The copy changes only through update_momentum_encoder, never via gradients.
        self.momentum_encoder.requires_grad_(False)

    @torch.no_grad()
    def update_momentum_encoder(self):
        """Blend backbone weights into the momentum encoder: k <- m * k + (1 - m) * q."""
        m = self.momentum
        online = self.backbone.parameters()
        target = self.momentum_encoder.parameters()
        for p_online, p_target in zip(online, target):
            p_target.data = p_target.data * m + p_online.data * (1 - m)

    def forward(self, x):
        """Return ``(backbone(x), momentum_encoder(x))``."""
        online_out = self.backbone(x)
        momentum_out = self.momentum_encoder(x)
        return online_out, momentum_out

######## DLinear ##############################################
class SeriesDecomp(nn.Module):
    """Split a series into a residual (seasonal) part and a moving-average trend, as used by DLinear.

    ``forward`` returns ``(x - trend, trend)``.
    """

    def __init__(self, kernel_size):
        super(SeriesDecomp, self).__init__()
        self.moving_avg = MovingAvg(kernel_size, stride=1)

    def forward(self, x):
        trend = self.moving_avg(x)
        return x - trend, trend

class MovingAvg(nn.Module):
    """Moving average along the time axis, used to extract the trend of a series.

    The input layout is (batch, length, channels): padding and concatenation
    happen on dim 1, and the tensor is permuted so that ``AvgPool1d`` pools
    over time. The series is edge-padded by repeating its first and last time
    steps, so with ``stride=1`` the output has the same length as the input
    for both odd and even ``kernel_size``.

    Args:
        kernel_size: Window length of the average.
        stride: Stride of the pooling window.
    """

    def __init__(self, kernel_size, stride):
        super(MovingAvg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # Total padding of kernel_size - 1 keeps the length unchanged at stride 1.
        # Odd kernels split it evenly (as before); for even kernels the extra step
        # goes to the end. Otherwise the output would be one step shorter than the
        # input and SeriesDecomp's ``x - moving_mean`` would fail to broadcast.
        pad_front = (self.kernel_size - 1) // 2
        pad_end = self.kernel_size - 1 - pad_front
        front = x[:, 0:1, :].repeat(1, pad_front, 1)
        end = x[:, -1:, :].repeat(1, pad_end, 1)
        x = torch.cat([front, x, end], dim=1)
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x
        
class DLinear(nn.Module):
    """DLinear backbone: series decomposition followed by linear maps over time.

    The input (Batch, Channel, Length) is split into seasonal and trend parts
    by a moving average. Each part is mapped from ``sequence_len`` to
    ``features_len`` steps by a linear layer, and the two results are summed.
    The (Batch, Channel, features_len) result is flattened and projected to
    ``features_len * final_out_channels`` features.

    Args:
        configs: Object exposing ``sequence_len``, ``features_len``,
            ``input_channels`` and ``final_out_channels``. Optional:
            ``individual`` (default ``False``) gives each channel its own
            linear layers; ``kernel_size`` (default 25) is the moving-average
            window.
    """

    def __init__(self, configs):
        super(DLinear, self).__init__()
        self.seq_len = configs.sequence_len
        self.pred_len = configs.features_len  # output length per channel is features_len
        self.individual = getattr(configs, 'individual', False)
        self.channels = configs.input_channels

        kernel_size = getattr(configs, 'kernel_size', 25)
        self.decomposition = SeriesDecomp(kernel_size)

        if self.individual:
            # One seasonal and one trend layer per input channel.
            self.Linear_Seasonal = nn.ModuleList()
            self.Linear_Trend = nn.ModuleList()

            for i in range(self.channels):
                self.Linear_Seasonal.append(nn.Linear(self.seq_len, self.pred_len))
                self.Linear_Trend.append(nn.Linear(self.seq_len, self.pred_len))
        else:
            self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len)
            self.Linear_Trend = nn.Linear(self.seq_len, self.pred_len)

        # Projection layer that brings the flattened output to the feature size
        # expected downstream (features_len * final_out_channels).
        self.final_projection = nn.Linear(
            self.pred_len * self.channels,
            configs.features_len * configs.final_out_channels
        )

    def forward(self, x_in):
        """Map ``x_in`` of shape [Batch, Channel, Length] to [Batch, features_len * final_out_channels]."""
        # [Batch, Channel, Length] -> [Batch, Length, Channel] for the decomposition.
        x = x_in.permute(0, 2, 1)

        seasonal_init, trend_init = self.decomposition(x)
        # Back to [Batch, Channel, Length] so the Linear layers act on the time axis.
        seasonal_init, trend_init = seasonal_init.permute(0, 2, 1), trend_init.permute(0, 2, 1)

        if self.individual:
            # Stack per-channel outputs into [Batch, Channel, pred_len]. This avoids
            # allocating a zeros buffer on the CPU and copying it to the input's
            # device on every call.
            seasonal_output = torch.stack(
                [layer(seasonal_init[:, i, :]) for i, layer in enumerate(self.Linear_Seasonal)], dim=1
            )
            trend_output = torch.stack(
                [layer(trend_init[:, i, :]) for i, layer in enumerate(self.Linear_Trend)], dim=1
            )
        else:
            seasonal_output = self.Linear_Seasonal(seasonal_init)
            trend_output = self.Linear_Trend(trend_init)

        x = seasonal_output + trend_output
        # Flatten and project to expected dimension
        x_flat = x.reshape(x.shape[0], -1)
        output = self.final_projection(x_flat)
        return output
        
######## NLinear ##############################################
class NLinear(nn.Module):
    """NLinear backbone: subtract the last time step, apply a linear map over time, add it back.

    Each channel's length-``sequence_len`` series is mapped to
    ``features_len`` values. All channels are then flattened and projected to
    ``features_len * final_out_channels`` features.

    Args:
        configs: Object exposing ``sequence_len``, ``features_len``,
            ``input_channels`` and ``final_out_channels``. An optional boolean
            ``individual`` (default ``False``) gives each channel its own
            linear layer.
    """

    def __init__(self, configs):
        super(NLinear, self).__init__()
        self.seq_len = configs.sequence_len
        self.pred_len = configs.features_len
        self.individual = getattr(configs, 'individual', False)
        self.channels = configs.input_channels

        if self.individual:
            self.Linear = nn.ModuleList()
            for i in range(self.channels):
                self.Linear.append(nn.Linear(self.seq_len, self.pred_len))
        else:
            self.Linear = nn.Linear(self.seq_len, self.pred_len)

        # Projection layer that brings the flattened output to the feature size
        # expected downstream (features_len * final_out_channels).
        self.final_projection = nn.Linear(
            self.pred_len * self.channels,
            configs.features_len * configs.final_out_channels
        )

    def forward(self, x_in):
        """Map ``x_in`` of shape [Batch, Channel, Length] to [Batch, features_len * final_out_channels]."""
        # [Batch, Channel, Length] -> [Batch, Length, Channel]
        x = x_in.permute(0, 2, 1)

        # Subtract the last time step; detached, so no gradient flows through this offset.
        seq_last = x[:, -1:, :].detach()
        x = x - seq_last

        if self.individual:
            # Stack per-channel outputs into [Batch, pred_len, Channel]. This avoids
            # allocating a zeros buffer on the CPU and copying it to the input's
            # device on every call.
            output = torch.stack(
                [layer(x[:, :, i]) for i, layer in enumerate(self.Linear)], dim=2
            )
        else:
            output = self.Linear(x.permute(0, 2, 1)).permute(0, 2, 1)

        output = output + seq_last
        # [Batch, pred_len, Channel] -> [Batch, Channel, pred_len], then flatten and project.
        output = output.permute(0, 2, 1)
        x_flat = output.reshape(output.shape[0], -1)
        final_output = self.final_projection(x_flat)
        return final_output
        
######## LSTM ##############################################
class LSTM(nn.Module):
    """Two-layer unidirectional LSTM encoder.

    Input is [Batch, Channel, Length]. The top layer's final hidden state is
    projected to ``features_len * final_out_channels`` features. The hidden
    size is ``2 * features_len``; depth (2) and dropout (0.4) are fixed here
    rather than read from ``configs``.
    """

    def __init__(self, configs):
        super().__init__()
        self.input_channels = configs.input_channels
        self.hidden_dim = configs.features_len * 2
        self.num_layers = 2
        self.dropout = 0.4

        # Inter-layer dropout, disabled when there is only a single layer.
        inter_layer_dropout = self.dropout if self.num_layers > 1 else 0
        self.lstm = nn.LSTM(
            input_size=self.input_channels,
            hidden_size=self.hidden_dim,
            num_layers=self.num_layers,
            dropout=inter_layer_dropout,
            batch_first=True,
            bidirectional=False,
        )

        # Projection from the final hidden state to the downstream feature size.
        out_features = configs.features_len * configs.final_out_channels
        self.fc = nn.Linear(self.hidden_dim, out_features)

    def forward(self, x_in):
        # batch_first=True expects [Batch, Length, Channel].
        seq = x_in.permute(0, 2, 1)
        _, (h_n, _) = self.lstm(seq)
        # h_n[-1] is the last layer's final hidden state: [Batch, hidden_dim].
        return self.fc(h_n[-1])

######## GRU ##############################################
class GRU(nn.Module):
    """Two-layer unidirectional GRU encoder.

    Input is [Batch, Channel, Length]. The top layer's final hidden state is
    projected to ``features_len * final_out_channels`` features. The hidden
    size is ``2 * features_len``; depth (2) and dropout (0.4) are fixed here
    rather than read from ``configs``.
    """

    def __init__(self, configs):
        super().__init__()
        self.input_channels = configs.input_channels
        self.hidden_dim = configs.features_len * 2
        self.num_layers = 2
        self.dropout = 0.4

        # Inter-layer dropout, disabled when there is only a single layer.
        inter_layer_dropout = self.dropout if self.num_layers > 1 else 0
        self.gru = nn.GRU(
            input_size=self.input_channels,
            hidden_size=self.hidden_dim,
            num_layers=self.num_layers,
            dropout=inter_layer_dropout,
            batch_first=True,
            bidirectional=False,
        )

        # Projection from the final hidden state to the downstream feature size.
        out_features = configs.features_len * configs.final_out_channels
        self.fc = nn.Linear(self.hidden_dim, out_features)

    def forward(self, x_in):
        # batch_first=True expects [Batch, Length, Channel].
        seq = x_in.permute(0, 2, 1)
        _, h_n = self.gru(seq)
        # h_n[-1] is the last layer's final hidden state: [Batch, hidden_dim].
        return self.fc(h_n[-1])
