import os
import yaml
import pickle
import argparse
from copy import deepcopy

import warnings
# Silence deprecation/user warnings emitted by the (older) gym/d4rl/torch stack.
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# Command-line options are parsed at import time; the resulting ``args``
# namespace is read as a module-level global by the functions below.
parser = argparse.ArgumentParser()
parser.add_argument('--env-name', '-e', default='hopper-medium-v2', help='task environment name')
parser.add_argument('--rollout-frequency', '-r', default=1, type=int, help='rollout frequency')
parser.add_argument('--num-episodes', default=5, type=int, help='number of episodes to evaluate')
parser.add_argument('--save-video', action='store_true', help='save video of evaluation')
parser.add_argument('--gpu', default=0, type=int, help='gpu number')
args = parser.parse_args()

# These variables are set BEFORE torch is imported further down:
# CUDA_VISIBLE_DEVICES only takes effect if set before CUDA is initialised,
# so the selected GPU then appears to torch as device 'cuda' (index 0).
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
# NOTE(review): hard-coded ffmpeg location (read by imageio-ffmpeg if it is
# used downstream); adjust for machines where ffmpeg lives elsewhere.
os.environ['IMAGEIO_FFMPEG_EXE'] = '/usr/bin/ffmpeg'
print('Using GPU:', args.gpu)

import gym
import d4rl
import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt 
import matplotlib.animation as animation

from utils import load_config
from models import load_policy, MLP, FlowPolicy
from preflow import FlowMatching


def _save_gif(frames, path):
    """Write ``frames`` (RGB images as produced by ``env.render``) to ``path`` as a GIF."""
    directory = os.path.dirname(path)
    if directory:
        # Create the output folder so saving does not fail on a fresh checkout.
        os.makedirs(directory, exist_ok=True)

    fig, ax = plt.subplots()
    im = ax.imshow(frames[0], animated=True)

    def update(i):
        im.set_array(frames[i])
        return im,

    try:
        animation.FuncAnimation(fig, update, frames=len(frames)).save(path)
    finally:
        # Release the figure; otherwise one figure is leaked per saved episode.
        plt.close(fig)


def evaluate_flow_policy(env, flow_policy, device='cpu', file_name=None, save_video=None):
    """Run ``flow_policy`` in ``env`` for one episode and return the episode return.

    Args:
        env: Gym environment. It is reset, stepped until ``done`` and, when
            recording, rendered with ``mode='rgb_array'``. It must expose
            ``_max_episode_steps`` (used as the progress-bar length) and
            ``data.qpos`` / ``data.qvel``.
        flow_policy: Callable invoked as
            ``flow_policy(state, qpos, qvel, use_torchdiffeq=False)``; its
            result is passed directly to ``env.step``.
        device: Unused; kept for interface compatibility.
        file_name: Base name of the GIF written to ``./results/video/``.
            When ``None`` no video file is written.
        save_video: Whether to record frames. ``None`` (default) falls back to
            the ``--save-video`` command-line flag.

    Returns:
        The sum of the per-step rewards collected during the episode.
    """
    if save_video is None:
        save_video = args.save_video

    state = env.reset()
    done = False
    total_reward = 0
    frames = []

    # Context manager guarantees the progress bar is closed, even on error.
    with tqdm(total=env._max_episode_steps) as pbar:
        while not done:
            if save_video:
                frames.append(env.render(mode='rgb_array'))
            # Copy the simulator state so the arrays handed to the policy are
            # independent of the environment's internal buffers.
            qpos = np.copy(env.data.qpos[:])
            qvel = np.copy(env.data.qvel[:])

            with torch.no_grad():
                action = flow_policy(
                    state, qpos, qvel,
                    use_torchdiffeq=False
                )
            state, reward, done, _ = env.step(action)
            total_reward += reward
            pbar.update(1)
            pbar.set_description(f'Total Reward: {total_reward:.2f}')

    # Only build a figure when there is actually something to save.
    if save_video and file_name is not None and frames:
        _save_gif(frames, f'./results/video/{file_name}.gif')

    return total_reward


def main(env_name=None, beta=None):
    """Load the trained flow model and report its return over several episodes.

    Args:
        env_name: Optional environment id. When given it overwrites the global
            ``args.env_name``, which is then used for config loading, the
            checkpoint file name and environment creation.
        beta: Optional value written to both ``config.rlhf.beta`` and
            ``config.dpo.beta`` after the config is loaded.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('Using device:', device)

    if env_name is not None:
        args.env_name = env_name

    config = load_config(args)

    if beta is not None:
        config.rlhf.beta = beta
        config.dpo.beta = beta

    pretrained_model = load_policy(config, device=device, path=config.policy.path)

    # NOTE: torch.load unpickles a full module object, so only load checkpoints
    # from trusted locations. Recent torch releases default to
    # weights_only=True, which may reject full-module checkpoints like this one.
    model = torch.load(
        os.path.join(config.model.save_dir, f'{args.env_name}.pth'),
        map_location=device
    ).to(device)
    print('Model loaded from:', config.model.save_dir)

    env = gym.make(args.env_name)
    flow = FlowMatching(model, device=device)
    # The policy receives its own copy of the environment.
    flow_policy = FlowPolicy(
        deepcopy(env), flow, pretrained_model,
        seg_len=config.data.seg_len,
        action_idx=config.action_idx,
        max_action=config.env.max_action,
        rollout_frequency=args.rollout_frequency,
        flow_iteration=config.flow_iteration,
        device=device
    )

    flow_total_reward = []

    for episode in range(args.num_episodes):
        # A fresh environment per episode is already independent, so no
        # deepcopy is needed; close it afterwards to free renderer/simulator
        # resources instead of leaking one environment per episode.
        episode_env = gym.make(args.env_name)
        try:
            total_reward = evaluate_flow_policy(
                episode_env, flow_policy,
                device=device,
                file_name=f'{args.env_name}_flow_{episode}'
            )
        finally:
            episode_env.close()
        flow_total_reward.append(total_reward)
        print('Total Reward (Flow Matching):', total_reward)

    print('Evaluation complete!\n')
    if flow_total_reward:
        print(f'PFM Average Return: {np.mean(flow_total_reward):.2f}, {np.std(flow_total_reward):.2f}')
    else:
        # np.mean/np.std of an empty list would print NaN with runtime warnings.
        print('No episodes were evaluated.')


if __name__ == '__main__':
    # Script entry point: no env_name/beta overrides, so everything comes from
    # the command-line arguments parsed at import time.
    main()