import numpy as np
import random
import torch
import time
import matplotlib.pyplot as plt

from src.gift.utils.evaluator import evaluate_agent
from src.gift.utils.visualization import plot_training_history

# Use the CUDA device when torch reports one is available; otherwise use the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
from src.data.her_data_generator import create_history_treatment_goal_samples


def _final_model_path(model_save_path):
    """Return the final-checkpoint path derived from ``model_save_path``.

    If the path ends with ``.pth``, ``_final`` is inserted just before that
    suffix (``run.pth`` -> ``run_final.pth``); otherwise ``_final`` is appended.
    Only the trailing suffix is examined, so a ``.pth`` appearing earlier in the
    path (e.g. in a directory name) is left untouched.
    """
    suffix = ".pth"
    if model_save_path.endswith(suffix):
        return model_save_path[: -len(suffix)] + "_final" + suffix
    return model_save_path + "_final"


def train_agent(agent, dataset_collection, model_save_path=None,
                min_history_length=10, max_history_length=20,
                training_iterations=50000, eval_interval=5000,
                *, progress_interval=1000, final_eval_episodes=2000):
    """
    Collect HER samples, train the agent offline, evaluate it and save it.

    Args:
        agent: Reinforcement learning agent. Must expose ``algorithm``,
            ``memory``, ``future_length``, ``train_offline`` and ``save``.
        dataset_collection: Dataset collection passed to
            ``collect_her_samples``, ``agent.train_offline`` and
            ``evaluate_agent``.
        model_save_path: Base model save path. Defaults to
            ``best_model_<agent.algorithm>.pth``. The final model is saved to
            this path with ``_final`` inserted before a trailing ``.pth``
            (or appended when the path has no ``.pth`` suffix).
        min_history_length: Minimum history length for HER sample collection.
        max_history_length: Maximum history length for HER sample collection.
        training_iterations: Number of offline training iterations.
        eval_interval: Passed to ``agent.train_offline`` as ``eval_interval``.
        progress_interval: Passed to ``agent.train_offline`` as
            ``progress_interval`` (keyword-only; default 1000).
        final_eval_episodes: ``num_episodes`` for the final
            ``evaluate_agent`` call (keyword-only; default 2000).

    Returns:
        tuple: ``(agent, metrics, losses, eval_results)`` -- the trained agent,
        the metrics returned by the final ``evaluate_agent`` call, and the
        losses and evaluation results returned by ``agent.train_offline``.
    """
    # Default save path
    if model_save_path is None:
        model_save_path = f"best_model_{agent.algorithm}.pth"

    # Collect HER samples into the agent's memory.
    # NOTE(review): collect_her_samples is neither imported nor defined in the
    # visible part of this module (only create_history_treatment_goal_samples
    # is imported) -- confirm it is defined elsewhere, otherwise NameError.
    collect_her_samples(
        dataset_collection,
        agent.memory,
        min_history_length=min_history_length,
        max_history_length=max_history_length,
        future_length=agent.future_length
    )

    # Offline training.
    # NOTE(review): model_save_path is not passed to train_offline, so this
    # function itself never writes a "best" checkpoint -- confirm
    # train_offline handles that if it is intended.
    losses, eval_results = agent.train_offline(
        training_iterations,
        progress_interval=progress_interval,
        eval_interval=eval_interval,
        dataset_collection=dataset_collection
    )

    # Draw training history
    plot_training_history(agent)

    # Final evaluation ("\n" was previously the invalid escape "\ n").
    print("\nFinal Assessment...")
    metrics = evaluate_agent(
        agent,
        dataset_collection,
        num_episodes=final_eval_episodes
    )

    # Save final model next to the base path (suffix-only rewrite).
    final_path = _final_model_path(model_save_path)
    agent.save(final_path)

    return agent, metrics, losses, eval_results