from matplotlib import rc_params
import pandas as pd
from numpy import mod
import torch
import torch.optim as optim

from torch.random import seed
from torch.utils import data
from torch.utils.tensorboard import SummaryWriter


from deephfts.mats import coord_matrix, difference_matrix
from deephfts.modules import RBlock, RMatrix
from deephfts.losses import sblock_loss, rblock_loss
from deephfts.optimizers import sblock_optimizer, rmatrix_optimizer
from deephfts.utils.checkpoint import checkpoint

from deephfts.utils.metrics import smape
from deephfts.utils.datasets import dataset_helper
from deephfts.utils.plots import tensorboard_preds, visualize_TS, visualize_loss, tensorboard_TS, tensorboard_loss, tensorboard_graph
from deephfts.utils.zsl import zsl_file, zsl_facet, zsl_scatter
from deephfts.utils.cfi import cfi_file, cfi_scatter
from deephfts.model import SETR


def training_loop(
    path: str = None,
    dataset: str = None,
    batch_size: int = 1,
    input_dim: int = 1,
    window_size: int = 80,
    forecast_size: int = 20,
    s_losstype: str = "NLL",
    breakpoint: int = 1400,
    optimizer: str = "Adam",
    mode: str = "online",
    run_name: str = None,
    checkpoint_path: str = None,
    checkpoint_interval: int = 100,
    tensorboard: bool = True,
    width: int = None,
    height: int = None,
    zsl_path: str = None,
    cfi_path: str = None,
):
    """Train a SETR model sample by sample, optimizing the S-block and one R-block per step.

    For every ``(x, y)`` pair from the training loader the S-block is optimized
    against the difference matrix of the R-matrix predictions, then the R-block
    at the arg-max of the S-block output is optimized against ``y``.

    Args:
        path: Forwarded to ``dataset_helper`` as ``path``.
        dataset: Forwarded to ``tensorboard_TS`` as ``dataset``.
        batch_size: Leading dimension used when reshaping ``x`` and ``y``.
        input_dim: Channel dimension used when reshaping ``x`` and ``y``.
        window_size: Number of input time steps per sample.
        forecast_size: Number of target time steps per sample.
        s_losstype: Forwarded to ``sblock_loss`` as ``losstype``.
        breakpoint: Training stops after the first iteration whose index
            exceeds this value.
        optimizer: Optimizer name forwarded to ``sblock_optimizer`` and
            ``rmatrix_optimizer`` as ``method``.
        mode: Forwarded to ``dataset_helper`` as ``mode``.
        run_name: Suffix for the TensorBoard log directory; ``None`` means no
            suffix.
        checkpoint_path: Prefix for checkpoint files, which are written to
            ``checkpoint_path + f'{index}.pth'``. ``None`` disables
            checkpointing.
        checkpoint_interval: A checkpoint is written whenever
            ``index % checkpoint_interval == 1``.
        tensorboard: When true, log the graph, series and losses to TensorBoard.
        width: First entry of the ``block_size`` passed to ``SETR``.
        height: Second entry of the ``block_size`` passed to ``SETR``.
        zsl_path: When set, per-sample zero-shot-learning rows are written to
            this path and plotted next to it.
        cfi_path: When set, per-sample CFI rows are written to this path and
            plotted next to it.
    """
    # SummaryWriter appends `comment` to the log directory name, so it has to
    # be a string even when no run name was given.
    writer = SummaryWriter(comment=run_name or "") if tensorboard else None

    try:
        # initialize model
        model = SETR(block_size=[width, height], window_size=window_size, forecast_size=forecast_size)
        sblock = model.get_sblock().double()
        rmatrix = model.get_rmatrix()
        rmatrix.to_double()

        # initialize optimizers; the block size is read back from the model.
        s_optimizer = sblock_optimizer(sblock, method=optimizer)
        width, height = model.get_blocksize()
        r_optimizers = rmatrix_optimizer(rmatrix=rmatrix, width=width, height=height, method=optimizer)

        # initialize coordinate matrix, flattened over its first two dimensions
        # so it can be indexed with a flat arg-max/arg-min index.
        flat_coord = torch.flatten(model.get_coordmatrix(), start_dim=0, end_dim=1)

        if writer is not None:
            tensorboard_graph(writer, model, torch.randn(batch_size, input_dim, window_size), add_graph=None)
        train_loader, _val_loader = dataset_helper(
            path=path, mode=mode, window_size=window_size, forecast_size=forecast_size
        )

        # per-sample dataframes for the zero-shot-learning and CFI exports.
        zsl_dfs = []
        cfi_dfs = []
        zsl_steps = list(range(window_size))
        cfi_steps = list(range(window_size + forecast_size))

        ## Training loop ------------------------------------------------------------------
        for index, (x, y) in enumerate(train_loader):
            loss_dict = {}

            if writer is not None:
                tensorboard_TS(writer=writer, index=index, x=x, y=y, dataset=dataset)

            # Target values are captured before `y` is reshaped and cast.
            y_values = torch.flatten(y).tolist() if cfi_path else None

            # Shaped for convolutions (Spatial)
            x = x.reshape(batch_size, input_dim, window_size).double()
            # Shaped for time series (Temporal)
            t = x.reshape(batch_size, window_size, input_dim)
            y = y.reshape(batch_size, input_dim, forecast_size).double()

            ### Training section. ----------------------------------------------
            s_optimizer.zero_grad()

            # Calculate S-block predictions.
            s_preds = sblock(x)

            # Calculate R-matrix predictions and hidden layers.
            rmatrix_hidden = model.rmatrix_hidden(t)
            rmatrix_preds = model.rmatrix_preds(x=t, x_adj=rmatrix_hidden)

            # Optimize S-block
            y_reshape = y.expand(width, height, -1, -1)
            R_diff_rounded = difference_matrix(y=y_reshape, rmatrix_preds=rmatrix_preds)
            s_preds = s_preds.double()

            s_loss = sblock_loss(y=R_diff_rounded, y_pred=s_preds, losstype=s_losstype)
            loss_dict['Loss/S-block'] = s_loss.item()
            s_loss.backward()
            s_optimizer.step()

            ## Optimize R-block.
            # NOTE(review): argmax is taken over the whole of s_preds; this
            # presumably assumes batch_size == 1 -- confirm.
            best_index = torch.argmax(s_preds)
            loss_dict['Int/Index'] = best_index
            best_rblock = flat_coord[best_index].int().tolist()

            current_r_optimizer = r_optimizers[best_rblock[0]][best_rblock[1]]
            current_r_optimizer.zero_grad()
            current_r_preds = rmatrix_preds[best_rblock[0], best_rblock[1], :, :]

            y = y.reshape(current_r_preds.shape)
            rloss, loss_k, loss_block = rblock_loss(
                pred_matrix=rmatrix_preds, y=y, losstype="forecasting", coordinates=best_rblock
            )
            loss_dict['Loss/R-block'] = rloss.item()
            loss_dict['Loss/DWCL'] = loss_block.item()
            loss_dict['Loss/R-MSE'] = loss_k.item()
            rloss.backward()
            # Apply the gradients; without this step the selected R-block is
            # never updated.
            current_r_optimizer.step()

            loss_dict['metric/SMAPE'] = smape(current_r_preds, y)

            # copy adjacency matrix weights between networks in the ensemble.
            rmatrix.adjacency_lstm_copy(coordinate=best_rblock)

            ### Logging section. ----------------------------------------------
            # The export rows are only built when an output path was requested.
            if zsl_path or cfi_path:
                x_values = torch.flatten(x).tolist()

            if zsl_path:
                zsl_dfs.append(pd.DataFrame({
                    'step': zsl_steps,
                    'value': x_values,
                    'label': [index + 1] * window_size,
                    'X': [best_rblock[0]] * window_size,
                    'Y': [best_rblock[1]] * window_size,
                }))

            if cfi_path:
                # R-block at the arg-min of the S-block output.
                ci_rblock = flat_coord[torch.argmin(s_preds)].int().tolist()
                ci_r_preds = rmatrix_preds[ci_rblock[0], ci_rblock[1], :, :]
                cfi_dfs.append(pd.DataFrame({
                    'step': cfi_steps,
                    'value': x_values + y_values,
                    'label': [index + 1] * (window_size + forecast_size),
                    'best_preds': x_values + torch.flatten(current_r_preds).tolist(),
                    'ci_preds': x_values + torch.flatten(ci_r_preds).tolist(),
                }))

            ## Log via tensorboard
            if writer is not None:
                tensorboard_loss(writer=writer, index=index + len(x), **loss_dict)

            ## Checkpoint model (skipped when no checkpoint path was given).
            # NOTE(review): `== 1` means the first checkpoint lands on index 1
            # and an interval of 1 never checkpoints -- confirm this is intended.
            if checkpoint_path is not None and index % checkpoint_interval == 1:
                checkpoint(
                    model_path=checkpoint_path + f'{index}.pth',
                    s_block=sblock, s_optimizer=s_optimizer,
                    rmatrix=rmatrix, rmatrix_optimizers=r_optimizers,
                    width=width, height=height,
                    save=True, load=False
                )

            ## Break at breakpoint.
            if index > breakpoint:
                break

        # pd.concat raises on an empty list, so skip the exports when the
        # loader yielded no samples.
        if zsl_path and zsl_dfs:
            zsl_df = pd.concat(zsl_dfs)
            zsl_file(df=zsl_df, outfile=zsl_path)
            zsl_facet(in_path=zsl_path, out_path=zsl_path + '.facet.png')
            zsl_scatter(in_path=zsl_path, out_path=zsl_path + '.scatter.png')

        if cfi_path and cfi_dfs:
            cfi_df = pd.concat(cfi_dfs)
            cfi_file(df=cfi_df, out_path=cfi_path)
            cfi_scatter(in_path=cfi_path, out_path=cfi_path + '.scatter.png')
    finally:
        # Flush and release the event file even if training fails part-way.
        if writer is not None:
            writer.close()