from dataloader import get_dataset
from utils import GlobalConfig
from .exp_basic import ExpBasic
import torch.nn as nn
import os
from utils.tools import EarlyStopping
import time
import numpy as np
from torch import optim
import torch
from autoaugment import AutoAugmentBasic

# TODO: extend the loss-function options available in
# ExpShortTermForecasting._select_criterion.

class M4Meta:
    """Per-subset metadata for the M4 forecasting benchmark.

    ``seasonal_patterns`` fixes the subset order.  ``horizons`` and
    ``frequencies`` are aligned with it index by index, and the ``*_map``
    dictionaries are derived from those lists so the list and dict views can
    never disagree.
    """

    seasonal_patterns = ['Yearly', 'Quarterly', 'Monthly', 'Weekly', 'Daily', 'Hourly']
    # Forecast horizon (prediction length) of each subset, aligned with
    # seasonal_patterns.
    horizons = [6, 8, 18, 13, 14, 48]
    # Frequency value of each subset, aligned with seasonal_patterns.
    frequencies = [1, 4, 12, 1, 1, 24]

    # Name -> value lookups built from the aligned lists above; key order
    # follows seasonal_patterns.  dict(zip(...)) is evaluated directly in the
    # class body, so it can see the class-level names (a comprehension body
    # could not).
    horizons_map = dict(zip(seasonal_patterns, horizons))  # different predict length
    frequency_map = dict(zip(seasonal_patterns, frequencies))

    # Per-subset multiplier; presumably scales the look-back window relative
    # to the horizon.  Not referenced in the visible part of this file — TODO
    # confirm its consumer (likely the data loader).
    history_size = {
        'Yearly': 1.5,
        'Quarterly': 1.5,
        'Monthly': 1.5,
        'Weekly': 10,
        'Daily': 10,
        'Hourly': 10
    }

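# config.args fields used by ExpShortTermForecasting:
#   data              -- dataset name; data = 'm4' enables the M4-specific setup
#   seasonal_patterns -- M4 subset name, e.g. 'Monthly' (a key of M4Meta.horizons_map)
#   pred_len          -- forecast horizon (overwritten from M4Meta for M4)
#   seq_len           -- input length (set to 2 * pred_len for M4)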


class ExpShortTermForecasting(ExpBasic):

    def _build_model(self) -> AutoAugmentBasic:
        """Configure sequence lengths, warm the data cache and build the model.

        For the M4 dataset (``config.args.data == 'm4'``) the forecast horizon,
        input length, label length and frequency are derived from
        ``config.args.seasonal_patterns`` via ``M4Meta`` and written back onto
        ``config.args`` before anything else runs.

        The TRAIN and TEST splits are then loaded through ``_get_data`` only
        for its caching side effect; the returned objects are not used here.

        Returns:
            The ``self.model_class`` instance built from ``self.config``,
            converted with ``.float()`` and moved to ``self.device``.

        Raises:
            KeyError: for M4, if ``seasonal_patterns`` is not a key of
                ``M4Meta.horizons_map``.
        """
        args = self.config.args
        if args.data == 'm4':
            pattern = args.seasonal_patterns
            args.pred_len = M4Meta.horizons_map[pattern]
            # The input window is twice the forecast horizon.
            args.seq_len = 2 * args.pred_len
            # Keep a user-supplied label_len; otherwise fall back to the horizon.
            args.label_len = getattr(args, "label_len", args.pred_len)
            # NOTE: despite its name, this stores the single frequency value for
            # the chosen pattern, not the whole mapping.
            args.frequency_map = M4Meta.frequency_map[pattern]

        # Fill self.loaded_data before the model is created; presumably dataset
        # construction may also update self.config — TODO confirm.
        for split in ('TRAIN', 'TEST'):
            self._get_data(flag=split, load_as=split)

        return self.model_class(self.config).float().to(self.device)

    def _get_data(self, flag: str, load_as:str):
        """Return ``(dataset, data_loader)`` for ``flag``, memoised under ``load_as``.

        The first call for a given ``load_as`` builds the pair with
        ``get_dataset(self.config, flag)`` and stores it in
        ``self.loaded_data``.  Later calls with the same ``load_as`` return the
        cached pair and ignore ``flag`` entirely, so two different flags must
        use different ``load_as`` keys to get different data.
        """
        cache = self.loaded_data
        if load_as in cache:
            dataset, data_loader = cache[load_as]
            return dataset, data_loader

        dataset, data_loader = get_dataset(self.config, flag)
        cache[load_as] = (dataset, data_loader)
        return dataset, data_loader

    def _select_optimizer(self):
        """Create an Adam optimizer over all parameters of ``self.model``.

        The learning rate comes from ``config.args.learning_rate``; every other
        Adam hyper-parameter keeps the PyTorch default.
        """
        learning_rate = self.config.args.learning_rate
        return optim.Adam(self.model.parameters(), lr=learning_rate)
    
    def _select_criterion(self):
        """Build the training loss named by ``config.args.forecasting_loss``.

        The name is matched case-insensitively.  It defaults to ``"MSE"`` when
        the attribute is missing, ``None`` or empty.

        Returns:
            nn.Module: ``nn.MSELoss()`` for ``"MSE"``; ``nn.L1Loss()`` for
            ``"MAE"`` or ``"L1"``.

        Raises:
            ValueError: if the requested loss name is not supported.  (Earlier
                this case silently returned ``None``, which only failed later
                when the training loop tried to call the criterion.)
        """
        raw_name = getattr(self.config.args, "forecasting_loss", None) or "MSE"
        name = str(raw_name).upper()

        criteria = {
            "MSE": nn.MSELoss,
            "MAE": nn.L1Loss,
            "L1": nn.L1Loss,
        }
        factory = criteria.get(name)
        if factory is None:
            raise ValueError(
                f"Unsupported forecasting_loss {raw_name!r}; "
                f"expected one of {sorted(criteria)}"
            )
        return factory()

     def train(self):
        train_data, train_loader = self._get_data(flag='TRAIN', load_as='TRAIN')
        test_data, test_loader = self._get_data(flag='TEST', load_as='TEST')
        vali_data, vali_loader = self._get_data(flag='TEST', load_as='TEST')

        checkpoint_path = self.config.get_checkpoint_path()
        if not os.path.exists(os.path.dirname(checkpoint_path)):
            os.makedirs(os.path.dirname(checkpoint_path))
        
        time_now = time.time()
        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.config.args.patience, verbose=True)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion()

        for epoch in range(self.config.args.train_epochs):
            print("[Model State]", end="")
            print(self.model.summarize_state())
            iter_count = 0
            train_loss = []

            self.model.train()
            epoch_time = time.time()

            for i, (batch_x, batch_y, batch_f, batch_masks) in enumerate(train_loader):
                iter_count += 1
                model_optim.zero_grad()

                batch_x = batch_x.float().to(self.device) # [B, C, Lx]
                batch_y = batch_y.float().to(self.device) # [B, C, Ly]
                batch_f = batch_f.float().to(self.device)
                batch_masks = batch_masks.to(self.device)

                




