import os
import sys
import random
import numpy as np
import pdb

import torch
import torch.nn as nn
import torch.utils.data as data
from torch.utils.tensorboard import SummaryWriter

from arguments import get_args
import init_path
from utils_bc import utils
from data_loader import Dataloader

from bc_agent import BC_Agent
from models.bc_model import BC_MODEL

from utils_bc import utils_interactive_eval
from utils_bc.utils import save_model, load_pretrained_model
from utils_bc.utils_llm import get_pretrained_tokenizer
from interactive_interface import interactive_interface_fn


import GPUtil
# Pick a GPU at import time. With order='memory', GPUtil returns candidate device
# ids sorted by ascending memory usage, keeping only devices at or below 20% load
# and 20% memory use (maxLoad / maxMemory).
gpu_chunks = GPUtil.getAvailable(order='memory', limit=16, maxLoad=0.2, maxMemory=0.2, includeNan=False, excludeID=[], excludeUUID=[])
# Each id is wrapped in a one-element list; this shape is kept for any code that reads gpu_chunks.
gpu_chunks = [[device_id] for device_id in gpu_chunks]
if not gpu_chunks:
    # Fail with an actionable message instead of an IndexError on gpu_chunks[0].
    raise RuntimeError(
        'No free GPU found: GPUtil.getAvailable returned no device with '
        'load <= 20% and memory use <= 20%.'
    )
gpu_id = gpu_chunks[0][0]
print('gpu_id', gpu_id)
torch.cuda.set_device(gpu_id)


def get_logger(args, log_path):
    """Configure the root logger to emit INFO records to ``log_path`` and stdout.

    Any existing file at ``log_path`` is deleted first, so each run starts with
    a fresh log. Handlers installed by an earlier call to this function are
    detached and closed first. Calling it more than once therefore does not
    duplicate every log line or leak open file handles.

    Args:
        args: Unused; kept so existing call sites keep working.
        log_path: Path of the log file to (re)create.

    Returns:
        The root ``logging.Logger``.
    """
    import contextlib
    import logging

    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    # Remove only the handlers this function installed previously; handlers
    # added by other code are left alone.
    for handler in list(root_logger.handlers):
        if getattr(handler, '_bc_managed', False):
            root_logger.removeHandler(handler)
            handler.close()

    # Start from an empty log file.
    with contextlib.suppress(FileNotFoundError):
        os.remove(log_path)

    file_handler = logging.FileHandler(log_path)
    stdout_handler = logging.StreamHandler(sys.stdout)
    for handler in (file_handler, stdout_handler):
        handler._bc_managed = True  # ownership tag used by the cleanup loop above
        root_logger.addHandler(handler)

    return root_logger


def main():
    """Script entry point: parse the command line and run one job on GPU index 0."""
    main_single(gpu=0, args=get_args())


def main_single(gpu, args):
    """Train or evaluate the behavior-cloning agent on one GPU.

    Seeds Python, NumPy and torch RNGs with ``args.seed + gpu``. It then builds
    ``BC_MODEL`` and ``BC_Agent``, restores state via ``load_pretrained_model``,
    and does one of two things:

    * ``args.eval``: run a single evaluation pass and return 0. The pass is
      interactive if ``args.interactive_eval``; otherwise it is offline on the
      'val' split and the test metrics are logged.
    * otherwise: train for epochs ``start_epoch .. args.train_epoch - 1``.
      From epoch 2 onward, each epoch is evaluated, and a checkpoint is saved
      whenever the score ties or beats the best so far. TensorBoard scalars are
      written under ``<parent of args.save_dir>/tensorboard``.

    Args:
        gpu: GPU index passed to ``torch.cuda.set_device``; also offsets the seed.
        args: Parsed command-line arguments from ``get_args()``.

    Returns:
        0 in evaluation mode, otherwise None.
    """
    seed = args.seed + gpu
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # NOTE(review): benchmark=True lets cuDNN pick convolution algorithms by
    # timing, which can vary between runs. PyTorch's reproducibility notes
    # recommend benchmark=False together with deterministic=True. Kept as-is
    # to preserve current speed.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True

    writer = None
    if not args.eval:
        # TensorBoard logs live in the parent directory of save_dir.
        log_dir = os.path.expanduser('/'.join(args.save_dir.split('/')[:-1]))
        utils.cleanup_log_dir(log_dir)
        tensorboard_dir = os.path.join(log_dir, "tensorboard")
        os.makedirs(tensorboard_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=tensorboard_dir)

    try:
        return _run_single(gpu, args, writer)
    finally:
        # Close even when training raises, so buffered TensorBoard events are flushed.
        if writer is not None:
            writer.close()


def _run_single(gpu, args, writer):
    """Build the agent, then evaluate once (``args.eval``) or train with ``writer``."""
    args = init_path.get_logger_path(args)
    logger = get_logger(args, args.log_path)

    # NOTE(review): main() always passes gpu=0, so this overrides the device
    # chosen from GPUtil at import time. Confirm this is intended.
    torch.cuda.set_device(gpu)

    args = init_path.initialize_path(args)
    args = init_path.load_data_info(args)

    model = BC_MODEL(args).cuda()
    action_criterion = nn.CrossEntropyLoss()
    obj_criterion = nn.CrossEntropyLoss()
    agent = BC_Agent(args, model, action_criterion, obj_criterion, logger, gpu)
    agent, best_top1, start_epoch = load_pretrained_model(args, agent, gpu, logger)

    # The environment connection and tokenizer are only needed for interactive evaluation.
    vh_envs = tokenizer = None
    if args.interactive_eval:
        vh_envs = utils_interactive_eval.connect_env(args, logger)
        tokenizer = get_pretrained_tokenizer(model_type=args.model_type, model_name_or_path=args.model_name_or_path)

    if args.eval:
        _evaluate_once(args, agent, logger, vh_envs, tokenizer)
        return 0

    trainset = Dataloader(args, 'train')
    valset = Dataloader(args, 'val')
    trainloader = data.DataLoader(trainset, batch_size=args.num_mini_batch, shuffle=True, num_workers=4, drop_last=True, pin_memory=True)
    valloader = _make_val_loader(args, valset)

    for j in range(start_epoch, args.train_epoch):
        # Re-enter train mode every epoch: an evaluation pass at the end of the
        # previous epoch may have switched the model to eval mode.
        agent.model.train()
        loss, _, _, top1, _, _ = agent.run(trainloader, j, mode='train')
        writer.add_scalar("Train/loss", loss, j)
        writer.add_scalar("Train/top1", top1, j)

        # Evaluation is skipped for epochs 0 and 1 and runs after every later epoch.
        if j <= 1:
            continue

        if args.interactive_eval:
            success_rate = interactive_interface_fn(args, vh_envs, iteri=0, agent_model=agent, data_info=args.data_info, logging=logger, tokenizer=tokenizer)
            best_top1 = _update_best(args, agent, logger, j, success_rate, best_top1)
            writer.add_scalar("Eval/interactive_eval_success_rate", success_rate, j)
        else:
            loss, _, _, top1, _, _ = agent.run(valloader, j, mode='eval')
            best_top1 = _update_best(args, agent, logger, j, top1, best_top1)
            writer.add_scalar("Eval/loss", loss, j)
            writer.add_scalar("Eval/top1", top1, j)

    return None


def _evaluate_once(args, agent, logger, vh_envs, tokenizer):
    """Run one evaluation pass for eval-only mode and log offline metrics."""
    if args.interactive_eval:
        # Result is not logged here, matching earlier behaviour.
        interactive_interface_fn(args, vh_envs, iteri=0, agent_model=agent, data_info=args.data_info, logging=logger, tokenizer=tokenizer)
        return

    valloader = _make_val_loader(args, Dataloader(args, 'val'))
    loss, action_loss, obj_loss, top1, action_top1, obj_top1 = agent.run(valloader, 0, mode='eval')

    logger.info("Test loss: %.4f" % loss)
    logger.info("Test action_loss: %.4f" % action_loss)
    logger.info("Test obj_loss: %.4f" % obj_loss)

    logger.info("Test top1: %.4f" % top1)
    logger.info("Test action_top1: %.4f" % action_top1)
    logger.info("Test obj_top1: %.4f" % obj_top1)


def _make_val_loader(args, valset):
    """Wrap ``valset`` in a shuffled DataLoader whose batch size is half the training mini-batch (when > 1)."""
    batch_size = int(args.num_mini_batch / 2) if args.num_mini_batch > 1 else args.num_mini_batch
    return data.DataLoader(valset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True, pin_memory=True)


def _update_best(args, agent, logger, epoch, score, best_top1):
    """Save a best checkpoint when ``score`` ties or beats ``best_top1``; return the (possibly new) best."""
    if score >= best_top1:
        logger.info('eval best top1 %.3f' % score)
        logger.info('save model to %s' % args.save_dir)
        save_model(args, agent, epoch, score, is_best=True)
        return score
    return best_top1


if __name__ == "__main__":
    # Run only when executed as a script. Note that importing this module still
    # performs the module-level GPU selection above.
    main()
