# coding=utf-8

""" ReservoirTransformer model configuration"""
""" Author: Md Kowsher"""
from collections import OrderedDict
from typing import Mapping

from transformers import PretrainedConfig
from transformers.onnx import OnnxConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)




class ReservoirTConfig(PretrainedConfig):
    r"""
    Configuration class for the ReservoirT time-series models (e.g. [`ReservoirTTimeSeries`]). It stores the
    arguments that define the model architecture, the reservoir group, and a few data/training settings.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information. Any extra keyword arguments are forwarded to
    [`PretrainedConfig.__init__`].

    Args:
        hidden_size (`int`, *optional*, defaults to 8):
            Dimensionality of the encoder layers and the pooler layer.
        output_size (`int`, *optional*, defaults to `None`):
            The output dimension of the predicted values. For multivariate time series all features are usually
            predicted.
        re_output_size (`int`, *optional*, defaults to 4):
            The reservoir output dimension.
        num_hidden_layers (`int`, *optional*, defaults to 4):
            Number of hidden layers in the Transformer encoder.
        pred_len (`int`, *optional*, defaults to 720):
            The forecast horizon, i.e. the number of future steps to predict.
        num_attention_heads (`int`, *optional*, defaults to 4):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 64):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_sequence_length (`int`, *optional*, defaults to 500):
            The maximum sequence length.
        sequence_length (`int`, *optional*, defaults to 12):
            The length of the input look-back window, i.e. how much history each prediction sees.
        type_vocab_size (`int`, *optional*, defaults to 2):
            The vocabulary size (mask / non-mask) of the `token_type_ids` passed to the model.
        num_reservoirs (`int`, *optional*, defaults to 10):
            Number of reservoirs in the reservoir group (ensemble).
        reservoir_size (`Sequence[int]`, *optional*, defaults to `(30, 15, 20, 25, 30, 35, 40, 45, 50, 50)`):
            Size of each reservoir in the group.
        spectral_radius (`Sequence[float]`, *optional*, defaults to `(0.6, 0.8, 0.55, 0.6, 0.5, 0.4, 0.3, 0.2, 0.81, 0.05)`):
            Spectral radius of each reservoir in the group.
        sparsity (`Sequence[float]`, *optional*, defaults to `(0.6, 0.55, 0.5, 0.45, 0.4, 0.35, 0.3, 0.25, 0.2, 0.15)`):
            Sparsity rate of each reservoir in the group.
        leaky (`Sequence[float]`, *optional*, defaults to `(0.3, 0.31, 0.32, 0.33, 0.34, 0.35, 0.36, 0.37, 0.38, 0.39)`):
            Leaky rate of each reservoir in the group.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id; forwarded to [`PretrainedConfig`].
        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
            Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
            positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
            [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
            For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
            with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        decoder_dropout (`float`, *optional*, defaults to `None`):
            The dropout ratio for the classification or regression head.
        problem_type (`str`, *optional*, defaults to `None`):
            Type of problem, such as `"regression"`, `"single_label_classification"` or
            `"multi_label_classification"`.
        soft_border (`int`, *optional*, defaults to 8):
            Stored unchanged on the config; this class does not read it.
        batch (`int`, *optional*, defaults to 64):
            Stored unchanged on the config (presumably the training batch size).
        train_size (`float`, *optional*, defaults to 0.7):
            Stored unchanged; presumably the fraction of the data used for training.
        val_size (`float`, *optional*, defaults to 0.1):
            Stored unchanged; presumably the fraction of the data used for validation.
        test_size (`float`, *optional*, defaults to 0.2):
            Stored unchanged; presumably the fraction of the data used for testing.
        scaling (`bool`, *optional*, defaults to `True`):
            Stored unchanged; presumably whether the data is scaled before training.

    The per-reservoir sequences (`reservoir_size`, `spectral_radius`, `sparsity`, `leaky`) are copied into new
    lists. Editing one config's list in place therefore never affects another config or the defaults.

    Other options such as `is_decoder` or `max_position_embeddings` are not declared by this signature. If passed,
    they travel through `**kwargs` to [`PretrainedConfig`].

    Examples:

    ```python
    >>> from configuration import ReservoirTConfig

    >>> # Initializing a transformer-style configuration
    >>> configuration = ReservoirTConfig()

    >>> # Initializing a model (with random weights) from the transformer-style configuration
    >>> model = ReservoirTTimeSeries(config = configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
    model_type = "ReservoirTransformer"

    def __init__(
        self,
        hidden_size=8,
        output_size=None,
        re_output_size=4,
        num_hidden_layers=4,
        pred_len=720,
        num_attention_heads=4,
        intermediate_size=64,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_sequence_length=500,
        sequence_length=12,
        type_vocab_size=2,
        num_reservoirs=10,
        reservoir_size=(30, 15, 20, 25, 30, 35, 40, 45, 50, 50),
        spectral_radius=(0.6, 0.8, 0.55, 0.6, 0.5, 0.4, 0.3, 0.2, 0.81, 0.05),
        sparsity=(0.6, 0.55, 0.5, 0.45, 0.4, 0.35, 0.3, 0.25, 0.2, 0.15),
        leaky=(0.3, 0.31, 0.32, 0.33, 0.34, 0.35, 0.36, 0.37, 0.38, 0.39),
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        use_cache=True,
        decoder_dropout=None,
        problem_type=None,
        soft_border=8,
        batch=64,
        train_size=0.7,
        val_size=0.1,
        test_size=0.2,
        scaling=True,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        def _fresh_list(values):
            # Copy into a new list: the defaults are shared tuples, and caller-supplied lists must not be aliased
            # by the config (in-place edits would otherwise leak between objects). None is kept as None.
            return None if values is None else list(values)

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
        self.decoder_dropout = decoder_dropout
        self.output_size = output_size
        self.re_output_size = re_output_size
        self.pred_len = pred_len
        self.max_sequence_length = max_sequence_length
        self.sequence_length = sequence_length
        self.problem_type = problem_type
        self.num_reservoirs = num_reservoirs
        self.spectral_radius = _fresh_list(spectral_radius)
        self.sparsity = _fresh_list(sparsity)
        self.reservoir_size = _fresh_list(reservoir_size)
        self.leaky = _fresh_list(leaky)
        self.soft_border = soft_border
        self.batch = batch
        self.train_size = train_size
        self.val_size = val_size
        self.test_size = test_size
        self.scaling = scaling


class ReservoirTOnnxConfig(OnnxConfig):
    """ONNX export description for ReservoirT: declares the dynamic axes of every model input."""

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Map each input name to its dynamic axes (batch/sequence, plus choice for multiple-choice tasks)."""
        axes = {0: "batch", 1: "sequence"}
        if self.task == "multiple-choice":
            axes = {0: "batch", 1: "choice", 2: "sequence"}

        # All three inputs share the same axis description (the same dict object).
        input_names = ("input_ids", "attention_mask", "token_type_ids")
        return OrderedDict((name, axes) for name in input_names)


class TPGNConfig:
    """Hyper-parameters for the TPGN forecaster (``TPGN_fl``).

    Every attribute has the default shown below. Any of them can be overridden
    by keyword, e.g. ``TPGNConfig(pred_len=96, norm=1)``.

    Raises:
        TypeError: if an override names an attribute that is not defined here.
    """

    def __init__(self, **overrides):
        # Dataset configuration
        self.freq = 'h'    # time-feature frequency code; 'h' = hourly data
        self.enc_in = 140  # number of input channels (variables)
        self.c_out = 140   # number of output channels (same as the input dimension)

        # Model structure
        self.d_model = 128     # hidden dimension (tune to available GPU memory)
        self.norm = 0          # 1 enables per-series normalization in TPGN_fl; 0 disables it
        self.TPGN_period = 24  # period length (24 hours per day)

        # Sequence lengths
        self.seq_len = 168   # input length (7 days * 24 hours)
        self.pred_len = 720  # forecast length (30 days * 24 hours)

        # Training settings (extend as needed)
        self.batch_size = 32
        self.learning_rate = 1e-3
        self.num_epochs = 100
        self.dropout = 0.1

        # Apply keyword overrides; reject unknown names so typos are not silently ignored.
        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError(f"TPGNConfig got an unexpected keyword argument {name!r}")
            setattr(self, name, value)


config_TPGN = TPGNConfig()


import torch
import torch.nn as nn
import math

class PGN_2d(nn.Module):
    """Long-term (across-period) branch of TPGN built from grouped 1-D convolutions.

    Each of the ``c_in`` channels is processed independently (``groups=c_in``).
    The branch has three steps:

    1. A causal window of ``windows_size`` periods is convolved along the period
       axis (``deal``).
    2. The result is refined by a sigmoid/tanh gate (``gated_unit``).
    3. It is collapsed over the ``seq_R`` periods (``fc``).

    Args:
        seq_R: number of periods in the input; ``fc`` uses it as kernel size,
            so ``forward`` expects exactly ``seq_R`` periods.
        freq: time-feature frequency code: ``'t'`` (5 features), ``'h'`` (4)
            or ``'d'`` (3).
        c_in: number of variables/channels.
        c_out: hidden size per channel.
        windows_size: number of previous periods each step looks at.

    Raises:
        ValueError: if ``freq`` is not one of the supported codes.
    """

    # Number of time-mark features expected for each frequency code.
    _TIME_DIMS = {'t': 5, 'h': 4, 'd': 3}

    def __init__(self, seq_R, freq, c_in, c_out, windows_size):
        super(PGN_2d, self).__init__()

        self.seq_R = seq_R
        self.freq = freq
        self.c_out = c_out
        self.windows_size = windows_size

        # Previously an unknown freq left dim_time unbound (UnboundLocalError).
        try:
            dim_time = self._TIME_DIMS[freq]
        except KeyError:
            raise ValueError(
                f"Unsupported freq {freq!r}; expected one of {sorted(self._TIME_DIMS)}") from None

        self.hidden_MLP = nn.Conv1d(
            in_channels = c_in * (1 + dim_time),
            out_channels = c_in * c_out,
            kernel_size = windows_size,
            stride = 1, groups = c_in)

        self.gate = nn.Conv1d(
            in_channels = c_in * (1 + dim_time + c_out),
            out_channels = c_in * 2 * c_out,
            kernel_size = 1,
            stride = 1, groups = c_in)

        self.fc = nn.Conv1d(
            in_channels = c_in * c_out,
            out_channels = c_in * c_out,
            kernel_size = seq_R,
            stride = 1, groups = c_in)

    def deal(self, x, x_mark):
        """Causal windowed convolution over the period axis.

        Args:
            x: values of shape ``(B, R, C, c_in, 1)``.
            x_mark: time marks of shape ``(B, R, C, c_in, c_time)``.

        Returns:
            Tensor of shape ``(B, R, C, c_in, c_out)``.
        """
        B, R, C, c_in, _ = x.shape
        c_time = x_mark.shape[-1]
        x_input = torch.cat([x, x_mark], dim=-1)
        # windows_size zero periods are prepended so step r only sees periods < r.
        # new_zeros keeps x's dtype and device (no CPU allocation + copy).
        x_supply = x.new_zeros(B, self.windows_size, C, c_in, 1 + c_time)
        x_all = torch.cat([x_supply, x_input], dim=1).permute(0, 2, 1, 3, 4)
        x_all = x_all.reshape(B * C, R + self.windows_size,
            c_in * (1 + c_time)).permute(0, 2, 1)
        # Drop the last step so that the kernel of width windows_size yields exactly R outputs.
        x_all_out = self.hidden_MLP(x_all[:, :, :-1]).reshape(
            B, C, c_in, self.c_out, R).permute(0, 4, 1, 2, 3)
        return x_all_out

    def gated_unit(self, x, x_mark, hid):
        """Blend ``hid`` with a tanh candidate using a sigmoid gate computed from ``[x, x_mark, hid]``."""
        x = torch.cat([x, x_mark, hid], dim=-1)
        B, R, C, c_in, c_all = x.shape
        x = x.reshape(B * R * C, c_in * c_all, 1)
        x_embed = self.gate(x).reshape(B, R, C, c_in, -1)
        sigmod_gate, tanh_gate = torch.split(x_embed, self.c_out, dim = -1)
        sigmod_gate = torch.sigmoid(sigmod_gate)
        tanh_gate = torch.tanh(tanh_gate)
        hid = hid * sigmod_gate + (1 - sigmod_gate) * tanh_gate
        return hid

    def forward(self, x, x_mark):
        """Return long-term features of shape ``(B, C, c_in, c_out)``.

        ``x`` is ``(B, R, C, c_in, 1)`` and ``x_mark`` is ``(B, R, C, c_in, c_time)``.
        ``R`` must equal ``seq_R`` so the final convolution reduces the period axis to length 1.
        """
        B, R, C, c_in, _ = x.shape
        out = self.deal(x, x_mark)
        out = self.gated_unit(x, x_mark, out)
        out = self.fc(out.permute(0, 2, 3, 4, 1).reshape(
            B * C, c_in * self.c_out, R)).reshape(B, C, c_in, self.c_out)
        return out

class short_term_deal(nn.Module):
    """Short-term (within-period) branch of TPGN.

    First a grouped convolution of width ``period`` collapses each period
    (``fc_row``). Then a grouped convolution of width ``seq_R`` collapses the
    periods (``fc_col``). The resulting per-channel vector is repeated for
    every position in the period.

    Args:
        seq_R: number of periods in the input (kernel of ``fc_col``).
        freq: time-feature frequency code: ``'t'`` (5 features), ``'h'`` (4)
            or ``'d'`` (3).
        c_in: number of variables/channels.
        c_out: hidden size per channel.
        period: number of steps per period (kernel of ``fc_row``).

    Raises:
        ValueError: if ``freq`` is not one of the supported codes.
    """

    # Number of time-mark features expected for each frequency code.
    _TIME_DIMS = {'t': 5, 'h': 4, 'd': 3}

    def __init__(self, seq_R, freq, c_in, c_out, period):
        super(short_term_deal, self).__init__()

        self.seq_R = seq_R
        self.freq = freq
        self.c_out = c_out
        self.period = period

        # Previously an unknown freq left dim_time unbound (UnboundLocalError).
        try:
            dim_time = self._TIME_DIMS[freq]
        except KeyError:
            raise ValueError(
                f"Unsupported freq {freq!r}; expected one of {sorted(self._TIME_DIMS)}") from None

        self.fc_row = nn.Conv1d(
            in_channels = c_in * (1 + dim_time),
            out_channels = c_in * c_out,
            kernel_size = period,
            stride = 1, groups = c_in)

        self.fc_col = nn.Conv1d(
            in_channels = c_in * c_out,
            out_channels = c_in * c_out,
            kernel_size = seq_R,
            stride = 1, groups = c_in)

    def forward(self, x, x_mark):
        """Return short-term features of shape ``(B, period, c_in, c_out)``.

        ``x`` is ``(B, R, C, c_in, 1)`` and ``x_mark`` is ``(B, R, C, c_in, c_time)``.
        The convolutions reduce both axes to length 1 only when ``C == period``
        and ``R == seq_R``.
        """
        B, R, C, c_in, _ = x.shape
        c_time = x_mark.shape[-1]
        x_input = torch.cat([x, x_mark], dim=-1)
        out = self.fc_row(x_input.permute(0, 1, 3, 4, 2).reshape(
            B * R, c_in * (1 + c_time), C)).reshape(B, R, c_in * self.c_out)
        out = self.fc_col(out.permute(0, 2, 1)).reshape(
            B, c_in, 1, self.c_out).repeat(1, 1, self.period, 1)
        return out.permute(0, 2, 1, 3)

class TPGN(nn.Module):
    """TPGN core: long-term PGN_2d features, optional short-term features, and a per-channel projection.

    Args:
        seq_R: number of input periods.
        freq: time-feature frequency code forwarded to the branches.
        c_in: number of variables/channels.
        c_out: hidden size per channel of each branch.
        windows_size: look-back window (in periods) of the long-term branch.
        period: steps per period.
        pred_R: number of periods to predict.
        need_short: truthy to also use the short-term branch.
    """

    def __init__(self, seq_R, freq, c_in, c_out, windows_size,
            period, pred_R, need_short=1):
        super(TPGN, self).__init__()

        self.freq = freq
        self.c_in = c_in
        self.c_out = c_out
        self.windows_size = windows_size
        self.pred_R = pred_R
        self.need_short = need_short

        # Submodules are created in this fixed order (long-term, short-term, fc) so that
        # parameter initialization consumes the RNG the same way.
        self.LNN_dim = PGN_2d(seq_R, freq, c_in, c_out, windows_size)

        n_branches = 1
        if self.need_short:
            self.s_t_p_e = short_term_deal(seq_R, freq, c_in, c_out, period)
            n_branches = 2

        self.fc = nn.Conv1d(
            in_channels=c_in * n_branches * c_out,
            out_channels=c_in * pred_R,
            kernel_size=1,
            stride=1,
            groups=c_in,
        )

    def forward(self, x, x_mark):
        """Map ``x`` of shape ``(B, R, C, c_in)`` and ``x_mark`` of shape ``(B, R, C, c_time)``
        to a forecast of shape ``(B, pred_R * C, c_in)``."""
        batch, _, cols, n_ch = x.shape
        values = x.unsqueeze(-1)
        # Broadcast the time marks to every channel: (B, R, C, c_in, c_time).
        marks = x_mark.unsqueeze(-2).repeat(1, 1, 1, n_ch, 1)

        features = self.LNN_dim(values, marks)
        width = self.c_out
        if self.need_short:
            short_features = self.s_t_p_e(values, marks)
            features = torch.cat([short_features, features], dim=-1)
            width = 2 * self.c_out

        flat = features.reshape(batch * cols, n_ch * width, 1)
        projected = self.fc(flat).reshape(batch, cols, n_ch, self.pred_R)
        return projected.permute(0, 3, 1, 2).reshape(batch, -1, n_ch)

class TPGN_fl(nn.Module):
    """Forecasting wrapper around ``TPGN`` configured from a ``TPGNConfig``-like object.

    The input series is folded into ``seq_R`` periods of length ``TPGN_period``.
    The prediction is trimmed to ``pred_len`` steps and, when ``configs.norm == 1``,
    denormalized with the per-series mean and standard deviation of the input.
    """

    def __init__(self, configs):
        super(TPGN_fl, self).__init__()

        self.configs = configs

        self.freq = configs.freq
        self.period = configs.TPGN_period
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.seq_R = int(math.ceil(self.seq_len/self.period))
        self.pred_R = int(math.ceil(self.pred_len/self.period))

        self.c_in = configs.enc_in
        self.c_out = configs.c_out

        self.d_model = configs.d_model
        self.norm = configs.norm

        self.TPGN = TPGN(self.seq_R, self.freq, self.c_in,
            self.d_model, self.seq_R-1, self.period, self.pred_R)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec,
                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
        """Predict ``pred_len`` future steps.

        Args:
            x_enc: input values of shape ``(B, L, c_in)``.
            x_mark_enc: time marks of shape ``(B, L, c_time)``.
            x_dec, x_mark_dec, enc_self_mask, dec_self_mask, dec_enc_mask:
                accepted for interface compatibility and ignored.

        Returns:
            Tensor of shape ``(B, pred_len, c_in)``.

        Raises:
            ValueError: if ``L`` is not ``seq_R * period``, which is required to fold the input into periods.
        """
        B, L, c_in = x_enc.shape
        c_time = x_mark_enc.shape[-1]

        expected_len = self.seq_R * self.period
        if L != expected_len:
            raise ValueError(
                f"TPGN_fl expects an input length of seq_R * period = {expected_len}, got {L}")

        if self.norm == 1:
            means = x_enc.mean(1, keepdim=True).detach()
            x_enc = x_enc - means
            stdev = torch.sqrt(torch.var(x_enc, dim=1,
                keepdim=True, unbiased=False) + 1e-5)
            # Out-of-place: an in-place division would clobber a tensor autograd saved for torch.var.
            x_enc = x_enc / stdev

        x_enc = x_enc.reshape(B, self.seq_R, self.period, c_in)
        x_mark_enc = x_mark_enc.reshape(B, self.seq_R, self.period, c_time)

        output = self.TPGN(x_enc, x_mark_enc)
        # TPGN emits pred_R * period steps. Drop the surplus of a partial last period so that
        # the horizon is exactly pred_len.
        output = output[:, :self.pred_len, :]

        if self.norm == 1:
            # stdev/means are (B, 1, c_in) and broadcast over the horizon.
            output = output * stdev + means

        return output


import os

from typing import List, Optional, Tuple, Union
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from transformers import PatchTSTModel, PatchTSTConfig, TrainingArguments, EarlyStoppingCallback, Trainer, PatchTSTForPrediction, PatchTSMixerConfig, PatchTSMixerForPrediction
#from reservoir_computing.modules import RC_model
#from configuration import ReservoirTConfig
from tqdm import tqdm
from datasets import Dataset
import wandb
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    NextSentencePredictorOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)



# NOTE(review): this starts a Weights & Biases run as an import-time side effect of this module.
wandb.init(
    project="my-awesome-project",  # wandb project that receives this run
    config={"learning_rate": 0.002, "epochs": 100},  # hyper-parameters recorded as run metadata
)

configuration = ReservoirTConfig()

# Experiment-specific overrides. They are applied in this order on top of the ReservoirTConfig defaults.
_CONFIG_OVERRIDES = {
    "output_size": 140,
    "re_output_size": 21,
    "max_sequence_length": 1000,
    "sequence_length": 336,
    "pred_len": 720,
    "hidden_size": 7,
    "num_attention_heads": 7,
    "hidden_dropout_prob": 0.1,
    "num_hidden_layers": 16,
    "num_reservoirs": 10,
    "intermediate_size": 128,
    # One entry per reservoir (num_reservoirs == 10).
    "reservoir_size": [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
    "spectral_radius": [0.6, 0.8, 0.55, 0.6, 0.5, 0.4, 0.3, 0.2, 0.81, 0.05],
    "sparsity": [0.6, 0.55, 0.5, 0.45, 0.4, 0.35, 0.3, 0.25, 0.2, 0.15],
    "leaky": [0.4, 0.41, 0.42, 0.43, 0.44, 0.45, 0.46, 0.47, 0.48, 0.49],
    "activation_function": ["tanh", "sigmoid", "relu", "tanh", "sigmoid", "relu", "tanh", "sigmoid", "relu", "tanh"],
    "attention_probs_dropout_prob": 0.0,
    "batch_size": 32,
    "embedding_size": 140,
    "embedding_type": 2,
    "num_heads": 7,
}
for _attr_name, _attr_value in _CONFIG_OVERRIDES.items():
    setattr(configuration, _attr_name, _attr_value)
del _attr_name, _attr_value


class TimeSeriesEmbedding(nn.Module):
    """Embed a multivariate series of ``config.hidden_size`` features into ``config.embedding_size`` channels.

    The mode is selected by ``config.embedding_type``:

    * ``1``: project the input with separate query/key/value linear layers
      (``hidden_size -> embedding_size``) and mix them with multi-head attention
      (``config.num_heads`` heads, batch-first).
    * ``2``: give each input feature its own ``nn.Linear(1, embedding_size // hidden_size)``,
      multiply that embedding by the raw feature value, and concatenate the
      per-feature embeddings along the last axis.
    """

    def __init__(self, config):

        super(TimeSeriesEmbedding, self).__init__()
        self.hidden_size = config.hidden_size
        self.embedding_size = config.embedding_size
        self.embedding_type = config.embedding_type
        self.sequence_length = config.sequence_length
        self.feature_as_token_each_feature_emb_size = int(config.embedding_size/config.hidden_size)
        self.query = nn.Linear(self.hidden_size,self.embedding_size)
        self.key = nn.Linear(self.hidden_size,self.embedding_size)
        self.value = nn.Linear(self.hidden_size,self.embedding_size)
        self.multihead_attn = nn.MultiheadAttention(self.embedding_size, num_heads=config.num_heads,batch_first=True)
        self.batch_size = config.batch_size
        self.feature_as_token_weights = nn.ModuleList([nn.Linear(1, self.feature_as_token_each_feature_emb_size) for _ in range(self.hidden_size)])

    def forward(self,input_ids,key_values_input_ids = None):
        """Embed ``input_ids``, which is indexed as ``(batch, time, feature)``.

        Args:
            input_ids: input series; converted to float.
            key_values_input_ids: optional separate source for keys/values
                (type 1 only). Defaults to ``input_ids``.

        Returns:
            Type 1: the attention output, ``(batch, time, embedding_size)``.
            Type 2: ``(batch, time, hidden_size * (embedding_size // hidden_size))``.

        Raises:
            ValueError: if ``embedding_type`` is neither 1 nor 2.
        """
        input_ids = input_ids.float()
        if self.embedding_type == 1:
            source = input_ids if key_values_input_ids is None else key_values_input_ids.float()
            query = self.query(input_ids)
            key = self.key(source)
            # Values use the value projection (previously the key projection was reused by mistake).
            value = self.value(source)
            attn_output, _ = self.multihead_attn(query, key, value)
            return attn_output

        if self.embedding_type == 2:
            fl_inputs_embeds_list = []
            for i in range(self.hidden_size):
                input_features_seq = input_ids[:, :, i].unsqueeze(-1)  # (batch, time, 1)
                input_featrues_embeds = self.feature_as_token_weights[i](input_features_seq)  # (batch, time, per-feature emb)
                # Scale the embedding by the raw feature value (broadcast over the embedding axis).
                fl_inputs_embeds_list.append(input_featrues_embeds * input_features_seq)

            return torch.cat(fl_inputs_embeds_list, dim=-1)

        raise ValueError(f"Unsupported embedding_type {self.embedding_type!r}; expected 1 or 2")


class DeepReservoirNet(nn.Module):
    """Reservoir with fixed random input (``W_in``) and recurrent (``W_res``) weights.

    The reservoir state is carried sequentially across the samples of the
    first input dimension. ``forward`` returns the state after every sample.

    Args:
        config: provides ``sequence_length`` (input width of ``W_in``),
            ``re_output_size`` and ``num_attention_heads``.
        reservoir_size: number of reservoir units.
        spectral_radius: stored but not used by this class; ``W_res`` is
            rescaled to unit spectral radius in ``compute_spectral_radius``.
        leaky: stored but not used by this class.
        activation_function: ``"tanh"``, ``"relu"`` or ``"sigmoid"``; any other
            value falls back to tanh.
        sparsity: fraction of ``W_res`` entries set to zero.
    """

    def __init__(self, config, reservoir_size=1000, spectral_radius=0.9, leaky=0.3,activation_function="tanh", sparsity=0.5):
        super(DeepReservoirNet, self).__init__()

        self.input_size = config.sequence_length
        self.reservoir_size = reservoir_size
        self.output_size = config.re_output_size
        self.spectral_radius = spectral_radius
        self.leaky = leaky
        # forward() dispatches on this name. The old code compared the nn.Module self.act to a
        # string, which was always False, so the relu/sigmoid update rules never ran.
        self.activation_function = activation_function

        self.W_in = nn.Linear(self.input_size, reservoir_size, bias=False).float()
        self.W_in.weight.requires_grad = False
        self.W_res = nn.Linear(reservoir_size, reservoir_size, bias=False).float()
        self.W_res.weight.requires_grad = False
        self.W_leaky_d = nn.Linear(reservoir_size, reservoir_size).float()
        self.W_leaky_u = nn.Linear(reservoir_size, reservoir_size).float()
        self.W_leaky_d.weight.requires_grad = False
        self.W_leaky_u.weight.requires_grad = False
        self.res_state = torch.zeros(1, reservoir_size).float()
        self.tanh = nn.Tanh()
        self.act= nn.Tanh()
        if activation_function == "tanh":
            self.act = nn.Tanh()
        elif activation_function == "relu":
            self.act = nn.ReLU()
        elif activation_function == "sigmoid":
            self.act = nn.Sigmoid()
        # Used by the relu update rule (it was referenced but never defined). It has no affine
        # parameters, so the state_dict keys and the RNG consumption are unchanged.
        self.layer_norm = nn.LayerNorm(reservoir_size, elementwise_affine=False)

        self.W_res_norm = self.compute_spectral_radius(sparsity)
        self.self_attention = nn.MultiheadAttention(self.output_size, config.num_attention_heads, dropout=0.2)

    def compute_spectral_radius(self, sparsity=0.5):
        """Re-draw ``W_res`` with the given sparsity and rescale it to unit spectral radius.

        Returns:
            The spectral radius measured before rescaling. If the sparsified matrix has a zero
            radius (e.g. every entry was zeroed), no rescaling is done and ``1.0`` is returned
            instead of producing NaN weights.
        """
        with torch.no_grad():
            self.W_res.weight.data = torch.randn(self.reservoir_size, self.reservoir_size)
            # Set a fraction of the entries to zero.
            num_zeros = int(sparsity * self.reservoir_size ** 2)
            idxs = torch.randperm(self.reservoir_size ** 2)[:num_zeros]
            self.W_res.weight.data.view(-1)[idxs] = 0

            eigenvals = torch.linalg.eigvals(self.W_res.weight)
            radius = torch.max(torch.abs(eigenvals))
            if radius > 0:
                self.W_res.weight.data /= radius
            else:
                radius = torch.ones_like(radius)
        return radius

    def forward(self, input_data, res_state):
        """Run the reservoir over the samples of ``input_data``.

        Args:
            input_data: rank-3 tensor whose middle axis has length ``sequence_length``
                (it is permuted to ``(N, F, sequence_length)`` before ``W_in``).
            res_state: initial state, broadcast-compatible with ``(F, reservoir_size)``.

        Returns:
            dict with ``'Output'`` (the stacked per-sample states) and ``"State"`` (the final state).
        """
        outputs = []
        batch_size = input_data.shape[0]
        input_data = input_data.permute(0, 2, 1)
        for t in range(batch_size):
            input_proj = self.W_in(input_data[t].float())
            res_proj = self.W_res(res_state)
            middle_proj = input_proj + res_proj

            if self.activation_function == "relu":
                middle_proj = self.layer_norm(middle_proj)
                res_state = self.tanh(self.W_leaky_d(torch.clip(self.act(middle_proj), min=-1, max=1)) + self.W_leaky_u(res_state))
            elif self.activation_function == "sigmoid":
                # Map sigmoid's (0, 1) range to (-1, 1).
                res_state = self.tanh(self.W_leaky_d(2 * self.act(middle_proj) - 1) + self.W_leaky_u(res_state))
            else:
                res_state = self.tanh(self.W_leaky_d(self.act(middle_proj)) + self.W_leaky_u(res_state))

            # Scale the state by the spectral radius measured at initialization.
            res_state = res_state / self.W_res_norm
            outputs.append(res_state.squeeze(0))

        final_output = torch.stack(outputs, dim=0)
        return {'Output':final_output, "State": res_state}



class ReservoirTTimeSeries(nn.Module):
    """Group of ``DeepReservoirNet`` reservoirs driven by sample-shifted inputs.

    One reservoir is built per index ``i < config.num_reservoirs``, using the i-th
    entries of ``config.reservoir_size``, ``spectral_radius``, ``leaky``,
    ``activation_function`` and ``sparsity``. ``forward`` computes no loss. It
    returns the inputs together with the per-sample reservoir outputs.
    """

    # Values of ``dataset_type`` accepted by ``forward``.
    _DATASET_TYPES = (None, "eval_dataset", "test_dataset")

    def __init__(self, config):
        super().__init__()
        self.num_labels = config.num_labels
        self.config = config

        self.layer_norm = nn.LayerNorm(config.hidden_size)

        self.reservoirs=nn.ModuleList()
        self.id_train = None
        self.id_test = None
        self.reservoir_state = None
        self.state_ids = None

        for i in range(config.num_reservoirs):

            reservoir = DeepReservoirNet(config=config,
                                         reservoir_size=config.reservoir_size[i],
                                         spectral_radius=config.spectral_radius[i],
                                         leaky=config.leaky[i],
                                         activation_function=config.activation_function[i],
                                         sparsity=config.sparsity[i])

            self.reservoirs.append(reservoir)

    @staticmethod
    def _column_as_tensor(column, device):
        """Return a dataset column (a tensor, or a sequence of tensors/nested lists) as one tensor on ``device``."""
        if isinstance(column, torch.Tensor):
            return column.to(device)
        return torch.stack([torch.as_tensor(row) for row in column]).to(device)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        x_marks: Optional[torch.Tensor] = None,
        y_marks: Optional[torch.Tensor] = None,
        reservoir_ids: Optional[torch.Tensor] = None,
        state_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        dataset_type = None,
        train_dataset = None,
        eval_dataset = None,
        id = "id_train",
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        Run every reservoir over ``inputs_embeds`` (a rank-3 tensor whose first axis indexes samples).

        dataset_type (`str`, *optional*):
            ``None`` processes ``inputs_embeds`` alone. ``"eval_dataset"`` prepends
            ``train_dataset["inputs_embeds"]``, and ``"test_dataset"`` prepends the train and then
            the eval columns. The reservoir state thus runs through the earlier splits first; only
            the last ``len(inputs_embeds)`` samples are returned.
        reservoir_ids (`torch.Tensor`, *optional*):
            Reservoir inputs. Defaults to ``inputs_embeds`` shifted by one sample (sample ``k``
            receives sample ``k - 1``; the first receives zeros).
        state_ids:
            Accepted for API compatibility but ignored: every reservoir starts from a zero state.

        Most other arguments (``input_ids``, masks, ``return_dict``, ``id``, ...) are accepted for
        interface compatibility and not used.

        Returns:
            dict with ``"inputs_embeds"``, ``"reservoir_outputs"`` (one list per sample, holding one
            tensor per reservoir) and ``"labels_ids"`` (passed through).

        Raises:
            ValueError: if ``dataset_type`` is not ``None``, ``"eval_dataset"`` or ``"test_dataset"``.
        """
        if dataset_type not in self._DATASET_TYPES:
            raise ValueError(
                f"Unsupported dataset_type {dataset_type!r}; expected one of {self._DATASET_TYPES}")

        sample_size = None
        if dataset_type is not None:
            sample_size, _, _ = inputs_embeds.shape
            device = inputs_embeds.device
            history = [self._column_as_tensor(train_dataset["inputs_embeds"], device)]
            if dataset_type == "test_dataset":
                history.append(self._column_as_tensor(eval_dataset["inputs_embeds"], device))
            inputs_embeds = torch.cat((*history, inputs_embeds), dim=0)

        if reservoir_ids is None:
            # F.pad's tuple gives (before, after) pairs starting from the LAST dim: (0, 0) for dim 2,
            # (0, 0) for dim 1, (1, 0) for dim 0. This prepends one all-zero sample; dropping the
            # final sample shifts the inputs by one sample.
            padded_tensor = F.pad(inputs_embeds, (0, 0, 0, 0, 1, 0))
            reservoir_ids = padded_tensor[:-1]

        reservoir_ids = reservoir_ids.float()
        device = reservoir_ids.device

        per_reservoir = []
        for i, reservoir in tqdm(enumerate(self.reservoirs)):
            # Fresh zero state for each reservoir, created directly on the target device.
            initial_state = torch.zeros(
                self.config.hidden_size, self.config.reservoir_size[i],
                dtype=torch.float32, device=device)
            per_reservoir.append(reservoir(reservoir_ids, initial_state)['Output'])

        # Regroup from [reservoir][sample] to [sample][reservoir].
        reservoir_outputs = [list(per_sample) for per_sample in zip(*per_reservoir)]

        if dataset_type is None:
            return {"inputs_embeds":inputs_embeds,
                    "reservoir_outputs":reservoir_outputs,
                    "labels_ids":labels_ids}

        # Keep only the current split. Slice from an explicit start so that an empty split
        # yields nothing (a plain [-0:] would return everything).
        start = inputs_embeds.shape[0] - sample_size
        return {"inputs_embeds":inputs_embeds[start:],
                "reservoir_outputs":reservoir_outputs[start:],
                "labels_ids":labels_ids}


class Reservoir_fl_model(nn.Module):
    """Forecasting model that fuses embedded inputs with precomputed reservoir outputs.

    Each reservoir's outputs are passed through that reservoir's own MLP readout
    (``self.W_outputs``). The results are concatenated and used as keys/values
    for a stack of residual cross-attention layers, whose queries are the
    embedded inputs. The fused sequence is forecast by PatchTST, and
    ``self.decoder`` projects it back to ``hidden_size`` channels.

    Args:
        config: configuration object. This class reads ``num_labels``,
            ``hidden_size``, ``re_output_size``, ``sequence_length``,
            ``batch_size``, ``num_reservoirs``, ``reservoir_size``,
            ``embedding_size`` and ``output_size``. It also reads
            ``cross_attn_layers`` if present (default 3).
    """

    def __init__(self, config):
        super().__init__()
        self.num_labels = config.num_labels
        self.config = config
        self.hidden_size = config.hidden_size
        self.re_output_size = config.re_output_size
        self.sequence_length = config.sequence_length
        self.batch_size = config.batch_size
        # NOTE(review): not used by forward() (reservoir outputs arrive precomputed);
        # kept so existing state_dict keys keep matching.
        self.reservoir = ReservoirTTimeSeries(config)
        self.num_res = config.num_reservoirs
        # Depth of the residual cross-attention stack (3 unless the config overrides it).
        self.cross_attn_layers = getattr(config, "cross_attn_layers", 3)

        # Width of the fused sequence: embedding_size truncated to a multiple of hidden_size.
        fused_dim = int(self.config.embedding_size / self.config.hidden_size) * self.config.hidden_size

        patchTST_config = PatchTSTConfig(prediction_length=720,
                                         num_input_channels=fused_dim,
                                         context_length=336,
                                         num_attention_heads=16,
                                         patch_length=16,
                                         patch_stride=8,
                                         dropout=0.2,
                                         d_model=128,
                                         ffn_dim=256,
                                         head_dropout=0.2,
                                         scaling="std",
                                         pre_norm=True,
                                         norm_type="layernorm",
                                         channel_attention=False,
                                         random_mask_ratio=0.4,
                                         )
        self.patchtst = PatchTSTForPrediction(patchTST_config)
        # NOTE(review): the TPGN branch is not used by forward(); kept so state_dict keys match.
        TPGNconfig = TPGNConfig()
        self.TPGN = TPGN_fl(TPGNconfig)

        self.EmbeddingModel = TimeSeriesEmbedding(self.config)
        # NOTE(review): unused by forward(); kept so state_dict keys match.
        self.ReservoirModel = ReservoirTTimeSeries(self.config)
        # NOTE(review): legacy single cross-attention layer, superseded by self.cross_attns
        # and unused by forward(); kept so state_dict keys match.
        self.crossattn = nn.MultiheadAttention(fused_dim,
                                               kdim=self.config.embedding_size,
                                               vdim=self.config.embedding_size,
                                               num_heads=7,
                                               batch_first=True,
                                               dropout=0.2)
        self.dropout = nn.Dropout(0.2)
        self.decoder = nn.Linear(fused_dim, self.config.hidden_size)

        # One readout per reservoir: reservoir_size[i] -> output_size, with a ReLU
        # non-linearity between the two linear layers.
        self.W_outputs = nn.ModuleList()
        for i in range(self.num_res):
            self.W_outputs.append(
                nn.Sequential(
                    nn.Linear(config.reservoir_size[i], self.config.output_size),
                    nn.ReLU(),
                    nn.Linear(self.config.output_size, self.config.output_size)).float())

        self.cross_attns = nn.ModuleList()
        self.cross_attn_norms = nn.ModuleList()
        for _ in range(self.cross_attn_layers):
            self.cross_attns.append(
                nn.MultiheadAttention(
                    embed_dim=fused_dim,
                    kdim=self.config.embedding_size,
                    vdim=self.config.embedding_size,
                    num_heads=config.hidden_size,
                    batch_first=True,
                    dropout=0.2
                )
            )
            self.cross_attn_norms.append(nn.LayerNorm(fused_dim))

    def forward(self,
            inputs_embeds: Optional[torch.Tensor] = None,
            reservoir_outputs: list = None,
            labels_ids: Optional[torch.Tensor] = None,):
        """Forecast from raw inputs and precomputed reservoir outputs.

        Args:
            inputs_embeds: raw input window, embedded here by ``self.EmbeddingModel``.
            reservoir_outputs: one tensor per reservoir, in the same order as
                ``self.W_outputs``. Each tensor is projected by its readout, and
                the results are concatenated along dim 1 to form the attention
                keys/values.
            labels_ids: optional targets. When given, the MSE between prediction
                and targets is returned as ``"loss"``, and MSE/MAE are logged to wandb.

        Returns:
            dict with ``"prediction_outputs"``, plus ``"loss"`` when ``labels_ids``
            is provided.
        """
        inputs_embeds = self.EmbeddingModel(inputs_embeds)

        # Project each reservoir with its own readout and join along dim 1, which is the
        # key/value sequence axis for the batch_first attention layers.
        projected = [W_out(output) for output, W_out in zip(reservoir_outputs, self.W_outputs)]
        memory = torch.cat(projected, dim=1).float()

        # Residual cross-attention stack: each layer refines the previous layer's output
        # (previously every layer restarted from inputs_embeds, so only the last one counted).
        hidden = inputs_embeds
        for attn, norm in zip(self.cross_attns, self.cross_attn_norms):
            attn_layer_output, _ = attn(
                query=hidden.float(),
                key=memory,
                value=memory,
                need_weights=False,
            )
            hidden = norm(hidden + self.dropout(attn_layer_output))

        outputs = self.patchtst(past_values=hidden)
        prediction = self.decoder(outputs["prediction_outputs"].float()).float()

        if labels_ids is None:
            return {"prediction_outputs": prediction}

        labels_ids = labels_ids.float()
        loss = F.mse_loss(prediction, labels_ids)
        # MAE is only logged, so it is computed without building a graph.
        with torch.no_grad():
            mae_loss = F.l1_loss(prediction, labels_ids)
        wandb.log({"mse_loss": loss.item(), "mae_loss": mae_loss.item()})

        return {"loss": loss.float(),
                "prediction_outputs": prediction}

def extract_inputs_and_labels(dataset, batch_size=32, desc="Training Batches"):
    """Materialise a dataset's inputs and labels as two concatenated tensors.

    The dataset is iterated in order (no shuffling). The ``inputs_embeds`` and
    ``labels_ids`` fields of every batch are concatenated along dim 0.

    Args:
        dataset: map-style dataset whose items are dicts holding at least
            ``"inputs_embeds"`` and ``"labels_ids"`` tensors.
        batch_size: items loaded per step; affects only speed and memory, not the result.
        desc: label shown on the progress bar.

    Returns:
        dict with keys ``"inputs_embeds"`` and ``"labels_ids"``.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    inputs_embeds_list = []
    labels_ids_list = []
    for batch in tqdm(loader, desc=desc):
        inputs_embeds_list.append(batch["inputs_embeds"])
        labels_ids_list.append(batch["labels_ids"])

    return {"inputs_embeds": torch.cat(inputs_embeds_list, dim=0),
            "labels_ids": torch.cat(labels_ids_list, dim=0)}


#from time_data_normalize import Dataset_ETT_hour
import numpy as np
# prepare data for lstm
from sklearn.preprocessing import StandardScaler
from pandas import read_csv
from pandas import DataFrame
import pandas as pd
import random
from sklearn.model_selection import train_test_split
from pandas import concat
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import Dataset
# Turns on autograd anomaly detection for the whole process as soon as this module runs.
# PyTorch documents anomaly mode as a debugging aid that slows execution.
# NOTE(review): consider disabling it for full training runs.
torch.autograd.set_detect_anomaly(True)

if __name__ == "__main__":

    # Load the raw series (replace with the path to your CSV file).
    dataset = read_csv("ETTh1.csv")
    dataset = dataset.dropna()
    dataset['date'] = pd.to_datetime(dataset['date'])

    # Calendar features derived from the timestamp; they are kept out of the scaled signal.
    dataset['hour'] = dataset['date'].dt.hour
    dataset['dayofweek'] = dataset['date'].dt.dayofweek
    dataset['day'] = dataset['date'].dt.day
    dataset['month'] = dataset['date'].dt.month

    dataset = dataset.drop(['date'], axis=1)
    time_features = dataset[['hour', 'dayofweek', 'day', 'month']].values
    dataset = dataset.drop(['hour', 'dayofweek', 'day', 'month'], axis=1)
    X = dataset.values

    # NOTE(review): the scaler is fit on *all* rows. Statistics of rows that later fall in
    # the validation/test windows therefore leak into the normalisation. The usual ETT
    # protocol fits on the training span only — confirm before comparing to published results.
    scaler = StandardScaler()
    X = scaler.fit_transform(X)

    from tqdm.auto import tqdm

    def create_sequences(data, time_features, seq_length, pred_length):
        """Slide a stride-1 window over ``data`` and its aligned time features.

        For each start index ``i``, the input window is ``data[i:i+seq_length]`` and
        the target window is the next ``pred_length`` rows. The matching slices of
        ``time_features`` are returned alongside.

        Returns:
            (inputs, input_time_features, targets, target_time_features) tensors,
            each with a new leading window axis.
        """
        sequences = []
        seq_x_time = []
        targets = []
        seq_y_time = []
        for i in tqdm(range(len(data) - seq_length - pred_length + 1)):
            sequences.append(data[i:i+seq_length])
            seq_x_time.append(time_features[i:i+seq_length])
            targets.append(data[i+seq_length:i+seq_length+pred_length])
            seq_y_time.append(time_features[i+seq_length:i+seq_length+pred_length])
        # Stack into single ndarrays first: torch.tensor() on a list of ndarrays converts
        # element by element and is very slow for long series. Values and dtypes are unchanged.
        return (torch.tensor(np.array(sequences)), torch.tensor(np.array(seq_x_time)),
                torch.tensor(np.array(targets)), torch.tensor(np.array(seq_y_time)))

    X, x_marks, y, y_marks = create_sequences(X, time_features,
                                              seq_length=configuration.sequence_length,
                                              pred_length=configuration.pred_len)

    # Group window indices into chunks of `batch` consecutive windows and drop the ragged
    # tail. The chunks are then split 70/10/20 into train/val/test in time order.
    batch = 100
    indices = np.arange(len(X))
    barrier = int(len(indices) / batch) * batch
    indices = indices[0:barrier]
    # val/test start `soft_border` chunks early, overlapping the preceding split.
    # NOTE(review): presumably a warm-up span for the reservoir state that the preprocessing
    # step trims off again via its `sample_size` slicing — confirm.
    soft_border = int((configuration.sequence_length / batch)) + 8

    indices = [indices[i:i + batch] for i in range(0, len(indices), batch)]

    border1 = int(len(indices) * 0.7)
    border2 = border1 + int(len(indices) * 0.1)
    border3 = border2 + int(len(indices) * 0.2)

    train_ind = indices[0:border1]
    val_ind = indices[border1 - soft_border: border2]
    test_ind = indices[border2 - soft_border: border3]

    def take(source, chunks):
        """Return ``source[i]`` for every index ``i`` in ``chunks``, flattened in order."""
        return [source[item] for chunk in chunks for item in chunk]

    X_train = take(X, train_ind)
    x_marks_train = take(x_marks, train_ind)
    y_train = take(y, train_ind)
    print(y_train[0])
    y_marks_train = take(y_marks, train_ind)

    X_val = take(X, val_ind)
    x_marks_val = take(x_marks, val_ind)
    y_val = take(y, val_ind)
    y_marks_val = take(y_marks, val_ind)

    X_test = take(X, test_ind)
    x_marks_test = take(x_marks, test_ind)
    y_test = take(y, test_ind)
    y_marks_test = take(y_marks, test_ind)
#train_indices, test_indices =train_test_split(indices,  test_size=0.2, shuffle=False)
#indices = [item for sublist in indices for item in sublist]

import torch
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    """Map-style dataset that yields float32 tensors for one forecasting window.

    Args:
        tokenized_inputs: indexable collection of input windows (anything
            ``torch.as_tensor`` accepts per item).
        x_marks: optional per-item time features of the input window.
        y_marks: optional per-item time features of the target window.
        labels: optional per-item targets.
        pos: stored on the instance but not used by this class.

    Each item is a dict with ``"inputs_embeds"``. It also has ``"x_marks"`` when
    ``x_marks`` was given. When ``labels`` was given, it has ``"labels_ids"`` and,
    if ``y_marks`` was given too, ``"y_marks"``. Missing optional sources are
    omitted instead of raising.
    """

    def __init__(self, tokenized_inputs,  x_marks = None, y_marks = None, labels=None, pos=None):
        self.tokenized_inputs = tokenized_inputs
        self.x_marks = x_marks
        self.y_marks = y_marks
        self.labels = labels
        self.pos = pos
        # Placeholders for per-item ids / reservoir inputs; currently unused.
        self.id_list = None
        self.re = None

    def __len__(self):
        return len(self.tokenized_inputs)

    @staticmethod
    def _as_float_tensor(value):
        """Return a fresh float32 tensor copy of ``value``, detached from any graph.

        ``torch.tensor(existing_tensor)`` emits a copy-construct warning. Going
        through ``as_tensor`` avoids the warning and still always returns a copy.
        """
        return torch.as_tensor(value).detach().to(torch.float32, copy=True)

    def __getitem__(self, idx):
        item = {"inputs_embeds": self._as_float_tensor(self.tokenized_inputs[idx])}
        if self.x_marks is not None:
            item["x_marks"] = self._as_float_tensor(self.x_marks[idx])
        if self.labels is not None:
            item["labels_ids"] = self._as_float_tensor(self.labels[idx])
            if self.y_marks is not None:
                item["y_marks"] = self._as_float_tensor(self.y_marks[idx])
        return item

# Expects the X_*/y_* splits and their time marks (x_marks_*/y_marks_*) built in the __main__ block above.
#print(CustomDataset(X_train, x_marks_train ,y_train,y_marks_train)[0]["x_marks"].shape[-1])

if __name__ == "__main__":
    # Positional order is (inputs, x_marks, y_marks, labels), matching CustomDataset.__init__.
    train_dataset = CustomDataset(X_train, x_marks_train, y_marks_train, y_train)
    val_dataset = CustomDataset(X_val, x_marks_val, y_marks_val, y_val)
    test_dataset = CustomDataset(X_test, x_marks_test, y_marks_test, y_test)

    preprocess = ReservoirTTimeSeries(configuration)
    train_dataset_dic = extract_inputs_and_labels(train_dataset)
    # NOTE: despite the `test_` prefix, this dict (and `test_dataset_fl` below) is built from
    # the *validation* split and is what the trainer evaluates on; `test_dataset` is unused.
    test_dataset_dic = extract_inputs_and_labels(val_dataset)

    # Aliased so the module-level name `Dataset` keeps referring to torch's Dataset.
    from datasets import Dataset as HFDataset, concatenate_datasets
    from tqdm import tqdm

    def build_dataset_from_dict_in_chunks(dic, batch_size=1024):
        """Build a Hugging Face dataset from a dict of columns, ``batch_size`` rows at a time.

        Each ``from_dict`` call sees only one small chunk. The chunks are then
        concatenated into one dataset, which is formatted to return torch tensors.
        """
        length = len(dic["inputs_embeds"])
        parts = []
        for start in tqdm(range(0, length, batch_size), desc="Building dataset", ncols=100):
            end = min(start + batch_size, length)
            chunk = {k: v[start:end] for k, v in dic.items()}
            parts.append(HFDataset.from_dict(chunk))

        dataset = concatenate_datasets(parts)
        dataset.set_format(type="torch")
        return dataset

    dict_in_process = preprocess(inputs_embeds=train_dataset_dic["inputs_embeds"],
                                 labels_ids=train_dataset_dic["labels_ids"])
    train_dataset_fl = build_dataset_from_dict_in_chunks(dict_in_process, batch_size=1024)
    del dict_in_process

    test_dict = preprocess(inputs_embeds=test_dataset_dic["inputs_embeds"],
                           labels_ids=test_dataset_dic["labels_ids"],
                           dataset_type="eval_dataset",
                           train_dataset=train_dataset_fl)
    test_dataset_fl = build_dataset_from_dict_in_chunks(test_dict, batch_size=1024)
    del test_dict


#embedding_model = TimeSeriesEmbedding(configuration)
#reservoir_model = ReservoirTTimeSeries(configuration)
#fl_model = Reservoir_fl_model(configuration)
#dataloader = DataLoader(train_dataset,batch_size=64,shuffle = False)

#for batch in dataloader:
#    inputs_embeds = batch["inputs_embeds"]
#    label_ids = batch["labels_ids"]
#    inputs_embeds = embedding_model(inputs_embeds)
#    #print(inputs_embeds.shape)
#    #reservoir_output,reservoir_state = reservoir_model(inputs_embeds = inputs_embeds)
#    result = fl_model(inputs_embeds = inputs_embeds)
#    break

# Hugging Face training configuration. Evaluation, checkpointing and logging all happen
# once per epoch; at most three checkpoints are kept, and the checkpoint with the lowest
# eval loss is reloaded when training ends.
training_args = TrainingArguments(
    # Output locations.
    output_dir="./checkpoint/patchtst/ETTh1/pretrain/last_hope_output/",
    overwrite_output_dir=True,
    logging_dir="./checkpoint/patchtst/ETTh1/pretrain/logs/",
    # Optimisation.
    learning_rate=0.002,
    num_train_epochs=100,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    # Per-epoch evaluation, checkpointing and logging.
    do_eval=True,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    save_total_limit=3,
    # Model selection: a lower eval_loss is better.
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    # Key under which the model's forward() receives its targets.
    label_names=["labels_ids"],
)

# Early stopping: stop after 20 evaluations (one per epoch here) without an improvement
# of at least 0.001 in the monitored metric.
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_threshold=0.001,
    early_stopping_patience=20,
)
#print(train_dataset[0])
#print(train_dataset[0].keys())



class ReservoirTrainer(Trainer):
    """Trainer that serves the preprocessed datasets through plain PyTorch DataLoaders.

    The loaders are built directly with PyTorch's default collation. Batch sizes come
    from ``TrainingArguments``. Only the training loader is shuffled, so evaluation is
    deterministic and prediction order matches the dataset order.
    """

    def get_train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_dataset,
                          shuffle=True,
                          batch_size=self.args.per_device_train_batch_size)

    def get_eval_dataloader(self, eval_dataset=None) -> DataLoader:
        if eval_dataset is None:
            eval_dataset = self.eval_dataset
        # Sequential order keeps evaluation deterministic.
        return DataLoader(eval_dataset,
                          shuffle=False,
                          batch_size=self.args.per_device_eval_batch_size)

    def get_test_dataloader(self, test_dataset=None) -> DataLoader:
        # Trainer keeps no stored test set, so the caller must supply one.
        if test_dataset is None:
            raise ValueError("ReservoirTrainer.get_test_dataloader requires a test_dataset.")
        # Sequential order so predictions line up with the dataset rows.
        return DataLoader(test_dataset,
                          shuffle=False,
                          batch_size=self.args.per_device_eval_batch_size)

model = Reservoir_fl_model(configuration)

# Report model size: every parameter, and the subset that will receive gradients.
total_params = 0
trainable_params = 0
for _param in model.parameters():
    total_params += _param.numel()
    if _param.requires_grad:
        trainable_params += _param.numel()

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

if __name__ == "__main__":
    # Pretrain on the preprocessed training set, evaluating on `test_dataset_fl` each epoch.
    # Early stopping and custom metrics are currently not wired in.
    trainer = ReservoirTrainer(
        args=training_args,
        model=model,
        eval_dataset=test_dataset_fl,
        train_dataset=train_dataset_fl,
    )
    trainer.train()
# Training loop

#res_state = torch.zeros(1, 1000)
#for batch in train_loader:
#    inputs_embeds = batch['inputs_embeds']  # Extract input sequences from the batch
#    labels_ids = batch['labels_ids']        # Extract target sequences from the batch#

    # Forward pass through the DeepReservoirNet
#    reservoir_outputs = model(inputs_embeds=inputs_embeds)

    # Get the model's outputs and updated reservoir state
    #outputs = output_dict['Output']
    #res_state = output_dict['State']
#    print(reservoir_outputs.shape) #the output shape is (batch_size,output_size,num_features)
#    print(reservoir_outputs)
#    break  #next step is to keep track of all Reservoir states across all batches
           #next step is to use the cross attention to combine input and reservoir_outputs

# Step 4: Forward pass through the model
#output_dict = model(train_dataset)
#print(res_state.shape)