
from clus.models.utils.train_state import *
from clus.models.model.basic import *
from clus.models.model.cdm import *
from clus.env.offline import *
from clus.env.cl_scenario import ContinualScenario , MultiTaskSceario
from clus.env.metaworld_env import MMEvaluator,CWEvaluator
from clus.utils.utils import create_directory_if_not_exists

from clus.trainer.base_trainer import ContinualTrainer
from clus.env.continual_config import *
from clus.env.metaworld_env import get_task_list_equal_normal, configs_task_list
from clus.models.model.ewccdm import EWCDiffusion

import sys
class DualStream:
    """Minimal write-only stream that mirrors everything to two targets.

    Both targets only need to provide ``write(message)`` and ``flush()``.
    In this script it is used to tee ``sys.stdout`` into a log file.
    """

    def __init__(self, stream1, stream2):
        # Kept as two named attributes so existing users of
        # ``stream1`` / ``stream2`` keep working.
        self.stream1, self.stream2 = stream1, stream2

    def write(self, message):
        """Forward ``message`` to ``stream1`` first, then to ``stream2``."""
        for target in (self.stream1, self.stream2):
            target.write(message)

    def flush(self):
        """Flush ``stream1`` first, then ``stream2``."""
        for target in (self.stream1, self.stream2):
            target.flush()
# Parse arguments
import argparse


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is
    ``True`` because every non-empty string is truthy, so ``--debug False``
    would silently enable debug mode.

    Args:
        value: The raw command-line string (or an already-converted bool).

    Returns:
        ``True`` for true/t/yes/y/1/on, ``False`` for false/f/no/n/0/off or
        an empty string (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(description='L2M based continual learner trianing function.')
# `-d True` / `-d False` keep working; a bare `-d` now also enables debug mode.
parser.add_argument('-d', '--debug', type=_str2bool, nargs='?', const=True,
                    help='Set experiment to debug mode', default=False)
parser.add_argument('-e', '--env', type=str, help='mmworld, kitchen, libero', default='kitchen')
parser.add_argument('-sp', '--spec', type=str, help='none 10 20 10h 20h', default='complete')

parser.add_argument('-ep', '--eval_episodes', type=int, help='default to 3', default=3)
parser.add_argument('-r', '--rank', type=int, help='Adapter basic rank 3', default=1)
parser.add_argument('-s', '--shot', type=int, help='shot count per each dataset', default=None)

parser.add_argument('-al', '--algo', type=str, help='Seq Tuning Algorithm', default='seq')
parser.add_argument('-id', '--save_id', type=str, help='save path under default path', default='test')
parser.add_argument('-seed', '--seed', type=int, help='make seed for all experiments', default=None)

parser.add_argument('-lr', '--lr', type=float, help='learning rate', default=1e-4)
parser.add_argument('-alpha', '--ewc_alpha', type=float, help='reg_ratio', default=1e0)
parser.add_argument('-epoch', '--epoch', type=int, help='epochs', default=3000)


parser.add_argument('-replay', '--replay', type=int, help='replay per phase', default=None)
# NOTE(review): the help text below duplicates --replay's and looks like a
# copy-paste leftover; confirm the intended meaning of --eval_count.
parser.add_argument('-ec', '--eval_count', type=int, help='replay per phase', default=1)


# Root directory under which every run stores its outputs.
default_path = './data/seq_expriments'
# Parsed at module level (not only under __main__), so the process argv is
# consumed as soon as this file is executed or imported.
args = parser.parse_args()
if __name__ == '__main__':
    # Script entry point. Steps, in order:
    #   1. read the parsed CLI arguments,
    #   2. build optimizer / model / experiment / dataloader configs,
    #   3. build the continual scenario (with evaluator) for the chosen env,
    #   4. build the ContinualTrainer and run training while mirroring
    #      stdout into `<run dir>/training_log.log`.
    print('args' , args)
    debug_flag = args.debug
    env_type = args.env
    eval_episodes = int(args.eval_episodes)
    # `rank` and `few_shot_len` are read here but not used anywhere in the
    # visible part of this script.
    rank = int(args.rank)
    algo = args.algo
    lr = args.lr
    epoch = args.epoch
    few_shot_len = args.shot
    replay = args.replay
    if algo == 'mtseq':
        replay = -1

    if algo == 'er' and replay is None:
        print("er algorithm requires replay")
        raise NotImplementedError

    # Debug runs share a single fixed directory; regular runs get
    # <default_path>/<algo>/<env>/<save_id>.
    if debug_flag:
        full_path = f"{default_path}/debug"
    else:
        full_path = f"{default_path}/{args.algo}/{args.env}/{args.save_id}"
    logging_path_base = f"{full_path}/training_log.log"
    # Create the directory in debug mode as well: the log file is opened with
    # mode "w" further down, which fails if the directory does not exist.
    create_directory_if_not_exists(full_path)

    optim_config = {
        'optimizer_cls' : optax.adamw,
        'optimizer_kwargs' : {
            'learning_rate' : lr, # 1e-4 default
            'weight_decay' : 1e-4,
        },
    }
    if env_type == 'kitchen':
        # NOTE(review): this overrides whatever was passed via --lr, so the
        # CLI learning rate has no effect for the kitchen environment.
        optim_config['optimizer_kwargs']['learning_rate'] = 5e-4

    diffusion_model_config = {
        'model_cls' : ConditionalDiffusion, # EWCDiffusion
        'model_kwargs' : {
            'input_config' : None,
            'optimizer_config' : optim_config,
            'model_config' : {
                'model_cls' : FlaxDenoisingBlockMLP,
                'model_kwargs' : {
                    'dim' : 512,
                    'n_blocks' : 4,
                    'context_emb_dim' : 512,
                    'dropout' : 0.1,
                }
            },
            'clip_denoised' : False,
            'diffusion_step' : 64,
        },
    }

    # Only 'ewc' and 'l2' switch to the regularised model; every other
    # algorithm keeps the ConditionalDiffusion default set above.
    ewc_config = None
    if algo in ('ewc', 'l2'):
        diffusion_model_config['model_cls'] = EWCDiffusion
        ewc_config = {
            'ewc_mode' : algo, # L2 for l2 regularization, fisher for fisher regularization
            'ewc_ratio' : args.ewc_alpha,
            'fisher_epoch' : 10,
        }

    # load data and initialize the model
    exp_config = {
        'phase_epoch' : 20000,
        'eval_epoch' : 10000,
        'batch_size' : 1024,
        'eval_env' : eval_episodes > 0,
        'base_path' : full_path, # base path for saving items
        'phase_optim' : 're_initialize',
        'replay_method' : 'random',  # 'kmeans' or 'random' or 'sequential'
        # 'phase_batch_sz' : 0, # No Replay
        'phase_batch_sz' : replay,
        'init_model_path' : './data/l2ms/kitchen_base/models/model_0.pkl',
    }

    ## Continual Scenario
    # Default dataloader config; the kitchen branch replaces it below.
    dataloader_config = {
        'dataloader_cls' : MemoryPoolDataloader,
        'dataloader_kwargs' : {
            'skill_embedding_path' : 'data/continual_dataset/evolving_world/mm_lang_embedding.pkl',
            'skill_exclude' : None,
            'semantic_flag' : True,
        }
    }
    if env_type == 'kitchen':
        print("kitchen_evaluation")
        exp_config['phase_epoch'] = epoch
        exp_config['eval_epoch'] = epoch
        exp_config['init_model_path'] = 'data/pre_trained_models/evolving_kitchen/diffusion/model_0.pkl'
        state_dim = 572

        scenario_cls = ContinualScenario
        if args.spec == 'complete':
            phase_configures = EK_COMPLETE
        elif args.spec == 'semi':
            phase_configures = EK_SEMI
        elif args.spec == 'incomplete':
            phase_configures = EK_INCOMPLETE
        elif args.spec == 'compret':
            phase_configures = EK_COMP_RET
        elif args.spec == 'incompret':
            phase_configures = EK_INCOMP_RET
        elif args.spec == 'ucomp':
            phase_configures = UEK_COMPLETE
        elif args.spec == 'uincom':
            phase_configures = UEK_INCOMPLETE
        else:
            print("not supported spec")
            raise NotImplementedError

        dataloader_config = {
            'dataloader_cls' : MemoryPoolDataloader,
            'dataloader_kwargs' : {
                'skill_embedding_path' : 'data/continual_dataset/evolving_kitchen/kitchen_lang_embedding.pkl',
                'skill_exclude' : None,
                'semantic_flag' : True,
            }
        }
        # Imported lazily so the kitchen module is only loaded when needed.
        from clus.env.kitchen import KitchenEvaluator

        # Pick the phase list the evaluator is built from:
        #   * compret / incompret / ucomp / uincom -> EK_COMPLETE
        #   * semi                                 -> first 10 training phases
        #   * everything else                      -> the training phases
        if args.spec in ('compret', 'incompret', 'ucomp', 'uincom'):
            eval_phase_configures = EK_COMPLETE
        elif args.spec == 'semi':
            eval_phase_configures = phase_configures[:10]
        else:
            eval_phase_configures = phase_configures

        # NOTE(review): --eval_episodes is not forwarded here; the kitchen
        # evaluator always runs 3 episodes, unlike the mmworld branch.
        continual_scenario = scenario_cls(
            dataloader_config=dataloader_config,
            phase_configures=phase_configures,
            evaluator=KitchenEvaluator(
                phase_configures=eval_phase_configures,
                eval_mode='obs',
                eval_episodes=3,
            ),
        )

    elif env_type == 'mmworld':
        print("mmworld_evaluation")
        exp_config['init_model_path'] = 'data/pre_trained_models/evolving_world/diffusion/model_0.pkl'
        exp_config['phase_epoch'] = epoch
        exp_config['eval_epoch'] = epoch
        state_dim = 652

        scenario_cls = ContinualScenario
        if args.spec == 'complete':
            phase_configures = MW_COMPLETE
        elif args.spec == 'semi':
            phase_configures = MW_SEMI_COMPLETE
        elif args.spec == 'incomplete':
            phase_configures = MW_INCOMPLETE
        elif args.spec == 'ret':
            # NOTE(review): None is passed on to configs_task_list() and the
            # scenario below; confirm that 'ret' is actually supported.
            phase_configures = None
        else:
            print("not supported spec")
            raise NotImplementedError

        # Choose training phases and evaluation tasks. 'semi' takes priority
        # over debug mode; the former '20hs2' / '20hsm' checks were
        # unreachable because such specs are rejected above.
        if args.spec == 'semi':
            scenario_phases = phase_configures
            eval_tasks = configs_task_list(phase_configures[:10])
        elif not debug_flag:
            scenario_phases = phase_configures
            eval_tasks = configs_task_list(phase_configures)
        else:
            # Debug run: fixed phase list and only the first task of the
            # list returned by get_task_list_equal_normal.
            scenario_phases = MM_EASY_TO_HARD_HS_U20
            eval_tasks = get_task_list_equal_normal(only_normal=False)[:1]

        continual_scenario = scenario_cls(
            dataloader_config=dataloader_config,
            phase_configures=scenario_phases,
            evaluator=MMEvaluator(
                eval_tasks,
                eval_mode='obs',
                eval_episodes=eval_episodes,
            ),
        )
    else:
        print( f"env_type : {env_type} is not supported")
        raise NotImplementedError

    # Seed NumPy and the stdlib RNG only when a seed was requested.
    if args.seed is not None:
        np.random.seed(args.seed)
        random.seed(args.seed)

    # exp_config['phase_epoch'] = 0
    trainer = ContinualTrainer(
        continual_scenario=continual_scenario,
        model_config=diffusion_model_config,
        exp_config=exp_config,
        adapt_from_zero=(algo == 'seq0'),
        ewc_config=ewc_config,
    )

    # Mirror stdout into the log file for the rest of the run. `with` plus
    # try/finally ensure the file is closed and stdout is restored even when
    # training raises.
    logging_path = logging_path_base
    with open(logging_path, "w") as file_log:
        previous_stdout = sys.stdout
        sys.stdout = DualStream(previous_stdout, file_log)
        try:
            print(f'diffusion model config : {diffusion_model_config}')
            print(f'experiment config : {exp_config}')
            if dataloader_config is not None:
                print(f'dataloader config : {dataloader_config}')

            trainer.continual_train()
        finally:
            sys.stdout = previous_stdout