# eval/eval_firl.py
import csv

import numpy as np
import torch

from algorithms.finsler_actor_critic import Actor
from train.train_firl import set_seed

def _reset_obs(env):
    """Reset ``env`` and return only the initial observation.

    Accepts both the classic ``reset() -> obs`` convention and the newer
    ``reset() -> (obs, info)`` one.
    """
    out = env.reset()
    if isinstance(out, tuple) and len(out) == 2 and isinstance(out[1], dict):
        return out[0]
    return out


def _step_env(env, action):
    """Step ``env`` and normalize the result to ``(obs, reward, terminated, truncated)``.

    Accepts the classic 4-tuple ``(obs, reward, done, info)`` and the newer
    5-tuple ``(obs, reward, terminated, truncated, info)``. For the 4-tuple
    form, an episode ended by classic gym's ``TimeLimit`` wrapper (flagged via
    ``info['TimeLimit.truncated']``) is reported as truncated, not terminated.
    """
    out = env.step(action)
    if len(out) == 5:
        obs, reward, terminated, truncated, _ = out
        return obs, reward, bool(terminated), bool(truncated)
    obs, reward, done, info = out
    done = bool(done)
    truncated = done and isinstance(info, dict) and bool(info.get('TimeLimit.truncated', False))
    return obs, reward, done and not truncated, truncated


def evaluate_policy(env, actor: "Actor", episodes=100, max_ep_len=1000, cvar_alpha=0.1):
    """Roll out the deterministic (mean-action) policy and aggregate cost metrics.

    At every step the observation is fed to ``actor``. The returned mean is
    squashed through ``tanh`` and executed. The per-step cost is the negated
    environment reward (the reward is assumed to be ``-cost``), summed per
    episode.

    Args:
        env: Environment exposing ``reset()`` and ``step(action)``; both the
            classic gym API and the ``(obs, info)`` / 5-tuple API are accepted.
        actor: Callable mapping a float32 observation tensor to ``(mean, std)``;
            only ``mean`` is used.
        episodes: Number of evaluation episodes; must be >= 1.
        max_ep_len: Step cap per episode.
        cvar_alpha: Fraction of the worst (highest total cost) episodes averaged
            for the CVaR metric; at least one episode is always included.

    Returns:
        ``(avg_cost, cvar_cost, success_rate, avg_len)`` as Python floats.

    Raises:
        ValueError: If ``episodes < 1`` or ``cvar_alpha`` is outside ``(0, 1]``.
    """
    if episodes < 1:
        raise ValueError(f"episodes must be >= 1, got {episodes}")
    if not 0.0 < cvar_alpha <= 1.0:
        raise ValueError(f"cvar_alpha must be in (0, 1], got {cvar_alpha}")

    costs, lengths, successes = [], [], []
    for _ in range(episodes):
        state = _reset_obs(env)
        total_cost = 0.0
        terminated = truncated = False
        t = 0
        while not (terminated or truncated) and t < max_ep_len:
            state_t = torch.tensor(state, dtype=torch.float32)
            with torch.no_grad():
                mean, _ = actor(state_t)
                # Greedy evaluation: execute the squashed mean, no sampling.
                action = torch.tanh(mean).cpu().numpy()
            state, reward, terminated, truncated = _step_env(env, action)
            total_cost -= reward  # cost = -reward
            t += 1
        costs.append(total_cost)
        lengths.append(t)
        # Success unless the env terminated (e.g. a fall) before the step cap;
        # ending through a time limit / truncation is not a failure.
        successes.append(1.0 if (not terminated or t >= max_ep_len) else 0.0)

    avg_cost = float(np.mean(costs))
    # CVaR: mean over the worst cvar_alpha fraction of episodes (highest cost).
    tail = max(1, int(cvar_alpha * episodes))
    cvar_cost = float(np.mean(sorted(costs)[-tail:]))
    success_rate = float(np.mean(successes))
    avg_len = float(np.mean(lengths))
    return avg_cost, cvar_cost, success_rate, avg_len

def _make_scenario_env(env_name, sc, cfg):
    """Build ``env_name`` wrapped with the modifications requested by scenario ``sc``.

    Recognized scenario keys (all optional):
        incline: slope in degrees; a non-zero value adds ``InclineWrapper``.
        disturbance: truthy adds ``DisturbanceWrapper`` (``force_mag`` default
            50.0, ``interval`` default 100, both overridable via ``sc``).
        actuator_mode: ``'scale'`` (uses ``scale``, default 0.5) or ``'drop'``
            (uses ``failure_step``, default 100).
    The reward is always wrapped with ``FinslerRewardWrapper`` using weights from ``cfg``.

    Raises:
        ValueError: If ``actuator_mode`` is set to an unsupported value.
    """
    mode = sc.get('actuator_mode', None)
    # Validate before creating the env so a bad scenario cannot leak an env instance.
    if mode and mode not in ('scale', 'drop'):
        raise ValueError(f"Unknown actuator_mode {mode!r}; expected 'scale' or 'drop'")
    env = gym.make(env_name)
    if sc.get('incline', 0):
        env = InclineWrapper(env, slope_degrees=sc['incline'])
    if sc.get('disturbance', False):
        env = DisturbanceWrapper(env, force_mag=sc.get('force_mag', 50.0),
                                 interval=sc.get('interval', 100), noise_actions=False)
    if mode == 'scale':
        env = ActuatorModificationWrapper(env, failure_mode='scale', scale=sc.get('scale', 0.5))
    elif mode == 'drop':
        env = ActuatorModificationWrapper(env, failure_mode='drop',
                                          failure_step=sc.get('failure_step', 100))
    return FinslerRewardWrapper(env, we=cfg.get('w_e', 1.0), wd=cfg.get('w_d', 1.0), wf=cfg.get('w_f', 1.0),
                                beta_coef=cfg.get('beta_coef', 50.0), lambda_lat=cfg.get('lambda_lat', 1.0))


def eval_firl(config_file, model_path, out_file, test_scenarios=None, episodes=100):
    """Evaluate a trained actor on one or more env scenarios and write a CSV summary.

    Args:
        config_file: YAML config. ``env_name`` is required. The keys
            ``incline``, ``max_episode_length``, ``w_e``, ``w_d``, ``w_f``,
            ``beta_coef`` and ``lambda_lat`` are optional.
        model_path: Path to the actor ``state_dict`` checkpoint.
        out_file: Destination CSV (one row per scenario).
        test_scenarios: List of scenario dicts (see ``_make_scenario_env``).
            ``None`` or empty falls back to a single scenario built from the
            config's ``incline``.
        episodes: Evaluation episodes per scenario.

    NOTE(review): ``yaml``, ``gym`` and the wrappers ``InclineWrapper``,
    ``DisturbanceWrapper``, ``ActuatorModificationWrapper`` and
    ``FinslerRewardWrapper`` are not imported in this module as shown. Confirm
    they are brought into scope; otherwise this raises NameError.
    """
    with open(config_file, 'r', encoding='utf-8') as f:
        cfg = yaml.safe_load(f)
    env_name = cfg['env_name']
    scenarios = test_scenarios or [
        {'incline': cfg.get('incline', 0), 'disturbance': False, 'actuator_mode': None}
    ]

    # Probe one unwrapped env for the network dimensions, then release it.
    probe_env = gym.make(env_name)
    try:
        state_dim = probe_env.observation_space.shape[0]
        action_dim = probe_env.action_space.shape[0]
    finally:
        probe_env.close()

    actor = Actor(state_dim, action_dim)
    # NOTE(review): torch.load unpickles the checkpoint. Only load trusted files
    # (consider weights_only=True on torch versions that support it).
    actor.load_state_dict(torch.load(model_path, map_location='cpu'))
    actor.eval()

    max_ep_len = cfg.get('max_episode_length', 1000)
    fieldnames = ['scenario', 'incline', 'disturbance', 'actuator_mode',
                  'avg_cost', 'cvar_cost', 'success_rate', 'avg_len']
    with open(out_file, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for sc in scenarios:
            env = _make_scenario_env(env_name, sc, cfg)
            try:
                avg_cost, cvar_cost, success_rate, avg_len = evaluate_policy(
                    env, actor, episodes=episodes, max_ep_len=max_ep_len)
            finally:
                env.close()
            writer.writerow({
                'scenario': str(sc),
                'incline': sc.get('incline', 0),
                'disturbance': sc.get('disturbance', False),
                'actuator_mode': sc.get('actuator_mode', None),
                'avg_cost': avg_cost,
                'cvar_cost': cvar_cost,
                'success_rate': success_rate,
                'avg_len': avg_len,
            })
            print(f"Scenario {sc}: AvgCost={avg_cost:.2f}, CVaR_Cost={cvar_cost:.2f}, "
                  f"Success={success_rate:.2f}, AvgLen={avg_len:.1f}")
