import argparse
import torch
from accelerate import Accelerator, DeepSpeedPlugin
from accelerate import DistributedDataParallelKwargs
from torch import nn, optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from models import MMRCR
from data_provider.data_factory import data_provider_training
import time
import random
import numpy as np
import os
import yaml

from Parrot.models.parrot_model import ParrotConditionPredictionModel

# Process-wide environment tweaks, applied before the project modules below are imported.
# NOTE(review): blanking CURL_CA_BUNDLE is a common trick to bypass TLS certificate
# verification in `requests`-based downloads; confirm this is intended, since it
# weakens transport security for any download made by this process.
os.environ['CURL_CA_BUNDLE'] = ''
# Allocator option for PyTorch's CUDA caching allocator: blocks larger than 64 MB are
# not split, which is meant to limit memory fragmentation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"

from utils.tools import EarlyStopping, vali, _load_model_args

# Command-line interface for the MM-RCR training script.
# The parsed namespace is exposed as `args_part`; the RNGs are seeded from
# `--seed` right after parsing so that the flag actually takes effect.
parser = argparse.ArgumentParser(description='MM-RCR')

# basic config
parser.add_argument('--task_name', type=str, required=True, default='multi-RCR',
                    help='task name, options:[multi-RCR]')
parser.add_argument('--seed', type=int, default=2024, help='random seed')
parser.add_argument('--debug', action='store_true',
                    help='whether it is the debug mode')
parser.add_argument('--comment', type=str, default='', help='other comments that need to be recorded')
parser.add_argument('--llm_pretrained_path', type=str, default=None, help='LLM pretrained path')
parser.add_argument('--llama_pretrained_path', type=str, required=True, default=None, help='llama pretrained path')

# data loader
parser.add_argument('--data', type=str, default='USPTO-Condition', help='dataset type')
parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints')
parser.add_argument('--logs', type=str, default='./logs/', help='location of log files')
parser.add_argument('--config_path', type=str, default='./configs/uspto_condition_simcorpus.yaml',
                    help='path to the YAML configuration file')

# forecasting task
parser.add_argument('--seq_len', type=int, default=96, help='input sequence length')
parser.add_argument('--label_len', type=int, default=48, help='start token length')
parser.add_argument('--pred_len', type=int, default=96, help='prediction sequence length')

# model define
parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
parser.add_argument('--c_out', type=int, default=7, help='output size')
parser.add_argument('--d_model', type=int, default=16, help='dimension of model')
parser.add_argument('--n_heads', type=int, default=8, help='num of heads')
parser.add_argument('--e_layers', type=int, default=2, help='num of encoder layers')
parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers')
parser.add_argument('--d_ff', type=int, default=32, help='dimension of fcn')
parser.add_argument('--moving_avg', type=int, default=25, help='window size of moving average')
parser.add_argument('--factor', type=int, default=1, help='attn factor')
parser.add_argument('--dropout', type=float, default=0.1, help='dropout')
parser.add_argument('--activation', type=str, default='gelu', help='activation')
parser.add_argument('--output_attention', action='store_true', help='whether to output attention in encoder')
parser.add_argument('--patch_len', type=int, default=16, help='patch length')
parser.add_argument('--stride', type=int, default=8, help='stride')

parser.add_argument('--projection_type', type=str, choices=["reprogramming","perceiver","mlp"], default="perceiver", help='Multimodal feature fusion methods')
parser.add_argument("--rxn_text_flag", default=False, action='store_true', help='whether to use the reaction SMILES as the textual content in prompt')
parser.add_argument("--rxn_source_flag", default=False, action='store_true', help='whether to use reaction SMILES tokens as source embedding in modality projection layer')
parser.add_argument("--only_parrot", default=False, action='store_true', help='whether to only take the SMILES modality')
parser.add_argument("--only_corpus", default=False, action='store_true', help='whether to only take the similar corpus modality')
parser.add_argument("--use_graph", default=False, action='store_true', help='whether to add the graph modality')
parser.add_argument("--use_fp", default=False, action='store_true', help='whether to replace the Parrot encoder with Morgan fingerprints')

# optimization
parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
parser.add_argument('--itr', type=int, default=1, help='experiments times')
parser.add_argument('--train_epochs', type=int, default=10, help='train epochs')
parser.add_argument('--batch_size', type=int, default=16, help='batch size of train input data')
parser.add_argument('--eval_batch_size', type=int, default=8, help='batch size of model evaluation')
parser.add_argument('--patience', type=int, default=10, help='early stopping patience')
parser.add_argument('--learning_rate', type=float, default=0.0001, help='optimizer learning rate')
parser.add_argument('--des', type=str, default='test', help='exp description')
parser.add_argument('--lradj', type=str, default='COS', help='adjust learning rate')
parser.add_argument('--pct_start', type=float, default=0.2, help='pct_start')
parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False)
parser.add_argument('--llm_layers', type=int, default=32)

args_part = parser.parse_args()

# Seed the Python, PyTorch and NumPy RNGs from --seed (default 2024). Seeding is done
# after parsing on purpose: seeding with a constant beforehand made --seed a no-op.
fix_seed = args_part.seed
random.seed(fix_seed)
torch.manual_seed(fix_seed)
np.random.seed(fix_seed)
# Distributed runtime: DDP is told to tolerate parameters that receive no gradient,
# and DeepSpeed is configured from the JSON file next to this script.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
deepspeed_plugin = DeepSpeedPlugin(hf_ds_config='./ds_config_zero2.json')
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs], deepspeed_plugin=deepspeed_plugin)

# Load the YAML experiment config; the context manager guarantees the handle is closed.
# NOTE(review): FullLoader can construct more than plain data types; keep config files trusted.
with open(args_part.config_path, "r") as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)
parrot_pretrained_path = config["model_args"]['pretrained_path']
args = _load_model_args(parrot_pretrained_path)

# Command-line values are written on top of the model arguments loaded above.
args.update_from_dict(vars(args_part))

def _build_parrot_dataset(parrot_model, frame, use_graph, evaluate=False):
    """Build a dataset through ``parrot_model.load_and_cache_examples``.

    Args:
        parrot_model: the Parrot wrapper exposing ``load_and_cache_examples``.
        frame: table indexed by the columns "text", "labels", "corpus_text" and,
            when ``use_graph`` is set, "graph_info" (presumably a pandas DataFrame).
        use_graph: also pass the "graph_info" column, converted to a tensor, as ``graph``.
        evaluate: build the evaluation variant (``evaluate=True, silent=True,
            no_cache=False``) instead of the training one (``no_cache=True``).

    Returns:
        Whatever ``load_and_cache_examples`` returns for these arguments.
    """
    reactions = frame["text"].astype(str).tolist()
    examples = (reactions, frame["labels"].tolist())
    options = {
        "verbose": None,
        "corpus_text": frame["corpus_text"].astype(str).tolist(),
        "rxn_text": frame["text"].astype(str).tolist(),
    }
    if evaluate:
        options.update(evaluate=True, silent=True, no_cache=False)
    else:
        options["no_cache"] = True
    if use_graph:
        options["graph"] = torch.Tensor(frame["graph_info"].tolist())
    return parrot_model.load_and_cache_examples(examples, **options)


def _build_llm_inputs(parrot_inputs, memory, device, use_graph):
    """Assemble the input dict consumed by the MM-RCR model for one batch.

    Args:
        parrot_inputs: dict produced by ``ParrotModel._get_inputs_dict``; the keys
            "labels", "corpus_text", "rxn_text" (and "graph" when ``use_graph``) are read.
        memory: last element of the Parrot model output, used as "input_emb".
        device: device the tensors are moved to.
        use_graph: include the "graph" entry.

    Returns:
        dict with "input_emb", "labels", "paragraph_text", "rxn_text" and optionally "graph".
    """
    labels = torch.Tensor(parrot_inputs['labels'])
    # NOTE(review): bfloat16 represents integers exactly only up to 256; if label ids
    # can exceed that, this cast rounds them. Left unchanged because the model and the
    # validation routine (not visible here) consume labels in this form - confirm.
    llm_inputs = {
        "input_emb": memory.to(torch.bfloat16).to(device),
        "labels": labels.to(torch.bfloat16).to(device),
        "paragraph_text": parrot_inputs['corpus_text'],
        "rxn_text": parrot_inputs['rxn_text'],
    }
    if use_graph:
        llm_inputs["graph"] = parrot_inputs['graph'].to(torch.bfloat16).to(device)
    return llm_inputs


# One full experiment (data loading, training, checkpointing) per iteration of --itr.
for ii in range(args.itr):
    # setting record of experiments; also used as log and checkpoint directory name
    setting = '{}_data{}_dm{}_df{}_bs{}_lr{}_llmlayer{}_epoch{}_debug{}_proj{}_txtflag{}_srcflat{}_{}'.format(
        args.task_name,
        args.data,
        args.d_model,
        args.d_ff,
        args.batch_size,
        args.learning_rate,
        args.llm_layers,
        args.train_epochs,
        args.debug,
        args.projection_type,
        args.rxn_text_flag,
        args.rxn_source_flag,
        args.comment
    )
    accelerator.print('**********settings**********')
    accelerator.print(f'rxn_text_flag: {args.rxn_text_flag}')
    accelerator.print(f'projection_type: {args.projection_type}')
    accelerator.print(f'rxn_source_flag: {args.rxn_source_flag}')
    accelerator.print(f'config_path: {args.config_path}')
    accelerator.print(f'llama_pretrained_path: {args.llama_pretrained_path}')

    save_path = f'{args.logs}/{setting}'
    writer = SummaryWriter(log_dir=save_path)

    train_df, eval_df, condition_label_mapping, model_args = data_provider_training(args, config)
    if args.debug:
        # Debug mode trains and evaluates on a small prefix of the data.
        accelerator.print('*********debug mode**********')
        train_df = train_df.iloc[:10000]
        eval_df = eval_df.iloc[:1000]

    accelerator.print('train dataset number: {}'.format(len(train_df)))
    accelerator.print('eval dataset number: {}'.format(len(eval_df)))

    config['model_args'] = model_args
    trained_path = model_args['best_model_dir']
    ParrotModel = ParrotConditionPredictionModel(
        "bert",
        trained_path,
        args=model_args,
        use_cuda=True,
        cuda_device=accelerator.device
    )

    # The Parrot encoder is used as a frozen feature extractor.
    for param in ParrotModel.model.parameters():
        param.requires_grad = False

    if args.use_graph:
        accelerator.print('**********add graph modality**********')
    train_dataset = _build_parrot_dataset(ParrotModel, train_df, args.use_graph)
    train_loader = DataLoader(
        train_dataset,
        sampler=RandomSampler(train_dataset),
        batch_size=args.batch_size,
        num_workers=ParrotModel.args.dataloader_num_workers,
    )

    eval_dataset = _build_parrot_dataset(ParrotModel, eval_df, args.use_graph, evaluate=True)
    # NOTE(review): the evaluation loader uses --batch_size; --eval_batch_size is not
    # read anywhere in this script. Kept as is because the validation-loss scaling
    # below also divides by --batch_size - confirm the intended behaviour.
    eval_loader = DataLoader(eval_dataset,
                             sampler=SequentialSampler(eval_dataset),
                             batch_size=args.batch_size)

    decoder_args = config['model_args']['decoder_args']
    args.tgt_vocab_size = decoder_args["tgt_vocab_size"]
    all_idx2data, all_data2idx = decoder_args["condition_label_mapping"]
    args.all_data2idx = all_data2idx

    model = MMRCR.Model(args).bfloat16()
    ckpt_path = os.path.join(args.checkpoints, setting)

    # exist_ok avoids the check-then-create race; makedirs also creates ckpt_path itself.
    if accelerator.is_local_main_process:
        for sub_dir in ('last_epoch', 'best'):
            os.makedirs(f'{ckpt_path}/{sub_dir}', exist_ok=True)

    time_now = time.time()
    train_steps = len(train_loader)
    # Constructed for parity with the original script; it is not consulted in this loop.
    early_stopping = EarlyStopping(accelerator=accelerator, patience=args.patience)
    trained_parameters = [p for p in model.parameters() if p.requires_grad]
    model_optim = optim.Adam(trained_parameters, lr=args_part.learning_rate)

    # The cosine schedule advances once per epoch; OneCycleLR is sized in batches
    # (steps_per_epoch * epochs) and therefore has to advance once per batch.
    step_scheduler_per_batch = args.lradj != 'COS'
    if args.lradj == 'COS':
        # T_max is fixed at 20 scheduler steps, independent of --train_epochs.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(model_optim, T_max=20, eta_min=1e-8)
    else:
        scheduler = lr_scheduler.OneCycleLR(optimizer=model_optim,
                                            steps_per_epoch=train_steps,
                                            pct_start=args_part.pct_start,
                                            epochs=args_part.train_epochs,
                                            max_lr=args_part.learning_rate)

    # Padding positions are excluded from the loss.
    loss_fn = torch.nn.CrossEntropyLoss(
        ignore_index=decoder_args['condition_label_mapping'][1]['[PAD]'])

    train_loader, eval_loader, model, model_optim, scheduler, ParrotModel = accelerator.prepare(
        train_loader, eval_loader, model, model_optim, scheduler, ParrotModel)

    if args.llm_pretrained_path is not None:
        accelerator.print(f'loading parameters from {args.llm_pretrained_path}...')
        accelerator.load_state(args.llm_pretrained_path)
    if args.use_amp:
        scaler = torch.cuda.amp.GradScaler()

    best_acc = 0.0
    for epoch in range(args.train_epochs):
        iter_count = 0
        train_loss = []

        model.train()
        epoch_time = time.time()
        # Device placement is loop-invariant, so do it once per epoch instead of per batch.
        ParrotModel.model = ParrotModel.model.to(accelerator.device)
        # Wrapping the loader (not the enumerate object) lets tqdm show the total;
        # only the local main process draws a bar.
        progress = tqdm(train_loader, disable=not accelerator.is_local_main_process)
        for i, batch in enumerate(progress):
            inputs_parrot = ParrotModel._get_inputs_dict(batch, accelerator.device)
            if 'graph' not in inputs_parrot.keys():
                inputs_parrot['graph'] = None
            if ParrotModel.args.fp16:
                with torch.cuda.amp.autocast():
                    outputs = ParrotModel.model(**inputs_parrot)
            else:
                outputs = ParrotModel.model(**inputs_parrot)

            # The last element of the Parrot output feeds the LLM as its input embedding.
            memory_unpool = outputs[-1]
            inputs = _build_llm_inputs(inputs_parrot, memory_unpool, accelerator.device, args.use_graph)

            iter_count += 1
            model_optim.zero_grad()

            if args.use_amp:
                with torch.cuda.amp.autocast():
                    raw_outputs = model(inputs)
            else:
                raw_outputs = model(inputs)
            outputs = raw_outputs[0] if args.output_attention else raw_outputs

            # Both precision paths now use the slice of the default (non-AMP) path:
            # the model output is the same on both, so only one slice can match it.
            labels_out = inputs["labels"][:, 1:-1]
            loss = loss_fn(outputs.reshape(-1, outputs.shape[-1]).float(),
                           labels_out.reshape(-1).long())
            # Only the scalar is kept; holding on to per-batch logits would keep every
            # batch's autograd graph alive for the whole epoch.
            train_loss.append(loss.item())

            if (i + 1) % 100 == 0:
                accelerator.print(
                    "\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                speed = (time.time() - time_now) / iter_count
                left_time = speed * ((args.train_epochs - epoch) * train_steps - i)
                accelerator.print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                iter_count = 0
                time_now = time.time()

            if args.use_amp:
                # NOTE(review): this path bypasses accelerator.backward; confirm it is
                # compatible with the DeepSpeed plugin configured for this script.
                scaler.scale(loss).backward()
                scaler.step(model_optim)
                scaler.update()
            else:
                accelerator.backward(loss)
                model_optim.step()

            if step_scheduler_per_batch:
                scheduler.step()

        if not step_scheduler_per_batch:
            scheduler.step()
        accelerator.print("lr = {:.10f}".format(model_optim.param_groups[0]['lr']), scheduler.get_last_lr()[0])

        accelerator.print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
        # Both losses are reported divided by --batch_size, as in the original script.
        train_loss = np.average(train_loss) / args.batch_size
        vali_loss, vali_accuracy = vali(args, accelerator, model, eval_loader, loss_fn, ParrotModel)
        vali_loss /= args.batch_size
        writer.add_scalars('Loss',
                           {
                               'Training loss': train_loss,
                               'Validation loss': vali_loss
                           },
                           epoch + 1)
        accelerator.print(
            "Epoch: {0} | Train Loss: {1:.7f} Vali Loss: {2:.7f} Accuracy: {3:.7f}".format(
                epoch + 1, train_loss, vali_loss, vali_accuracy))
        # Keep a separate checkpoint of the best validation accuracy seen so far.
        if vali_accuracy > best_acc:
            best_acc = vali_accuracy
            accelerator.print('Improvement on the validation accuracy')
            accelerator.print(f'Saving model to {ckpt_path}/best')
            accelerator.save_state(f'{ckpt_path}/best/')

    accelerator.wait_for_everyone()
    accelerator.print('Finish training!!!')
    accelerator.print(f'Saving model to {ckpt_path}/last_epoch')
    accelerator.save_state(f'{ckpt_path}/last_epoch/')
    # Flush and release the TensorBoard event file of this experiment.
    writer.close()
