import argparse
import copy
import math
import os
import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from scipy.stats import pearsonr
from sklearn.metrics import mean_squared_error, r2_score

from mprf.data_preprocessing.data_utils import load_data
from mprf.data_preprocessing.model_utils import load_model, load_optimizer, retrieve_params


def evaluate_predictions(y_true, y_pred):
    """Score predictions against targets.

    Both arguments are torch tensors; they are detached, moved to the CPU and
    converted to NumPy arrays before scoring.

    Returns:
        A ``(mse, corr, r2)`` tuple: scikit-learn's mean squared error, the
        Pearson coefficient taken from ``np.corrcoef`` on the flattened
        arrays, and scikit-learn's R^2 score.
    """
    targets = y_true.detach().cpu().numpy()
    preds = y_pred.detach().cpu().numpy()

    mse = mean_squared_error(targets, preds)
    # Off-diagonal entry of the 2x2 correlation matrix.
    corr_matrix = np.corrcoef(targets.flatten(), preds.flatten())
    corr = corr_matrix[0, 1]
    r2 = r2_score(targets, preds)
    return mse, corr, r2


def safe_correlation(y_true, y_pred, verbose=False):
    """Return the Pearson correlation of targets and predictions, or NaN.

    The inputs are flattened before the correlation is computed. Instead of
    raising or emitting warnings, NaN is returned when the correlation is
    undefined:

    * fewer than two samples,
    * any NaN or infinite value in either input,
    * either input is constant (zero variance).

    Args:
        y_true: Array-like of target values.
        y_pred: Array-like of predicted values.
        verbose: When True, print the input shapes, a few sample values and
            the reason for any NaN result. Off by default because this
            function is called several times per epoch.

    Returns:
        The Pearson correlation coefficient as a Python float, or
        ``float('nan')`` for the degenerate cases listed above.

    Raises:
        ValueError: If the flattened inputs differ in length.
    """
    true_flat = np.asarray(y_true).flatten()
    pred_flat = np.asarray(y_pred).flatten()

    if verbose:
        print(f"Input shapes - y_pred: {np.shape(y_pred)}, y_true: {np.shape(y_true)}")
        print(f"Sample predictions (first 5): {pred_flat[:5]}")
        print(f"Sample targets (first 5): {true_flat[:5]}")

    if true_flat.size != pred_flat.size:
        raise ValueError(
            f"y_true and y_pred must have the same number of elements, "
            f"got {true_flat.size} and {pred_flat.size}"
        )

    # pearsonr needs at least two samples.
    if true_flat.size < 2:
        if verbose:
            print("Fewer than two samples; correlation undefined")
        return float('nan')

    # A single NaN/inf makes the coefficient meaningless.
    true_bad = int((~np.isfinite(true_flat)).sum())
    pred_bad = int((~np.isfinite(pred_flat)).sum())
    if true_bad or pred_bad:
        if verbose:
            print(f"Non-finite count - predictions: {pred_bad}, targets: {true_bad}")
        return float('nan')

    # Exact constant check: np.var can come out as a tiny non-zero number for
    # a constant float array because of rounding in the mean.
    if np.all(true_flat == true_flat[0]) or np.all(pred_flat == pred_flat[0]):
        if verbose:
            print("Zero variance detected")
        return float('nan')

    correlation = float(pearsonr(true_flat, pred_flat)[0])
    if verbose:
        print(f"Calculated correlation: {correlation}")
    return correlation



def _score_predictions(y_true, y_pred):
    """Return ``(mse, corr, r2)`` for flat arrays, reporting constant predictions."""
    distinct = np.unique(y_pred)
    if len(distinct) == 1:
        print(f'Constant array, value: {distinct[0]}')
    mse = mean_squared_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    corr = safe_correlation(y_true, y_pred)
    return mse, corr, r2


def _evaluate_loader(model, loader, criterion, device):
    """Run ``model`` over ``loader`` in eval mode without gradients.

    Returns ``(loss_sum, y_true, y_pred)`` where ``loss_sum`` is the batch
    loss weighted by batch size and summed, and the arrays are flat.
    """
    model.eval()
    loss_sum = 0.0
    target_chunks, output_chunks = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs, None, None, None)
            loss = criterion(outputs.reshape(-1, 1), targets.reshape(-1, 1))
            loss_sum += loss.item() * inputs.size(0)
            target_chunks.append(targets.reshape(-1).detach().cpu().numpy())
            output_chunks.append(outputs.reshape(-1).detach().cpu().numpy())
    return loss_sum, np.concatenate(target_chunks), np.concatenate(output_chunks)


def run_experiment(model, optimizer, train_loader, val_loader, test_loader, seed,
                   LEARNING_RATE, patience, file_name, WEIGHT_DECAY, EPOCHS, LOSS_TYPE=None,
                   LOSS_CALCULATION_MP_FLAG=None, SCHEDULER_FLAG=False, SCHEDULER_STEP_FLAG=False,
                   STEP_SIZE=None, GAMMA=None):
    """Train with early stopping on validation loss, then evaluate on the test set.

    The model is called as ``model(inputs, None, None, None)`` and trained with
    MSE loss; gradients are clipped to a max norm of 1.0. The weights from the
    epoch with the lowest summed validation loss are restored before testing.

    Args:
        model: Network to train; wrapped in ``nn.DataParallel`` when more than
            one GPU is visible.
        optimizer: Optimizer already bound to ``model``'s parameters.
        train_loader, val_loader, test_loader: Loaders yielding
            ``(inputs, targets)`` batches.
        seed: Seed applied to torch (CPU and CUDA) and NumPy.
        patience: Number of consecutive non-improving epochs before stopping.
        EPOCHS: Maximum number of training epochs.
        LEARNING_RATE, file_name, WEIGHT_DECAY, LOSS_TYPE,
        LOSS_CALCULATION_MP_FLAG, SCHEDULER_FLAG, SCHEDULER_STEP_FLAG,
        STEP_SIZE, GAMMA: Accepted for interface compatibility; not used by
            this function.

    Returns:
        ``(mse_test, corr_test, r2_test)`` computed on the test set.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        print(f'GPUs available: {torch.cuda.device_count()}')
        model = nn.DataParallel(model)
    model = model.to(device)
    criterion = nn.MSELoss()
    best_val_loss, check = math.inf, 0
    # state_dict() returns references to the live parameter tensors, so the
    # snapshot must be deep-copied or it silently follows later updates.
    best_model_state = copy.deepcopy(model.state_dict())
    best_epoch = 0

    # Progress is reported roughly every tenth of an epoch; clamp to 1 so
    # loaders with fewer than 10 batches do not cause a modulo by zero.
    batch_checkpoint = max(1, len(train_loader) // 10)

    for epoch in range(1, EPOCHS + 1):
        # Training
        model.train()
        train_loss = 0.0
        target_chunks, output_chunks = [], []

        for batch, (inputs, targets) in enumerate(train_loader):
            if batch % batch_checkpoint == 0:
                print(f'Progress through epoch: {batch // batch_checkpoint} / 10')
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs, None, None, None)
            loss = criterion(outputs.reshape(-1, 1), targets.reshape(-1, 1))

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            train_loss += loss.item() * inputs.size(0)
            target_chunks.append(targets.reshape(-1).detach().cpu().numpy())
            output_chunks.append(outputs.reshape(-1).detach().cpu().numpy())

        mse_train, corr_train, r2_train = _score_predictions(
            np.concatenate(target_chunks), np.concatenate(output_chunks)
        )

        # Validation
        val_loss, val_true, val_pred = _evaluate_loader(model, val_loader, criterion, device)
        mse_val, corr_val, r2_val = _score_predictions(val_true, val_pred)

        # The run of spaces before "Val Loss" reproduces the existing log format.
        print(
            f'Epoch [{epoch}/{EPOCHS}], Train Loss: {train_loss/len(train_loader.dataset):.6f}, '
            f'            Val Loss: {val_loss/len(val_loader.dataset):.6f}'
        )
        print(f'Train Metrics: MSE: {mse_train:.6f}, Correlation: {corr_train:.6f}, R^2: {r2_train:.6f}')
        print(f'Validation Metrics: MSE: {mse_val:.6f}, Correlation: {corr_val:.6f}, R^2: {r2_val:.6f}')

        # Keep the best model based on validation loss; otherwise count
        # towards early stopping.
        if val_loss < best_val_loss:
            print("Update best model")
            check = 0
            print(f'patience count: {check}')
            best_val_loss, best_epoch = val_loss, epoch
            best_model_state = copy.deepcopy(model.state_dict())
        else:
            check += 1
            print(f'patience count: {check}')
            if check >= patience:
                break

    # Epochs are 1-based already; 0 means no epoch improved on the initial weights.
    print(f'Best model at epoch {best_epoch}')
    model.load_state_dict(best_model_state)

    # Testing
    _, test_true, test_pred = _evaluate_loader(model, test_loader, criterion, device)
    mse_test, corr_test, r2_test = _score_predictions(test_true, test_pred)

    print(f'Test Metrics: MSE: {mse_test:.4f}, Correlation: {corr_test:.4f}, R^2: {r2_test:.4f}')
    return mse_test, corr_test, r2_test

def list_of_int(arg):
    """Parse a comma-separated string such as ``"1,2,3"`` into a list of ints."""
    return [int(piece) for piece in arg.split(',')]


def _run_single_experiment(args, seed, lr, bsz, lookback, patience, model_filename):
    """Load data, model and optimizer for one configuration and run it.

    Returns ``[mse_test, corr_test, r2_test]`` as produced by ``run_experiment``.
    """
    tic = time.time()

    print('loading data')
    train_loader, val_loader, test_loader = load_data(
        args.dataset, args.num_features, args.horizon, lookback, bsz,
        small=args.small_chf, seed=seed, steps=args.prediction_steps, sample_gap=args.sample_gap,
    )

    print('loading model')
    # Both the HPO and the regular path use this call form (bsz positional,
    # lookback/pred_len by keyword) so the model is built identically.
    model = load_model(
        args.model, args.num_features + 1, args.out_c, args.kernel, args.dilation,
        args.num_conv, args.conv_type, args.depthwise, bsz,
        lookback=lookback, pred_len=args.prediction_steps,
    )
    optimizer = load_optimizer(args.model, model, lr)

    print('running experiment')
    mse_test, corr_test, r2_test = run_experiment(
        model, optimizer, train_loader, val_loader, test_loader, seed, lr, patience,
        model_filename, args.weight_decay, args.epochs, LOSS_TYPE='MSE', LOSS_CALCULATION_MP_FLAG=True,
    )
    print(f"Single Experiment Time: {time.strftime('%H:%M:%S', time.gmtime(time.time() - tic))}")
    return [mse_test, corr_test, r2_test]


def main():
    """Command-line entry point.

    Parses the arguments, resolves hyper-parameters (CLI values first, then
    the values from ``retrieve_params``), and either sweeps batch size and
    learning rate for a single seed (``--hpo``) or runs one experiment per
    seed and reports the test metrics, averaged over the seeds whose
    correlation is not NaN.
    """
    np.set_printoptions(precision=5)
    torch.set_printoptions(precision=5)
    parser = argparse.ArgumentParser(description="Run LOB prediction experiment")
    parser.add_argument("--dataset", type=str, required=True, help="Dataset to use (FI or CHF)")
    parser.add_argument("--rolling_norm", action='store_true')
    parser.add_argument("--data_dir", type=str, help="Directory containing the data")
    parser.add_argument("--model", type=str, required=True, help='Model')
    parser.add_argument("--num_features", type=int, default=40, help="Number of features to use")
    parser.add_argument("--lookback", type=int, default=None, help="Lookback period")
    parser.add_argument("--horizon", type=int, default=1, help="Prediction horizon")
    parser.add_argument("--prediction_steps", type=int, default=1, help="Number of prediction steps")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=0, help="weight decay")
    parser.add_argument("--seed", type=list_of_int, default=[1], help="experiment seed")
    parser.add_argument("--out_c", type=str, default='half', help="Output-channel setting (out_c) passed to load_model")
    parser.add_argument("--kernel", type=int, default=2, help="kernel size")
    parser.add_argument("--dilation", type=int, default=2, help="dilation")
    parser.add_argument("--conv_type", type=str, default='exp', help="convolution type")
    parser.add_argument("--num_conv", type=int, default=5, help='num convolution layers')
    parser.add_argument("--epochs", type=int, default=50, help="Number of epochs")
    parser.add_argument("--bsz", type=int, default=None, help="Batch size")
    parser.add_argument("--hpo", action='store_true')
    parser.add_argument("--small_chf", action='store_true')
    parser.add_argument("--depthwise", action='store_true')
    parser.add_argument("--sample_gap", type=int, default=1, help="sample gap")
    parser.set_defaults(small_chf=False, depthwise=False, hpo=False)
    args = parser.parse_args()
    torch.manual_seed(args.seed[0])
    np.random.seed(args.seed[0])
    print(args)
    print(f'GPU is available {torch.cuda.is_available()}')

    hps = retrieve_params(args.model, args.dataset)
    lookback = args.lookback or hps[0]
    # NOTE(review): --lr defaults to 1e-4, so hps[1] is only used when --lr 0
    # is passed explicitly, unlike --lookback/--bsz which default to None.
    # Confirm whether the tuned learning rate should be the default.
    lr = args.lr or hps[1]
    bsz = args.bsz or hps[2]
    # The original expression chose 10 for every dataset.
    patience = 10

    start_time = time.time()
    if args.hpo:
        # Validate explicitly rather than with assert, which -O strips.
        if len(args.seed) != 1:
            parser.error("--hpo expects exactly one seed")
        s = args.seed[0]
        results = []
        configs = []
        print("doing HPO")

        for hpo_bsz in [8192, 4096, 1024, 512, 128, 64, 32]:
            for hpo_lr in [0.001, 0.0001, 0.00001, 0.000001]:
                print(f'Learning Rate: {hpo_lr}, Batch Size: {hpo_bsz}, Patience: {patience}')
                model_filename = 'hpo_{}_dataset{}_horizon{}_seed{}_features{}_lr{}_bsz{}_kernel{}_dilate{}_outc{}_convtype{}_numconvlayers{}_epoch'.format(
                    args.model.lower(), args.dataset, args.horizon, s, args.num_features+1, hpo_lr,
                    hpo_bsz, args.kernel, args.dilation, args.out_c, args.conv_type, args.num_conv
                )
                results.append(
                    _run_single_experiment(args, s, hpo_lr, hpo_bsz, lookback, patience, model_filename)
                )
                configs.append((hpo_bsz, hpo_lr))

        end_time = time.time()
        print(f"Total Experiment Time: {time.strftime('%H:%M:%S', time.gmtime(end_time - start_time))}")
        # Report every configuration so the sweep results are not discarded.
        for (cfg_bsz, cfg_lr), (mse, corr, r2) in zip(configs, results):
            print(f'HPO Batch Size: {cfg_bsz}, Learning Rate: {cfg_lr} -> MSE: {mse}, Corr: {corr}, R^2: {r2}')

    else:
        print(f'Learning Rate: {lr}, Batch Size: {bsz}, Lookback {lookback}, Patience: {patience}')
        results = []
        for s in args.seed:
            print(f'Seed {s}')
            # Use the current seed, not the whole seed list, in the file name.
            model_filename = '{}_dataset{}_smallchf{}_horizon{}_seed{}_features{}_lr{}_bsz{}_kernel{}_dilate{}_outc{}_convtype{}_numconvlayers{}_epoch'.format(
                args.model.lower(), args.dataset, args.small_chf, args.horizon, s, args.num_features+1, lr,
                bsz, args.kernel, args.dilation, args.out_c, args.conv_type, args.num_conv
            )
            results.append(_run_single_experiment(args, s, lr, bsz, lookback, patience, model_filename))

        end_time = time.time()
        print(f"Total Experiment Time: {time.strftime('%H:%M:%S', time.gmtime(end_time - start_time))}")

        results = torch.tensor(results)
        print(f'Raw results: {results}')

        if len(args.seed) > 1:
            print('Excluding nan experiments...')
            # A run counts as failed when its correlation (column 1) is NaN.
            succeeded = ~results[:, 1].isnan()
            kept = results[succeeded]
            means = kept.mean(dim=0)
            stds = kept.std(dim=0)
            used_seeds = [seed for seed, ok in zip(args.seed, succeeded.tolist()) if ok]
            print(f'Experiment over seeds {args.seed}, using {used_seeds}')
            print(f'MSE: {means[0]}, Corr: {means[1]}, R^2: {means[2]}')
            print(f'Std - MSE: {stds[0]}, Corr: {stds[1]}, R^2: {stds[2]}')

        else:
            results = results.flatten()
            print(f'MSE: {results[0].item()}, Corr: {results[1].item()}, R^2: {results[2].item()}')

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()