# Referred to: https://github.com/PrieureDeSion/drive-any-robot

from typing import Any
import wandb
import argparse
import yaml
import time
import os
import random
import torch
import torch.nn as nn
from torchvision import transforms
import numpy as np
from utils import load_model, get_saved_optimizer, train_eval_loop
from datasets import MapPretrainDataset
from model import IGNet
from torch.utils.data import DataLoader
from PIL import Image
from navigation_evaluator import NavigationEvaluator

class IGNetPolicy:
    """Callable policy that turns a (state image, goal image) pair into an action.

    The policy keeps a template input batch (``data_inference``). On every call
    it overwrites the first two image slots of that batch with the processed
    state and goal images, runs the model without gradients and picks one
    action index from the averaged action logits.

    Args:
        model: Callable invoked as ``model(data_inference, eval_loss=False)``;
            must return a mapping that contains ``'action_logits'``.
        data_inference: Mapping with an ``'images'`` tensor. It is modified in
            place on every call.
        dataset: Object exposing ``process_images(list_of_pil_images)`` that
            returns a tensor.
        device: Device the processed images are moved to.
        sample_method: How the action is chosen from the averaged logits:
            ``'argmax'`` (greedy), ``'softmax'`` (sample from the softmax
            distribution, the default) or ``'sharp-softmax'`` (sample after
            scaling the logits by ``SHARP_SOFTMAX_SCALE``).

    Raises:
        NotImplementedError: If ``sample_method`` is not a supported value.
    """

    # Seconds to sleep before re-trying to read an image that is not readable yet.
    IMAGE_RETRY_INTERVAL = 0.02
    # Logit multiplier applied by the 'sharp-softmax' sampling method.
    SHARP_SOFTMAX_SCALE = 1.5
    SAMPLE_METHODS = ('argmax', 'softmax', 'sharp-softmax')

    def __init__(self, model, data_inference, dataset, device, sample_method='softmax') -> None:
        if sample_method not in self.SAMPLE_METHODS:
            raise NotImplementedError(
                f'Unknown sample_method {sample_method!r}; expected one of {self.SAMPLE_METHODS}'
            )
        self.model = model
        self.data_inference = data_inference
        self.dataset = dataset
        self.device = device
        self.sample_method = sample_method

    def _load_image_pair(self, state_path, goal_path):
        """Load both images as RGB, retrying until both can be read."""
        while True:
            try:
                with Image.open(state_path) as state_file:
                    state_image = state_file.convert('RGB')
                with Image.open(goal_path) as goal_file:
                    goal_image = goal_file.convert('RGB')
            except Exception:
                # Deliberate best-effort retry: the read is attempted again
                # after a short pause. `Exception` (not a bare `except`) keeps
                # KeyboardInterrupt / SystemExit able to stop the loop.
                time.sleep(self.IMAGE_RETRY_INTERVAL)
            else:
                return state_image, goal_image

    def _select_action(self, current_action_logits):
        """Pick an action index from 1-D logits according to ``self.sample_method``."""
        if self.sample_method == 'argmax':
            return current_action_logits.argmax().item()

        if self.sample_method == 'softmax':
            scaled_logits = current_action_logits
        elif self.sample_method == 'sharp-softmax':
            scaled_logits = current_action_logits * self.SHARP_SOFTMAX_SCALE
        else:
            raise NotImplementedError

        # An explicit `dim` avoids the deprecated implicit-dimension softmax.
        action_distribution = torch.softmax(scaled_logits, dim=-1)
        rounded = [round(pr, 2) for pr in action_distribution.cpu().numpy().tolist()]
        print('action_distribution:', rounded)
        return torch.multinomial(action_distribution, 1)[0].item()

    def __call__(self, state_path, goal_path) -> int:
        """Return the action index chosen for the given state and goal images.

        Args:
            state_path: Path of the image showing the current state.
            goal_path: Path of the image showing the goal.

        Returns:
            The selected action index as a Python int.
        """
        state_image, goal_image = self._load_image_pair(state_path, goal_path)

        image_tensors = self.dataset.process_images([state_image, goal_image]).to(self.device)
        # Slots 0 and 1 along dim 1 of the template batch receive the
        # processed state and goal images; the rest of the batch is untouched.
        self.data_inference['images'][:, :2] = image_tensors

        with torch.no_grad():
            outputs = self.model(self.data_inference, eval_loss=False)

        # Only index 1 along dim 1 of the output is used.
        # Original author's shape note: (16, 5, action_num) -- TODO confirm.
        action_logits = outputs['action_logits'][:, 1]

        # Ensemble trick: every row of the batch yields a prediction, and the
        # logits at index 0 of the next dimension are averaged over the batch.
        current_action_logits = action_logits[:, 0].mean(0)  # (action_num)

        return self._select_action(current_action_logits)

def eval_main(configs):
    """Build the model, wrap it in an ``IGNetPolicy`` and run the navigation evaluation.

    Args:
        configs: Configuration dict. Keys read here:
            ``gpu_ids`` (optional, defaults to ``[0]``), ``seed`` (optional),
            ``batch_size``, ``num_workers``, ``port``, ``run_name``, and
            ``load_run`` (optional; when present ``project_name`` is read too
            and the checkpoint ``logs/<project_name>/<load_run>`` is loaded).
            The dict is also passed to ``IGNet`` and ``MapPretrainDataset``.
            ``gpu_ids`` is written back into the dict when it is missing.

    Returns:
        The value returned by ``NavigationEvaluator.evaluate`` (also printed).
    """
    # Default the GPU list unconditionally: it is read further down even when
    # CUDA is unavailable, which previously raised KeyError on CPU-only hosts.
    configs.setdefault("gpu_ids", [0])
    gpu_ids = configs["gpu_ids"]

    # Set up GPU devices. CUDA_VISIBLE_DEVICES is not modified here; the
    # device is addressed by its explicit index instead.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

    # Set up device
    device = torch.device(f'cuda:{gpu_ids[0]}' if use_cuda else 'cpu')

    # Set up seed (Python, NumPy and PyTorch generators)
    if 'seed' in configs:
        random.seed(configs['seed'])
        np.random.seed(configs['seed'])
        torch.manual_seed(configs['seed'])

    # Create the model
    model = IGNet(configs)
    multi_gpu = len(gpu_ids) > 1
    if multi_gpu:
        model = nn.DataParallel(model, device_ids=gpu_ids)
    model = model.to(device)
    # DataParallel hides the wrapped network behind `.module`.
    core_model = model.module if multi_gpu else model

    print('Build PM-Net model finished')
    num_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    num_new_parameters = sum(
        p.numel() for name, p in model.named_parameters() if p.requires_grad and not name.startswith('image_encoder')
    )
    print(f'{num_parameters = }')
    print(f'{num_new_parameters = }')

    # Load the dataset. It supplies the template input batch below and the
    # image preprocessing used by the policy.
    train_dataset = MapPretrainDataset(
        configs=configs,
        data_split_type='train',
        transform=core_model.image_processor,
        use_huggingface_transform=core_model.huggingface_image_processor,
    )

    # Create Dataloader
    train_loader = DataLoader(
        train_dataset,
        batch_size=configs['batch_size'],
        shuffle=True,
        num_workers=configs['num_workers'],
        drop_last=True,
    )

    if 'load_run' in configs:
        # Load checkpoints. Only the model weights are restored; the optimizer
        # state and epoch counter are not used during evaluation.
        load_model_path = os.path.join('logs', configs['project_name'], configs['load_run'])
        print(f'Loading model from {load_model_path}')
        latest_checkpoint = torch.load(load_model_path, map_location=device)
        load_model(model, latest_checkpoint)

    model.eval()

    # One batch from the loader serves as the template input that the policy
    # overwrites with the current state/goal images on every call.
    data = next(iter(train_loader))
    data_inference = {
        key: data[key].to(device)
        for key in ['images', 'map_positions', 'map_yaws', 'pos_scale', 'location_types']
    }

    IGNet_policy = IGNetPolicy(model, data_inference, train_dataset, device)

    game_configs = {
        'binary_path': '../binaries/LinuxNoEditor_08_01_3_starts/ShooterGame.sh',
        'port': configs['port'],
        'cuda_id': gpu_ids[0],
        'num_test_episodes': 50,
        'eval_file_path': 'eval_dataset.pkl',
    }

    navigation_evaluator = NavigationEvaluator(**game_configs)
    try:
        performance = navigation_evaluator.evaluate(IGNet_policy, configs['run_name'])
    finally:
        # Always shut the evaluator down, even when the evaluation fails.
        navigation_evaluator.end()
    print(f'performance: {performance}')
    return performance

if __name__ == '__main__':
    # Script entry point: parse the CLI, load the YAML config, create the run
    # folder and start the evaluation.
    parser = argparse.ArgumentParser(description="PM-Net Evaluation")

    # Config files
    parser.add_argument(
        '--config',
        default="configs/default.yaml",
        type=str,
        help="Path to the config file.",
    )

    parser.add_argument(
        '--run-name',
        default='IGNet',
        type=str,
        help="Prefix of the run name; a timestamp is appended.",
    )

    args = parser.parse_args()

    # Read the config with an explicit encoding so parsing does not depend on
    # the platform's default locale.
    with open(args.config, "r", encoding="utf-8") as f:
        configs = yaml.safe_load(f)

    # Experiment Name setting, and experiment folder setting.
    # The run name always comes from the CLI (any `run_name` in the YAML file
    # is overwritten) and gets a second-resolution timestamp suffix.
    configs["run_name"] = args.run_name + "_" + time.strftime("%Y_%m_%d_%H_%M_%S")
    configs["project_folder"] = os.path.join(
        "logs", configs["project_name"], configs["run_name"]
    )

    # exist_ok: two runs started within the same second with the same name
    # would otherwise crash here with FileExistsError.
    os.makedirs(configs["project_folder"], exist_ok=True)

    # wandb logging is currently disabled; the setup is kept for reference.
    # if configs["use_wandb"]:
    #     wandb.login()
    #     wandb.init(
    #         project=configs["project_name"], settings=wandb.Settings(start_method="fork")
    #     )
        
    #     wandb.run.name = configs["run_name"]
    #     # Update the wandb args with the training configurations
    #     if wandb.run:
    #         wandb.config.update(configs)

    print('All configs:')
    for key, val in configs.items():
        print(f'{key}: {val}')

    eval_main(configs)