import os
import argparse
import numpy as np
import pandas as pd
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset

from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr
from tqdm import tqdm

from layers.conv_layer import Conv_Lob
from data_preprocessing.data_utils import load_data

from typing import Optional
from torch import Tensor
import math


# Default hyperparameters.
# NOTE(review): nothing in the visible part of this file reads these globals
# directly; run_experiment takes identically named arguments, so they are
# presumably forwarded by a caller that is not shown here -- confirm.
LEARNING_RATE = 1e-4             # optimizer step size
BATCH_SIZE = 128                 # originally written as 64 * 2
EPOCHS = 15                      # upper bound on training epochs
WEIGHT_DECAY = 0                 # L2 penalty; 0 disables it
GAMMA = None                     # scheduler decay factor (no scheduler configured)
STEP_SIZE = None                 # scheduler step interval (no scheduler configured)
SCHEDULER_FLAG = False           # scheduler switched off
SCHEDULER_STEP_FLAG = False      # per-step scheduler update switched off
LOSS_TYPE = 'MSE'                # name of the training loss
LOSS_CALCULATION_MP_FLAG = True  # NOTE(review): meaning not visible here -- confirm


class ResBlock(nn.Module):
    """Residual block that mixes along the time axis, then along the channel axis.

    Both mixers are two-layer MLPs (Linear -> ReLU -> Linear -> Dropout) whose
    hidden width is ``configs.d_model``.

    Args:
        configs: object exposing ``seq_len``, ``d_model`` and ``dropout``.
        out_c: size of the last (channel) dimension of the input.
    """

    def __init__(self, configs, out_c):
        super().__init__()
        hidden_width = configs.d_model
        drop_prob = configs.dropout

        # The temporal mixer is built before the channel mixer so parameter
        # registration order (and therefore seeded initialisation) is stable.
        self.temporal = self._make_mixer(configs.seq_len, hidden_width, drop_prob)
        self.channel = self._make_mixer(out_c, hidden_width, drop_prob)

    @staticmethod
    def _make_mixer(width, hidden_width, drop_prob):
        """Build a width -> hidden_width -> width MLP followed by dropout."""
        return nn.Sequential(
            nn.Linear(width, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, width),
            nn.Dropout(drop_prob),
        )

    def forward(self, x):
        """Apply both residual mixers to ``x`` of shape [B, L, D]; shape is preserved."""
        # nn.Linear acts on the last dimension, so swap time into last place
        # for the temporal mixer and swap it back afterwards.
        time_mixed = self.temporal(x.transpose(1, 2)).transpose(1, 2)
        x = x + time_mixed
        return x + self.channel(x)


class TimeMixer_conv(nn.Module):
    """Convolutional front-end followed by residual mixer blocks and a linear head.

    Pipeline: ``Conv_Lob`` -> ``configs.e_layers`` x ``ResBlock`` -> linear map
    from ``configs.seq_len`` to ``configs.pred_len`` along the time axis.

    Args:
        configs: object exposing ``e_layers``, ``pred_len`` and ``seq_len``
            (plus whatever ``ResBlock`` reads from it).
        in_c: number of input channels; also used as the conv output channel
            count, the conv ``groups`` value and the ResBlock channel width.
        out_c: accepted for interface compatibility; not used by this class.
        kernel, dilation, num_conv, conv_type: forwarded to ``Conv_Lob`` as
            ``kernel``, ``dilation``, ``num_layers`` and ``conv_type``.
    """

    def __init__(self, configs, in_c=41, out_c=20, kernel=2, dilation=2, num_conv=5, conv_type='exp'):
        super().__init__()
        self.task_name = 'long_term_forecast'
        self.layer = configs.e_layers

        # NOTE(review): Conv_Lob is project-local; with out_c == groups == in_c
        # it is presumably a per-channel convolution that keeps the [B, L, C]
        # layout and the sequence length -- confirm against layers.conv_layer.
        self.conv = Conv_Lob(
            conv_type=conv_type,
            in_c=in_c,
            out_c=in_c,
            kernel=kernel,
            dilation=dilation,
            num_layers=num_conv,
            groups=in_c,
        )

        mixer_blocks = [ResBlock(configs, in_c) for _ in range(configs.e_layers)]
        self.model = nn.ModuleList(mixer_blocks)

        self.pred_len = configs.pred_len
        self.projection = nn.Linear(configs.seq_len, configs.pred_len)

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Return the projected output for ``x_enc`` ([B, L, C]).

        ``x_mark_enc``, ``x_dec``, ``x_mark_dec`` and ``mask`` are ignored.
        The time axis of the result has length ``pred_len``.
        """
        hidden = self.conv(x_enc)

        for depth in range(self.layer):
            hidden = self.model[depth](hidden)

        # Project along time: move L to the last dim, map seq_len -> pred_len,
        # then restore the [B, time, C] layout.
        return self.projection(hidden.transpose(1, 2)).transpose(1, 2)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Forecast and return the last channel only, shape [B, pred_len].

        Raises:
            ValueError: if ``task_name`` is not a forecasting task.
        """
        if self.task_name not in ('long_term_forecast', 'short_term_forecast'):
            raise ValueError('Only forecast tasks implemented yet')

        dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
        # Keep the final pred_len steps of the last channel (drops the channel dim).
        return dec_out[:, -self.pred_len:, -1]



class Config:
    """Plain container for model hyperparameters.

    Every constructor argument is stored unchanged on the instance under the
    same name. The descriptions below restate the original author's notes;
    the code that consumes these fields is not visible in this file, so read
    them as intent rather than verified behavior.

    Args:
        enc_in: number of input features.
        seq_len: input sequence length (lookback).
        pred_len: number of steps to predict.
        e_layers: number of encoder layers.
        n_heads: number of attention heads.
        d_model: model width.
        d_ff: feed-forward layer width.
        dropout: dropout rate.
        fc_dropout: dropout rate in the fully connected layer.
        head_dropout: dropout rate at the output head.
        individual: whether to use per-channel ("individual") branches.
        patch_len: patch length.
        stride: stride used when patching.
        padding_patch: padding mode used when patching.
        revin: NOTE(review): originally described as "reversible layers";
            the name suggests reversible instance normalization -- confirm.
        affine: whether to use an affine transformation.
        subtract_last: whether to subtract the last element of each sequence.
        decomposition: whether to decompose input sequences.
        kernel_size: kernel size used for decomposition.
    """

    def __init__(
        self, enc_in, seq_len, pred_len, e_layers, n_heads, d_model, d_ff,
        dropout, fc_dropout, head_dropout, individual, patch_len, stride,
        padding_patch, revin, affine, subtract_last, decomposition, kernel_size,
    ):
        # Data / horizon dimensions.
        self.enc_in, self.seq_len, self.pred_len = enc_in, seq_len, pred_len
        # Encoder sizing.
        self.e_layers, self.n_heads = e_layers, n_heads
        self.d_model, self.d_ff = d_model, d_ff
        # Regularisation.
        self.dropout, self.fc_dropout, self.head_dropout = dropout, fc_dropout, head_dropout
        self.individual = individual
        # Patching.
        self.patch_len, self.stride, self.padding_patch = patch_len, stride, padding_patch
        # Normalisation and decomposition switches.
        self.revin, self.affine, self.subtract_last = revin, affine, subtract_last
        self.decomposition, self.kernel_size = decomposition, kernel_size

class LeftPad1d(nn.Module):
    """Pad the start of the last dimension with ``left_pad`` zeros.

    Nothing is added at the end of the dimension.
    """

    def __init__(self, left_pad):
        super().__init__()
        self.left_pad = left_pad

    def forward(self, x):
        """Return ``x`` with ``left_pad`` zeros prepended along its last dim."""
        # F.pad takes (left, right) for the last dimension; its default mode
        # is constant padding with value 0.
        pad_spec = (self.left_pad, 0)
        return F.pad(x, pad_spec)


def evaluate_predictions(y_true, y_pred):
    """Compute MSE, Pearson correlation and R^2 between two tensors.

    Both tensors are detached, moved to the CPU and converted to NumPy first.

    Args:
        y_true: tensor of ground-truth values.
        y_pred: tensor of predictions.

    Returns:
        Tuple ``(mse, corr, r2)``. ``corr`` is computed over the flattened
        arrays; ``mse`` and ``r2`` are computed on the arrays as given.
    """
    truth = y_true.detach().cpu().numpy()
    preds = y_pred.detach().cpu().numpy()

    mse = mean_squared_error(truth, preds)
    # np.corrcoef returns a 2x2 matrix for two 1-D inputs; the off-diagonal
    # entry is the correlation between them.
    corr_matrix = np.corrcoef(truth.flatten(), preds.flatten())
    corr = corr_matrix[0, 1]
    r2 = r2_score(truth, preds)

    return mse, corr, r2

class DictToClass:
    """Expose the entries of a dictionary as instance attributes.

    Args:
        dictionary: mapping whose keys become attribute names and whose
            values become the corresponding attribute values.
    """

    def __init__(self, dictionary):
        for entry in dictionary.items():
            # entry is a (name, value) pair.
            setattr(self, *entry)

def _regression_metrics(y_true, y_pred):
    """Return ``(mse, r2, pearson_corr)`` for two equal-length 1-D arrays."""
    mse = mean_squared_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    corr, _ = pearsonr(y_true, y_pred)
    return mse, r2, corr


def _clone_state_dict(model):
    """Return a snapshot of ``model``'s state that later training cannot alter.

    ``state_dict()`` hands back tensors that share storage with the live
    parameters, so each tensor is cloned.
    """
    return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}


def _evaluate_loader(model, loader, criterion, device):
    """Run ``model`` over ``loader`` in eval mode without gradients.

    Returns:
        ``(summed_loss, y_true, y_pred)`` where ``summed_loss`` is the sum of
        per-batch loss times batch size, and ``y_true`` / ``y_pred`` are
        flattened 1-D NumPy arrays covering every batch.
    """
    model.eval()
    summed_loss = 0.0
    true_chunks, pred_chunks = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs, None, None, None)
            loss = criterion(outputs.reshape(-1, 1), targets.reshape(-1, 1))
            summed_loss += loss.item() * inputs.size(0)
            true_chunks.append(targets.reshape(-1).detach().cpu().numpy())
            pred_chunks.append(outputs.reshape(-1).detach().cpu().numpy())
    return summed_loss, np.concatenate(true_chunks), np.concatenate(pred_chunks)


def run_experiment(train_loader, val_loader, test_loader, num_features, PRED_LEN, LEARNING_RATE, SEED, WEIGHT_DECAY, BATCH_SIZE, EPOCHS, LOSS_TYPE, LOSS_CALCULATION_MP_FLAG, SCHEDULER_FLAG=False, SCHEDULER_STEP_FLAG=False, STEP_SIZE=None, GAMMA=None):
    """Train ``TimeMixer_conv`` with early stopping and report test metrics.

    The model is trained with Adam and an MSE loss for at most ``EPOCHS``
    epochs. Training stops early after two consecutive epochs without an
    improvement in summed validation loss. The weights from the best
    validation epoch are restored before the test set is evaluated. Progress
    and metrics (MSE, R^2, Pearson correlation) are printed; nothing is
    returned.

    Args:
        train_loader, val_loader, test_loader: iterables yielding
            ``(inputs, targets)`` tensor batches and exposing ``.dataset``
            (its length normalises the printed losses).
        num_features: the model is built with ``num_features + 1`` channels.
        LEARNING_RATE: Adam learning rate.
        SEED: seed applied to torch (CPU and CUDA) and NumPy.
        WEIGHT_DECAY: Adam weight decay; ``None`` is treated as 0.
        EPOCHS: maximum number of training epochs.
        PRED_LEN, BATCH_SIZE, LOSS_TYPE, LOSS_CALCULATION_MP_FLAG,
        SCHEDULER_FLAG, SCHEDULER_STEP_FLAG, STEP_SIZE, GAMMA: accepted for
            interface compatibility but not read by this function. The model
            always uses ``pred_len=1``, the loss is always MSE and no
            scheduler is created.
    """
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)

    # Of these keys, TimeMixer_conv and ResBlock (as defined in this file) read
    # only seq_len, pred_len, e_layers, d_model and dropout.
    # NOTE(review): seq_len is fixed at 100, so the loaders presumably must
    # yield inputs with 100 time steps -- confirm against the data pipeline.
    config = dict(pred_len=1,
              task_name='long_term_forecast',
              freq='h',
              seq_len=100,
              label_len=48,
              expand=2,
              d_conv=4,
              top_k=5,
              num_kernels=6,
              enc_in=num_features+1,
              dec_in=1,
              c_out=1,
              d_model=16,
              n_heads=8,
              e_layers=2,
              d_layers=1,
              d_ff=2048,
              moving_avg=25,
              factor=3,
              distill=True,
              dropout=.1,
              embed='timeF',
              activation='gelu',
              output_attention=False,
              channel_independence=1,
              decomp_method='moving_avg',
              use_norm=1,
              down_sampling_layers=0,
              down_sampling_window=1,
              down_sampling_method=None,
              seg_len=48)

    model = TimeMixer_conv(DictToClass(config), in_c=num_features+1)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    criterion = nn.MSELoss()
    # Honour the caller's weight decay; `or 0` keeps a None value valid.
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY or 0)

    patience, check, best_val_loss = 2, 0, math.inf
    # Start from the initial weights so there is always a state to restore,
    # even if no epoch improves (e.g. NaN validation loss or EPOCHS == 0).
    best_model_state = _clone_state_dict(model)

    for epoch in range(EPOCHS):
        # Training
        model.train()
        train_loss = 0.0
        true_chunks, pred_chunks = [], []
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs, None, None, None)
            loss = criterion(outputs.reshape(-1, 1), targets.reshape(-1, 1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch total can be divided by the
            # dataset size even when the last batch is smaller.
            train_loss += loss.item() * inputs.size(0)
            true_chunks.append(targets.reshape(-1).detach().cpu().numpy())
            pred_chunks.append(outputs.reshape(-1).detach().cpu().numpy())

        mse_train, r2_train, corr_train = _regression_metrics(
            np.concatenate(true_chunks), np.concatenate(pred_chunks)
        )

        # Validation
        val_loss, y_true, y_pred = _evaluate_loader(model, val_loader, criterion, device)
        mse_val, r2_val, corr_val = _regression_metrics(y_true, y_pred)

        # Print training and validation loss for each epoch
        print(f'Epoch [{epoch+1}/{EPOCHS}], Train Loss: {train_loss/len(train_loader.dataset):.6f}, \
            Val Loss: {val_loss/len(val_loader.dataset):.6f}')

        # Print metrics for training and validation
        print(f'Train Metrics: MSE: {mse_train:.6f}, R^2: {r2_train:.6f}, Correlation: {corr_train:.6f}')
        print(f'Validation Metrics: MSE: {mse_val:.6f}, R^2: {r2_val:.6f}, Correlation: {corr_val:.6f}')

        # Keep the best model based on validation loss
        if val_loss < best_val_loss:
            print("Update best model")
            check = 0
            print(f'patience count: {check}')
            best_val_loss = val_loss
            # Clone: a bare state_dict() would keep tracking the live weights.
            best_model_state = _clone_state_dict(model)
        else:
            check += 1
            print(f'patience count: {check}')
            if check >= patience:
                print(f'Run out of patience on epoch {epoch}')
                break

    # Load the best model state for testing
    model.load_state_dict(best_model_state)

    # Testing
    _, y_true, y_pred = _evaluate_loader(model, test_loader, criterion, device)
    mse_test, r2_test, corr_test = _regression_metrics(y_true, y_pred)

    # Print metrics for testing
    print(f'Test Metrics: MSE: {mse_test:.4f}, R^2: {r2_test:.4f}, Correlation: {corr_test:.4f}')
