import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.Conv_layers import Conv1dSubampling, Conv1dUpsampling, Conv1dSubampling_new
from layers.TGPT_layers import RetNetBlock
from layers.Embed import PatchEmbedding, ValueEmbedding
from typing import List, Optional, Tuple, Union
from layers.snippets import get_gpu_memory_usage, SigmoidRange
import gc
from layers.RevIN import RevIN
from sklearn.metrics import roc_auc_score, average_precision_score


class TGP(nn.Module):
    '''
    Temporal Generative Pre-training leverages a recurrent-form transformer architecture for
    multivariate time series.

    Pipeline in ``forward``: convolutional subsampling (tokenizer) -> linear input projection
    -> shift-right with a learned SOS embedding -> stack of ``RetNetBlock`` layers -> final
    LayerNorm -> task head. ``forward`` returns the task loss/metrics, not raw head outputs.
    '''
    def __init__(self, configs, head_type='pretrain', num_classes=2, num_var=1, regr_dim=1):
        """
        Args:
            configs: config object; this class reads ``d_layers``, ``n_heads``, ``d_model``,
                ``d_ff``, ``dropout``, ``c_in``, ``qk_dim``, ``v_dim`` and ``use_grad_ckp``,
                and passes it whole to every ``RetNetBlock``.
            head_type: one of 'pretrain', 'forecasting', 'classification', 'regression'.
            num_classes: number of output classes for the classification head.
            num_var: currently unused.
            regr_dim: output dimension of the regression head.

        Raises:
            ValueError: if ``head_type`` is not one of the supported values.
        """
        super(TGP, self).__init__()

        # load parameters
        self.n_layers = configs.d_layers
        self.n_heads = configs.n_heads
        self.d_model = configs.d_model
        self.ffn_size = configs.d_ff
        self.dropout = configs.dropout

        self.c_in = configs.c_in

        self.qk_dim = configs.qk_dim
        # fall back to qk_dim when v_dim is unset/falsy
        self.v_dim = configs.v_dim if configs.v_dim else self.qk_dim

        # learned start-of-sequence embedding used for the shift-right
        self.sos = torch.nn.Parameter(torch.zeros(self.d_model))
        nn.init.normal_(self.sos)

        # NOTE(review): revin_layer is constructed but not used in forward() as written.
        self.revin_layer = RevIN(self.c_in)

        # ConvSubsampling acts as the tokenizer/embedding
        self.conv_subsampling = Conv1dSubampling_new(in_channels=self.c_in, out_channels=self.d_model, reduce_time_layers=2)

        self.input_projection = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.Dropout(p=self.dropout)
        )

        # the stacked decoder layers
        self.blocks = nn.ModuleList([RetNetBlock(configs) for _ in range(self.n_layers)])

        # ConvUpsampling; note that the argument intermediate_channels is not used.
        # NOTE(review): conv_upsampling is not used in forward() as written.
        self.conv_upsampling = Conv1dUpsampling(hidden_dim=self.d_model, reduce_time_layers=2)

        # output layer
        self.ln_f = nn.LayerNorm(self.d_model)
        self.head_type = head_type
        self.num_classes = num_classes
        if self.head_type == "pretrain":
            self.head = PretrainHead(self.d_model, self.c_in)
        elif self.head_type == "forecasting":
            self.head = ForecastingHead(self.d_model, self.c_in)
        elif self.head_type == "classification":
            self.head = ClassificationHead(self.d_model, num_classes)
        elif self.head_type == "regression":
            self.head = RegressionHead(self.d_model, regr_dim)
        else:
            raise ValueError("Invalid head_type provided.")

        self.gradient_checkpointing = configs.use_grad_ckp

    def forward(self,
                X, y,
                retention_mask: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[torch.FloatTensor]] = None,
                forward_impl: Optional[str] = 'chunkwise',
                chunk_size: Optional[int] = None,
                sequence_offset: Optional[int] = 0,
                output_retentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                ):
        """
        Run the model and return the loss/metrics of the configured head.

        Args:
            X: raw input series fed to ``conv_subsampling`` (presumably
                [batch_size x time x c_in] -- TODO confirm against Conv1dSubampling_new).
            y: task target; unused when ``head_type == 'pretrain'``.
            retention_mask: forwarded to every block; defaults to an all-True
                [batch_size x num_tokens] bool mask.
            past_key_values: optional per-layer cache, indexed by layer number.
            forward_impl, chunk_size, sequence_offset, output_retentions:
                forwarded unchanged to every ``RetNetBlock``.
            output_hidden_states: collect per-layer hidden states. They are collected
                but, like the retentions and key/values, not returned.

        Returns:
            'pretrain': MSE loss tensor against the tokenizer's targets.
            'classification': dict from ``compute_classify_loss``.
            'regression': ``(mse, mae)`` tuple.
            'forecasting': raises NotImplementedError (no loss defined yet).
        """
        # Use ConvSubsampling as tokenizer; X_tokens are the pre-training targets.
        X, X_tokens = self.conv_subsampling(X)
        hidden_states = self.input_projection(X)
        batch_size, seq_len, _ = X.shape

        # Shift right: prepend the SOS embedding and drop the last token so the length is
        # unchanged and output position t is compared against token t in the pretrain loss.
        sos_token = self.sos.unsqueeze(0).repeat(batch_size, 1, 1)  # [batch_size, 1, d_model]
        hidden_states = torch.cat([sos_token, hidden_states[:, :-1, :]], dim=1)

        if retention_mask is None:
            # NOTE(review): the SOS position is left unmasked; confirm that is intended.
            retention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=X.device)

        all_hidden_states = () if output_hidden_states else None
        all_retentions = () if output_retentions else None
        present_key_values = ()  # current key/value pairs from each block

        use_checkpointing = self.gradient_checkpointing and self.training
        if use_checkpointing:
            # local import: `torch.utils.checkpoint` is not guaranteed to be loaded by `import torch`
            from torch.utils.checkpoint import checkpoint

        for i, block in enumerate(self.blocks):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[i] if past_key_values is not None else None

            if use_checkpointing:
                # `block` is bound as a default argument on purpose: the checkpointed
                # function is re-run during backward, after this loop has advanced, and a
                # plain closure would then see the *last* block instead of this one.
                def custom_forward(h, mask, pkv, _block=block):
                    return _block(h,
                                  retention_mask=mask,
                                  forward_impl=forward_impl,
                                  past_key_value=pkv,
                                  sequence_offset=sequence_offset,
                                  chunk_size=chunk_size,
                                  output_retentions=output_retentions)

                # use_reentrant=True keeps the historical default explicitly
                block_outputs = checkpoint(custom_forward,
                                           hidden_states,
                                           retention_mask,
                                           past_key_value,
                                           use_reentrant=True)
            else:
                block_outputs = block(hidden_states,
                                      retention_mask=retention_mask,
                                      forward_impl=forward_impl,
                                      past_key_value=past_key_value,
                                      sequence_offset=sequence_offset,
                                      chunk_size=chunk_size,
                                      output_retentions=output_retentions)

            hidden_states = block_outputs[0]
            present_key_values += (block_outputs[1],)

            # NOTE(review): emptying the CUDA cache and running a full GC after every
            # block trades a lot of speed for lower peak memory.
            torch.cuda.empty_cache()
            gc.collect()

            if output_retentions:
                all_retentions += (block_outputs[2],)

        # add hidden states from the last decoder layer (one-element tuple)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # Apply the task head on the normalized hidden states
        X = self.ln_f(hidden_states)
        outputs = self.head(X)

        if self.head_type == 'pretrain':
            return self.compute_pretrain_loss(outputs, X_tokens)
        elif self.head_type == 'forecasting':
            return self.compute_forecasting_loss(outputs, y)
        elif self.head_type == 'classification':
            return self.compute_classify_loss(outputs, y)
        elif self.head_type == 'regression':
            return self.compute_regr_loss(outputs, y)

    def compute_pretrain_loss(self, token_predictions, token_targets):
        """
        Compute the loss of the pre-training task (next token prediction) as the MSE
        between predicted and target tokens.
        """
        return F.mse_loss(token_predictions, token_targets)

    def compute_forecasting_loss(self, forecast_outputs, forecast_targets):
        """
        Loss for the forecasting head -- not implemented.

        ``forward`` dispatches here for ``head_type='forecasting'``, but no loss was ever
        defined (previously this surfaced as an AttributeError).

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "TGP.compute_forecasting_loss is not implemented: ForecastingHead returns "
            "(next_token, next_timesteps) but no loss has been defined for them."
        )

    def compute_regr_loss(self, regr_predictions, regr_targets):
        """
        Compute the loss of the regression task.

        Targets are cast to float and reshaped to the predictions' shape.

        Returns:
            (mse, mae): MSE loss (used for training) and mean absolute error.
        """
        regr_targets = regr_targets.float().view_as(regr_predictions)

        regr_loss = F.mse_loss(regr_predictions, regr_targets)
        mae = F.l1_loss(regr_predictions, regr_targets)  # L1 loss is equivalent to MAE
        return regr_loss, mae

    @staticmethod
    def _per_class_prf(predicted, targets, num_classes, eps=1e-7):
        """Return per-class (precision, recall, f1) tensors, each of shape [num_classes]."""
        precisions, recalls, f1s = [], [], []
        for c in range(num_classes):
            pred_c = predicted == c
            true_c = targets == c
            tp = (pred_c & true_c).sum().float()
            fp = (pred_c & ~true_c).sum().float()
            fn = (~pred_c & true_c).sum().float()
            p = tp / (tp + fp + eps)
            r = tp / (tp + fn + eps)
            precisions.append(p)
            recalls.append(r)
            f1s.append(2 * (p * r) / (p + r + eps))
        return torch.stack(precisions), torch.stack(recalls), torch.stack(f1s)

    def compute_classify_loss(self, cls_logits, cls_targets):
        """
        Compute the cross-entropy loss and summary metrics of the classification task.

        Args:
            cls_logits: [batch_size x num_classes] unnormalized scores.
            cls_targets: [batch_size] integer class indices.

        Returns:
            dict with 'cls_loss' (differentiable tensor) and Python floats 'accuracy'
            (percent), 'precision', 'recall', 'f1_score' and 'macro_f1_score'.
            With two classes, precision/recall/f1 are for the positive class 1;
            otherwise they are macro-averaged. 'macro_f1_score' is always the
            unweighted mean of the per-class F1 scores.
        """
        cls_loss = F.cross_entropy(cls_logits, cls_targets)

        # make prediction
        probs = F.softmax(cls_logits, dim=1)
        predicted = torch.argmax(probs, dim=1)

        accuracy = (predicted == cls_targets).float().mean() * 100.0

        num_classes = cls_logits.size(1)
        precision_c, recall_c, f1_c = self._per_class_prf(predicted, cls_targets, num_classes)
        if num_classes == 2:
            precision, recall, f1_score = precision_c[1], recall_c[1], f1_c[1]
        else:
            precision, recall, f1_score = precision_c.mean(), recall_c.mean(), f1_c.mean()
        macro_f1_score = f1_c.mean()

        return {
            'cls_loss': cls_loss,
            'accuracy': accuracy.item(),
            'precision': precision.item(),
            'recall': recall.item(),
            'f1_score': f1_score.item(),
            'macro_f1_score': macro_f1_score.item(),
        }

class PretrainHead(nn.Module):
    """Per-token linear projection from the hidden size back to ``c_in`` channels."""

    def __init__(self, d_model, c_in):
        super().__init__()
        # one prediction per token, in the tokenizer's input-channel space
        self.head = nn.Linear(d_model, c_in)

    def forward(self, x):
        """
        x: tensor [batch_size x seq_len x d_model]
        returns: tensor [batch_size x seq_len x c_in]
        """
        projected = self.head(x)
        return projected


class ClassificationHead(nn.Module):
    """Mean-pool the token sequence, then map the pooled vector to class logits."""

    def __init__(self, d_model, num_classes):
        super().__init__()
        self.clf_layer = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """
        x: tensor [batch_size x seq_len x d_model]
        returns: logits [batch_size x num_classes]
        """
        pooled = torch.mean(x, dim=1)  # average over the sequence axis
        return self.clf_layer(pooled)

class RegressionHead(nn.Module):
    """
    Regression head: mean-pool over the sequence axis, then a linear map to ``output_dim``.

    Args:
        d_model: hidden size of the incoming features.
        output_dim: number of regression outputs.
        y_range: optional tuple unpacked into ``SigmoidRange`` (presumably ``(low, high)``)
            to squash the output; ``None`` leaves the output unbounded.
    """
    def __init__(self, d_model, output_dim, y_range=None):
        super().__init__()
        self.y_range = y_range
        self.regr_layer = nn.Linear(d_model, output_dim)

    def forward(self, x):
        """
        x: tensor [batch_size x seq_len x d_model]
        output: tensor [batch_size x output_dim]
        """
        x = x.mean(dim=1)         # Average pool over the sequence dimension
        y = self.regr_layer(x)
        if self.y_range:
            y = SigmoidRange(*self.y_range)(y)

        return y


class ForecastingHead(nn.Module):
    """
    Predict the next token and its ``red_factor`` decompressed time steps.

    ``red_factor`` defaults to 4, matching the comment that conv-subsampling reduces
    the sequence length to 1/4.
    """

    def __init__(self, d_model, c_in, red_factor=4):
        super().__init__()
        self.d_model, self.c_in, self.red_factor = d_model, c_in, red_factor

        self.token_projection = nn.Linear(d_model, d_model)
        self.time_projection = nn.Linear(d_model, red_factor * c_in)

    def forward(self, hidden_states):
        """
        hidden_states: tensor [batch_size x seq_len x d_model]
        returns: (next_token [batch_size x d_model],
                  next_timesteps [batch_size x red_factor x c_in])
        """
        pooled = torch.mean(hidden_states, dim=1)
        next_token = self.token_projection(pooled)
        flat_steps = self.time_projection(next_token)
        # unpack the flat projection into (red_factor, c_in) per sample
        return next_token, flat_steps.view(-1, self.red_factor, self.c_in)
