import os
import argparse
import numpy as np
import pandas as pd
# import matplotlib.pyplot as plt
import math
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset

# import pytorch_lightning as pl
from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr
from tqdm import tqdm
import sys


from layers.Transformer_EncDec import Encoder, EncoderLayer
from layers.SelfAttention_Family import FullAttention, AttentionLayer
from layers.Embed import DataEmbedding_inverted
from layers.conv_layer import Conv_Lob
from data_preprocessing.data_utils import load_data

import pickle

# Constants: module-level default hyper-parameters for the experiment.
LEARNING_RATE = 0.0001
BATCH_SIZE = 128
EPOCHS = 15
WEIGHT_DECAY = 0
GAMMA = None
STEP_SIZE = None
SCHEDULER_FLAG = False
SCHEDULER_STEP_FLAG = False
LOSS_TYPE = 'MSE'
LOSS_CALCULATION_MP_FLAG = True



class LeftPad1d(nn.Module):
    """Pad only the left side of a tensor's last dimension.

    Args:
        left_pad: how many elements ``F.pad`` prepends along the final axis
            (constant mode, i.e. zeros by default).
    """

    def __init__(self, left_pad):
        super().__init__()
        self.left_pad = left_pad

    def forward(self, x):
        """Return ``x`` with ``self.left_pad`` padding values prepended on its last axis."""
        pad_spec = (self.left_pad, 0)  # (left, right) for the last dimension
        return F.pad(x, pad_spec)
    
    

class iTransformer_conv(nn.Module):
    """Convolutional front-end followed by an iTransformer encoder.

    Paper link: https://arxiv.org/abs/2310.06625

    The input is first passed through ``Conv_Lob``. Each channel of the
    result is then embedded as one variate token by
    ``DataEmbedding_inverted`` and processed by the encoder, as in
    iTransformer.

    Args:
        seq_len: look-back window length (forwarded to the inverted embedding).
        pred_len: number of steps produced by the final linear projector.
        use_norm: if true, each series is centred and scaled by its mean/std
            over dim 1 before encoding, and the output is de-normalized.
        d_model, embed, freq, dropout: forwarded to ``DataEmbedding_inverted``
            (``d_model`` and ``dropout`` are also used by the encoder).
        class_strategy: stored on the instance; not used by this class.
        factor, n_heads, d_ff, activation, e_layers: encoder hyper-parameters.
        output_attention: forwarded to ``FullAttention``. Note that the
            default is the (truthy) string ``'attention_true'``.
        in_c, out_c, kernel, dilation, num_conv, conv_type: forwarded to
            ``Conv_Lob`` (``num_conv`` is passed as ``num_layers``).
    """

    def __init__(self, seq_len, pred_len, use_norm, d_model, embed, freq, dropout, 
                 class_strategy, factor, n_heads, d_ff, activation, e_layers, 
                 output_attention='attention_true', in_c=41, out_c=14, kernel=2, dilation=2,
                 num_conv=5, conv_type='exp'):
        super().__init__()

        self.conv = Conv_Lob(
            conv_type=conv_type, in_c=in_c, out_c=out_c, kernel=kernel,
            dilation=dilation, num_layers=num_conv
            )

        self.seq_len = seq_len
        self.pred_len = pred_len
        self.output_attention = output_attention
        self.use_norm = use_norm

        # Embedding
        self.enc_embedding = DataEmbedding_inverted(seq_len, d_model, embed, freq,
                                                    dropout)
        self.class_strategy = class_strategy
        # Encoder-only architecture
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        FullAttention(False, factor, attention_dropout=dropout,
                                      output_attention=output_attention), d_model, n_heads),
                    d_model,
                    d_ff,
                    dropout=dropout,
                    activation=activation
                ) for _ in range(e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(d_model)
        )
        self.projector = nn.Linear(d_model, pred_len, bias=True)
        # Detached CPU copy of the last encoder layer's attention from the most
        # recent forward pass, or None if the encoder returned none.
        self.attns = None

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Return the (de-normalized) forecast with shape (B, pred_len, N).

        ``x_dec`` and ``x_mark_dec`` are accepted for interface compatibility
        and ignored. As a side effect, ``self.attns`` is set to the last
        encoder layer's attention (detached, on CPU), or to ``None`` when the
        encoder returns no attention map for that layer.
        """
        # B L C (original note: C = 40 -- TODO confirm against the data)

        # Pass the input tensor through a series of convolutional layers
        x_enc = self.conv(x_enc)

        # B L N (original note: N = 14 -- presumably the conv output channels)

        if self.use_norm:
            # Normalization from Non-stationary Transformer
            means = x_enc.mean(1, keepdim=True).detach()
            x_enc = x_enc - means
            stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
            x_enc = x_enc / stdev
        _, _, N = x_enc.shape # B L N
        # B: batch_size;    E: d_model;
        # L: seq_len;       S: pred_len;
        # N: number of variate (tokens), can also includes covariates

        # Embedding
        # B L N -> B N E                (B L N -> B L E in the vanilla Transformer)
        enc_out = self.enc_embedding(x_enc, x_mark_enc) # covariates (e.g timestamp) can be also embedded as tokens
        # B N E -> B N E                (B L E -> B L E in the vanilla Transformer)
        # the dimensions of embedded time series has been inverted, and then processed by native attn, layernorm and ffn modules
        enc_out, attns = self.encoder(enc_out, attn_mask=None)
        # Guard against an empty attention list or a None entry (presumably
        # what the encoder yields when attention output is disabled) instead
        # of crashing on ``None.detach()``.
        last_attn = attns[-1] if attns else None
        self.attns = last_attn.detach().cpu() if last_attn is not None else None

        # B N E -> B N S -> B S N
        dec_out = self.projector(enc_out).permute(0, 2, 1)[:, :, :N] # filter the covariates
        if self.use_norm:
            # De-Normalization from Non-stationary Transformer
            dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
            dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))

        return dec_out

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Return the forecast of the last variate only, shape (B, pred_len).

        ``mask`` is accepted for interface compatibility and ignored.
        """
        dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
        return dec_out[:, -self.pred_len:, -1]  # [B, pred_len]
    
def evaluate_predictions(y_true, y_pred):
    """Score predictions against targets.

    Both tensors are detached and copied to host memory as numpy arrays.

    Returns:
        tuple: ``(mse, corr, r2)`` -- sklearn mean squared error, the Pearson
        correlation of the flattened arrays (from ``np.corrcoef``), and the
        sklearn R^2 score.
    """
    truth, pred = (t.detach().cpu().numpy() for t in (y_true, y_pred))
    mse = mean_squared_error(truth, pred)
    corr_matrix = np.corrcoef(truth.ravel(), pred.ravel())
    r2 = r2_score(truth, pred)
    return mse, corr_matrix[0, 1], r2



def _regression_metrics(true_chunks, pred_chunks):
    """Concatenate per-batch 1-D arrays and return ``(mse, r2, pearson_corr)``."""
    y_true = np.concatenate(true_chunks)
    y_pred = np.concatenate(pred_chunks)
    mse = mean_squared_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    corr, _ = pearsonr(y_true, y_pred)
    return mse, r2, corr


def _evaluate_loader(model, loader, criterion, device, attn_store=None):
    """Run ``model`` over ``loader`` in eval mode without gradients.

    Returns ``(summed_loss, true_chunks, pred_chunks)``. The first value is
    the criterion loss weighted by ``inputs.size(0)`` and summed over
    batches. The two lists hold per-batch 1-D numpy arrays of targets and
    predictions. When ``attn_store`` is a list, ``model.attns`` is appended
    to it after every batch and then reset to ``None``.
    """
    model.eval()
    total_loss = 0.0
    true_chunks, pred_chunks = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)  # Move data to device
            outputs = model(inputs, None, None, None)
            if attn_store is not None:
                attn_store.append(model.attns)
                model.attns = None
            loss = criterion(outputs.reshape(-1, 1), targets.reshape(-1, 1))
            total_loss += loss.item() * inputs.size(0)
            true_chunks.append(targets.reshape(-1).cpu().numpy())
            pred_chunks.append(outputs.reshape(-1).cpu().numpy())
    return total_loss, true_chunks, pred_chunks


def run_experiment(train_loader, val_loader, test_loader, num_features, PRED_LEN, LEARNING_RATE, SEED, WEIGHT_DECAY, BATCH_SIZE, EPOCHS, LOSS_TYPE, LOSS_CALCULATION_MP_FLAG, SCHEDULER_FLAG=False, SCHEDULER_STEP_FLAG=False, STEP_SIZE=None, GAMMA=None):
    """Train ``iTransformer_conv``, keep the best-validation weights, and test them.

    Args:
        train_loader, val_loader, test_loader: iterables of ``(inputs, targets)``
            batches. ``len(loader.dataset)`` is used to average the printed
            train/val losses.
        num_features: the model is built with ``in_c=num_features + 1``.
        PRED_LEN: used only to name the attention dump file
            ``./attn_abla/itransformer20_h{PRED_LEN}``. The model itself is
            built with ``pred_len=1``.
        LEARNING_RATE: Adam learning rate.
        SEED: seed for torch (CPU and CUDA) and numpy.
        WEIGHT_DECAY: Adam weight decay (``None`` is treated as 0).
        EPOCHS: maximum number of epochs. Training stops early after 15
            consecutive epochs without a lower validation loss.
        BATCH_SIZE, LOSS_TYPE, LOSS_CALCULATION_MP_FLAG, SCHEDULER_FLAG,
        SCHEDULER_STEP_FLAG, STEP_SIZE, GAMMA: accepted for interface
            compatibility but currently unused (the loss is always MSE and
            no LR scheduler is created).

    Side effects:
        Prints per-epoch train/validation metrics and the final test
        metrics. Pickles the per-batch test attention maps to
        ``./attn_abla/itransformer20_h{PRED_LEN}``, creating the directory
        if needed.
    """
    import copy

    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)

    model = iTransformer_conv(
            seq_len=100, pred_len=1, output_attention='store_true',
            use_norm=True, d_model=512, embed='timeF',
            freq='h', dropout=0.1, class_strategy='projection',
            factor=1 , n_heads=8,
            d_ff=2048, activation='gelu', e_layers=2,
            in_c=num_features+1, out_c=20
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    criterion = nn.MSELoss()
    # WEIGHT_DECAY used to be accepted but never reached the optimizer.
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE,
                           weight_decay=WEIGHT_DECAY or 0.0)
    patience, check = 15, 0
    best_val_loss = math.inf
    best_model_state = None

    for epoch in range(EPOCHS):
        # Training
        model.train()
        train_loss = 0.0
        true_chunks, pred_chunks = [], []
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)  # Move data to device
            # Forward pass
            outputs = model(inputs, None, None, None)
            loss = criterion(outputs.reshape(-1, 1), targets.reshape(-1, 1))
            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * inputs.size(0)
            true_chunks.append(targets.reshape(-1).detach().cpu().numpy())
            pred_chunks.append(outputs.reshape(-1).detach().cpu().numpy())
        mse_train, r2_train, corr_train = _regression_metrics(true_chunks, pred_chunks)

        # Validation
        val_loss, true_chunks, pred_chunks = _evaluate_loader(model, val_loader, criterion, device)
        mse_val, r2_val, corr_val = _regression_metrics(true_chunks, pred_chunks)

        # Print training and validation loss for each epoch
        print(f'Epoch [{epoch+1}/{EPOCHS}], Train Loss: {train_loss/len(train_loader.dataset):.6f}, '
              f'Val Loss: {val_loss/len(val_loader.dataset):.6f}')

        # Print metrics for training and validation
        print(f'Train Metrics: MSE: {mse_train:.6f}, R^2: {r2_train:.6f}, Correlation: {corr_train:.6f}')
        print(f'Validation Metrics: MSE: {mse_val:.6f}, R^2: {r2_val:.6f}, Correlation: {corr_val:.6f}')

        # Save the best model based on validation loss
        if val_loss < best_val_loss:
            print("Update best model")
            check = 0
            print(f'patience count: {check}')
            best_val_loss = val_loss
            # state_dict() returns references to the live parameter tensors;
            # deep-copy so later optimizer steps don't overwrite the snapshot.
            best_model_state = copy.deepcopy(model.state_dict())
        else:
            check += 1
            print(f'patience count: {check}')
            if check >= patience:
                print(f'Run out of patience on epoch {epoch}')
                break

    # Load the best model state for testing
    if best_model_state is None:
        # No epoch ever improved on +inf (EPOCHS < 1 or NaN/inf losses).
        print('Warning: no best model recorded; testing with the final weights')
    else:
        model.load_state_dict(best_model_state)

    # Testing
    tst_last_attns = []
    _, true_chunks, pred_chunks = _evaluate_loader(
        model, test_loader, criterion, device, attn_store=tst_last_attns)

    attn_path = f'./attn_abla/itransformer20_h{PRED_LEN}'
    os.makedirs(os.path.dirname(attn_path), exist_ok=True)
    with open(attn_path, 'wb') as f:
        pickle.dump(tst_last_attns, f)

    # Calculate metrics for testing
    mse_test, r2_test, corr_test = _regression_metrics(true_chunks, pred_chunks)

    # Print metrics for testing
    print(f'Test Metrics: MSE: {mse_test:.4f}, R^2: {r2_test:.4f}, Correlation: {corr_test:.4f}')
