import pickle
import numpy as np
import pandas as pd
import yaml
from data import get_data
from rl_environment import TradingEnvironment
from rl_model import PPOTrainer
import torch
from datetime import datetime
import os
import matplotlib.pyplot as plt
from market_impact import MarketImpactCalculator

def load_config(path='config.yaml'):
    """Load and return the YAML configuration.

    Args:
        path: location of the YAML file; defaults to ``config.yaml`` in the
            current working directory (the original hard-coded location).

    Returns:
        Whatever ``yaml.safe_load`` produces for the file (normally a dict).
    """
    # Explicit UTF-8 so parsing does not depend on the platform's locale encoding.
    with open(path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)

def load_data(ticker, from_date, until_date, override=False):
    """Return ``(df_intra, df_daily)``, using the CSV cache under ``data/`` when present.

    If both ``data/intraday_data.csv`` and ``data/daily_data.csv`` exist and
    ``override`` is False, the cached files are read back and returned.
    Otherwise the data is fetched with ``get_data(ticker, from_date, until_date)``
    and written to those files (creating ``data/`` if needed).

    NOTE: the cache is keyed only on file existence -- ``ticker`` and the date
    range are ignored on a cache hit, so pass ``override=True`` whenever they change.

    Args:
        ticker: symbol forwarded to ``get_data``.
        from_date: start of the requested range, forwarded to ``get_data``.
        until_date: end of the requested range, forwarded to ``get_data``.
        override: when True, always re-fetch and overwrite the cache.

    Returns:
        Tuple ``(df_intra, df_daily)`` of pandas DataFrames.
    """
    data_dir = 'data'
    intra_path = os.path.join(data_dir, 'intraday_data.csv')
    daily_path = os.path.join(data_dir, 'daily_data.csv')

    if not override and os.path.exists(intra_path) and os.path.exists(daily_path):
        return pd.read_csv(intra_path), pd.read_csv(daily_path)

    df_intra, df_daily = get_data(ticker, from_date, until_date)
    # to_csv does not create parent directories; make sure the cache dir exists.
    os.makedirs(data_dir, exist_ok=True)
    df_intra.to_csv(intra_path, index=False)
    df_daily.to_csv(daily_path, index=False)
    return df_intra, df_daily

def _performance_metrics(portfolio_values, step_returns):
    """Compute cumulative returns, annualised Sharpe ratio and maximum drawdown.

    Args:
        portfolio_values: 1-D numpy array of portfolio values, one per environment step.
        step_returns: 1-D numpy array of the per-step returns reported by the environment.

    Returns:
        Tuple ``(cumulative_returns, sharpe_ratio, max_drawdown)``.
    """
    # Growth relative to the first recorded portfolio value.
    cumulative_returns = (portfolio_values / portfolio_values[0]) - 1

    # sqrt(252) annualisation assumes one environment step per trading day.
    # NOTE(review): confirm the environment steps daily.
    # The 1e-8 avoids division by zero for a constant return series.
    sharpe_ratio = np.sqrt(252) * np.mean(step_returns) / (np.std(step_returns) + 1e-8)

    # Drawdown relative to the running peak; the minimum is the worst drawdown
    # (a non-positive fraction when portfolio values are positive).
    peak = np.maximum.accumulate(portfolio_values)
    drawdown = (portfolio_values - peak) / peak
    max_drawdown = np.min(drawdown)

    return cumulative_returns, sharpe_ratio, max_drawdown


def backtest_rl(config, df_intra, df_daily, model_path, consider_market_impact=True):
    """Backtest the trained RL model.

    Runs one episode of the PPO policy loaded from ``model_path`` in a
    ``TradingEnvironment`` and summarises the result.

    Args:
        config: parsed configuration. It is passed to ``TradingEnvironment``,
            and ``config['rl']['hidden_dim']`` sizes the policy network.
        df_intra: intraday data, passed to ``TradingEnvironment``.
        df_daily: daily data, passed to ``TradingEnvironment``.
        model_path: checkpoint loaded via ``PPOTrainer.load``.
        consider_market_impact: forwarded to ``TradingEnvironment``.

    Returns:
        dict with keys:
            'portfolio_values': np.ndarray of ``info['portfolio_value']`` per step.
            'positions': list of ``info['position']`` per step.
            'daily_returns': np.ndarray of ``info['returns']`` per step.
            'cumulative_returns': np.ndarray, growth relative to the first value.
            'sharpe_ratio': annualised Sharpe ratio.
            'max_drawdown': worst peak-to-trough drop, as a fraction.
    """
    env = TradingEnvironment(df_intra, df_daily, config, consider_market_impact=consider_market_impact)

    trainer = PPOTrainer(
        state_dim=env.observation_space.shape[0],
        action_dim=env.action_space.shape[0],
        hidden_dim=config['rl']['hidden_dim']
    )
    trainer.load(model_path)

    portfolio_values = []
    positions = []
    daily_returns = []

    # Legacy Gym API: reset() returns only the state and step() returns a
    # 4-tuple (state, reward, done, info).
    state = env.reset()
    done = False
    # The loop body runs at least once, so the lists are never empty below.
    while not done:
        action = trainer.select_action(state)
        state, _reward, done, info = env.step(action)

        portfolio_values.append(info['portfolio_value'])
        positions.append(info['position'])
        daily_returns.append(info['returns'])

    portfolio_values = np.array(portfolio_values)
    daily_returns = np.array(daily_returns)
    cumulative_returns, sharpe_ratio, max_drawdown = _performance_metrics(portfolio_values, daily_returns)

    return {
        'portfolio_values': portfolio_values,
        'positions': positions,
        'daily_returns': daily_returns,
        'cumulative_returns': cumulative_returns,
        'sharpe_ratio': sharpe_ratio,
        'max_drawdown': max_drawdown
    }

def _benchmark_curve(df_daily, initial_value, lookback_period, n_steps):
    """Return a buy-and-hold value curve for ``df_daily['close']`` aligned to the backtest.

    Close-to-close returns are offset by ``lookback_period`` rows and truncated
    to at most ``n_steps`` values, then compounded from ``initial_value``.
    The result can be shorter than ``n_steps`` when ``df_daily`` has too few rows.
    """
    close_returns = df_daily['close'].pct_change().dropna()
    # .iloc makes the slice positional regardless of the index type.
    window = close_returns.iloc[lookback_period:lookback_period + n_steps]
    return initial_value * (1 + window).cumprod()


def _plot_comparison(results_no_impact, results_with_impact, benchmark, benchmark_label, output_path):
    """Plot both portfolio-value curves plus the benchmark and save the figure to ``output_path``."""
    plt.figure(figsize=(12, 6))
    try:
        # Each series gets its own x range, so curves of different lengths
        # (an early-terminated episode, a short benchmark) still plot.
        for results, label in (
            (results_no_impact, 'Without Market Impact'),
            (results_with_impact, 'With Market Impact'),
        ):
            values = results['portfolio_values']
            plt.plot(range(len(values)), values, label=label)
        plt.plot(range(len(benchmark)), benchmark.values, label=benchmark_label, linestyle='--', alpha=0.7)

        plt.title('Portfolio Value Comparison')
        # NOTE(review): x counts environment steps, labelled as days on the
        # assumption that the environment steps daily -- confirm.
        plt.xlabel('Trading Days')
        plt.ylabel('Portfolio Value ($)')
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(output_path)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close()


def _compare_results(results_no_impact, results_with_impact):
    """Build the without/with/difference summary of two ``backtest_rl`` result dicts."""
    def _row(without, with_impact, scale=1):
        # difference is always (with - without), scaled (e.g. x100 for percentages).
        return {
            'without_impact': without * scale,
            'with_impact': with_impact * scale,
            'difference': (with_impact - without) * scale,
        }

    return {
        'final_portfolio_value': _row(results_no_impact['portfolio_values'][-1],
                                      results_with_impact['portfolio_values'][-1]),
        'total_return': _row(results_no_impact['cumulative_returns'][-1],
                             results_with_impact['cumulative_returns'][-1], scale=100),
        'sharpe_ratio': _row(results_no_impact['sharpe_ratio'],
                             results_with_impact['sharpe_ratio']),
        'max_drawdown': _row(results_no_impact['max_drawdown'],
                             results_with_impact['max_drawdown'], scale=100),
    }


def _print_comparison(comparison_results):
    """Print ``comparison_results`` as an aligned four-column table."""
    # (row label, comparison key, value format)
    rows = (
        ('Final Portfolio Value', 'final_portfolio_value', '${:,.2f}'),
        ('Total Return', 'total_return', '{:.2f}%'),
        ('Sharpe Ratio', 'sharpe_ratio', '{:.2f}'),
        ('Maximum Drawdown', 'max_drawdown', '{:.2f}%'),
    )
    header = f"{'Metric':<25} {'Without Impact':<15} {'With Impact':<15} {'Difference':<15}"
    print("\nComparison Results:")
    print("=" * len(header))
    print(header)
    print("-" * len(header))
    for label, key, fmt in rows:
        metric = comparison_results[key]
        cells = [fmt.format(metric[column]) for column in ('without_impact', 'with_impact', 'difference')]
        print(f"{label:<25} {cells[0]:<15} {cells[1]:<15} {cells[2]:<15}")
    print("=" * len(header))


def final_backtest_rl(ticker, model_path, from_date='2022-05-09', until_date='2022-11-09'):
    """Backtest a saved model with and without market impact and report the comparison.

    Steps:
      1. Re-fetch data for ``ticker`` over the window via ``load_data(..., override=True)``.
      2. Run ``backtest_rl`` twice, without and with market impact.
      3. Pickle both result dicts under ``backtest_rl_results/``.
      4. Save a comparison plot under ``backtest_rl_results/``.
      5. Print a summary table.

    Args:
        ticker: symbol whose data is loaded; also labels the buy-and-hold benchmark.
        model_path: path to the saved model. Its file name, minus the extension,
            prefixes the output files.
        from_date: start of the backtest window (defaults keep the original window).
        until_date: end of the backtest window.

    Returns:
        The comparison dict, with keys 'final_portfolio_value', 'total_return' (%),
        'sharpe_ratio' and 'max_drawdown' (%). Each maps to without_impact /
        with_impact / difference. Returns None if the data source returned
        no rows.
    """
    # splitext strips only the final extension, so "SPY_v1.2.pth" -> "SPY_v1.2".
    model_name = os.path.splitext(os.path.basename(model_path))[0]
    config = load_config()

    df_intra, df_daily = load_data(ticker, from_date, until_date, override=True)
    if df_intra.empty or df_daily.empty:
        print(f"Warning: API returned empty data for {ticker} from {from_date} to {until_date}")
        return None

    print("Running backtest without market impact...")
    results_no_impact = backtest_rl(config, df_intra, df_daily, model_path, consider_market_impact=False)
    print("\nRunning backtest with market impact...")
    results_with_impact = backtest_rl(config, df_intra, df_daily, model_path, consider_market_impact=True)

    # The output directory may not exist on a fresh checkout.
    output_dir = 'backtest_rl_results'
    os.makedirs(output_dir, exist_ok=True)
    for suffix, results in (('no_impact', results_no_impact), ('with_impact', results_with_impact)):
        with open(os.path.join(output_dir, f'{model_name}_{suffix}.pkl'), 'wb') as f:
            pickle.dump(results, f)

    # Buy-and-hold benchmark on the same ticker, offset by the lookback period.
    lookback_period = config.get('backtesting', {}).get('lookback_period', 30)
    benchmark = _benchmark_curve(
        df_daily,
        initial_value=results_no_impact['portfolio_values'][0],
        lookback_period=lookback_period,
        n_steps=len(results_no_impact['portfolio_values']),
    )
    _plot_comparison(
        results_no_impact,
        results_with_impact,
        benchmark,
        benchmark_label=f'{ticker} Buy & Hold',
        output_path=os.path.join(output_dir, f'{model_name}_comparison.png'),
    )

    comparison_results = _compare_results(results_no_impact, results_with_impact)
    _print_comparison(comparison_results)
    return comparison_results

if __name__ == "__main__":
    # Backtest every saved model (*.pth) in ./models. The ticker is taken as the
    # text before the first underscore in the file name, e.g. "SPY_ppo.pth" -> "SPY".
    # NOTE(review): a file such as "best_model.pth" would yield ticker "best" -- confirm
    # that all model files follow the "<TICKER>_..." naming convention.
    # sorted() gives a deterministic run order (os.listdir order is arbitrary).
    #
    # To backtest a single model instead:
    #     final_backtest_rl('SPY', os.path.join('models', 'best_model.pth'))
    for file_name in sorted(os.listdir('models')):
        if not file_name.endswith('.pth'):
            continue
        ticker = file_name.split('_')[0]
        final_backtest_rl(ticker, os.path.join('models', file_name))