#!/usr/bin/env python3


import torch
import torch.nn as nn
import gymnasium as gym
import numpy as np
import os, sys, glob
import ctypes
from dataclasses import dataclass
import json
import shimmy

@dataclass
class Args:
    """Fallback configuration object used when a checkpoint has no 'args' entry.

    NOTE(review): the evaluation code in this file reads ``env_id`` from the
    args object, and this class does not define that attribute, so the
    fallback alone is not enough to run an evaluation. Presumably the class
    also exists so that pickled ``Args`` instances stored in checkpoints can
    be restored - TODO confirm against the training script.
    """
    # Placeholder field; it is not read anywhere in this file.
    test: float = 10


def _load_vmf_geo_core():

    torch_lib = os.path.join(os.path.dirname(torch.__file__), 'lib')
    libc10 = os.path.join(torch_lib, 'libc10.so')
    if os.path.exists(libc10):
        ctypes.CDLL(libc10, mode=ctypes.RTLD_GLOBAL)

    try:
        import geo_core
        print("GAC core loaded successfully")
        return geo_core
    except ImportError:

        import importlib.util
        so_files = glob.glob("gac_core*.so")
        if so_files:
            spec = importlib.util.spec_from_file_location("gac_core", so_files[0])
            vmf_geo_core = importlib.util.module_from_spec(spec)
            sys.modules["gac_core"] = vmf_geo_core
            spec.loader.exec_module(vmf_geo_core)
            print(f"GAC core loaded from {so_files[0]}")
            return vmf_geo_core
        else:
            raise RuntimeError("❌ Cannot find gac.so")



# Load the compiled core once, at import time. Importing this file therefore
# raises RuntimeError when the extension cannot be found.
vmf_geo_core = _load_vmf_geo_core()


class GeoActor(nn.Module):
    """Actor network whose weights are handed to the compiled core for action selection.

    The module owns a shared two-layer MLP backbone and three heads
    (``mu_head``, ``kappa_head``, ``geo_head``). It defines no ``forward``;
    ``get_action`` passes the raw weight/bias tensors to
    ``vmf_geo_core.get_action_with_geo``.

    Args:
        obs_dim: Input width of the backbone's first linear layer.
        action_dim: Output width of ``mu_head`` and ``geo_head``.
        args: Configuration object. Only an optional ``kappa_init`` attribute
            is read by this class (default 2.0 when absent).

    Attributes:
        action_scale, action_bias: Tensors forwarded to the core on every
            ``get_action`` call. They default to ones / zeros of length
            ``action_dim`` and may be overwritten by the caller.
    """

    def __init__(self, obs_dim, action_dim, args):
        super().__init__()
        self.args = args
        self.action_dim = action_dim

        self.backbone = nn.Sequential(
            nn.Linear(obs_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
        )

        self.mu_head = nn.Linear(256, action_dim)

        self.kappa_head = nn.Sequential(
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

        self.geo_head = nn.Sequential(
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, action_dim)
        )

        # get_action() reads these two attributes, so give them defaults here
        # instead of relying on every caller to assign them first.
        # Non-persistent buffers follow .to(device) but stay out of
        # state_dict(), so checkpoint loading behaves exactly as before.
        self.register_buffer(
            "action_scale",
            torch.ones(action_dim, dtype=torch.float32),
            persistent=False,
        )
        self.register_buffer(
            "action_bias",
            torch.zeros(action_dim, dtype=torch.float32),
            persistent=False,
        )

    def get_params_dict(self):
        """Return the layer weights and biases keyed by the names the core expects.

        The dictionary is unpacked as keyword arguments in ``get_action``, so
        the key names are part of the contract with the compiled extension.
        """
        return {
            'backbone_w1': self.backbone[0].weight,
            'backbone_b1': self.backbone[0].bias,
            'backbone_w2': self.backbone[2].weight,
            'backbone_b2': self.backbone[2].bias,
            'mu_head_w': self.mu_head.weight,
            'mu_head_b': self.mu_head.bias,
            'kappa_head_w1': self.kappa_head[0].weight,
            'kappa_head_b1': self.kappa_head[0].bias,
            'kappa_head_w2': self.kappa_head[2].weight,
            'kappa_head_b2': self.kappa_head[2].bias,
            'geo_head_w1': self.geo_head[0].weight,
            'geo_head_b1': self.geo_head[0].bias,
            'geo_head_w2': self.geo_head[2].weight,
            'geo_head_b2': self.geo_head[2].bias,
        }

    def get_action(self, x, deterministic=False):
        """Ask the compiled core for an action for observation tensor ``x``.

        Args:
            x: Observation tensor (assumed to be batched as (batch, obs_dim) -
                TODO confirm against the core's signature).
            deterministic: Forwarded unchanged to the core.

        Returns:
            The first of the three values returned by
            ``vmf_geo_core.get_action_with_geo``; the other two are discarded.
        """
        action, _entropy, _mean = vmf_geo_core.get_action_with_geo(
            x,
            **self.get_params_dict(),
            kappa_init=getattr(self.args, 'kappa_init', 2.0),
            action_scale=self.action_scale,
            action_bias=self.action_bias,
            deterministic=deterministic
        )

        return action


def _make_eval_env(env_id):
    """Create the evaluation environment for ``env_id``.

    Returns:
        ``(env, obs_dim)`` where ``env`` is wrapped in RecordEpisodeStatistics
        and ``obs_dim`` is the first dimension of its observation space.
    """
    if not env_id.startswith("dm_control/"):
        env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        return env, env.observation_space.shape[0]

    env = None
    try:
        from gymnasium.wrappers import FlattenObservation
        env = gym.make(env_id)
        env = FlattenObservation(env)
        env = gym.wrappers.RecordEpisodeStatistics(env)

        # The probe reset stays inside the try so that an environment which
        # builds but cannot run still triggers the fallback below.
        probe_obs, _ = env.reset()
        obs_space = getattr(env, 'observation_space', None)
        obs_dim = obs_space.shape[0] if obs_space is not None else probe_obs.shape[0]
        return env, obs_dim
    except Exception as e:  # deliberate best-effort: any failure -> fallback env
        print(f"⚠️ : {e}")
        print("...")

        # Do not leak a partially constructed environment.
        if env is not None:
            try:
                env.close()
            except Exception:
                pass  # cleanup is best-effort; the fallback below matters more

        env_name = env_id.split('/')[-1]
        lowered = env_name.lower()
        if 'cheetah' in lowered:
            env_name = 'HalfCheetah-v4'
        elif 'walker' in lowered:
            env_name = 'Walker2d-v4'
        print(f"📝 : {env_name}")

        # NOTE(review): the fallback is a different environment from env_id;
        # its observation size may not match the checkpoint - confirm intent.
        env = gym.make(env_name)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        return env, env.observation_space.shape[0]


def _run_eval_episode(env, actor, device, seed, deterministic, low, high):
    """Roll out one episode and return ``(episode_return, episode_length)``."""
    obs, _ = env.reset(seed=seed)
    episode_return = 0.0
    episode_length = 0

    while True:
        # Coerce to a float32 array first so any observation container works.
        obs_tensor = torch.as_tensor(
            np.asarray(obs, dtype=np.float32), device=device
        ).unsqueeze(0)

        with torch.no_grad():
            action = actor.get_action(obs_tensor, deterministic=deterministic)
            action_np = np.clip(action.cpu().numpy().flatten(), low, high)

        obs, reward, terminated, truncated, _ = env.step(action_np)
        episode_return += reward
        episode_length += 1

        if terminated or truncated:
            return episode_return, episode_length


def evaluate_vmf_geo(checkpoint_path, num_episodes=20, deterministic=True,
                     save_results=False):
    """Evaluate a saved actor checkpoint and return its mean episode return.

    Args:
        checkpoint_path: Path to a checkpoint readable by ``torch.load``. It
            must contain 'model_state_dict' and an 'args' object exposing
            ``env_id``; 'global_step' is optional.
        num_episodes: Number of episodes to run (must be >= 1). Episode ``i``
            is reset with ``seed=i``.
        deterministic: Forwarded to the actor's ``get_action``.
        save_results: When True, write a JSON summary next to the checkpoint
            as ``<checkpoint stem>_eval_<mode>.json``. Off by default.

    Returns:
        Mean of the per-episode returns.

    Raises:
        ValueError: if ``num_episodes`` < 1, the checkpoint args carry no
            ``env_id``, or the action space has no usable shape.
    """
    if num_episodes < 1:
        raise ValueError(f"num_episodes must be >= 1, got {num_episodes}")

    print("=" * 60)
    print("GAC Evaluation")
    print("=" * 60)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # SECURITY: weights_only=False unpickles arbitrary Python objects (needed
    # to restore the stored args object). Only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    args = checkpoint.get('args', Args())
    global_step = checkpoint.get('global_step', 0)

    env_id = getattr(args, 'env_id', None)
    if env_id is None:
        raise ValueError(
            f"Checkpoint {os.path.basename(checkpoint_path)} does not provide args.env_id"
        )

    print(f"Checkpoint: {os.path.basename(checkpoint_path)}")
    print(f"Environment: {env_id}")
    print(f"Training steps: {global_step:,}")
    print(f"Device: {device}")

    env, obs_dim = _make_eval_env(env_id)
    try:
        action_space = env.action_space
        action_shape = getattr(action_space, 'shape', None)
        if not action_shape:
            raise ValueError(f"Unsupported action space without a shape: {action_space}")
        action_dim = action_shape[0]

        print(f"📐 Observation dim: {obs_dim}, Action dim: {action_dim}")

        actor = GeoActor(obs_dim, action_dim, args).to(device)
        # strict=False: the checkpoint may hold keys this actor does not define.
        actor.load_state_dict(checkpoint['model_state_dict'], strict=False)

        # Resolve the action bounds once; they do not change between steps.
        if hasattr(action_space, 'high') and hasattr(action_space, 'low'):
            low, high = action_space.low, action_space.high
            actor.action_scale = torch.tensor(
                (high - low) / 2.0, dtype=torch.float32
            ).to(device)
            actor.action_bias = torch.tensor(
                (high + low) / 2.0, dtype=torch.float32
            ).to(device)
        else:
            low, high = -1.0, 1.0
            actor.action_scale = torch.ones(action_dim, dtype=torch.float32).to(device)
            actor.action_bias = torch.zeros(action_dim, dtype=torch.float32).to(device)

        actor.eval()

        mode = 'Deterministic' if deterministic else 'Stochastic'
        print(f"\n📊 Running {num_episodes} episodes ({mode})")
        print("-" * 40)

        episode_returns = []
        for episode in range(num_episodes):
            episode_return, episode_length = _run_eval_episode(
                env, actor, device, episode, deterministic, low, high
            )
            episode_returns.append(episode_return)
            print(f"Episode {episode + 1:2d}: Return = {episode_return:8.2f}, Length = {episode_length:4d}")
    finally:
        # Release the simulator even when an episode raises.
        env.close()

    print("-" * 40)
    mean_return = np.mean(episode_returns)
    std_return = np.std(episode_returns)
    max_return = np.max(episode_returns)
    min_return = np.min(episode_returns)

    print(f"📈 Mean Return: {mean_return:.2f} ± {std_return:.2f}")
    print(f"📈 Max Return: {max_return:.2f}")
    print(f"📈 Min Return: {min_return:.2f}")
    print("=" * 60)

    if save_results:
        results = {
            'checkpoint': os.path.basename(checkpoint_path),
            'environment': env_id,
            'training_steps': int(global_step),
            'evaluation_mode': mode,
            'num_episodes': num_episodes,
            'mean_return': float(mean_return),
            'std_return': float(std_return),
            'max_return': float(max_return),
            'min_return': float(min_return),
            'episodes': [float(r) for r in episode_returns]
        }
        # splitext only touches the real extension, unlike str.replace('.pt', ...),
        # which would also rewrite '.pt' appearing elsewhere in the path.
        stem, _ = os.path.splitext(checkpoint_path)
        result_file = f"{stem}_eval_{mode.lower()}.json"
        with open(result_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)
        print(f"💾 Results saved to: {os.path.basename(result_file)}")

    return mean_return



def main():
    """Command-line entry point.

    With ``--checkpoint`` a single checkpoint is evaluated. Otherwise every
    path in the built-in list is evaluated in turn and a summary table is
    printed. Missing checkpoints are skipped, and a failing evaluation is
    recorded as 'FAILED' rather than aborting the run.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Evaluate GAC on DMC Environments')
    parser.add_argument('--checkpoint', type=str, default=None,
                        help='Specific checkpoint to evaluate')
    parser.add_argument('--episodes', type=int, default=20,
                        help='Number of evaluation episodes')
    parser.add_argument('--stochastic', action='store_true',
                        help='Use stochastic policy (default: deterministic)')
    # Kept for command-line compatibility; the value is never consulted
    # because the "all" mode is simply what runs when --checkpoint is absent.
    parser.add_argument('--all', action='store_true', default=True,
                        help='Evaluate all environments (default: True)')

    args = parser.parse_args()
    deterministic = not args.stochastic

    if args.checkpoint is not None:
        evaluate_vmf_geo(args.checkpoint, args.episodes, deterministic)
        return

    print("\n" + "=" * 70)
    print("🚀 GAC: Evaluating ALL DMC Environments")
    print("=" * 70 + "\n")

    all_checkpoints = [
        "gac_pretrain_model/fish-upright-v0_step_500000.pt",
        "gac_pretrain_model/walker-walk-v0_step_500000.pt",
        "gac_pretrain_model/walker-run-v0_step_500000.pt",
        "gac_pretrain_model/cheetah-run-v0_step_500000.pt",
        "gac_pretrain_model/quadruped-walk-v0_step_500000.pt",
        "gac_pretrain_model/quadruped-run-v0_step_500000.pt",
    ]
    total = len(all_checkpoints)

    results_summary = []

    for i, checkpoint_pattern in enumerate(all_checkpoints, 1):
        checkpoints = glob.glob(checkpoint_pattern)

        if not checkpoints:
            print(f"⚠️  Checkpoint not found: {checkpoint_pattern}")
            continue

        checkpoint_path = checkpoints[0]
        checkpoint_name = os.path.basename(checkpoint_path)
        # File names look like "<env>_step_<n>.pt"; the env is the first field.
        env_name = checkpoint_name.split('_')[0]

        print(f"\n{'=' * 70}")
        # Derive the denominator from the list so it cannot drift out of date.
        print(f"[{i}/{total}] Evaluating: {env_name}")
        print(f"{'=' * 70}")

        try:
            outcome = evaluate_vmf_geo(checkpoint_path, args.episodes, deterministic)
        except Exception as e:  # top-level boundary: report and keep going
            print(f"❌ Error evaluating {env_name}: {e}")
            outcome = 'FAILED'

        results_summary.append({
            'environment': env_name,
            'mean_return': outcome,
            'checkpoint': checkpoint_name
        })

    print("\n" + "=" * 70)
    print("📊 FINAL RESULTS SUMMARY")
    print("=" * 70)
    print(f"{'Environment':<25} {'Mean Return':>15} {'Checkpoint':<30}")
    print("-" * 70)

    for result in results_summary:
        env = result['environment']
        score = result['mean_return']
        ret = f"{score:.2f}" if isinstance(score, float) else score
        ckpt = result['checkpoint']
        if len(ckpt) > 30:
            # 27 characters + "..." fills the 30-wide column exactly.
            ckpt = ckpt[:27] + "..."
        print(f"{env:<25} {ret:>15} {ckpt:<30}")

    print("=" * 70)
    print(f"✅ Evaluation complete! Total environments: {len(results_summary)}")
    print("=" * 70 + "\n")


# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()