import os
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset

import pytorch_lightning as pl
from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr
from tqdm import tqdm

# ---------------------------------------------------------------------------
# Module-level default settings.
# NOTE(review): run_experiment takes parameters with these same names;
# presumably callers pass these values through -- confirm at the call site.
# ---------------------------------------------------------------------------

# Optimisation
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0
BATCH_SIZE = 128
EPOCHS = 15

# Learning-rate scheduler (disabled: no step size / gamma configured)
SCHEDULER_FLAG = False
SCHEDULER_STEP_FLAG = False
STEP_SIZE = None
GAMMA = None

# Loss
LOSS_TYPE = 'MSE'
LOSS_CALCULATION_MP_FLAG = True

class ResBlock(nn.Module):
    """Residual block that mixes along the time axis, then the channel axis.

    Two bottleneck MLPs (Linear -> ReLU -> Linear -> Dropout) are applied with
    skip connections:

    * ``temporal`` maps ``seq_len -> d_model -> seq_len`` and acts on the
      sequence dimension.
    * ``channel`` maps ``enc_in -> d_model -> enc_in`` and acts on the
      feature dimension.

    Args:
        configs: object exposing ``seq_len``, ``d_model``, ``enc_in`` and
            ``dropout`` attributes.
    """

    @staticmethod
    def _bottleneck(width, hidden, drop_prob):
        """Build ``Linear(width, hidden) -> ReLU -> Linear(hidden, width) -> Dropout``."""
        return nn.Sequential(
            nn.Linear(width, hidden),
            nn.ReLU(),
            nn.Linear(hidden, width),
            nn.Dropout(drop_prob),
        )

    def __init__(self, configs):
        super().__init__()
        hidden = configs.d_model
        drop_prob = configs.dropout

        # Creation order (temporal first, then channel) is kept so that seeded
        # parameter initialisation stays the same.
        self.temporal = self._bottleneck(configs.seq_len, hidden, drop_prob)
        self.channel = self._bottleneck(configs.enc_in, hidden, drop_prob)

    def forward(self, x):
        """Apply both residual mixing steps.

        Args:
            x: tensor laid out as ``[B, L, D]`` with ``L == seq_len`` and
                ``D == enc_in`` (required by the Linear layer sizes).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Swap to [B, D, L] so the Linear layers act on the time axis,
        # then swap back before adding the skip connection.
        time_major = x.transpose(1, 2)
        time_update = self.temporal(time_major).transpose(1, 2)
        mixed = x + time_update

        # The channel MLP already operates on the last (feature) axis.
        return mixed + self.channel(mixed)


class TimeMixer(nn.Module):
    """Stack of ``ResBlock`` layers followed by a linear projection over time.

    Args:
        configs: object exposing ``e_layers``, ``seq_len`` and ``pred_len``
            (plus whatever ``ResBlock`` reads from it).
        out_variates: how many trailing feature columns ``forward`` returns.
    """

    def __init__(self, configs, out_variates=1):
        super().__init__()
        # The task is fixed here; it is not read from ``configs``.
        self.task_name = 'long_term_forecast'
        self.layer = configs.e_layers

        blocks = [ResBlock(configs) for _ in range(configs.e_layers)]
        self.model = nn.ModuleList(blocks)

        self.pred_len = configs.pred_len
        # Maps the time axis from seq_len steps to pred_len steps.
        self.projection = nn.Linear(configs.seq_len, configs.pred_len)
        self.out_variates = out_variates

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Run the residual blocks and project ``seq_len`` to ``pred_len``.

        Only ``x_enc`` (laid out as ``[B, L, D]``) is used; the remaining
        arguments are accepted but ignored.

        Returns:
            Tensor of shape ``[B, pred_len, D]``.
        """
        hidden = x_enc
        for idx in range(self.layer):
            hidden = self.model[idx](hidden)

        # Project along time: [B, L, D] -> [B, D, L] -> Linear -> [B, pred_len, D].
        projected = self.projection(hidden.transpose(1, 2))
        return projected.transpose(1, 2)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Forecast and keep the last ``pred_len`` steps / ``out_variates`` columns.

        Raises:
            ValueError: if ``task_name`` is not one of the forecast tasks.
        """
        if self.task_name not in ('long_term_forecast', 'short_term_forecast'):
            raise ValueError('Only forecast tasks implemented yet')

        # ``mask`` is not forwarded; ``forecast`` does not use it.
        dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
        return dec_out[:, -self.pred_len:, -self.out_variates:]  # [B, L, D]

class DictToClass:
    """Expose the entries of a dict as attributes of an object.

    Each key becomes an attribute name and each value the attribute's value,
    so ``DictToClass({'a': 1}).a == 1``.
    """

    def __init__(self, dictionary):
        for attr_name in dictionary:
            setattr(self, attr_name, dictionary[attr_name])


def evaluate_predictions(y_true, y_pred):
    """Compute MSE, Pearson correlation and R^2 for a pair of tensors.

    Both tensors are detached, moved to the CPU and converted to NumPy before
    scoring. The correlation is computed on the flattened arrays.

    Args:
        y_true: tensor of reference values.
        y_pred: tensor of predicted values.

    Returns:
        Tuple ``(mse, corr, r2)``.
    """
    truth, estimate = (t.detach().cpu().numpy() for t in (y_true, y_pred))

    mse = mean_squared_error(truth, estimate)
    # Off-diagonal entry of the 2x2 correlation matrix.
    corr = np.corrcoef(truth.flatten(), estimate.flatten())[0, 1]
    r2 = r2_score(truth, estimate)

    return mse, corr, r2



def _regression_metrics(y_true, y_pred):
    """Return ``(mse, r2, pearson_corr)`` for two 1-D arrays."""
    mse = mean_squared_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    corr, _ = pearsonr(y_true, y_pred)
    return mse, r2, corr


def _collect_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``, training only when ``optimizer`` is given.

    Returns:
        Tuple ``(summed_loss, y_true, y_pred)``. ``summed_loss`` is each batch
        loss multiplied by its batch size and summed; the two arrays are the
        flattened targets and predictions collected over the pass.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    training = optimizer is not None
    model.train(training)

    summed_loss = 0.0
    true_chunks = []
    pred_chunks = []
    with torch.set_grad_enabled(training):
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(x_enc=inputs, x_mark_enc=None, x_dec=None, x_mark_dec=None)

            # TimeMixer.forward already keeps only the trailing output
            # variate(s), so the target column is the last one.
            preds = outputs[:, :, -1]
            targets = targets.reshape(-1, 1)
            loss = criterion(preds, targets)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            summed_loss += loss.item() * inputs.size(0)
            true_chunks.append(targets.detach().cpu().numpy())
            pred_chunks.append(preds.detach().cpu().numpy())

    if not true_chunks:
        raise ValueError("data loader yielded no batches")

    y_true = np.concatenate(true_chunks).ravel()
    y_pred = np.concatenate(pred_chunks).ravel()
    return summed_loss, y_true, y_pred


def run_experiment(train_loader, val_loader, test_loader, num_features, PRED_LEN, LEARNING_RATE, WEIGHT_DECAY, BATCH_SIZE, EPOCHS, LOSS_TYPE, LOSS_CALCULATION_MP_FLAG, SCHEDULER_FLAG=False, SCHEDULER_STEP_FLAG=False, STEP_SIZE=None, GAMMA=None):
    """Train a ``TimeMixer`` model, keep the best-validation weights, and test.

    The model is trained with Adam and MSE loss for ``EPOCHS`` epochs. After
    every epoch the summed validation loss is compared with the best seen so
    far; the weights of the best epoch are restored before the test pass.
    Loss and metrics (MSE, R^2, Pearson correlation) are printed per epoch and
    for the test set.

    Args:
        train_loader, val_loader, test_loader: iterables yielding
            ``(inputs, targets)`` batches and exposing ``.dataset`` (its
            length is used to average the printed losses). ``inputs`` must be
            3-D ``[B, L, D]``; ``targets`` are reshaped to ``[-1, 1]``.
        LEARNING_RATE: Adam learning rate.
        EPOCHS: number of training epochs.
        num_features, PRED_LEN, WEIGHT_DECAY, BATCH_SIZE, LOSS_TYPE,
        LOSS_CALCULATION_MP_FLAG, SCHEDULER_FLAG, SCHEDULER_STEP_FLAG,
        STEP_SIZE, GAMMA: accepted for interface compatibility but not used
            by this implementation (the model is built with ``pred_len=1``,
            the loss is always MSE, and no scheduler is created).

    Returns:
        None. Results are reported through ``print``.

    Raises:
        ValueError: if a loader is empty or the inputs are not 3-D.
    """
    # Peek at one batch *before* seeding so the model dimensions come from the
    # data itself while the seeded random stream used for initialisation and
    # training is left untouched by this extra iteration.
    try:
        sample_inputs, _ = next(iter(train_loader))
    except StopIteration:
        raise ValueError("train_loader yielded no batches") from None
    if sample_inputs.dim() != 3:
        raise ValueError(
            f"expected inputs of shape [B, L, D], got {tuple(sample_inputs.shape)}"
        )
    seq_len, enc_in = sample_inputs.shape[1], sample_inputs.shape[2]

    seed = 1
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)

    # Only seq_len, enc_in, d_model, dropout, e_layers and pred_len are read
    # by TimeMixer / ResBlock; the other entries are carried along unchanged.
    config = dict(pred_len=1,
              task_name='long_term_forecast',
              freq='h',
              seq_len=seq_len,
              label_len=48,
              expand=2,
              d_conv=4,
              top_k=5,
              num_kernels=6,
              enc_in=enc_in,
              dec_in=1,
              c_out=1,
              d_model=16,
              n_heads=8,
              e_layers=2,
              d_layers=1,
              d_ff=2048,
              moving_avg=25,
              factor=3,
              distill=True,
              dropout=.1,
              embed='timeF',
              activation='gelu',
              output_attention=False,
              channel_independence=1,
              decomp_method='moving_avg',
              use_norm=1,
              down_sampling_layers=0,
              down_sampling_window=1,
              down_sampling_method=None,
              seg_len=48)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The batches are moved to ``device``, so the model must live there too.
    model = TimeMixer(DictToClass(config)).to(device)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    best_val_loss = float('inf')
    best_model_state = None

    for epoch in range(EPOCHS):
        # Training (metrics are gathered while the weights are being updated)
        train_loss, y_true, y_pred = _collect_epoch(
            model, train_loader, criterion, device, optimizer)
        mse_train, r2_train, corr_train = _regression_metrics(y_true, y_pred)

        # Validation
        val_loss, y_true, y_pred = _collect_epoch(
            model, val_loader, criterion, device)
        mse_val, r2_val, corr_val = _regression_metrics(y_true, y_pred)

        # Save the best model based on validation loss
        if val_loss < best_val_loss:
            print("Update best model")
            best_val_loss = val_loss
            # state_dict() returns references to the live tensors; clone them
            # so later optimizer steps cannot overwrite the snapshot.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }

        # Print training and validation loss for each epoch
        print(f'Epoch [{epoch+1}/{EPOCHS}], Train Loss: {train_loss/len(train_loader.dataset):.6f}, \
            Val Loss: {val_loss/len(val_loader.dataset):.6f}')

        # Print metrics for training and validation
        print(f'Train Metrics: MSE: {mse_train:.6f}, R^2: {r2_train:.6f}, Correlation: {corr_train:.6f}')
        print(f'Validation Metrics: MSE: {mse_val:.6f}, R^2: {r2_val:.6f}, Correlation: {corr_val:.6f}')

    # Load the best model state for testing. It stays None when no epoch ran
    # or no validation loss was comparable (e.g. NaN); in that case the
    # current weights are tested.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    # Testing
    _, y_true, y_pred = _collect_epoch(model, test_loader, criterion, device)
    mse_test, r2_test, corr_test = _regression_metrics(y_true, y_pred)

    # Print metrics for testing
    print(f'Test Metrics: MSE: {mse_test:.4f}, R^2: {r2_test:.4f}, Correlation: {corr_test:.4f}')
