import argparse
import torch
from accelerate import Accelerator, DeepSpeedPlugin
from accelerate import DistributedDataParallelKwargs
from torch import nn, optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from models import MMRCR
from data_provider.data_factory import data_provider_testing
import time
import random
import numpy as np
import os
import yaml
import pandas as pd

from Parrot.models.parrot_model import ParrotConditionPredictionModel

# Process-wide environment tweaks; they are set before the remaining project
# imports below are executed.
# NOTE(review): an empty CURL_CA_BUNDLE is commonly used to bypass TLS
# certificate verification for requests-based downloads - confirm this is
# intended and acceptable for the environment this script runs in.
os.environ['CURL_CA_BUNDLE'] = ''
# Caps the block size the CUDA caching allocator is allowed to split (in MB);
# presumably chosen to limit memory fragmentation - TODO confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"

from utils.tools import EarlyStopping, vali, _load_model_args

# Command-line interface of the MM-RCR evaluation script.
parser = argparse.ArgumentParser(description='MM-RCR')

# Seed the Python, PyTorch and NumPy RNGs with a fixed constant. This runs
# before the command line is parsed, so it always uses `fix_seed`.
fix_seed = 2024
random.seed(fix_seed)
torch.manual_seed(fix_seed)
np.random.seed(fix_seed)

# basic config
# `--task_name` and `--llama_pretrained_path` are required, so they carry no
# default: argparse would never fall back to one.
parser.add_argument('--task_name', type=str, required=True,
                    help='task name, options:[multi-RCR]')
parser.add_argument('--seed', type=int, default=2024, help='random seed')
parser.add_argument('--debug', action='store_true',
                    help='whether it is the debug mode')
parser.add_argument('--comment', type=str, default='', help='other comments that need to be recorded')
parser.add_argument('--llm_pretrained_path', type=str, default=None, help='LLM pretrained path')
parser.add_argument('--llama_pretrained_path', type=str, required=True, help='llama pretrained path')

# data loader
parser.add_argument('--data', type=str, default='USPTO-Condition', help='dataset type')
parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints')
parser.add_argument('--logs', type=str, default='./logs/', help='location of log files')
# The help text used to read 'activation' (copied from the --activation option).
parser.add_argument('--config_path', type=str, default='./configs/uspto_condition_simcorpus.yaml',
                    help='path to the YAML experiment config file')

# forecasting task
parser.add_argument('--seq_len', type=int, default=96, help='input sequence length')
parser.add_argument('--label_len', type=int, default=48, help='start token length')
parser.add_argument('--pred_len', type=int, default=96, help='prediction sequence length')

# model define
parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
parser.add_argument('--c_out', type=int, default=7, help='output size')
parser.add_argument('--d_model', type=int, default=16, help='dimension of model')
parser.add_argument('--n_heads', type=int, default=8, help='num of heads')
parser.add_argument('--e_layers', type=int, default=2, help='num of encoder layers')
parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers')
parser.add_argument('--d_ff', type=int, default=32, help='dimension of fcn')
parser.add_argument('--moving_avg', type=int, default=25, help='window size of moving average')
parser.add_argument('--factor', type=int, default=1, help='attn factor')
parser.add_argument('--dropout', type=float, default=0.1, help='dropout')
parser.add_argument('--activation', type=str, default='gelu', help='activation')
parser.add_argument('--output_attention', action='store_true', help='whether to output attention in encoder')
parser.add_argument('--patch_len', type=int, default=16, help='patch length')
parser.add_argument('--stride', type=int, default=8, help='stride')

# modality / fusion switches
parser.add_argument('--projection_type', type=str, choices=["reprogramming", "perceiver", "mlp"],
                    default="perceiver", help='Multimodal feature fusion methods')
parser.add_argument("--rxn_text_flag", default=False, action='store_true',
                    help='whether to use the reaction SMILES as the textual content in prompt')
parser.add_argument("--rxn_source_flag", default=False, action='store_true',
                    help='whether to use reaction SMILES tokens as source embedding in modality projection layer')
parser.add_argument("--only_parrot", default=False, action='store_true',
                    help='whether to only take the SMILES modality')
parser.add_argument("--only_corpus", default=False, action='store_true',
                    help='whether to only take the similar corpus modality')
parser.add_argument("--use_graph", default=False, action='store_true',
                    help='whether to add the graph modality')
parser.add_argument("--use_fp", default=False, action='store_true',
                    help='whether to replace the Parrot encoder with Morgan fingerprints')

# optimization
parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
parser.add_argument('--itr', type=int, default=1, help='experiments times')
parser.add_argument('--train_epochs', type=int, default=10, help='train epochs')
parser.add_argument('--batch_size', type=int, default=16, help='batch size of train input data')
parser.add_argument('--eval_batch_size', type=int, default=8, help='batch size of model evaluation')
parser.add_argument('--patience', type=int, default=10, help='early stopping patience')
parser.add_argument('--learning_rate', type=float, default=0.0001, help='optimizer learning rate')
parser.add_argument('--des', type=str, default='test', help='exp description')
parser.add_argument('--lradj', type=str, default='COS', help='adjust learning rate')
parser.add_argument('--pct_start', type=float, default=0.2, help='pct_start')
parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False)
parser.add_argument('--llm_layers', type=int, default=32)


# Parse the command line once; the rest of the script is driven by these values.
args_part = parser.parse_args()

# Apply the seed requested with --seed. Seeding here means the option is
# actually honored; with the default value (2024) the RNG state is the same as
# after a fixed 2024 seed.
random.seed(args_part.seed)
torch.manual_seed(args_part.seed)
np.random.seed(args_part.seed)

# Distributed setup: DDP is told to tolerate parameters that receive no
# gradient, and DeepSpeed is configured from the JSON file in the working dir.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
deepspeed_plugin = DeepSpeedPlugin(hf_ds_config='./ds_config_zero2.json')
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs], deepspeed_plugin=deepspeed_plugin)

# Read the YAML experiment config; the context manager closes the file handle.
# NOTE(review): FullLoader can build more than plain data types, so the config
# file must come from a trusted source.
with open(args_part.config_path, "r") as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)
parrot_pretrained_path = config["model_args"]['pretrained_path']
# Start from the arguments stored with the pretrained Parrot model ...
args = _load_model_args(parrot_pretrained_path)

# ... and overlay every command-line value on top of them.
args.update_from_dict(vars(args_part))

# Names of the condition slots exported for each supported dataset, in label
# order. Every slot becomes a `pred_<slot>` / `true_<slot>` column pair in the
# results CSV, and the number of slots is how many label positions are exported.
CONDITION_SLOTS = {
    'USPTO_500MT_Condition': ('r1', 'r2', 'r3', 'r4', 'r5', 'r6'),
    'USPTO-Condition': ('c1', 's1', 's2', 'r1', 'r2'),
}

# Evaluation driver. Each iteration loads the test split, builds the Parrot
# encoder and the MM-RCR model, optionally restores a checkpoint, runs the test
# set without gradients, prints the token-level accuracy and writes one CSV of
# per-reaction outputs for the current device.
for ii in range(args.itr):
    # Fail fast on an unsupported dataset instead of hitting an undefined
    # `condition_n` only after the data and models have been loaded.
    if args.data not in CONDITION_SLOTS:
        raise ValueError(
            f"unsupported --data {args.data!r}; expected one of {sorted(CONDITION_SLOTS)}")
    condition_slots = CONDITION_SLOTS[args.data]
    condition_n = len(condition_slots)

    # setting record of experiments
    # The run name is the second-to-last '/'-separated component of the
    # checkpoint path. --llm_pretrained_path is optional (default None), so a
    # placeholder name is used when no checkpoint is given.
    if args.llm_pretrained_path is None:
        setting = 'no_checkpoint'
    else:
        path_parts = args.llm_pretrained_path.split('/')
        setting = path_parts[-2] if len(path_parts) > 1 else path_parts[-1]
    accelerator.print('**********settings**********')
    accelerator.print(f'rxn_text_flag: {args.rxn_text_flag}')
    accelerator.print(f'projection_type: {args.projection_type}')
    accelerator.print(f'rxn_source_flag: {args.rxn_source_flag}')
    accelerator.print(f'config_path: {args.config_path}')
    accelerator.print(f'llama_pretrained_path: {args.llama_pretrained_path}')

    log_dir = f'{args.logs}/{setting}'
    writer = SummaryWriter(log_dir=log_dir)

    # ---- data -------------------------------------------------------------
    test_df, condition_label_mapping, model_args = data_provider_testing(args, config)
    if args.debug:
        print('*********debug mode**********')
        test_df = test_df.iloc[:1000]  # debug mode keeps only the first 1000 rows

    accelerator.print('test dataset number: {}'.format(len(test_df)))

    config['model_args'] = model_args
    trained_path = model_args['best_model_dir']
    print(f'Parrot pretrained path : {trained_path}')
    ParrotModel = ParrotConditionPredictionModel(
        "bert",
        trained_path,
        args=model_args,
        use_cuda=True,
        cuda_device=accelerator.device,
    )

    test_examples = (
        test_df["text"].astype(str).tolist(),
        test_df["labels"].tolist(),
    )
    # Keyword arguments shared by both modality setups; the graph tensor is
    # added only when the graph modality is enabled.
    dataset_kwargs = {
        "verbose": None,
        "no_cache": True,
        "corpus_text": test_df["corpus_text"].astype(str).tolist(),
        "rxn_text": test_df["text"].astype(str).tolist(),
    }
    if args.use_graph:
        dataset_kwargs["graph"] = torch.Tensor(test_df["graph_info"].tolist())
    test_dataset = ParrotModel.load_and_cache_examples(test_examples, **dataset_kwargs)

    # NOTE(review): the loader uses --batch_size although an --eval_batch_size
    # option exists; left unchanged - confirm which one is intended.
    test_sampler = SequentialSampler(test_dataset)
    test_loader = DataLoader(test_dataset,
                             sampler=test_sampler,
                             batch_size=args.batch_size)

    # ---- model ------------------------------------------------------------
    decoder_args = config['model_args']['decoder_args']
    args.tgt_vocab_size = decoder_args["tgt_vocab_size"]
    all_idx2data, all_data2idx = decoder_args["condition_label_mapping"]
    args.all_data2idx = all_data2idx
    model = MMRCR.Model(args).bfloat16()

    # No training happens in this script, but an optimizer and a scheduler are
    # still built because they are handed to accelerator.prepare() below
    # (presumably so load_state() can restore a checkpoint that contains their
    # state - TODO confirm).
    trained_parameters = [p for p in model.parameters() if p.requires_grad]
    model_optim = optim.Adam(trained_parameters, lr=args_part.learning_rate)

    if args.lradj == 'COS':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(model_optim, T_max=20, eta_min=1e-8)
    else:
        # `test_steps` was never defined; the loader length is the number of
        # batches per pass, which is what steps_per_epoch expects.
        scheduler = lr_scheduler.OneCycleLR(optimizer=model_optim,
                                            steps_per_epoch=len(test_loader),
                                            pct_start=args_part.pct_start,
                                            epochs=args_part.train_epochs,
                                            max_lr=args_part.learning_rate)

    test_loader, model, model_optim, scheduler, ParrotModel = accelerator.prepare(
        test_loader, model, model_optim, scheduler, ParrotModel)

    if args.llm_pretrained_path is not None:
        accelerator.print(f'loading parameters from {args.llm_pretrained_path}...')
        accelerator.load_state(args.llm_pretrained_path)

    # ---- inference --------------------------------------------------------
    eval_data = []      # one row per reaction: text, then (logits, label) per slot
    logists = []        # per-batch logits flattened to (positions, vocab)
    true_labels = []    # per-batch ground-truth ids flattened to (positions,)
    model.eval()
    # Loop-invariant: move the Parrot encoder to the target device once.
    ParrotModel.model = ParrotModel.model.to(accelerator.device)

    with torch.no_grad():
        for batch in tqdm(test_loader):
            inputs_parrot = ParrotModel._get_inputs_dict(batch, accelerator.device)
            if 'graph' not in inputs_parrot:
                inputs_parrot['graph'] = None

            # The last element of the Parrot output is used as the input
            # embedding of the MM-RCR model.
            if ParrotModel.args.fp16:
                with torch.cuda.amp.autocast():
                    parrot_outputs = ParrotModel.model(**inputs_parrot)
            else:
                parrot_outputs = ParrotModel.model(**inputs_parrot)
            memory_unpool = parrot_outputs[-1]

            # Keep a float32 copy of the label ids. The model still receives
            # them as bfloat16, but bfloat16 has only 8 significant bits, so
            # ids above 256 are not all representable and must not be read
            # back from that tensor as ground truth.
            labels_fp32 = torch.as_tensor(inputs_parrot['labels'],
                                          dtype=torch.float32,
                                          device=accelerator.device)
            inputs = {
                "input_emb": memory_unpool.to(torch.bfloat16).to(accelerator.device),
                "labels": labels_fp32.to(torch.bfloat16),
                "paragraph_text": inputs_parrot['corpus_text'],
                "rxn_text": inputs_parrot['rxn_text'],
            }
            if args.use_graph:
                inputs["graph"] = inputs_parrot['graph'].to(torch.bfloat16).to(accelerator.device)

            if args.use_amp:
                with torch.cuda.amp.autocast():
                    model_out = model(inputs)
            else:
                model_out = model(inputs)
            # With --output_attention the model returns a sequence whose first
            # element holds the logits.
            outputs = model_out[0] if args.output_attention else model_out

            # Drop the first and last label position so the labels line up
            # with the model output (checked by the shape assert below).
            labels_out = labels_fp32[:, 1:-1]
            logists.append(outputs.reshape(-1, outputs.shape[-1]).float())
            true_labels.append(labels_out.reshape(-1).long())

            # One device-to-host copy per batch instead of one per element.
            logits_cpu = outputs.cpu()
            labels_cpu = labels_out.cpu()
            for bn in range(logits_cpu.shape[0]):
                cur_eval_data = [inputs_parrot['rxn_text'][bn]]
                for cn in range(condition_n):
                    cur_eval_data.append(logits_cpu[bn][cn].tolist())
                    cur_eval_data.append(labels_cpu[bn][cn].tolist())
                eval_data.append(cur_eval_data)

    # ---- report -----------------------------------------------------------
    columns = ['reaction']
    for slot in condition_slots:
        columns += [f'pred_{slot}', f'true_{slot}']
    eval_data = pd.DataFrame(eval_data, columns=columns)

    predict_labels = torch.argmax(torch.cat(logists, dim=0), dim=1)
    true_labels = torch.cat(true_labels, dim=0)
    assert predict_labels.shape == true_labels.shape
    correct = (predict_labels == true_labels).sum().item()
    accuracy = correct / len(true_labels)
    accelerator.print(
        "Avg Test Accuracy: {0:.7f} ".format(accuracy),  f'for totally {predict_labels.shape} conditions')

    save_path = f'./evaluation_results/{setting}/results{args.comment}'
    # exist_ok avoids a FileExistsError when several processes reach this
    # point at the same time.
    os.makedirs(save_path, exist_ok=True)
    # Each process writes its own file, keyed by the suffix of its device string.
    save_idx = str(accelerator.device).split(':')[-1]
    print(f'saving evaluation result to {save_path}/test_results_device{save_idx}.csv')
    eval_data.to_csv(f'{save_path}/test_results_device{save_idx}.csv')

    # Release the TensorBoard event file opened for this iteration.
    writer.close()
