import sys

import torch

from models.itransformer_conv import iTransformer_conv
from models.timemixer_conv import TimeMixer_conv
from models.dlinear_conv import DLinear_conv
from models.patchtst_conv import PatchTST_conv
from models.itransformer import iTransformer
from models.patchtst import PatchTST
from models.timemixer import TimeMixer
from models.dlinear import DLinear
from models.bin_tabl import BiN_CTABL
from models.mlp import MLP
from models.lstm import LSTM
from models.dain import DAIN
from models.translob import TransLob
from models.cnn1 import CNN1

class DictToClass:
    """Wrap a dictionary so that its keys can be read as attributes.

    Example: ``DictToClass({'seq_len': 100}).seq_len == 100``.
    Every entry is installed with ``setattr``, so keys must be strings.
    """

    def __init__(self, dictionary):
        pairs = dictionary.items()
        for attr_name, attr_value in pairs:
            setattr(self, attr_name, attr_value)

class Config:
    """Attribute bag of hyper-parameters handed to the config-driven models.

    Every keyword argument except ``in_c`` and ``task_name`` is stored
    verbatim as an attribute of the same name; optional settings that are
    not supplied stay ``None``.
    """

    def __init__(self, enc_in=40, seq_len=100, pred_len=1, conv_k=2, conv_d=2,
                 in_c=40, out_c=14, e_layers=None, n_heads=None, d_model=None,
                 d_ff=None, dropout=None, fc_dropout=None, head_dropout=None,
                 individual=None, patch_len=None, freq=None, stride=None,
                 padding_patch=None, revin=None, affine=None,
                 subtract_last=None, decomposition=None, kernel_size=None,
                 task_name=None):
        # NOTE(review): ``in_c`` and ``task_name`` are accepted but never
        # stored on the instance -- confirm whether that is intended.

        # Encoder input width, lookback length, number of steps to predict.
        self.enc_in, self.seq_len, self.pred_len = enc_in, seq_len, pred_len
        # conv_k / freq / conv_d / out_c (kept in the original assignment order).
        self.conv_k, self.freq = conv_k, freq
        self.conv_d, self.out_c = conv_d, out_c
        # Encoder layers, attention heads, model width, feed-forward width.
        self.e_layers, self.n_heads = e_layers, n_heads
        self.d_model, self.d_ff = d_model, d_ff
        # Dropout rates: general, fully connected layer, output head.
        self.dropout, self.fc_dropout, self.head_dropout = (
            dropout, fc_dropout, head_dropout)
        # Per-task branching flag, then patch length / stride / padding.
        self.individual = individual
        self.patch_len, self.stride, self.padding_patch = (
            patch_len, stride, padding_patch)
        # Reversible layers, affine transform, subtract-last-element flags.
        self.revin, self.affine, self.subtract_last = revin, affine, subtract_last
        # Input decomposition toggle and the kernel size it uses.
        self.decomposition, self.kernel_size = decomposition, kernel_size

def _resolve_out_c(out_c, num_features):
    """Translate a symbolic ``out_c`` preset into an integer channel count.

    'full' -> num_features, 'half' -> num_features // 2,
    'quarter' -> num_features // 4, 'tlb' -> 14; any other value
    (e.g. an explicit int) is returned unchanged.
    """
    if out_c == 'full':
        return num_features
    if out_c == 'half':
        return num_features // 2
    if out_c == 'quarter':
        return num_features // 4
    if out_c == 'tlb':
        return 14
    return out_c


def load_model(model_name, num_features, out_c, kernel, dilation, num_conv, conv_type, depthwise, bsz, lookback=100, pred_len=1, out_variates=1):
    """Instantiate the model identified by ``model_name``.

    Args:
        model_name: case-insensitive model identifier, e.g. 'itransformer',
            'patchtst_conv', 'dlinear', 'lstm', 'dain', 'translob'.
        num_features: input feature count; passed to the models as
            ``in_c`` / ``enc_in`` / ``num_features``.
        out_c: output channel count for the ``*_conv`` variants, either an
            int or one of the presets 'full', 'half', 'quarter', 'tlb'.
        kernel, dilation, num_conv, conv_type: forwarded unchanged to the
            ``*_conv`` variants.
        depthwise: accepted for interface compatibility only.
            NOTE(review): every ``*_conv`` model is built with
            ``depthwise=False``, so this argument is currently ignored.
        bsz: batch size; only used by 'translob'.
        lookback: input sequence length (``seq_len``).
        pred_len: prediction length.
        out_variates: forwarded to the models that accept it.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``model_name`` does not match any known model.
    """
    model_name = model_name.lower()
    out_c = _resolve_out_c(out_c, num_features)
    print(f'Number of features: {out_c}')

    if model_name == 'itransformer_conv':
        model = iTransformer_conv(
            seq_len=lookback, pred_len=pred_len, output_attention='store_true',
            use_norm=True, d_model=512, embed='timeF',
            freq='h', dropout=0.1, class_strategy='projection',
            factor=1, n_heads=8,
            d_ff=2048, activation='gelu', e_layers=2,
            in_c=num_features, out_c=out_c, kernel=kernel, dilation=dilation,
            num_conv=num_conv, conv_type=conv_type, depthwise=False, out_variates=out_variates
        )

    elif model_name == 'patchtst_conv':
        # enc_in is the conv front-end's output width, not num_features.
        config = Config(
            enc_in=out_c, seq_len=lookback, pred_len=pred_len,
            e_layers=3, n_heads=4, d_model=16, d_ff=128,
            dropout=0.3, fc_dropout=0.3, head_dropout=0,
            individual=False,
            patch_len=16, stride=8, padding_patch=0,
            revin=True, affine=True, subtract_last=False, decomposition=True,
            kernel_size=25
        )
        model = PatchTST_conv(
            config,
            in_c=num_features, out_c=out_c, kernel=kernel, dilation=dilation,
            num_conv=num_conv, conv_type=conv_type, depthwise=False, out_variates=out_variates
        )

    elif model_name == 'dlinear_conv':
        config = Config(seq_len=lookback, pred_len=pred_len, individual=True)
        model = DLinear_conv(
            config,
            in_c=num_features, out_c=out_c, kernel=kernel, dilation=dilation,
            num_conv=num_conv, conv_type=conv_type, depthwise=False, out_variates=out_variates)

    elif model_name == 'timemixer_conv':
        config = {
            'pred_len': pred_len,
            'seq_len': lookback,
            'd_model': 16,
            'e_layers': 2,
            'dropout': 0.1,
        }
        # NOTE(review): unlike the other *_conv variants, out_variates is
        # not forwarded here -- confirm TimeMixer_conv does not accept it.
        model = TimeMixer_conv(
            DictToClass(config),
            in_c=num_features, out_c=out_c, kernel=kernel, dilation=dilation,
            num_conv=num_conv, conv_type=conv_type, depthwise=False)

    elif model_name == 'itransformer':
        model = iTransformer(
            seq_len=lookback, pred_len=pred_len, output_attention='store_true',
            use_norm=True, d_model=512, embed='timeF',
            freq='h', dropout=0.1, class_strategy='projection',
            factor=1, n_heads=8,
            d_ff=2048, activation='gelu', e_layers=2, out_variates=out_variates
        )

    elif model_name == 'patchtst':
        configs = Config(
            enc_in=num_features, seq_len=lookback, pred_len=pred_len,
            e_layers=3, n_heads=4, d_model=16, d_ff=128,
            dropout=0.3, fc_dropout=0.3, head_dropout=0,
            individual=False,
            patch_len=16, stride=8, padding_patch=0,
            revin=True, affine=True, subtract_last=False, decomposition=True,
            kernel_size=25
        )
        model = PatchTST(configs, out_variates=out_variates)

    elif model_name == 'timemixer':
        config = {
            'pred_len': pred_len,
            'seq_len': lookback,
            'd_model': 16,
            'e_layers': 2,
            'dropout': 0.1,
            'enc_in': num_features
        }
        model = TimeMixer(DictToClass(config), out_variates=out_variates)

    elif model_name == 'dlinear':
        configs = Config(enc_in=num_features, seq_len=lookback, pred_len=pred_len, individual=True)
        model = DLinear(configs, out_variates=out_variates)

    elif model_name == 'binctabl':
        model = BiN_CTABL(60, num_features, lookback, 10, 120, 5, 1, 1)

    elif model_name == 'mlp':
        # The MLP takes the flattened (lookback x features) window.
        model = MLP(num_features=lookback * num_features, num_classes=1)

    elif model_name == 'lstm':
        model = LSTM(x_shape=num_features, hidden_layer_dim=32, hidden_mlp=64, num_layers=1, p_dropout=0.2)

    elif model_name == 'cnn1':
        model = CNN1(num_features=num_features, num_classes=1)

    elif model_name == 'dain':
        model = DAIN(
            backward_window=lookback,
            num_features=num_features,
            num_classes=1,
            mlp_hidden=512,
            p_dropout=0.5,
            mode='full',
            mean_lr=1e-06,
            scale_lr=1e-03,
            gate_lr=10
        )

    elif model_name == 'translob':
        model = TransLob(seq_len=lookback, in_c=num_features, btch_sz=bsz)

    else:
        # Without this, an unknown name reached `return model` unbound.
        raise ValueError(f'Unknown model name: {model_name!r}')

    return model


def load_optimizer(name, model, lr):
    """Build the optimizer used to train ``model``.

    For ``name == 'dain'`` an RMSprop optimizer with four parameter groups
    is returned:
    - ``model.base`` at the base rate ``lr``;
    - ``model.dean.mean_layer``, ``scaling_layer`` and ``gating_layer``,
      each at ``0.0001`` times the matching multiplier on ``model.dean``
      (``mean_lr``, ``scale_lr``, ``gate_lr``).

    Any other ``name`` gets Adam over ``model.parameters()`` at ``lr``.
    """
    if name != 'dain':
        return torch.optim.Adam(model.parameters(), lr=lr)

    # NOTE(review): the sub-layer rates are scaled from a fixed 1e-4 rather
    # than from ``lr``; the two agree only when lr == 1e-4 -- confirm intent.
    base_group = {'params': model.base.parameters()}
    dean = model.dean
    mean_group = {'params': dean.mean_layer.parameters(),
                  'lr': 0.0001 * dean.mean_lr}
    scale_group = {'params': dean.scaling_layer.parameters(),
                   'lr': 0.0001 * dean.scale_lr}
    gate_group = {'params': dean.gating_layer.parameters(),
                  'lr': 0.0001 * dean.gate_lr}
    return torch.optim.RMSprop(
        [base_group, mean_group, scale_group, gate_group], lr=lr)

# Default hyper-parameter triples for the 'FI' dataset, keyed by model name.
_FI_PARAMS = {
    'lstm': (100, 0.01, 64),
    'mlp': (100, 0.0001, 32),
    'cnn1': (100, 0.0001, 32),
    'binctabl': (10, 0.0001, 64),
    'dain': (15, 0.0001, 128),
    'translob': (100, 0.0001, 128),
    **{name: (100, 0.0001, 128)
       for name in ('timemixer', 'timemixer_conv', 'dlinear', 'dlinear_conv',
                    'itransformer', 'itransformer_conv', 'patchtst',
                    'patchtst_conv', 'timexer')},
}


def retrieve_params(m, d):
    """Return the default hyper-parameter triple for model ``m`` on dataset ``d``.

    The triple is presumably ``[epochs, learning_rate, batch_size]`` --
    confirm against the caller.

    Args:
        m: lower-case model name.
        d: dataset tag; 'FI' has per-model settings, 'CHF' one shared setting.

    Returns:
        A fresh 3-element list, or ``None`` when ``d`` is neither 'FI' nor
        'CHF'.

    Exits:
        With status 1 (after printing a message) when ``d == 'FI'`` and
        ``m`` is unknown.
    """
    if d == 'FI':
        try:
            params = _FI_PARAMS[m]
        except KeyError:
            print("Model does not exist...")
            # sys.exit rather than the site-provided exit(), with a non-zero
            # status so callers and shells see the failure.
            sys.exit(1)
        # Return a new list each call so callers can mutate it safely.
        return list(params)
    if d == 'CHF':
        return [100, 0.0001, 1024]
    return None