import os
import sys
import copy
import importlib

import torch
import numpy as np
import pandas as pd

from os.path import join
from einops import repeat

# Setup
from configs import initialize_configs
from setup import init_experiment, get_updated_config, get_updated_config_from_argparse

# Data
from dataloaders import initialize_data_functions

# Model
from models.spacetime.network import SpaceTime, MultiHorizonSpaceTime, process_network_config
from models.nn.transforms import get_data_transforms

# Training
from train import run_epoch, train_model
from models import get_optimizer, get_scheduler
from loss import get_criterions

# Evaluation
from evaluate import get_evaluation_loaders, plot_forecast, get_dataset_evaluation
import matplotlib.pyplot as plt

# Experiment args
from args import argparse_args as args

# Logging and warnings
from utils.logging import make_csv, save_results
import warnings



def _print_diff_kernel(model):
    """Print ``c[:, :7, :7]`` of the first layer's differencing SSM kernel.

    The kernel is looked up as ``diff`` on ``model.nn[0].ssm_block.kernels``;
    if that attribute is absent, the ``diff_error`` kernel is used instead.
    """
    kernels = model.nn[0].ssm_block.kernels
    try:
        diff_kernel = kernels.diff
    except AttributeError:
        # Fall back to the alternate kernel name when `diff` is not registered.
        diff_kernel = kernels.diff_error
    print(diff_kernel.c[:, :7, :7])


def _print_recurrent_flags(model):
    """Best-effort print of the ``recurrent`` flag of every SSM kernel in ``model.nn``.

    This is diagnostics only: if the model does not expose the expected
    ``nn`` / ``ssm_block.kernels`` structure, the problem is reported and
    skipped instead of aborting the run.
    """
    try:
        for layer in model.nn:
            for kernel_key, kernel in layer.ssm_block.kernels.items():
                print(f'{kernel_key} recurrent = {kernel.recurrent}')
    except Exception as err:  # deliberately best-effort; never fatal
        print(f'-> Skipping kernel recurrence report: {err!r}')


def main():
    """Train and evaluate a SpaceTime forecasting model from command-line args.

    Reads the module-level ``args`` (``args.argparse_args``), then:

    1. Builds the experiment config and seeds torch and numpy.
    2. Optionally starts a wandb run.
    3. Loads the train/val/test loaders and the evaluation loaders.
    4. Builds ``SpaceTime`` (replicate 1) or ``MultiHorizonSpaceTime``
       (otherwise) and trains it with ``train_model``, keeping the best model.
    5. Plots forecasts for each split and appends the split metrics to the
       results CSV under ``./logs``.
    6. Logs the plot and metrics to wandb when it is enabled.

    Side effects: mutates ``args`` (``evaluate_horizon``, and ``multihorizon``
    for replicate 35) and ``config`` in place.
    """
    dataset_name = args.dataset

    # Disable warnings for Monash datasets.
    if 'monash' in args.dataset:
        warnings.filterwarnings("ignore")

    # Snapshot the horizon used later to build the evaluation loaders.
    args.evaluate_horizon = copy.deepcopy(args.horizon)

    config = initialize_configs(args)
    config = get_updated_config_from_argparse(args, config)
    exp_name = init_experiment(config, args, prefix=f'{args.dataset}')

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    project_name = f'spacetime-d={dataset_name}-{"-".join(args.dataset_name.split("-")[1:])}'
    run_name = f'{args.network_config}-r={args.replicate}-tn={args.task_norm}-scale={args.scale}-bash={args.bash}-s={args.seed}'
    if not args.no_wandb:
        import wandb
        wandb.init(config=config, entity='mzhang',
                   name=run_name,
                   project=project_name,
                   dir='./logs',
                   notes=f'd={args.dataset_name}-e={exp_name}')
        wandb.config.update(args)
    else:
        wandb = None

    # Results CSV that save_results appends to after evaluation.
    results_path, results_columns = make_csv(args, save_dir='./logs', file_prefix='results', column_names=None)

    # LOAD DATA
    config.dataset['size'] = [args.lag, args.horizon, args.horizon]  # Teacher forcing hack
    load_data, visualize_data = initialize_data_functions(args)
    dataloaders, _ = load_data(config.dataset, config.loader, args)
    print(f'-> Forecasting horizon: {args.horizon}')
    print(f'-> Forecasting lag:     {args.lag}')

    # Update lag and horizon for the model.
    print(config.network_config)
    config.network_config = process_network_config(config.network_config, args)

    # Evaluation loaders use the horizon snapshotted at the start of main().
    # NOTE(review): unlike the training call above, `args` is not passed to
    # load_data here — confirm this is intentional.
    config.dataset['size'] = [args.lag, args.evaluate_horizon, args.evaluate_horizon]
    eval_loaders, _ = load_data(config.dataset, config.loader)
    eval_loaders = get_evaluation_loaders(eval_loaders, config)

    splits = ['train', 'val', 'test']
    dataloaders_by_split = {split: dataloaders[ix]
                            for ix, split in enumerate(splits)}

    visualize_data(dataloaders, splits,
                   save=False, args=args, title=f'{args.dataset_name}')

    for dataloader in eval_loaders:
        print(f'horizon: {dataloader.dataset.forecast_horizon}')

    # Input and output dimension are both 1 (multivariate inputs are
    # flattened), and `multivariate` is 1 for every `args.features` setting.
    input_dim = 1
    output_dim = 1
    config.network_config['multivariate'] = 1

    # LOAD MODEL
    model_cls = SpaceTime if args.replicate == 1 else MultiHorizonSpaceTime
    model = model_cls(input_dim=input_dim,
                      output_dim=output_dim,
                      **config.network_config)

    # Visualize differencing SSMs.
    if config.network_config['layers']['layer1']['ssm']['n_diff'] > 0:
        _print_diff_kernel(model)

    input_transform, output_transform = get_data_transforms(args.task_norm,
                                                            args.horizon)

    optimizer = get_optimizer(model, config.optimizer)
    scheduler = get_scheduler(model, optimizer, config.scheduler)
    criterions = get_criterions(args)

    print('---------------------')
    print('Trainable parameters:')
    print('---------------------')
    for param_name, param in model.named_parameters():
        if param.requires_grad:
            print(param_name)

    _print_recurrent_flags(model)

    # TRAIN MODEL
    print('----------')
    print('Experiment:')
    print('----------')
    print(f'-> Experiment name: {args.experiment_name}')
    print(f'-> Checkpoint path: {args.checkpoint_path}')
    model = train_model(model, optimizer, scheduler, dataloaders_by_split,
                        criterions, args.max_epochs, args,
                        input_transform, output_transform, wandb,
                        return_best=True)

    if args.replicate == 35:
        args.multihorizon = 0

    # EVALUATE MODEL: one forecast panel per split.
    n_plots = len(splits)
    fig, axes = plt.subplots(1, n_plots, figsize=(6.4 * n_plots, 4.8))

    split_metrics = plot_forecast(model, eval_loaders, splits, input_transform,
                                  criterions['mse'], torch.device('cpu'),
                                  forecast_rmse=True,
                                  output_transform=output_transform,
                                  args=args, axes=axes, show=False, save=True)

    save_results(split_metrics, results_columns, args, results_path, method='SpaceTime')

    if wandb is not None:
        wandb.log({"forecast_plot": fig})
        wandb.log(split_metrics)

    # pyplot keeps figures alive until explicitly closed; release it.
    plt.close(fig)


if __name__ == "__main__":
    # Run the train/evaluate pipeline only when executed as a script.
    main()
    
    
