import hydra
import tempfile
import subprocess
import os
import copy
import d4rl
import random
import time
import shutil

import torch.multiprocessing as mp

from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader, random_split
from sgcrl.utils.seeding import set_seed
from sgcrl.utils.imports import instantiate_class, get_class
from sgcrl.data.dbs.on_disk import DiskPythonObjectDB, load_model
from sgcrl.utils.logger import TensorBoardLogger
from sgcrl.evaluation.d4rl_evaluation import parallel_evaluation_loop
from sgcrl.utils.conf import omegaconf_to_conf
from sgcrl.data.torch_datasets.datasets import ContrastiveTripletsDataset, ContrastiveTripletsDatasetVisual, TransformerDataset, PytorchEpisodeFrameDatasetComplexRelabel
from sgcrl.algorithms.policies.qphil.trainers import train_quantizer, train_high_value, train_transformer, train_low_hiql_goal, train_low_hiql_subgoal
from sgcrl.utils.plotting import plot_maze, plot_godot, plot_visual_maze
from sgcrl.data.torch_datasets.relabel import NextTokenRelabeller, BatchRelabeller, TokenRelabeller
from sgcrl.models.quantizer import Tokenizer, special_tokens
from sgcrl.models.dual_policy import DualPolicy
from omegaconf import DictConfig, open_dict, ListConfig
from sgcrl.data.dbs.on_disk import PytorchOnDiskEpisodesDB
from sgcrl.evaluation.godot_evaluation import start_server_with_model_idx, single_evaluation

def train(training_db, model_db, cfg: DictConfig, logger, evaluation_logger):
    """Run the qphil training pipeline, one stage after another.

    Each stage either trains its model (when the matching ``cfg.train_*`` flag
    is set) or loads a previously saved one with ``load_model``. The stages, in
    source order, are: quantizer/tokenizer, high value function, transformer,
    low-level goal and subgoal policies, and finally the assembly of a
    ``DualPolicy`` that is pushed to ``model_db`` under the key ``"model"``.

    NOTE(review): an unconditional ``exit()`` follows the high value stage, so
    as written the transformer, policy and ``DualPolicy`` stages never run.

    Args:
        training_db: Episodes database handed to the dataset constructors.
        model_db: Model store. ``push`` is called on it to save the tokenizer
            and the dual policy, and its ``_directory`` is given to the
            ``Tokenizer``; it is also forwarded to the trainers.
        cfg: Hydra configuration; converted with ``omegaconf_to_conf`` first.
        logger: Training logger forwarded to the trainers.
        evaluation_logger: Logger forwarded to ``train_low_hiql_subgoal``.

    Side effects:
        Overwrites the ``EOS_TOKEN``, ``SOS_TOKEN`` and ``PADDING_VALUE``
        entries of the module-level ``special_tokens`` dict imported from
        ``sgcrl.models.quantizer``.
    """

    # Init env variables
    cfg = omegaconf_to_conf(cfg)  # Avoid costly Hydra config lookups by dumping config data into namespace
    device = cfg.device
    # The environment family is inferred from substrings of the environment name.
    is_godot = ('map_large' in cfg.env_name)
    is_visual = ('topview' in cfg.env_name)
    is_antmaze = (not is_visual) and ('antmaze' in cfg.env_name)
    is_kitchen = ('kitchen' in cfg.env_name)

    #############
    # Quantizer #
    #############
    quantizer = instantiate_class(cfg.quantizer).to(device)
    if cfg.train_quantizer:
        optimizer_quantizer = get_class(cfg.optimizer_quantizer)(quantizer.parameters(), cfg.optimizer_quantizer.lr)

        # Dataset
        if is_visual:
            contrastive_dataset = ContrastiveTripletsDatasetVisual(training_db, cfg.quantizer_window, target_slice=cfg.target_slice, cache_size=1,device=cfg.device)
        else:
            contrastive_dataset = ContrastiveTripletsDataset(training_db, cfg.quantizer_window, cfg.random_negative, keys_to_tokenize=cfg.keys_to_tokenize, is_godot=is_godot)

        # Dataloader
        # persistent_workers is only enabled when worker processes are actually used.
        contrastive_dataloader = DataLoader(contrastive_dataset, batch_size=cfg.batch_size_quantizer, num_workers=cfg.num_workers, 
                                                pin_memory=True, persistent_workers=(cfg.num_workers > 0))
        
        # Training
        quantizer = train_quantizer(quantizer, optimizer_quantizer, cfg.max_epoch_quantizer, contrastive_dataloader, device, model_db, logger, 
                    cfg.offset, cfg.norm, cfg.contrastive_coef, cfg.commit_coef, cfg.reconstruction_coef, cfg.save_every_quantizer, noise=cfg.noise)

        # Tokenizer
        tokenizer = Tokenizer(quantizer, model_db._directory, number_of_tokens=cfg.quantizer.codebook_size, offset=cfg.offset, norm=cfg.norm, 
                              device=device, keys_to_tokenize=cfg.keys_to_tokenize, is_visual=is_visual)

        # Plotting
        # Exactly one branch runs, chosen by environment family; the kitchen
        # branch does not plot but initialises the tokenizer from a dataloader.
        if is_antmaze:
            background_file = f"{cfg.log_dir}/background_tokens.pickle"
            plot_maze(tokenizer, cfg.env_name, device, points=None, background_file=background_file)
        elif is_godot:
            background_file = f"{cfg.log_dir}/background_tokens.pickle"
            plot_godot(tokenizer, device, 1.0, background_file)
        elif is_kitchen:
            dataset = PytorchEpisodeFrameDatasetComplexRelabel(training_db, 1)
            dataloader = DataLoader(dataset, batch_size=1024, shuffle=True, num_workers=1, pin_memory=True, persistent_workers=False)
            # NOTE(review): `init__tokens` has a double underscore — confirm it matches the Tokenizer method name.
            tokenizer.init__tokens(dataloader)
        elif is_visual:
            background_file = f"{cfg.log_dir}/background_tokens.pickle"
            plot_visual_maze(tokenizer, background_file)

        # Save model
        # A CPU deep copy is stored so the live tokenizer stays on `device`.
        m = copy.deepcopy(tokenizer).cpu()
        model_db.push("tokenizer", m)
    else:
        # Load quantizer and tokenizer
        # quantizer = load_model(cfg.env_name, cfg.algorithm_name, cfg.seed, device, 'quantizer') # Not necessarily the one of tokenizer, because not necessarily the last gradient step
        tokenizer = load_model(cfg.env_name, cfg.algorithm_name, cfg.seed, device, 'tokenizer')
    tokenizer.eval()

    # Special token ids are placed right after the tokenizer's own ids.
    # This mutates the module-level dict shared with sgcrl.models.quantizer.
    special_tokens['EOS_TOKEN'] = tokenizer._number_of_tokens
    special_tokens['SOS_TOKEN'] = tokenizer._number_of_tokens + 1 
    special_tokens['PADDING_VALUE'] = tokenizer._number_of_tokens + 2

    # print(f'EOS_TOKEN: {special_tokens['EOS_TOKEN']}')
    # print(f'SOS_TOKEN: {special_tokens['SOS_TOKEN']}')
    # print(f'PADDING_VALUE: {special_tokens['PADDING_VALUE']}')

    #######################
    # High value function #
    #######################
    if cfg.train_high_value:
        # The tokenizer is moved to CPU while the tokenized dataset is built and
        # consumed, then moved back to `device` after training.
        tokenizer = tokenizer.to('cpu')
        dataset_episodes = TransformerDataset(
            pytorch_episodes_db=training_db, 
            keys_to_tokenize=cfg.keys_to_tokenize, 
            tokenizer=tokenizer, 
            augment_dataset_prob=cfg.augment_transformer_data_prob, 
            remove_cycles=cfg.remove_cycles,
            stitching=False,
            is_godot=is_godot
        )

        # NOTE(review): the FIRST split, sized by cfg.validation_split, is used
        # as the training set — confirm the config value is the train fraction.
        train_dataset_tokenized_episode, val_dataset_tokenized_episode = random_split(
            dataset_episodes, (int(cfg.validation_split * len(dataset_episodes)), 
            len(dataset_episodes) - int(cfg.validation_split * len(dataset_episodes)))
        )
        dataloader_episodes_train = DataLoader(train_dataset_tokenized_episode, batch_size=cfg.batch_size_transformer, num_workers=cfg.num_workers, 
                                         pin_memory=True, persistent_workers=(cfg.num_workers > 0))
        dataloader_episodes_val   = DataLoader(val_dataset_tokenized_episode,   batch_size=cfg.batch_size_transformer, num_workers=cfg.num_workers, 
                                        pin_memory=True, persistent_workers=(cfg.num_workers > 0))
        # NOTE(review): the keyword below is spelled `tokenzier`; it must match
        # the parameter name in train_high_value's signature — confirm.
        # Of the three returned objects only `hvf1` is used afterwards.
        high_policy, hvf1, hvf2 = train_high_value(
            cfg=cfg,
            tokenzier=tokenizer,
            high_reward_scale=cfg.high_reward_scale,
            high_discount=cfg.high_discount,
            high_expectile=cfg.high_expectile,
            high_beta=cfg.high_beta,
            high_clip_score=cfg.high_clip_score,
            high_v_update_period=cfg.high_v_update_period,
            high_policy_update_period=cfg.high_policy_update_period,
            high_target_update_period=cfg.high_target_update_period,
            polyak_coef=cfg.high_polyak_coef,
            epochs=cfg.high_max_epochs,
            max_gradient_step=cfg.high_max_gradient_step,
            train_dataloader=dataloader_episodes_train, 
            val_dataloader=dataloader_episodes_val, 
            special_tokens=special_tokens, 
            model_db=model_db, 
            logger=logger, 
            save_every=cfg.high_save_every, 
            device=device
        )
        tokenizer = tokenizer.to(device)
        high_value = hvf1
    else:
        high_value = load_model(cfg.env_name, cfg.algorithm_name, cfg.seed, device, "high_value")
    high_value.eval()

    # Verify here that the high value is well learned
    # NOTE(review): this unconditional exit() terminates the process here, so
    # every stage below is currently unreachable. Remove it to run the rest.
    exit()

    ###############
    # Transformer #
    ###############
    if cfg.train_transformer:
        tokenizer = tokenizer.to('cpu')
        # Same dataset as for the high value stage, except that stitching is
        # taken from the config instead of being forced to False.
        dataset_episodes = TransformerDataset(
            pytorch_episodes_db=training_db, 
            keys_to_tokenize=cfg.keys_to_tokenize, 
            tokenizer=tokenizer, 
            augment_dataset_prob=cfg.augment_transformer_data_prob, 
            remove_cycles=cfg.remove_cycles,
            stitching=cfg.stitching,
            is_godot=is_godot
        )

        # NOTE(review): as above, the split sized by cfg.validation_split is the
        # one used for training — confirm.
        train_dataset_tokenized_episode, val_dataset_tokenized_episode = random_split(
            dataset_episodes, (int(cfg.validation_split * len(dataset_episodes)), 
            len(dataset_episodes) - int(cfg.validation_split * len(dataset_episodes)))
        )
        dataloader_episodes_train = DataLoader(train_dataset_tokenized_episode, batch_size=cfg.batch_size_transformer, num_workers=cfg.num_workers, 
                                         pin_memory=True, persistent_workers=(cfg.num_workers > 0))
        dataloader_episodes_val   = DataLoader(val_dataset_tokenized_episode,   batch_size=cfg.batch_size_transformer, num_workers=cfg.num_workers, 
                                        pin_memory=True, persistent_workers=(cfg.num_workers > 0))
        
        # NOTE(review): num_classes is EOS_TOKEN + 2, i.e. ids up to SOS_TOKEN;
        # PADDING_VALUE (EOS_TOKEN + 2) is outside that range — presumably
        # padding is never a predicted class, confirm.
        transformer = get_class(cfg.transformer)(num_classes=special_tokens['EOS_TOKEN'] + 2, max_output_length=cfg.transformer.max_sequence_length).to(device)
        transformer.set_special_tokens(special_tokens['SOS_TOKEN'], special_tokens['EOS_TOKEN'], special_tokens['PADDING_VALUE'])
        optimizer_transformer = get_class(cfg.optimizer_transformer)(transformer.parameters(), cfg.optimizer_transformer.lr)
        transformer = train_transformer(
            transformer=transformer, 
            optimizer=optimizer_transformer, 
            epochs=cfg.max_epoch_transformer,
            train_dataloader=dataloader_episodes_train, 
            val_dataloader=dataloader_episodes_val,
            special_tokens=special_tokens,
            model_db=model_db,
            logger=logger, 
            save_every=cfg.save_every_transformer, 
            device=device
        )
        tokenizer = tokenizer.to(device)
    else:
        transformer = load_model(cfg.env_name, cfg.algorithm_name, cfg.seed, device, "transformer")
    transformer.eval()

    ##########
    # Policy #
    ##########
    # Get the goal conditioned pi(a|s,g) for goal reaching at the last token, take this from hiql training
    # Training the goal policy here is not supported: it can only be loaded.
    if cfg.train_goal_policy:
        raise NotImplementedError(0)
    else:
        low_level_policy_goals = load_model(cfg.env_name, cfg.algorithm_name, cfg.seed, 'cpu', cfg.goal_model_name)
    
    # Get the subgoal conditioned policy pi(a|s,i) for subgoal reaching
    if cfg.train_subgoal_policy:

        # Get subgoal policy
        low_level_policy_subgoals = instantiate_class(cfg.low_level_policy_subgoals).to(device)
        # NOTE(review): the second positional argument is `.parameters`, while
        # the other optimizers in this function receive `.lr` — confirm intended.
        optimizer_low_level_policy_subgoals = get_class(cfg.optimizer_low_level_policy_subgoals)(low_level_policy_subgoals.parameters(), cfg.optimizer_low_level_policy_subgoals.parameters)
        
        # Get dataloader
        # Relabelling runs on CPU: the tokenizer is moved there and both token
        # relabellers are created with 'cpu' as their device argument.
        tokenizer = tokenizer.to('cpu')
        # Configured batch relabellers (if any) run first, then TokenRelabeller;
        # the combined object is used as the DataLoader collate function.
        batch_relabeller = BatchRelabeller([instantiate_class(r) for r in (cfg.batch_relabellers or [])] + [TokenRelabeller(tokenizer, 'cpu', cfg.keys_to_tokenize)])
        dataset = PytorchEpisodeFrameDatasetComplexRelabel(
            pytorch_episodes_db=training_db, 
            frame_size = 2, 
            cache_size = 1 if is_visual else 100000, 
            episode_relabellers=[instantiate_class(r) for r in (cfg.relabellers or [])] + [NextTokenRelabeller(tokenizer, 'cpu', cfg.keys_to_tokenize)]
        )
        dataloader = DataLoader(
            dataset, 
            batch_size=cfg.batch_size_low_level_policy, 
            shuffle=True, 
            pin_memory=True, 
            num_workers=cfg.num_workers, 
            persistent_workers=(cfg.num_workers > 0), 
            collate_fn=batch_relabeller
        )

        # Train low subgoal policy
        # The return value is ignored; `low_level_policy_subgoals` is the
        # instance passed in as `model`.
        # NOTE(review): `beta` is read from cfg.beta whereas the neighbouring
        # hyper-parameters come from cfg.hiql_parameters — confirm intended.
        train_low_hiql_subgoal(
            model=low_level_policy_subgoals,
            transformer=transformer, 
            tokenizer=tokenizer,
            use_obs_representation=cfg.use_obs_representation,
            keys_to_tokenize=cfg.keys_to_tokenize,
            low_level_policy_goals=low_level_policy_goals,
            save_dual_policy=True,
            dual_policy_relabellers=cfg.dual_policy_relabellers,
            optimizer=optimizer_low_level_policy_subgoals, 
            model_db=model_db,
            dataloader=dataloader, 
            device=device,  
            save_every=cfg.save_every_low_level_policy, 
            reward_scale=cfg.hiql_parameters.reward_scale,
            discount=cfg.hiql_parameters.discount,
            expectile=cfg.hiql_parameters.expectile,
            beta=cfg.beta,
            clip_score=cfg.hiql_parameters.clip_score,
            v_update_period=cfg.hiql_parameters.v_update_period,
            policy_update_period=cfg.hiql_parameters.policy_update_period,
            target_update_period=cfg.hiql_parameters.target_update_period,
            max_gradient_step=cfg.max_gradient_step,
            polyak_coef=cfg.hiql_parameters.polyak_coef,
            logger=logger,
            serial_evaluation=cfg.evaluation.serial, 
            cfg_evaluation=cfg.evaluation, 
            evaluation_logger=evaluation_logger,
            is_godot=is_godot
        )
        tokenizer = tokenizer.to(device)
    else:
        low_level_policy_subgoals = load_model(cfg.env_name, cfg.algorithm_name, cfg.seed, 'cpu', 'low_subgoal_policy')
    
    # Set dual policy
    # Bundles both low-level policies, the transformer (as high-level policy)
    # and the tokenizer into a single object.
    dual_policy = DualPolicy(
        low_level_policy_subgoal=low_level_policy_subgoals, 
        low_level_policy_goal=low_level_policy_goals, 
        high_level_policy=transformer, 
        tokenizer=tokenizer, 
        sos_token=special_tokens['SOS_TOKEN'], 
        eos_token=special_tokens['EOS_TOKEN'], 
        keys_to_tokenize=cfg.keys_to_tokenize, 
        frame_relabellers=[instantiate_class(r) for r in (cfg.dual_policy_relabellers or [])]
    )
    # Store a CPU deep copy under the "model" key of the model database.
    m = copy.deepcopy(dual_policy).cpu()
    model_db.push("model", m)

def _wait_for_new_model(model_db, learning_process, last_evaluated_idx):
    """Wait until ``model_db`` holds a "model" newer than ``last_evaluated_idx``.

    Args:
        model_db: Model store queried with ``size("model")``.
        learning_process: Process running ``train``; polled with ``is_alive``.
        last_evaluated_idx: Index of the last model already evaluated
            (``-1`` when none has been evaluated yet).

    Returns:
        The index of the newest stored model, or ``None`` when the learning
        process has exited without pushing anything newer (so there is nothing
        left to wait for).
    """
    while True:
        # Sample liveness BEFORE the size: a model pushed just before the
        # learner exits is then still picked up on this pass.
        learner_alive = learning_process.is_alive()
        newest_idx = model_db.size("model") - 1
        if newest_idx > last_evaluated_idx:
            return newest_idx
        if not learner_alive:
            return None
        time.sleep(1)


def _godot_command(cfg):
    """Return the argv list used to launch the Godot executable.

    The list is passed to ``subprocess.Popen`` without a shell, so every
    element reaches the executable verbatim: flags must not carry a leading
    space.
    """
    command = [cfg.run.exe]
    if cfg.run.headless:
        command.append("--headless")
    command.extend([
        "--drainc_port=" + str(cfg.application.port),
        "--drainc_host=" + str(cfg.application.host),
        "--map_name=" + cfg.env_name,
    ])
    return command


def _evaluate_model_with_godot(cfg, model_db, evaluation_logger, model_idx):
    """Run one Godot evaluation round for the model stored at ``model_idx``.

    Starts the Python-side server in a daemon process, launches the Godot
    executable, calls ``single_evaluation`` on the episodes database and then
    releases everything (server process, Godot process, temporary directory),
    even when the evaluation raises.
    """
    # mkdtemp creates the directory and leaves its removal to us.
    tdir = tempfile.mkdtemp()
    print("-- Launching evaluation in ", tdir)
    server_process = None
    godot_process = None
    try:
        evaluation_db = PytorchOnDiskEpisodesDB(tdir, use_pickle=True)
        # Every player writes into the same on-disk episodes database.
        db = {"player_" + str(k): evaluation_db for k in range(cfg.n_players)}

        server = mp.Process(target=start_server_with_model_idx,
                            args=(db, cfg.application, model_db, model_idx))
        server.daemon = True
        server.start()
        server_process = server  # only track processes that really started
        time.sleep(10)  # letting python server startup before launching exe

        command = _godot_command(cfg)
        print(f"Launching godot side: {command}")
        godot_process = subprocess.Popen(command)

        # NOTE(review): this is printed right after launching Godot;
        # presumably single_evaluation waits for the episodes — confirm.
        print("==== Evaluating on ", len(evaluation_db), " episodes")
        single_evaluation(evaluation_db, cfg.evaluation, evaluation_logger, model_idx)
    finally:
        if server_process is not None:
            server_process.kill()
            server_process.join()
        if godot_process is not None:
            # Stop and reap the Godot process so that successive evaluation
            # rounds do not pile up leftover processes.
            if godot_process.poll() is None:
                godot_process.terminate()
            try:
                godot_process.wait(timeout=10)
            except subprocess.TimeoutExpired:
                godot_process.kill()
                godot_process.wait()
        # Best-effort removal: a cleanup failure must not mask an exception
        # raised by the evaluation itself.
        shutil.rmtree(tdir, ignore_errors=True)


@hydra.main(version_base=None,config_path="yaml/qphil", config_name="antmaze-extreme")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: set up databases and loggers, then train and evaluate.

    For Godot environments (``'map_large'`` in the environment name) with
    parallel evaluation enabled, training runs in a separate process while
    this process evaluates each newly stored "model" through a Godot
    executable. For other environments, ``parallel_evaluation_loop`` is started
    (when enabled) and training runs in this process.

    Args:
        cfg: Hydra configuration of the run.
    """
    is_godot = ('map_large' in cfg.env_name)
    if is_godot:
        # Compute how many episodes per players needed Godot-side for evaluation and choose port before seeding
        with open_dict(cfg):
            cfg.application.config_configuration.decorators.config.value.n_episodes_per_player = int(cfg.n_test_episodes / cfg.n_players)
            cfg.application.port = random.randint(8000, 8900)
        print(f"Godot side each players will perform {cfg.application.config_configuration.decorators.config.value.n_episodes_per_player} episodes")
        print(f"Launching godot on port {cfg.application.port}")

    # Setting utils
    set_seed(cfg.seed)
    training_db = instantiate_class(cfg.episodes_reader)
    model_db = DiskPythonObjectDB(cfg.log_dir + "/models")
    logger = TensorBoardLogger(log_dir=f'{cfg.log_dir}/train', prefix="train", max_cache_size=1000)
    TensorBoardLogger.log_params(cfg.log_dir, cfg)
    evaluation_logger = TensorBoardLogger(log_dir=f'{cfg.log_dir}/train_evaluation', prefix='train_evaluation', max_cache_size=1)

    parallel_evaluation = "evaluation" in cfg and cfg.evaluation is not None and cfg.evaluation.parallel

    # Launch training
    if is_godot:
        if parallel_evaluation:
            # The config is resolved and rebuilt before being handed to the
            # training process.
            learning_process = mp.Process(target=train, args=(training_db, model_db, OmegaConf.create(OmegaConf.to_container(cfg, resolve=True)), logger, evaluation_logger))
            learning_process.daemon = False  # is it problematic ?
            learning_process.start()

            evaluation_epoch = -1
            while True:
                newest_idx = _wait_for_new_model(model_db, learning_process, evaluation_epoch)
                if newest_idx is None:
                    # Learner finished and every stored model was evaluated.
                    break
                evaluation_epoch = newest_idx
                _evaluate_model_with_godot(cfg, model_db, evaluation_logger, evaluation_epoch)
            learning_process.join()
        else:
            train(training_db, model_db, cfg, logger, evaluation_logger)
    else:
        # Launch the parallel evaluation process
        process = None
        if parallel_evaluation:
            process = parallel_evaluation_loop(model_db, cfg.evaluation, evaluation_logger)

        # Train
        train(training_db, model_db, cfg, logger, evaluation_logger)

        # Wait for parallel evaluation process to end
        print("Training ended.")
        if process is not None:
            print("Waiting for evaluation process to end...")
            process.join()

if __name__ == "__main__":
    # Configure torch.multiprocessing before main() creates any process:
    # children are started with the "spawn" method and the 'file_system'
    # sharing strategy is selected.
    mp.set_start_method("spawn")
    mp.set_sharing_strategy('file_system')
    main()