"""
Training utilities for HER + SAC treatment-policy models.
"""
import os
import numpy as np
import torch
import random
import matplotlib.pyplot as plt
from collections import deque
import yaml
import time

#Import SAC Agent
from src.gift.agents.sac_agent import SAC_HER_Agent

#Import additional required modules
from src.gift.buffers.her_buffer import HERReplayBuffer, collect_her_samples
# from src.gift.utils.trainer import collect_her_samples
from src.gift.utils.evaluator import evaluate_agent, evaluate_cip
from src.gift.utils.visualization import plot_training_history
from src.data.cip_dataset import CIPDataset, get_dataloader

def set_random_seeds(seed=42):
    """Seed all random number generators so that runs are repeatable.

    Seeds Python's ``random`` module, NumPy and PyTorch. When CUDA is available,
    every CUDA device is seeded as well. cuDNN is switched to deterministic mode
    with benchmark autotuning disabled.

    Args:
    seed: integer seed shared by every generator (default 42)
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN speed for bit-for-bit reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def train(dataset_collection, config, model_save_path='her_treatment_policy.pth', use_amp=False, load=False, logger=None):
    """
    Train a treatment-strategy (HER + SAC) model, or load an existing one, then evaluate it.

    Args:
    dataset_collection: collection of datasets containing train_f and val_f
    config: configuration parameter dictionary
    model_save_path: model file name; the checkpoint is stored at f"sac_{model_save_path}"
    use_amp: whether to use automatic mixed-precision training
    load: whether to load an existing model instead of training; if loading
        fails for any reason, the model is trained instead
    logger: logger forwarded to ``evaluate`` for the per-tau results

    Returns:
    agent: trained or loaded agent
    metrics: evaluation metrics, as returned by ``evaluate`` (keyed by tau)
    complexity_info: complexity information returned by ``evaluate`` (an empty
        dict if it returned None), extended with 'train_time' (seconds; 0.0 when
        the model was loaded instead of trained) and 'test_time' (seconds)
    """
    print(config)

    model_cfg = config['model']
    her_params = model_cfg['her_params']
    sac_params = model_cfg['sac_params']
    dataset_cfg = config['dataset']
    batch_size = config['exp']['batch_size']
    future_length = her_params['future_length']
    use_attention = model_cfg['use_attention']
    # Read unconditionally (as before) so a missing key still fails early.
    attention_heads = model_cfg['attention_heads']

    agent = SAC_HER_Agent(
        dataset_collection,
        config,
        input_dim=dataset_cfg.get('input_size'),
        output_dim=dataset_cfg.get('output_size'),
        treatment_dim=dataset_cfg.get('treatment_size'),
        static_dim=dataset_cfg.get('static_size'),
        hidden_dim=model_cfg['hidden_dim'],
        future_length=future_length,
        buffer_size=her_params['buffer_size'],
        batch_size=batch_size,
        goal_threshold=her_params['goal_threshold'],
        k_future=her_params['k_future'],
        use_amp=use_amp,
        reward_mode=her_params['reward_mode'],
        use_attention=use_attention,
        num_heads=attention_heads if use_attention else None,
        discount=sac_params['discount'],
        beta=sac_params['beta'],
        lr=sac_params['lr'],
        alpha=sac_params['alpha'],
        use_automatic_entropy=sac_params['use_automatic_entropy'],
        DR=sac_params['DR'],
        recover=sac_params['recover'],
        action_diff=sac_params['action_diff'],
        use_cql=sac_params['use_cql'],
        input_x=dataset_cfg.get('input_x'),
    )

    save_path = f"sac_{model_save_path}"

    if load:
        # Best effort: any failure to load falls back to training from scratch.
        try:
            print(f"\nTry to load model {save_path}...")
            agent.load(save_path)
            print("Model loaded successfully!")
        except FileNotFoundError:
            print(f"Model file {save_path} does not exist, will be trained...")
            load = False
        except Exception as e:
            print(f"Error loading model: {e}")
            print("Training will take place...")
            load = False

    # Stays 0.0 when the model was loaded rather than trained.
    train_time = 0.0
    if not load:
        # The agent owns its HER replay buffer; fill it from the dataset.
        replay_buffer = agent.memory
        collect_her_samples(
            dataset_collection,
            replay_buffer,
            min_history_length=her_params['min_history_length'],
            max_history_length=her_params['max_history_length'],
            future_length=future_length,
        )

        # Entries whose third field equals 0 are reported as "hit"
        # (presumably the reward field -- TODO confirm against HERReplayBuffer).
        hits = sum(1 for item in replay_buffer.buffer if item[2] == 0)
        print(f"all:{len(replay_buffer.buffer)}, hit:{hits}")

        training_iterations = config['exp']['max_epochs'] * (len(replay_buffer) // batch_size)
        print("\nStart training SAC agents...")
        start_time = time.time()
        agent.train_offline(training_iterations, progress_interval=config['exp']['log_freq'])
        train_time = time.time() - start_time

        agent.save(save_path)
        print(f"Model saved to {save_path}")

    print("\nEvaluation Model...")
    start_time = time.time()
    metrics, complexity_info = evaluate(agent, dataset_collection, config, logger=logger)
    test_time = time.time() - start_time

    # evaluate() returns None for non-'cip' evaluation modes.
    if complexity_info is None:
        complexity_info = {}
    complexity_info['train_time'] = train_time
    complexity_info['test_time'] = test_time

    return agent, metrics, complexity_info

def _eval_batch_size(config):
    """Return the evaluation batch size derived from the dataset config, or None if unknown.

    For 'mimic' datasets this is ``max_number`` times the split fraction; for
    'tumor' datasets it is ``num_patients`` of the split. The test split is used
    when ``config['exp']['test']`` is truthy, the validation split otherwise.
    (Presumably this sizes one batch to cover the whole split -- TODO confirm.)
    """
    dataset_cfg = config['dataset']
    name = dataset_cfg['name']
    if 'mimic' not in name and 'tumor' not in name:
        return None
    split = 'test' if config['exp']['test'] else 'val'
    if 'mimic' in name:
        return int(dataset_cfg['max_number'] * dataset_cfg['split'][split])
    return dataset_cfg['num_patients'][split]


def evaluate(agent, dataset_collection, config, max_tau=6, logger=None):
    """
    Evaluate the agent for every tau value from 1 to ``max_tau``.

    For each tau, ``agent.future_length`` and ``config['exp']['tau']`` are set to tau.
    The original config tau is restored afterwards, even if evaluation raises.
    ``agent.future_length`` is left at the last evaluated tau.

    Args:
    agent: trained agent
    dataset_collection: data collection; in 'cip' mode its val_f / test_f ``.data`` is used
    config: configuration parameter dictionary
    max_tau: largest tau value to evaluate; if None, ``config['exp']['tau']`` is used
    logger: object with an ``info`` method; if None, this module's standard
        ``logging`` logger is used

    Returns:
    all_metrics: dict mapping each evaluated tau to its metrics
    complexity_info: copy of the complexity info from the tau == 1 'cip'
        evaluation, or None for other evaluation modes or when no tau was evaluated

    Raises:
    ValueError: in 'cip' mode, when the batch size cannot be derived because the
        dataset name contains neither 'mimic' nor 'tumor'
    """
    if logger is None:
        import logging
        logger = logging.getLogger(__name__)

    orig_tau = config['exp']['tau']
    if max_tau is None:
        max_tau = orig_tau

    is_cip = config['exp']['evaluation_mode'] == 'cip'
    batch_size = _eval_batch_size(config)
    if is_cip and batch_size is None:
        raise ValueError(
            f"Cannot determine evaluation batch size for dataset "
            f"{config['dataset']['name']!r}; expected a 'mimic' or 'tumor' dataset."
        )

    all_metrics = {}
    complexity_info = None
    try:
        for tau in range(1, max_tau + 1):
            agent.future_length = tau
            logger.info(f"\nEvaluation model, tau = {tau}...")
            # Temporarily override tau in the config; restored in `finally`.
            config['exp']['tau'] = tau

            if is_cip:
                if config['exp']['test']:
                    data = dataset_collection.test_f.data
                else:
                    data = dataset_collection.val_f.data

                data_loader = get_dataloader(
                    CIPDataset(data, config),
                    batch_size=batch_size,
                    shuffle=False,
                )
                metrics, tau_complexity = evaluate_cip(
                    agent,
                    data_loader,
                    dataset_collection,
                    config,
                )
                # Only the tau == 1 run's complexity info is reported.
                if tau == 1:
                    complexity_info = tau_complexity.copy()
            else:
                metrics = evaluate_agent(
                    agent,
                    dataset_collection,
                    num_episodes=config['exp']['eval_episodes'],
                )
                complexity_info = None

            all_metrics[tau] = metrics
            print_evaluation_results(metrics, tau, logger)
    finally:
        config['exp']['tau'] = orig_tau

    print(f"complexity_info:{complexity_info}")
    return all_metrics, complexity_info

def print_evaluation_results(metrics, tau, logger):
    """Log a summary of the evaluation metrics for one tau value.

    Args:
    metrics: dict with the keys 'success_rate', 'avg_mse', 'avg_rmse',
        'avg_steps_used' and 'early_stop_rate'
    tau: tau value the metrics were computed for
    logger: object with an ``info`` method; if None, this module's standard
        ``logging`` logger is used
    """
    if logger is None:
        import logging
        logger = logging.getLogger(__name__)

    # No ' ' sign flag in the format specs: it padded positive values with an
    # extra space and produced double spaces after each colon.
    logger.info(f"Gift assessment results (tau = {tau}):")
    logger.info(f"Success rate: {metrics['success_rate']:.2%}")
    logger.info(f"Average MSE: {metrics['avg_mse']:.6f}")
    logger.info(f"Average RMSE: {metrics['avg_rmse']:.6f}")
    logger.info(f"Average steps used: {metrics['avg_steps_used']:.2f}/{tau}")
    logger.info(f"Early stop rate: {metrics['early_stop_rate']:.2%}")