import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import os
import sys
# Import-time path bootstrap.
# BASE_PATH is read from the environment (defaulting to the empty string) and
# a single trailing '/' is stripped so the f-strings below do not produce '//'.
BASE_PATH = os.environ.get("BASE_PATH", "")
if BASE_PATH and BASE_PATH.endswith('/'):
    BASE_PATH = BASE_PATH[:-1]
# NOTE: when BASE_PATH is unset or empty these resolve to the absolute paths
# "/src/models/sequence/" and "/".
path = f"{BASE_PATH}/src/models/sequence/"
path2 = f"{BASE_PATH}/"
# Both directories are appended to sys.path; the bare-name imports that follow
# this block (cross_encoder, cross_decoder, attn, cross_embed) are presumably
# resolved through `path` -- verify against the repository layout.
sys.path.append(path)
sys.path.append(path2)
import os
# Side effect: importing this module changes the process working directory to
# `path`. os.chdir raises an OSError subclass if that directory does not exist.
os.chdir(path)
from cross_encoder import Encoder
from cross_decoder import Decoder
from attn import FullAttention, AttentionLayer, TwoStageAttentionLayer
from cross_embed import DSW_embedding

from math import ceil

class Crossformer(nn.Module):
    """Segment-based encoder/decoder forecaster built from project modules.

    The input sequence is front-padded to a multiple of ``seg_len``, embedded
    with ``DSW_embedding``, offset by a learned positional embedding,
    layer-normalised, and passed through ``Encoder``. ``Decoder`` then
    consumes a learned positional embedding (repeated per batch element)
    together with the encoder output to produce the prediction.

    Args:
        data_dim: Size of the second axis of both learned positional
            embeddings (presumably the number of input series -- confirm
            with callers).
        in_len: Input length along dim 1 of ``x_seq``.
        out_len: Number of leading steps (dim 1) kept from the decoder output.
        seg_len: Segment length; ``in_len`` and ``out_len`` are rounded up to
            a multiple of it.
        win_size: Forwarded to ``Encoder`` and stored as ``merge_win``.
        factor: Forwarded to both ``Encoder`` and ``Decoder``.
        d_model: Embedding width; also reported by ``d_output``.
        d_ff: Forwarded to ``Encoder`` and ``Decoder``.
        n_heads: Forwarded to ``Encoder`` and ``Decoder``.
        e_layers: Forwarded to ``Encoder``; ``Decoder`` receives
            ``e_layers + 1``.
        dropout: Forwarded to ``Encoder`` and ``Decoder``.
        baseline: If true, the mean of the input over dim 1 is added to the
            prediction.
        device: Stored on the instance only; this class does not move any
            parameters to it.
        *args, **kwargs: Accepted and ignored.
    """

    def __init__(self, data_dim, in_len, out_len, seg_len, win_size = 4,
                factor=10, d_model=512, d_ff = 1024, n_heads=8, e_layers=3, 
                dropout=0.0, baseline = False, device=torch.device('cuda:0'), *args, **kwargs):
        super().__init__()
        self.data_dim = data_dim
        self.in_len = in_len
        self.out_len = out_len
        self.seg_len = seg_len
        self.merge_win = win_size
        self.d_model = d_model

        self.baseline = baseline

        self.device = device

        # Padding to handle lengths that are not divisible by the segment length.
        in_seg_num = ceil(in_len / seg_len)
        out_seg_num = ceil(out_len / seg_len)
        self.pad_in_len = in_seg_num * seg_len
        self.pad_out_len = out_seg_num * seg_len
        self.in_len_add = self.pad_in_len - self.in_len

        # Embedding (creation order is kept so seeded initialisation of the
        # randn-based parameters is unchanged).
        self.enc_value_embedding = DSW_embedding(seg_len, d_model)
        self.enc_pos_embedding = nn.Parameter(
            torch.randn(1, data_dim, in_seg_num, d_model))
        self.pre_norm = nn.LayerNorm(d_model)

        # Encoder
        self.encoder = Encoder(e_layers, win_size, d_model, n_heads, d_ff,
                               block_depth=1, dropout=dropout,
                               in_seg_num=in_seg_num, factor=factor)

        # Decoder
        self.dec_pos_embedding = nn.Parameter(
            torch.randn(1, data_dim, out_seg_num, d_model))
        self.decoder = Decoder(seg_len, e_layers + 1, d_model, n_heads, d_ff,
                               dropout, out_seg_num=out_seg_num, factor=factor)

    def forward(self, x_seq, *args, **kwargs):
        """Run the model on ``x_seq`` and return ``(prediction, 0)``.

        ``x_seq`` is indexed as a 3-D tensor whose dim 1 is padded up to
        ``pad_in_len`` (presumably ``(batch, in_len, data_dim)`` -- confirm
        with callers). The prediction is the decoder output truncated to the
        first ``out_len`` steps along dim 1, plus the input mean over dim 1
        when ``baseline`` is enabled. The second tuple element is the
        constant ``0``. Extra positional/keyword arguments are ignored.
        """
        base = x_seq.mean(dim=1, keepdim=True) if self.baseline else 0
        batch_size = x_seq.shape[0]

        if self.in_len_add != 0:
            # Front-pad by repeating the first time step in_len_add times.
            front_pad = x_seq[:, :1, :].expand(-1, self.in_len_add, -1)
            x_seq = torch.cat((front_pad, x_seq), dim=1)

        x_seq = self.enc_value_embedding(x_seq)
        # Out-of-place add: avoids mutating the embedding module's output.
        x_seq = x_seq + self.enc_pos_embedding
        x_seq = self.pre_norm(x_seq)

        enc_out = self.encoder(x_seq)

        # Tile the learned decoder embedding once per batch element.
        dec_in = repeat(self.dec_pos_embedding,
                        'b ts_d l d -> (repeat b) ts_d l d',
                        repeat=batch_size)
        predict_y = self.decoder(dec_in, enc_out)

        return base + predict_y[:, :self.out_len, :], 0

    @property
    def d_state(self):
        """Return the model width as the state dimension.

        NOTE(review): this previously read ``self.H``, an attribute this
        class never assigns, so every access raised ``AttributeError``.
        Reporting ``d_model`` is an assumption -- confirm with callers.
        """
        return self.d_model

    @property
    def d_output(self):
        """Return ``d_model``."""
        return self.d_model
    
if __name__ == "__main__":
    # Smoke test: build a small model, run one forward pass, and check that
    # the prediction has shape (batch, out_len, data_dim).
    data_dim = 5
    in_len = 100
    out_len = 50
    seg_len = 10
    bz = 12
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Initialize the Crossformer model. The constructor only stores `device`,
    # so the parameters must be moved explicitly to match the input tensor.
    model = Crossformer(data_dim=data_dim, in_len=in_len, out_len=out_len, seg_len=seg_len, device=device)
    model = model.to(device)

    # Example input tensor with a batch size of `bz`.
    x_seq = torch.randn(bz, in_len, data_dim, device=device)

    # Forward pass
    expected_shape = (bz, out_len, data_dim)
    try:
        output, _ = model(x_seq)
        print("Output shape:", output.shape)
        # Explicit raise instead of `assert` so the check survives `python -O`.
        if tuple(output.shape) != expected_shape:
            raise AssertionError("Output shape is incorrect!")
        print("Basic sanity check passed.")
    except Exception as e:
        # Best-effort script: report the failure rather than crash.
        print("Sanity check failed:", e)

    # Further checks can include testing different configurations,
    # testing with varying segment lengths, or testing edge cases.
