import os
import torch
import random
import numpy as np
from config.all_config import AllConfig
# from torch.utils.tensorboard.writer import SummaryWriter
from datasets.data_factory import DataFactory
from model.model_factory import ModelFactory
from modules.metrics import t2v_metrics, v2t_metrics
from modules.loss import LossFactory
from modules.basic_utils import compute_layer_parameter_stats, load_pretrained_weights_auto_report, load_DiT
from trainer.trainer_txt_trunc_dm import Trainer
from modules.optimization import AdamW, get_cosine_schedule_with_warmup
from config.all_config import gen_log
import os


import torch.multiprocessing
# Select the 'file_system' strategy for sharing tensors between processes.
# NOTE(review): presumably chosen to avoid exhausting file descriptors when
# DataLoader workers hand batches back to the main process — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')

# Freeze policies for the training stage. Each entry maps a policy name to a
# predicate that receives a parameter name and returns True when that
# parameter must stay trainable (requires_grad=True).
_TRAINING_POLICIES = {
    'fix_clip': lambda name: 'clip.' not in name,
    'joint_train': lambda name: True,
    'fix_xpool': lambda name: 'net_d' in name or 'stochastic' in name,
    'fix_tmass': lambda name: 'net_d' in name,
    'fix_dmtxt': lambda name: 'net_d' not in name,
}

# The pre-training stage of the 'pretrain+train' recipe accepts only these two
# policies; their semantics are the same as in the training stage.
_PRETRAINING_POLICIES = {
    key: _TRAINING_POLICIES[key] for key in ('fix_xpool', 'fix_tmass')
}


def _resolve_policy(policies, key, option):
    """Return the trainability predicate for *key*, or raise NotImplementedError."""
    try:
        return policies[key]
    except (KeyError, TypeError):
        # TypeError covers unhashable values, which are equally unsupported.
        raise NotImplementedError(f'unsupported {option}: {key!r}') from None


def _set_trainable(model, is_trainable):
    """Set requires_grad on every parameter of *model* from a name predicate."""
    for name, param in model.named_parameters():
        param.requires_grad = is_trainable(name)


def _build_param_groups(config, named_params):
    """Return fresh optimizer parameter groups, one per sub-module family."""
    # Parameters are assigned by substring match on their name. A parameter
    # whose name matches none of these substrings is not handed to the
    # optimizer at all.
    families = (
        ('clip.', config.clip_lr),
        ('pool_frames', config.noclip_lr),
        ('net_d', config.dm_lr),
        ('stochastic', config.noclip_lr),
    )
    return [
        {'params': [p for n, p in named_params if marker in n], 'lr': lr}
        for marker, lr in families
    ]


def main():
    """Build every training component from ``AllConfig`` and run the recipe.

    Steps, in order: read the configuration, optionally pin the visible GPUs,
    write the log headers, seed the RNGs, create the tokenizer, data loaders,
    model, optimizers, scheduler, loss and trainer, optionally load a
    checkpoint and evaluate, then run the stage(s) selected by
    ``config.training_recipe`` ('pretrain+train', 'pretrain' or 'train').

    Raises:
        RuntimeError: a GPU was requested through ``config.gpu`` but CUDA is
            not available.
        NotImplementedError: ``config.metric``, ``config.scheduler``,
            ``config.training_recipe``, ``config.pretraining`` or
            ``config.training`` holds an unsupported value.
    """
    config = AllConfig()
    os.environ['TOKENIZERS_PARALLELISM'] = "false"

    # TensorBoard logging is currently disabled, so the trainer always gets
    # writer=None. To re-enable it, create
    # SummaryWriter(log_dir=config.tb_log_dir) here unless
    # config.no_tensorboard is set.
    writer = None

    # The value '99' (like None) skips the explicit device selection.
    if config.gpu is not None and config.gpu != '99':
        print('set GPU')
        os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
        if not torch.cuda.is_available():
            raise RuntimeError('NO GPU!')

    # Log headers: the run configuration goes to 'log_trntst'; the two loss
    # logs only receive a banner line here.
    log_headers = (
        ('log_trntst', f'model pth = {config.model_path}'),
        ('log_trntst', f'\nconfig={config.__dict__}\n'),
        ('log_trntst', 'record all training and testing results'),
        ('log_ori_loss', 'Prepare to record loss values per batch '),
        ('log_dm_loss', 'Prepare to record loss values per batch '),
    )
    for log_name, msg in log_headers:
        gen_log(model_path=config.model_path, log_name=log_name, msg=msg)

    # A negative seed leaves the RNGs unseeded and cuDNN settings untouched.
    if config.seed >= 0:
        torch.manual_seed(config.seed)
        np.random.seed(config.seed)
        torch.cuda.manual_seed_all(config.seed)
        random.seed(config.seed)
        # Determinism overrides the benchmark mode enabled above.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    if config.huggingface:
        from transformers import CLIPTokenizer
        # NOTE(review): TOKENIZERS_PARALLELISM is already set through the
        # environment above; the keyword argument is kept unchanged — confirm
        # whether it is needed.
        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32", TOKENIZERS_PARALLELISM=False)
    else:
        from modules.tokenization_clip import SimpleTokenizer
        tokenizer = SimpleTokenizer()

    train_data_loader = DataFactory.get_data_loader(config, split_type='train')
    # Validation runs on the 'test' split.
    valid_data_loader = DataFactory.get_data_loader(config, split_type='test')
    model = ModelFactory.get_model(config)

    # Print every parameter name; the optimizer groups and freeze policies
    # below select parameters by substring match on these names.
    for name, _ in model.named_parameters():
        print(name)

    if config.metric == 't2v':
        metrics = t2v_metrics
    elif config.metric == 'v2t':
        metrics = v2t_metrics
    else:
        # `NotImplemented` is a comparison sentinel, not an exception; raising
        # it produces a TypeError instead of the intended error.
        raise NotImplementedError(f'unsupported metric: {config.metric!r}')

    named_params = list(model.named_parameters())
    # Each optimizer gets its own group dicts. The pre-training optimizer
    # uses no weight decay.
    pretrain_optimizer = AdamW(_build_param_groups(config, named_params), weight_decay=0)
    optimizer = AdamW(_build_param_groups(config, named_params), weight_decay=config.weight_decay)

    num_training_steps = len(train_data_loader) * config.num_epochs
    num_warmup_steps = int(config.warmup_proportion * num_training_steps)

    # The scheduler is attached to the main optimizer only.
    if config.scheduler == 'constant':
        scheduler = None
    elif config.scheduler == 'default':
        scheduler = get_cosine_schedule_with_warmup(optimizer,
                                                    num_warmup_steps=num_warmup_steps,
                                                    num_training_steps=num_training_steps)
    else:
        raise NotImplementedError

    loss = LossFactory.get_loss(config.loss)

    if config.record_DM_param_stats:
        msg = compute_layer_parameter_stats(model)
        gen_log(model_path=config.model_path, log_name='log_trntst', msg=msg)

    trainer = Trainer(model=model,
                      metrics=metrics,
                      optimizer=optimizer,
                      pretrain_optimizer=pretrain_optimizer,
                      loss=loss,
                      config=config,
                      train_data_loader=train_data_loader,
                      valid_data_loader=valid_data_loader,
                      lr_scheduler=scheduler,
                      writer=writer,
                      tokenizer=tokenizer)

    if config.load_epoch is not None:
        # A positive epoch selects that epoch's checkpoint; any other value
        # falls back to the best checkpoint.
        if config.load_epoch > 0:
            checkpoint_name = "checkpoint-epoch{}.pth".format(config.load_epoch)
        else:
            checkpoint_name = "model_best.pth"
        trainer.load_checkpoint(checkpoint_name)
        gen_log(model_path=config.model_path, log_name='log_trntst', msg='Pre-trained model is loaded')

    if config.eval_before_train:
        trainer.validate()

    recipe = config.training_recipe
    if recipe == 'pretrain+train':
        # Resolve both policies up front so that an invalid config.training
        # fails before the pre-training stage runs rather than after it.
        pretrain_rule = _resolve_policy(_PRETRAINING_POLICIES, config.pretraining, 'config.pretraining')
        train_rule = _resolve_policy(_TRAINING_POLICIES, config.training, 'config.training')

        _set_trainable(model, pretrain_rule)
        trainer.pre_train()

        _set_trainable(model, train_rule)
        trainer.train()

    elif recipe == 'pretrain':
        # Stand-alone pre-training always updates the 'net_d' parameters only,
        # independent of config.pretraining.
        _set_trainable(model, _TRAINING_POLICIES['fix_tmass'])
        trainer.pre_train()

    elif recipe == 'train':
        _set_trainable(model, _resolve_policy(_TRAINING_POLICIES, config.training, 'config.training'))
        trainer.train()

    else:
        raise NotImplementedError

    if config.record_DM_param_stats:
        msg = compute_layer_parameter_stats(model)
        gen_log(model_path=config.model_path, log_name='log_trntst', msg=msg)


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
