import os
import argparse
from tqdm import tqdm
import numpy as np
import pandas as pd

import torch
from torch.utils.tensorboard import SummaryWriter
from rdkit import Chem, DataStructs

from model import GPT, GPTConfig
from vocabulary import read_vocabulary
from utils import sample_SMILES, likelihood, to_tensor, calc_fingerprints
from scoring_function import get_scores, int_div


def memory_update(memory, smiles, scores, seqs, memory_size, replay):
    """Add a batch of scored samples to the replay memory and optionally replay from it.

    Every sample with a non-negative score is added to ``memory`` (a DataFrame
    with columns ``smiles``, ``scores``, ``seqs``, ``fps``). The memory is then
    de-duplicated on ``smiles``, sorted by ``scores`` in descending order and
    truncated to at most ``memory_size`` rows.

    When ``replay > 0``, ``min(len(memory), replay)`` entries are drawn at random
    from the top ``5 * replay`` rows of the updated memory. They are appended to
    the returned batch: SMILES, scores and token sequences.

    Args:
        memory: Current memory DataFrame; not modified in place.
        smiles: SMILES strings of this batch; not modified in place.
        scores: One score per SMILES; samples scoring below 0 are not stored.
        seqs: Token tensor whose row ``i`` belongs to ``smiles[i]``.
        memory_size: Maximum number of rows kept in memory.
        replay: Number of memory entries appended to the batch (0 disables replay).

    Returns:
        ``(memory, smiles, scores, seqs)``: the updated memory, the batch SMILES
        as a new list, the scores as a NumPy array, and the token tensor. The
        last three are extended with the replayed entries.
    """
    smiles = list(smiles)
    scores = list(scores)
    # A single device->host copy for the whole batch (instead of one per row).
    seqs_np = seqs.detach().cpu().numpy()

    new_smiles, new_scores, new_seqs, new_fps = [], [], [], []
    for i, smi in enumerate(smiles):
        if scores[i] < 0:
            continue
        fp, canon_smiles = calc_fingerprints([smi])
        # NOTE(review): presumably calc_fingerprints drops SMILES it cannot
        # parse; skip such samples instead of raising IndexError.
        if len(canon_smiles) == 0:
            continue
        new_smiles.append(canon_smiles[0])
        new_scores.append(scores[i])
        # Copy the row so the stored array does not share the batch buffer.
        new_seqs.append(seqs_np[i].copy())
        new_fps.append(fp[0])

    if new_smiles:
        # Concatenate once per call; concatenating per molecule is quadratic.
        new_data = pd.DataFrame({
            "smiles": new_smiles,
            "scores": new_scores,
            "seqs": new_seqs,
            "fps": new_fps,
        })
        memory = pd.concat([memory, new_data], ignore_index=True, sort=False)

    memory = memory.drop_duplicates(subset=["smiles"])
    memory = memory.sort_values('scores', ascending=False).reset_index(drop=True)
    memory = memory.head(memory_size)

    # Experience replay
    if replay > 0:
        n_replay = min(len(memory), replay)
        experience = memory.head(5 * replay).sample(n_replay).reset_index(drop=True)
        if len(experience) > 0:
            smiles.extend(experience["smiles"])
            scores.extend(experience["scores"])
            # Put replayed sequences on the batch's own device so CPU runs work
            # too. Stored rows must have the same length as the batch rows.
            replay_seqs = torch.as_tensor(
                np.stack(list(experience["seqs"])),
                dtype=torch.long,
                device=seqs.device,
            )
            seqs = torch.cat((seqs, replay_seqs), dim=0)

    return memory, smiles, np.array(scores), seqs


def summarize_memory_scores(memory):
    """Return summary statistics of the ``scores`` column of the replay memory.

    ``memory`` is expected to be sorted by ``scores`` in descending order (as
    ``memory_update`` leaves it), so the first rows are the best molecules.

    Returns:
        dict with keys ``"mean"`` (mean over all rows), ``"top1"`` (first row),
        ``"top10"`` and ``"top100"`` (mean over the first 10 / 100 rows). Every
        value is ``nan`` when the memory is empty, so logging does not raise
        before the first molecule has been stored.
    """
    values = np.asarray(memory["scores"], dtype=float)
    if values.size == 0:
        nan = float("nan")
        return {"mean": nan, "top1": nan, "top10": nan, "top100": nan}
    return {
        "mean": float(values.mean()),
        "top1": float(values[0]),
        "top10": float(values[:10].mean()),
        "top100": float(values[:100].mean()),
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--run_name', type=str, default="")
    parser.add_argument('--model_type', type=str, default="gpt")
    parser.add_argument('--device', type=str, default="cuda")
    parser.add_argument('--oracle', type=str, default="JNK3")
    parser.add_argument('--n_layer', type=int, default=12)
    parser.add_argument('--n_head', type=int, default=12)
    parser.add_argument('--n_embd', type=int, default=384)
    parser.add_argument('--max_length', type=int, default=128)
    parser.add_argument('--n_steps', type=int, default=1000)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--sigma', type=float, default=100)
    parser.add_argument('--learning_rate', type=float, default=2e-5)
    parser.add_argument('--temperature', type=float, default=1.0)
    parser.add_argument('--memory_size', type=int, default=1000)
    parser.add_argument('--replay', type=int, default=0)
    parser.add_argument('--prior_path', type=str, default="ckpt/final.pt")
    parser.add_argument('--vocab_path', type=str, default="data/vocab.txt")
    parser.add_argument('--output_dir', type=str, default="rl_log/")
    args = parser.parse_args()
    print(args)

    run_id = f"{args.oracle}_{args.run_name}"
    output_path = os.path.join("rl_outputs", run_id)
    ckpt_path = os.path.join("rl_ckpts", run_id)

    # os.path.join works whether or not --output_dir ends with a slash.
    log_path = os.path.join(args.output_dir, run_id)
    os.makedirs(log_path, exist_ok=True)
    writer = SummaryWriter(log_path)
    writer.add_text("configs", str(args))

    # Load vocabulary
    voc = read_vocabulary(args.vocab_path)

    # Build model config
    model_config = GPTConfig(
        len(voc),
        n_layer=args.n_layer,
        n_head=args.n_head,
        n_embd=args.n_embd,
        block_size=args.max_length,
    )

    # Setup models. Any other model type would leave prior/agent/optimizer
    # undefined, so reject it up front with a clear message.
    if args.model_type == "gpt":
        prior = GPT(model_config).to(args.device)
        agent = GPT(model_config).to(args.device)
        optimizer = agent.configure_optimizers(
            weight_decay=0.1,
            learning_rate=args.learning_rate,
            betas=(0.9, 0.95),
        )
    else:
        parser.error(f"unsupported --model_type: {args.model_type!r}")

    # Read the pretrained checkpoint once; load_state_dict copies it into each
    # model. map_location lets a GPU-saved checkpoint load with --device cpu.
    prior_state = torch.load(args.prior_path, map_location=args.device)

    # The prior is frozen: no gradients, eval mode.
    prior.load_state_dict(prior_state, strict=True)
    for param in prior.parameters():
        param.requires_grad = False
    prior.eval()

    # The agent starts from the same weights and is the model being optimized.
    # eval() only switches layer behaviour (e.g. any dropout); autograd still
    # tracks the agent's parameters, so the loss below still trains it.
    agent.load_state_dict(prior_state, strict=True)
    agent.eval()

    # Replay memory; memory_update keeps it sorted by score, best first.
    memory = pd.DataFrame(columns=["smiles", "scores", "seqs", "fps"])
    os.makedirs(output_path, exist_ok=True)
    os.makedirs(ckpt_path, exist_ok=True)

    # Training loop
    for step in tqdm(range(args.n_steps)):
        samples, seqs, entropies = sample_SMILES(agent, voc, n_mols=args.batch_size, temperature=args.temperature)
        scores = get_scores(samples, mode=args.oracle)

        writer.add_scalar('Entropy', np.mean(entropies.detach().cpu().numpy()), step)
        writer.add_scalar('Step Mean', np.mean(np.array(scores)), step)
        if (step + 1) % 10 == 0:
            smiles_df = pd.DataFrame(samples, columns=["SMILES"])
            smiles_df.to_csv(os.path.join(output_path, f'smiles_step{step+1}.csv'), index=False)
            writer.add_scalar('Step Div', int_div(samples), step)
            torch.save(agent.state_dict(), os.path.join(ckpt_path, f'agent_step{step+1}.pt'))

        memory, samples, scores, seqs = memory_update(
            memory, samples, scores, seqs, args.memory_size, args.replay
        )

        prior_likelihood = likelihood(prior, seqs)
        agent_likelihood = likelihood(agent, seqs)
        # Mean squared gap between sigma * score and (prior - agent) likelihood.
        # NOTE(review): this matches the REINVENT augmented-likelihood loss if
        # `likelihood` returns negative log-likelihoods; confirm in utils.
        loss = torch.pow(args.sigma * to_tensor(np.array(scores)) -
                         (prior_likelihood - agent_likelihood), 2).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Logging
        stats = summarize_memory_scores(memory)
        writer.add_scalar('Memory Mean', stats["mean"], step)
        writer.add_scalar('Prior Likelihood', np.mean(prior_likelihood.detach().cpu().numpy()), step)
        writer.add_scalar('Agent Likelihood', np.mean(agent_likelihood.detach().cpu().numpy()), step)

        writer.add_scalar('Top-1', stats["top1"], step)
        writer.add_scalar('Top-10 Mean', stats["top10"], step)
        writer.add_scalar('Top-100 Mean', stats["top100"], step)

        # Only measure diversity once the memory holds at least one molecule.
        if (step + 1) % 10 == 0 and len(memory) > 0:
            writer.add_scalar('Top-100 Div', int_div(list(memory["smiles"][:100])), step)

        if (step + 1) % 100 == 0:
            memory.to_csv(os.path.join(output_path, f'memory_step{step+1}.csv'))

    # Final save; the final agent is written next to, not inside, the per-run
    # checkpoint directory.
    memory.to_csv(os.path.join(output_path, f'final_{args.n_steps}steps.csv'))
    torch.save(agent.state_dict(), os.path.join("rl_ckpts", f"{run_id}_finalagent.pt"))

    stats = summarize_memory_scores(memory)
    top100_div = int_div(list(memory["smiles"][:100])) if len(memory) > 0 else float("nan")
    print(f'top-1 score: {stats["top1"]}')
    print(f'top-10 score: {stats["top10"]}')
    print(f'top-100 score: {stats["top100"]}, diversity: {top100_div}')

    writer.close()
