import os
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import math

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset

import pytorch_lightning as pl
from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr
from tqdm import tqdm
import sys


from layers.Transformer_EncDec import Encoder, EncoderLayer
from layers.SelfAttention_Family import FullAttention, AttentionLayer
from layers.Embed import DataEmbedding_inverted
from data_preprocessing.data_utils import load_data

# Constants
# Module-level defaults for a training run. `run_experiment` takes parameters
# with the same names; which values callers actually pass is not visible here.

# Optimisation settings.
LEARNING_RATE = 1e-4
BATCH_SIZE = 128  # 2 * 64
EPOCHS = 15
WEIGHT_DECAY = 0

# Scheduler-related settings (presumably for a step LR scheduler; both flags
# are off and the two values are unset).
SCHEDULER_FLAG = False
SCHEDULER_STEP_FLAG = False
STEP_SIZE = None
GAMMA = None

# Loss-related settings.
LOSS_TYPE = 'MSE'
LOSS_CALCULATION_MP_FLAG = True
    
    

class iTransformer(nn.Module):
    """
    Encoder-only forecaster built from an inverted embedding, a stack of
    attention encoder layers and a linear projection onto the horizon.

    Paper link: https://arxiv.org/abs/2310.06625
    """

    def __init__(self, seq_len, pred_len, output_attention, use_norm, d_model, embed, freq, dropout, class_strategy, factor, n_heads, d_ff, activation, e_layers, out_variates=1):
        """
        Build the embedding, the encoder stack and the output projector.

        `seq_len`/`pred_len` are the input window and forecast horizon lengths,
        `use_norm` toggles per-series normalisation in `forecast`, and
        `out_variates` is how many trailing variates `forward` returns. The
        remaining arguments are forwarded to the project layer classes.
        """
        super().__init__()

        # Plain configuration kept on the module.
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.output_attention = output_attention
        self.use_norm = use_norm

        # Inverted embedding of the input window.
        self.enc_embedding = DataEmbedding_inverted(seq_len, d_model, embed, freq, dropout)
        self.class_strategy = class_strategy

        def build_layer():
            # One encoder layer: full attention wrapped in an attention layer.
            attention = AttentionLayer(
                FullAttention(False, factor, attention_dropout=dropout,
                              output_attention=output_attention),
                d_model,
                n_heads,
            )
            return EncoderLayer(attention, d_model, d_ff, dropout=dropout, activation=activation)

        # Encoder-only architecture: `e_layers` identical layers plus a final LayerNorm.
        self.encoder = Encoder(
            [build_layer() for _ in range(e_layers)],
            norm_layer=torch.nn.LayerNorm(d_model),
        )

        # Maps each d_model-sized token onto pred_len forecast steps.
        self.projector = nn.Linear(d_model, pred_len, bias=True)
        self.out_variates = out_variates
        print(f"out_variates: {self.out_variates}")

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """
        Produce a forecast for every variate of `x_enc`.

        `x_enc` must be 3-D; per the layout notes below it is (B, L, N).
        `x_mark_enc` is handed to the embedding; `x_dec` and `x_mark_dec` are
        not used. Returns a tensor laid out as (B, S, N).

        B: batch_size;    E: d_model;
        L: seq_len;       S: pred_len;
        N: number of variates (tokens), can also include covariates
        """
        if self.use_norm:
            # Normalization from Non-stationary Transformer: zero-mean and
            # unit-variance along dim 1, statistics kept for de-normalisation.
            means = x_enc.mean(1, keepdim=True).detach()
            x_enc = x_enc - means
            variance = torch.var(x_enc, dim=1, keepdim=True, unbiased=False)
            stdev = torch.sqrt(variance + 1e-5)
            x_enc = x_enc / stdev

        _batch, _length, n_variates = x_enc.shape  # B L N

        # B L N -> B N E   (B L N -> B L E in the vanilla Transformer);
        # covariates (e.g. timestamps) can also be embedded as tokens.
        tokens = self.enc_embedding(x_enc, x_mark_enc)

        # B N E -> B N E; the attention maps returned alongside are not used.
        tokens, _attns = self.encoder(tokens, attn_mask=None)

        # B N E -> B N S -> B S N, then drop any covariate tokens.
        projected = self.projector(tokens).permute(0, 2, 1)
        dec_out = projected[:, :, :n_variates]

        if self.use_norm:
            # De-Normalization from Non-stationary Transformer.
            scale = stdev[:, 0:1, :].repeat(1, self.pred_len, 1)
            shift = means[:, 0:1, :].repeat(1, self.pred_len, 1)
            dec_out = dec_out * scale
            dec_out = dec_out + shift

        return dec_out

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """
        Run `forecast` and keep the last `pred_len` steps of the last
        `out_variates` variates. `mask` is accepted but not used.
        """
        full_forecast = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
        horizon = slice(-self.pred_len, None)
        kept_variates = slice(-self.out_variates, None)
        return full_forecast[:, horizon, kept_variates]  # B, S, N


def evaluate_predictions(y_true, y_pred):
    """
    Score predictions against ground truth.

    Both arguments are torch tensors; each is detached, moved to the CPU and
    converted to a NumPy array before scoring.

    Returns a tuple ``(mse, corr, r2)``: the mean squared error, the Pearson
    correlation between the flattened arrays, and the R^2 score.
    """
    truth, pred = (t.detach().cpu().numpy() for t in (y_true, y_pred))

    mse = mean_squared_error(truth, pred)

    # corrcoef yields the 2x2 correlation matrix; the off-diagonal entry is
    # the correlation between the two flattened arrays.
    corr_matrix = np.corrcoef(truth.flatten(), pred.flatten())
    corr = corr_matrix[0, 1]

    r2 = r2_score(truth, pred)
    return mse, corr, r2



def _regression_metrics(y_true, y_pred):
    """Return ``(mse, r2, pearson_r)`` for two equally long 1-D arrays."""
    mse = mean_squared_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    corr, _ = pearsonr(y_true, y_pred)
    return mse, r2, corr


def _run_epoch(model, loader, criterion, device, optimizer=None, desc=None):
    """
    Run one pass over ``loader``: train if ``optimizer`` is given, else evaluate.

    Returns ``(summed_loss, y_true, y_pred)``. ``summed_loss`` is the sum of the
    per-batch losses weighted by batch size; the two arrays are the flattened
    targets and predictions of every batch, in loader order.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    training = optimizer is not None
    model.train(training)
    batches = loader if desc is None else tqdm(loader, desc=desc)

    summed_loss = 0.0
    true_chunks, pred_chunks = [], []
    with torch.set_grad_enabled(training):
        for inputs, targets in batches:
            inputs, targets = inputs.to(device), targets.to(device)
            targets = targets.reshape(-1, 1)

            outputs = model(inputs, None, None, None)
            # The model's forward already keeps only its trailing
            # `out_variates` channels, so the target is the last channel.
            preds = outputs[:, :, -1]
            loss = criterion(preds, targets)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            summed_loss += loss.item() * inputs.size(0)
            true_chunks.append(targets.detach().cpu().numpy().reshape(-1))
            pred_chunks.append(preds.detach().cpu().numpy().reshape(-1))

    if not true_chunks:
        raise ValueError("data loader yielded no batches")
    return summed_loss, np.concatenate(true_chunks), np.concatenate(pred_chunks)


def run_experiment(train_loader, val_loader, test_loader, num_features, PRED_LEN, LEARNING_RATE, WEIGHT_DECAY, BATCH_SIZE, EPOCHS, LOSS_TYPE, LOSS_CALCULATION_MP_FLAG, SCHEDULER_FLAG=False, SCHEDULER_STEP_FLAG=False, STEP_SIZE=None, GAMMA=None):
    """
    Train an iTransformer with early stopping, then report test metrics.

    The model is trained with Adam and MSE loss for at most ``EPOCHS`` epochs.
    Training stops early once the validation loss has failed to improve for
    3 consecutive epochs. The weights of the epoch with the lowest validation
    loss are restored before the test pass. Losses and metrics (MSE, R^2,
    Pearson correlation) are printed; nothing is returned.

    Args:
        train_loader, val_loader, test_loader: iterables of
            ``(inputs, targets)`` tensor batches exposing ``.dataset`` (used
            to average the printed losses).
        LEARNING_RATE: learning rate passed to Adam.
        EPOCHS: maximum number of training epochs.
        num_features, PRED_LEN, WEIGHT_DECAY, BATCH_SIZE, LOSS_TYPE,
        LOSS_CALCULATION_MP_FLAG, SCHEDULER_FLAG, SCHEDULER_STEP_FLAG,
        STEP_SIZE, GAMMA: accepted for interface compatibility but not read
            here; the model configuration, loss and optimiser are hard-coded.

    Raises:
        ValueError: if any loader that gets iterated yields no batches.
    """
    seed = 1
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)

    model = iTransformer(
        seq_len=100, pred_len=1,
        # Previously the argparse action string 'store_true', which is always
        # truthy. The attention maps are discarded by iTransformer.forecast,
        # so none are requested (assumes FullAttention treats this as a
        # boolean flag -- TODO confirm).
        output_attention=False,
        use_norm=True, d_model=512, embed='timeF',
        freq='h', dropout=0.1, class_strategy='projection',
        factor=1, n_heads=8,
        d_ff=2048, activation='gelu', e_layers=2
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Early-stopping state lives outside the loop so it persists across epochs.
    patience, check = 3, 0
    best_val_loss = math.inf
    best_model_state = None

    for epoch in range(EPOCHS):
        # Training (metrics come from the outputs produced while training).
        train_loss, y_true, y_pred = _run_epoch(
            model, train_loader, criterion, device,
            optimizer=optimizer, desc=f"Epoch {epoch + 1}/{EPOCHS}",
        )
        mse_train, r2_train, corr_train = _regression_metrics(y_true, y_pred)

        # Validation
        val_loss, y_true, y_pred = _run_epoch(model, val_loader, criterion, device)
        mse_val, r2_val, corr_val = _regression_metrics(y_true, y_pred)

        # Print training and validation loss for each epoch
        print(f'Epoch [{epoch+1}/{EPOCHS}], '
              f'Train Loss: {train_loss/len(train_loader.dataset):.6f}, '
              f'Val Loss: {val_loss/len(val_loader.dataset):.6f}')

        # Print metrics for training and validation
        print(f'Train Metrics: MSE: {mse_train:.6f}, R^2: {r2_train:.6f}, Correlation: {corr_train:.6f}')
        print(f'Validation Metrics: MSE: {mse_val:.6f}, R^2: {r2_val:.6f}, Correlation: {corr_val:.6f}')

        # Save the best model based on validation loss
        if val_loss < best_val_loss:
            check = 0
            print("Update best model")
            best_val_loss = val_loss
            # state_dict() hands back references to the live tensors; clone
            # them so later optimiser steps cannot overwrite the snapshot.
            best_model_state = {
                key: value.detach().clone()
                for key, value in model.state_dict().items()
            }
        else:
            check += 1
            if check >= patience:
                break

    # Load the best model state for testing. No snapshot exists when no epoch
    # ran or the validation loss was never below infinity (e.g. NaN).
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    # Testing
    _test_loss, y_true, y_pred = _run_epoch(model, test_loader, criterion, device)
    mse_test, r2_test, corr_test = _regression_metrics(y_true, y_pred)

    # Print metrics for testing
    print(f'Test Metrics: MSE: {mse_test:.4f}, R^2: {r2_test:.4f}, Correlation: {corr_test:.4f}')

