from typing import Union, Tuple

import numpy as np
import torch
from tqdm import tqdm
import math

from src.utils.metric import metric


# data loading
def get_batch(
        data: torch.Tensor,
        block_size: int,
        batch_size: int,
        device: str
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sample a random batch of next-step prediction pairs.

    For each of ``batch_size`` random start indices ``i`` along the first axis of
    ``data``, ``x`` is ``data[i:i+block_size]`` and ``y`` is the same window shifted
    one step ahead, ``data[i+1:i+block_size+1]``.

    :param data: series to sample from, indexed along its first axis
    :param block_size: length of each window
    :param batch_size: number of windows to sample
    :param device: device to move the batch to
    :return: ``(x, y)``, each of shape ``(batch_size, block_size, ...)`` on ``device``;
        if ``data`` is 1-D a trailing size-1 feature axis is appended to both
    """
    # randint's upper bound is exclusive: the largest start is N - block_size - 1,
    # which keeps the shifted target window ``y`` inside ``data``.
    ix = torch.randint(data.shape[0] - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    return __setup_x_y(device, x, y)


def get_batch_past_futures(
        data: torch.Tensor,
        block_size: int,
        prediction_length: int,
        batch_size: int,
        device: str
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sample a random batch of (past window, following future window) pairs.

    For each of ``batch_size`` random start indices ``i`` along the first axis of
    ``data``, ``x`` is ``data[i:i+block_size]`` and ``y`` is the ``prediction_length``
    rows that follow it, ``data[i+block_size:i+block_size+prediction_length]``.

    :param data: series to sample from, indexed along its first axis
    :param block_size: length of each input (past) window
    :param prediction_length: length of each target (future) window
    :param batch_size: number of windows to sample
    :param device: device to move the batch to
    :return: ``(x, y)`` on ``device``; if ``data`` is 1-D a trailing size-1 feature
        axis is appended to both
    """
    # The largest valid start is N - block_size - prediction_length (y then ends
    # exactly at N). torch.randint's upper bound is exclusive, hence the +1; without
    # it the final window is never sampled and N == block_size + prediction_length fails.
    ix = torch.randint(data.shape[0] - block_size - prediction_length + 1, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + block_size:i + block_size + prediction_length] for i in ix])
    return __setup_x_y(device, x, y)


def __setup_x_y(device, x, y):
    """
    Move an ``(x, y)`` batch to ``device`` and give 2-D batches a feature axis.

    On CUDA devices (``'cuda'``, ``'cuda:1'``, ``torch.device('cuda')``, ...) both
    tensors are pinned first so the host-to-device copy can be issued with
    ``non_blocking=True``; any other device uses a plain ``.to(device)``.

    :param device: target device, as a string or ``torch.device``
    :param x: input batch
    :param y: target batch
    :return: ``(x, y)`` on ``device``; when ``x`` is 2-D, both ``x`` and ``y`` get a
        size-1 axis inserted at dim 2
    """
    # Comparing the device *type* (instead of ``device == 'cuda'``) also covers
    # indexed devices such as 'cuda:0' and torch.device objects.
    if torch.device(device).type == 'cuda':
        # pin arrays x, y, which allows moving them to the GPU asynchronously (non_blocking=True)
        x = x.pin_memory().to(device, non_blocking=True)
        y = y.pin_memory().to(device, non_blocking=True)
    else:
        x = x.to(device)
        y = y.to(device)
    # Only x's rank is checked; y is assumed to share it — TODO confirm for all callers.
    if x.dim() == 2:
        x = x.unsqueeze(2)
        y = y.unsqueeze(2)
    return x, y


def train_test_val_split(
        data: Union[np.ndarray, torch.Tensor],
        train_size: float = 0.7,
        test_size: float = 0.15,
        validation_split: bool = True
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """
    Split ``data`` in order (along its first axis, no shuffling) into train/val/test parts.

    The first ``train_size`` fraction of rows becomes the training set. When a
    validation set is produced, the remainder is divided so that ``test_size`` (a
    fraction of the *whole* data) goes to the test set at the end, and the rows in
    between become the validation set.

    :param data: numpy array (converted with ``torch.from_numpy``) or tensor
    :param train_size: fraction of rows used for training
    :param test_size: fraction of the whole data used for testing; only used when a
        validation set is produced. If it exceeds what is left after training, the
        test set takes the whole remainder and the validation set is empty.
    :param validation_split: produce a validation set. One is also produced whenever
        ``train_size + test_size < 1``, even if this flag is False.
    :return: ``(train, val, test)`` when a validation set is produced, otherwise ``(train, test)``
    """
    if isinstance(data, np.ndarray):
        data = torch.from_numpy(data)

    train_data_len = int(train_size * len(data))
    train_data = data[:train_data_len]
    test_data = data[train_data_len:]

    if not (validation_split or test_size + train_size < 1):
        return train_data, test_data

    remaining_share = 1 - train_size
    if remaining_share > 0:
        # test_size re-expressed as a fraction of the post-train remainder
        test_share = test_size / remaining_share
    else:
        # Nothing remains after the train split; avoid dividing by zero
        # (val and test are both empty in this case).
        test_share = 1.0
    # Clamp so an over-budget test_size yields an empty validation set instead of a
    # negative length (which would turn into a negative slice index).
    val_share = min(max(1 - test_share, 0.0), 1.0)

    valid_data_len = int(val_share * len(test_data))
    val_data = test_data[:valid_data_len]
    test_data = test_data[valid_data_len:]

    return train_data, val_data, test_data


def get_device():
    """Return ``'cuda'`` if ``torch.cuda.is_available()`` reports a GPU, else ``'cpu'``."""
    return 'cuda' if torch.cuda.is_available() else 'cpu'


def dtype(device):
    """
    Return the floating-point dtype to use for tensors on ``device``.

    ``device`` is currently ignored: the result is always ``torch.float32``.
    """
    return torch.float32


def to(tensor: torch.Tensor, device: str) -> torch.Tensor:
    """
    Place ``tensor`` on ``device`` and cast it to the dtype returned by ``dtype(device)``.

    :param tensor: tensor to convert
    :param device: target device
    :return: the converted tensor
    """
    target_dtype = dtype(device)
    return tensor.to(device=device, dtype=target_dtype)



@torch.no_grad()
def validate(model, vali_loader, criterion, config, test=True):
    """
    Evaluate ``model`` on ``vali_loader`` and return its mean per-batch loss.

    Each batch is moved/cast with ``to``. For datasets whose ``config.data`` contains
    'PEMS', 'Solar' or 'ETT' the time marks are replaced by ``None``. Only the last
    ``config.pred_len`` steps of outputs and targets are scored, on CPU. The model is
    put in eval mode for the pass and switched to train mode afterwards.

    :param model: called as ``model(batch_x, batch_x_mark)``
    :param vali_loader: yields ``(batch_x, batch_y, batch_x_mark, batch_y_mark)``
    :param criterion: called as ``criterion(pred, true)``
    :param config: read for ``device``, ``data`` and ``pred_len``
    :param test: unused
    :return: ``np.average`` of the per-batch criterion values
    """
    batch_losses = []
    model.eval()
    for batch_x, batch_y, batch_x_mark, batch_y_mark in vali_loader:
        batch_x, batch_y = to(batch_x, config.device), to(batch_y, config.device)

        skip_marks = any(tag in config.data for tag in ('PEMS', 'Solar', 'ETT'))
        if skip_marks:
            batch_x_mark = batch_y_mark = None
        else:
            batch_x_mark = to(batch_x_mark, config.device)
            batch_y_mark = to(batch_y_mark, config.device)

        forecast = model(batch_x, batch_x_mark)

        # Score only the forecast horizon, computed on CPU.
        horizon = config.pred_len
        target_dtype = dtype(config.device)
        pred = forecast[:, -horizon:, :].to(dtype=target_dtype).detach().cpu()
        true = batch_y[:, -horizon:, :].to(dtype=target_dtype).detach().cpu()

        batch_losses.append(criterion(pred, true))

    mean_loss = np.average(batch_losses)
    model.train()
    return mean_loss


@torch.no_grad()
def test(model, test_loader, config):
    """
    Run ``model`` over ``test_loader`` and report forecast error metrics.

    Outputs and targets are truncated to the last ``config.pred_len`` steps, cast to
    ``dtype(config.device)``, concatenated across batches and scored with ``metric``.
    For datasets whose ``config.data`` contains 'PEMS', 'Solar' or 'ETT' the time marks
    are replaced by ``None``. The model is left in eval mode.

    :param model: called as ``model(batch_x, batch_x_mark)``
    :param test_loader: yields ``(batch_x, batch_y, batch_x_mark, batch_y_mark)``
    :param config: read for ``device``, ``data`` and ``pred_len``
    :return: ``(mse, mae, rmse)``
    :raises ValueError: if ``test_loader`` yields no batches
    """
    preds = []
    trues = []

    model.eval()

    for batch_x, batch_y, batch_x_mark, batch_y_mark in tqdm(test_loader):
        batch_x = to(batch_x, config.device)
        batch_y = to(batch_y, config.device)

        if 'PEMS' in config.data or 'Solar' in config.data or 'ETT' in config.data:
            batch_x_mark = None
            batch_y_mark = None
        else:
            batch_x_mark = to(batch_x_mark, config.device)
            batch_y_mark = to(batch_y_mark, config.device)

        outputs = model(batch_x, batch_x_mark)

        # Cast outputs like the targets (as validate() does); this also makes
        # .numpy() work when the model emits reduced-precision tensors.
        outputs = outputs[:, -config.pred_len:, :].to(dtype=dtype(config.device))
        batch_y = batch_y[:, -config.pred_len:, :].to(dtype=dtype(config.device))

        preds.append(outputs.detach().cpu().numpy())
        trues.append(batch_y.detach().cpu().numpy())

    if not preds:
        raise ValueError("test_loader yielded no batches")

    # Concatenate along the sample axis. Unlike np.array(list_of_batches) + reshape,
    # this also works when the final batch is smaller than the others.
    preds = np.concatenate(preds, axis=0)
    trues = np.concatenate(trues, axis=0)
    print('test shape:', preds.shape, trues.shape)

    mae, mse, rmse, _mape, _mspe = metric(preds, trues)
    print('mae:{:.4f}, mse:{:.4f}, rmse:{:.4f}'.format(mae, mse, rmse))

    return mse, mae, rmse


def adjust_learning_rate(optimizer, epoch, config):
    """
    Set the learning rate of every param group: linear warmup, then step decay.

    While ``epoch < config.warmup_epochs`` the rate is
    ``learning_rate * epoch / warmup_epochs``. Afterwards, with
    ``config.lradj == 'type1'``, it is ``learning_rate * 0.5 ** k`` where
    ``k = floor(epoch - warmup_epochs - 1)`` clamped at 0. The rate is therefore
    halved every epoch and never exceeds ``learning_rate``.

    :param optimizer: object exposing ``param_groups`` (a list of dicts), e.g. a torch optimizer
    :param epoch: current epoch number
    :param config: read for ``learning_rate``, ``warmup_epochs`` and ``lradj``
    :return: the learning rate that was set
    :raises NotImplementedError: past warmup, for any ``lradj`` other than ``'type1'``
    """
    if epoch < config.warmup_epochs:
        print("warmup")
        lr = config.learning_rate * epoch / config.warmup_epochs
    elif config.lradj == 'type1':
        # Without the clamp, epoch == warmup_epochs gives 0.5 ** -1, i.e. a jump to
        # twice the base learning rate right after warmup.
        decay_steps = max(0, math.floor(epoch - config.warmup_epochs - 1))
        lr = config.learning_rate * (0.5 ** decay_steps)
    else:
        raise NotImplementedError("lr not implemented")

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    print('Updating learning rate to {}'.format(lr))
    return lr
        
def adjust_learning_rate_new(optimizer, epoch, config, *, min_lr=0.0):
    """
    Set the learning rate: linear warmup, then half-cycle cosine decay to ``min_lr``.

    While ``epoch < config.warmup_epochs`` the rate is
    ``learning_rate * epoch / warmup_epochs``. From then on it follows
    ``min_lr + (learning_rate - min_lr) * 0.5 * (1 + cos(pi * t / T))`` with
    ``t = epoch - warmup_epochs`` capped at ``T = train_epochs - warmup_epochs``
    (``T`` is at least 1). The rate stays at ``min_lr`` after ``train_epochs``.

    A param group containing ``"lr_scale"`` receives ``lr * lr_scale``; every other
    group receives ``lr``.

    :param optimizer: object exposing ``param_groups`` (a list of dicts), e.g. a torch optimizer
    :param epoch: current epoch number
    :param config: read for ``learning_rate``, ``warmup_epochs`` and ``train_epochs``
    :param min_lr: floor of the cosine schedule (default 0)
    :return: the unscaled learning rate
    """
    if epoch < config.warmup_epochs:
        lr = config.learning_rate * epoch / config.warmup_epochs
    else:
        # At least one decay epoch avoids dividing by zero when train_epochs == warmup_epochs.
        decay_epochs = max(config.train_epochs - config.warmup_epochs, 1)
        # Cap progress so epochs past train_epochs do not climb back up the cosine.
        elapsed = min(epoch - config.warmup_epochs, decay_epochs)
        lr = min_lr + (config.learning_rate - min_lr) * 0.5 * (
            1. + math.cos(math.pi * elapsed / decay_epochs))

    for param_group in optimizer.param_groups:
        if "lr_scale" in param_group:
            param_group["lr"] = lr * param_group["lr_scale"]
        else:
            param_group["lr"] = lr
    print(f'Updating learning rate to {lr:.7f}')
    return lr