import os
import sys
from pathlib import Path
import time
from datetime import datetime
import random
from dataclasses import dataclass
import inspect

import torch
import torch.nn as nn
from torch import optim
import numpy as np
import mlflow
import pandas as pd

@dataclass
class Configurations:
    """One hyperparameter set for the sweep run by this script.

    The ``__main__`` block reads every field of an instance and forwards it to
    ``train``, which in turn passes the numeric ones to ``LLMConfig``.
    """
    # Forwarded to ``LLMConfig(d_model=...)``.
    d_model: int
    # Forwarded to ``LLMConfig(n_encoder_layers=...)``.
    n_encoder_layers: int
    # Forwarded to ``LLMConfig(n_decoder_layers=...)``.
    n_decoder_layers: int
    # Forwarded to ``LLMConfig(dropout=...)``.
    dropout: float
    # Model class; ``train`` instantiates it as ``model_type(config)``.
    model_type: type
    # Forwarded to ``LLMConfig(batch_size=...)``.
    batch_size: int = 32
    # Forwarded to ``LLMConfig(n_heads=...)``.
    n_heads: int = 8
    # Forwarded to ``LLMConfig(learning_rate=...)``; also used as the AdamW ``lr``.
    learning_rate: float = 0.0005

# Project root, taken as two directory levels above the current working
# directory (so the result depends on where the script is launched from).
# It is appended to ``sys.path``, presumably so that the ``src.*`` imports
# that follow can be resolved.
BASEPATH = Path.cwd().parent.parent
sys.path.append(BASEPATH.as_posix())

from src.modules.llm_config import LLMConfig
from src.models.sentinel import Sentinel
from src.data_provider.data_factory import data_provider
from src.utils.early_stopping import EarlyStopping
from src.utils.training_utils import dtype, validate, test, adjust_learning_rate_new, adjust_learning_rate
from src.logger.logger_utils import LoggerType, get_logger, LoggerConfig

# Directory holding the electricity dataset: it lives next to the project
# root, under ``datasets/timeseries_datasets/electricity``.
DATASET_PATH = BASEPATH.parent.joinpath("datasets", "timeseries_datasets", "electricity")

# Settings shared by every run launched from this script.
seq_len = 96                  # forwarded to LLMConfig(seq_len=...)
data_name = 'custom'          # forwarded to LLMConfig(data=...); selects the dataloader
data_path = 'electricity.csv'  # forwarded to LLMConfig(data_path=...)
dataset_name = 'electricity'  # dataset label used in console output and the results file name
freq = 'h'                    # forwarded to LLMConfig(freq=...)
enc_in = 321                  # forwarded to LLMConfig(enc_in=...)


def train(
    batch_size: int,
    d_model: int,
    n_encoder_layers: int,
    n_decoder_layers: int,
    dropout: float,
    learning_rate: float,
    n_heads: int,
    pred_len: int,
    fix_seed: int,
    model_type: type
):
    """Train one model for a single (configuration, horizon, seed) and evaluate it.

    The function seeds the Python, NumPy and torch RNGs, builds an
    ``LLMConfig`` from the arguments and the module-level dataset settings,
    trains ``model_type(config)`` with AdamW and an L1 loss, validates after
    every epoch with early stopping, reloads the saved checkpoint and runs the
    final evaluation with ``test``.

    Args:
        batch_size: Forwarded to ``LLMConfig(batch_size=...)``.
        d_model: Forwarded to ``LLMConfig(d_model=...)``.
        n_encoder_layers: Forwarded to ``LLMConfig(n_encoder_layers=...)``.
        n_decoder_layers: Forwarded to ``LLMConfig(n_decoder_layers=...)``.
        dropout: Forwarded to ``LLMConfig(dropout=...)``.
        learning_rate: Forwarded to ``LLMConfig(learning_rate=...)`` and used
            as the AdamW learning rate.
        n_heads: Forwarded to ``LLMConfig(n_heads=...)``.
        pred_len: Forwarded to ``LLMConfig(pred_len=...)``; only the last
            ``pred_len`` time steps of outputs and targets enter the loss.
        fix_seed: Seed applied to ``random``, ``torch`` and ``numpy``.
        model_type: Callable invoked as ``model_type(config)`` to build the model.

    Returns:
        The ``(mse, mae)`` pair produced by ``test``; ``rmse`` is logged but
        not returned.
    """
    # Seed every RNG in use so the run is tied to ``fix_seed``.
    random.seed(fix_seed)
    torch.manual_seed(fix_seed)
    np.random.seed(fix_seed)

    config = LLMConfig(
        d_model=d_model,
        n_encoder_layers=n_encoder_layers,
        n_decoder_layers=n_decoder_layers,
        n_heads=n_heads,
        dropout=dropout,
        bias=False,
        enc_in=enc_in,
        patch_size=16,
        stride=8,
        seq_len=seq_len,
        pred_len=pred_len,
        freq=freq,
        data=data_name,
        root_path=DATASET_PATH.as_posix(),
        data_path=data_path,
        batch_size=batch_size,
        num_workers=10,
        train_epochs=100,
        patience=5,
        learning_rate=learning_rate,
        logging=LoggerType.NONE,
        warmup_epochs=0
    )
    print(f"starting new for {dataset_name} with {d_model}_{n_encoder_layers}_{n_decoder_layers}_{learning_rate}_{dropout}_{config.warmup_epochs}_{n_heads}_{batch_size}")

    # Only the loaders are used below; the dataset objects are discarded.
    _, train_dataloader = data_provider(config, 'train')
    _, valid_dataloader = data_provider(config, 'val')
    _, test_dataloader = data_provider(config, 'test')

    # Each run gets its own timestamped output directory.
    datetime_now = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    setting = f"model_{data_name}_{datetime_now}_{pred_len}_{d_model}_{fix_seed}"
    model_path = BASEPATH / "assets" / "models" / setting
    os.makedirs(model_path, exist_ok=True)
    checkpoints_path = model_path / config.checkpoints

    train_steps = len(train_dataloader)

    early_stopping = EarlyStopping(
        patience=config.patience,
        verbose=True,
        path=checkpoints_path.as_posix()
    )

    model = model_type(config)
    model.to(config.device)

    model_optim = optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=1e-2
    )

    # L1 drives training and validation; MSE is only used for the per-epoch
    # test-split report.
    criterion2 = nn.MSELoss()
    criterion = nn.L1Loss()

    logger_type = config.logging
    logger_config = LoggerConfig()
    logger = get_logger(logger_type, logger_config)

    if logger_type == LoggerType.MLFLOW:
        mlflow.start_run()

    try:
        for epoch in range(config.train_epochs):
            train_loss_list = []

            model.train()
            epoch_time = time.time()
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_dataloader):
                model_optim.zero_grad()
                batch_x = batch_x.to(config.device, dtype=dtype(config.device))
                batch_y = batch_y.to(config.device, dtype=dtype(config.device))

                # These dataset families are fed to the model without time marks.
                if 'PEMS' in config.data or 'Solar' in config.data or 'ETT' in config.data:
                    batch_x_mark = None
                    batch_y_mark = None
                else:
                    batch_x_mark = batch_x_mark.to(config.device, dtype=dtype(config.device))
                    batch_y_mark = batch_y_mark.to(config.device, dtype=dtype(config.device))

                outputs = model(batch_x, batch_x_mark)

                # Keep only the last ``pred_len`` steps of prediction and target.
                outputs = outputs[:, -config.pred_len:, :].to(dtype=dtype(config.device))
                batch_y = batch_y[:, -config.pred_len:, :].to(config.device, dtype=dtype(config.device))

                loss = criterion(outputs, batch_y)
                # Read the scalar once; it is reused for the progress print.
                loss_value = loss.item()
                train_loss_list.append(loss_value)

                if (i + 1) % 100 == 0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss_value))

                loss.backward()
                model_optim.step()

            print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
            train_loss = np.average(train_loss_list)
            # NOTE(review): ``test=True`` is passed for the validation split but
            # not for the test split — confirm this is intended.
            vali_loss = validate(model, valid_dataloader, criterion, config, test=True)
            test_loss = validate(model, test_dataloader, criterion2, config)

            print(
                f"Epoch: {epoch + 1}, "
                f"Steps: {train_steps} | "
                f"Train Loss: {train_loss:.7f} "
                f"Vali Loss: {vali_loss:.7f} "
                f"Test Loss: {test_loss:.7f}"
            )

            early_stopping(vali_loss, model)
            if early_stopping.early_stop:
                print("Early stopping")
                break

            adjust_learning_rate_new(model_optim, epoch + 1, config)

        # NOTE(review): EarlyStopping was given ``model_path / config.checkpoints``
        # while the weights are reloaded from ``model_path / 'checkpoint.pth'`` —
        # confirm both refer to the same file.
        best_model_path = model_path / 'checkpoint.pth'
        model.load_state_dict(torch.load(best_model_path.as_posix()))
        mse, mae, rmse = test(model, test_dataloader, config)

        test_results = {
            'mse': mse,
            'mae': mae,
            'rmse': rmse
        }

        logger.log_metrics(test_results)

    finally:
        # Mirror the start_run() guard: the check must use the logger *type*.
        # Comparing the logger object itself never matched, so the MLflow run
        # was left open.
        if logger_type == LoggerType.MLFLOW:
            mlflow.end_run()
    return mse, mae

if __name__ == '__main__':
    # Sweep driver: trains every configuration for every prediction length and
    # seed, and collects the resulting metrics into one Excel file.

    pred_len_list = [96, 192, 336, 720]
    seeds = [
        2023,
        42,
        52,
        66,
        48,
        20,
        7,
        37,
        17,
        373,
        2349,
        2395,
        2451,
        939,
        1545,
        899,
        1234,
    ]

    configurations = [
        Configurations(
            d_model=64,
            n_encoder_layers=3,
            n_decoder_layers=4,
            dropout=0.3,
            learning_rate=0.0005,
            model_type=Sentinel
        ),
    ]

    results_dict_list = []

    # Total number of runs = size of the full grid.
    counter = len(configurations) * len(pred_len_list) * len(seeds)
    print(f"running {counter} different configurations")

    current_counter = 0

    # One results file per sweep, named after the dataset and the start time.
    datetime_now = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    results_file_name = f"runs_{dataset_name}_{datetime_now}.xlsx"
    for configuration in configurations:
        # File name of the module defining the model class; resolved once per
        # configuration so a lookup failure surfaces before any training starts.
        model_type_str = os.path.basename(inspect.getfile(configuration.model_type))
        for pred_len in pred_len_list:
            for fix_seed in seeds:
                current_counter += 1
                print(f"starting run {current_counter} out of {counter}")

                mse, mae = train(
                    configuration.batch_size,
                    configuration.d_model,
                    configuration.n_encoder_layers,
                    configuration.n_decoder_layers,
                    configuration.dropout,
                    configuration.learning_rate,
                    configuration.n_heads,
                    pred_len,
                    fix_seed,
                    configuration.model_type
                )

                results_dict_list.append({
                    'model_type': model_type_str,
                    'batch_size': configuration.batch_size,
                    'seq_len': seq_len,
                    'pred_len': pred_len,
                    'fix_seed': fix_seed,
                    'd_model': configuration.d_model,
                    'n_encoder_layers': configuration.n_encoder_layers,
                    'n_decoder_layers': configuration.n_decoder_layers,
                    'n_heads': configuration.n_heads,
                    'dropout': configuration.dropout,
                    'learning_rate': configuration.learning_rate,
                    'mse': mse,
                    'mae': mae
                })

                # Rewrite the file after every run so finished results survive
                # a crash in a later run.
                pd.DataFrame(results_dict_list).to_excel(results_file_name)

    # Final write; also produces the (empty) file when the grid has no runs.
    pd.DataFrame(results_dict_list).to_excel(results_file_name)