import argparse
import glob
import os
import json
import time
import logging
import random
import re
from itertools import chain
from string import punctuation
import sys


import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl

from utils import get_dataset

from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5Tokenizer,
    get_linear_schedule_with_warmup,
    get_cosine_schedule_with_warmup
)

class T5FineTuner(pl.LightningModule):
  """Lightning module that tunes only newly added token rows of a T5 model.

  Construction loads the model/tokenizer named in ``hparams``, freezes every
  parameter whose name contains neither ``"shared"`` nor ``"lm_head"``, and
  extends the vocabulary with ``"[S*]"``, ten ``"[I**...]"`` tokens and
  ``int(sys.argv[4])`` ``"[S**...]"`` tokens. After every optimizer step the
  rows that existed before the extension are copied back from a snapshot, so
  only the added rows keep their updates.

  The module reads ``sys.argv[1]``, ``sys.argv[2]`` and ``sys.argv[4]``
  directly (file names and the dummy-token count).
  """

  # Position inside ``source_ids`` that is overwritten with a candidate token.
  CLUSTER_SLOT = 7
  # Number of candidate tokens tried for every training example.
  NUM_CLUSTERS = 10
  # Candidate ``j`` uses token id ``FIRST_CLUSTER_TOKEN_ID + j``.
  # NOTE(review): presumably these are the ids of the "[I**...]" tokens added
  # in __init__ when the base tokenizer has 32100 entries -- confirm.
  FIRST_CLUSTER_TOKEN_ID = 32101
  # Epochs at whose first optimizer step the assignment is recomputed.
  CLUSTER_EPOCHS = (0, 1, 2, 3, 4)

  def __init__(self, hparams):
    """Load model and tokenizer, freeze the backbone and add the new tokens.

    Args:
      hparams: namespace providing at least ``model_name_or_path`` and
        ``tokenizer_name_or_path``; stored unchanged as ``self.args``.

    Raises:
      IndexError / ValueError: if ``sys.argv[4]`` is missing or not an int.
    """
    super(T5FineTuner, self).__init__()
    self.args = hparams
    # Parsed before the (expensive) model load so a bad CLI fails immediately.
    self.num_dummy_tokens = int(sys.argv[4])
    # Dataset behind the most recent train_dataloader() call (None until then).
    self._train_dataset = None

    self.model = T5ForConditionalGeneration.from_pretrained(self.args.model_name_or_path)
    self.tokenizer = T5Tokenizer.from_pretrained(self.args.tokenizer_name_or_path)

    # Only parameters named "...shared..." or "...lm_head..." get gradients.
    for name, param in self.model.named_parameters():
      param.requires_grad = "shared" in name or "lm_head" in name

    total_params = sum(param.numel() for param in self.model.parameters())
    print(total_params)

    # Snapshot the input/output embedding parameters at tokenizer size, before
    # any token is added. Assigning a Parameter as an attribute registers it on
    # this module, so the snapshot follows the module across devices.
    self.model.resize_token_embeddings(len(self.tokenizer))
    for p in self.model.get_input_embeddings().parameters():
      self.orig_init_emb = p

    for p in self.model.get_output_embeddings().parameters():
      self.orig_init_lin = p

    new_tokens = ["[S*]"]
    self.new_tokens = new_tokens
    self.tokenizer.add_tokens(list(new_tokens))

    inter_tokens = ["[I" + "*" * (j + 2) + "]" for j in range(10)]
    self.tokenizer.add_tokens(list(inter_tokens))
    self.model.resize_token_embeddings(len(self.tokenizer))

    dummy_tokens = ["[S" + "*" * (j + 2) + "]" for j in range(self.num_dummy_tokens)]
    self.tokenizer.add_tokens(list(dummy_tokens))
    self.model.resize_token_embeddings(len(self.tokenizer))

  def is_logger(self):
    """Return True; LoggingCallback consults this before reporting metrics."""
    return True

  def forward(
      self, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, lm_labels=None
  ):
    """Delegate to the wrapped T5 model, passing ``lm_labels`` as ``labels``."""
    return self.model(
        input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        labels=lm_labels,
    )

  def _step(self, batch):
    """Return the model loss (first output) for one batch.

    Pad positions of ``batch["target_ids"]`` are set to -100 in place before
    the forward pass.
    """
    lm_labels = batch["target_ids"]
    lm_labels[lm_labels[:, :] == self.tokenizer.pad_token_id] = -100

    outputs = self(
        input_ids=batch["source_ids"],
        attention_mask=batch["source_mask"],
        lm_labels=lm_labels,
        decoder_attention_mask=batch['target_mask']
    )
    return outputs[0]

  def training_step(self, batch, batch_idx):
    """Compute the training loss for ``batch``."""
    loss = self._step(batch)

    tensorboard_logs = {"train_loss": loss}
    return {"loss": loss, "log": tensorboard_logs}

  def training_epoch_end(self, outputs):
    """Log the mean training loss of the epoch as ``avg_train_loss``."""
    avg_train_loss = torch.stack([x["loss"] for x in outputs]).mean()
    # The average used to be computed and dropped; report it the same way
    # validation_epoch_end reports val_loss.
    self.log("avg_train_loss", avg_train_loss)

  def validation_step(self, batch, batch_idx):
    """Compute the validation loss for ``batch``."""
    loss = self._step(batch)
    return {"val_loss": loss}

  def validation_epoch_end(self, outputs):
    """Log the mean validation loss and export model and tokenizer."""
    avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
    tensorboard_logs = {"val_loss": avg_loss}
    self.log("val_loss", avg_loss)
    self._save_pretrained()

    return {"avg_val_loss": avg_loss, "log": tensorboard_logs, 'progress_bar': tensorboard_logs}

  def configure_optimizers(self):
    """Build AdamW with weight decay disabled for bias / LayerNorm weights.

    The optimizer is also kept as ``self.opt`` because train_dataloader()
    attaches the learning-rate scheduler to it.
    """
    model = self.model
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": self.args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon)
    self.opt = optimizer
    return [optimizer]

  def _save_pretrained(self):
    """Write model and tokenizer to ``t5_large_<sys.argv[2]>``."""
    export_dir = f't5_large_{sys.argv[2]}'
    self.model.save_pretrained(export_dir)
    self.tokenizer.save_pretrained(export_dir)

  def _reassign_clusters(self, epoch):
    """Pick, per training example, the candidate token with the lowest loss.

    For every example the token at ``CLUSTER_SLOT`` is replaced in turn by
    each of the ``NUM_CLUSTERS`` candidates; the best candidate is written
    back into the example and its index is stored in a tensor saved as
    ``idx_tensor_<argv[1]>_<epoch>_<argv[4]>.pt``.
    """
    train_dset = self._train_dataset
    if train_dset is None:
      # No loader has been built yet: fall back to building one. Note that
      # this also (re)creates self.lr_scheduler.
      train_dset = self.train_dataloader().dataset

    pad_id = self.tokenizer.pad_token_id
    # At least one slot per example; never smaller than the previous
    # sys.argv[4]-sized tensor so existing consumers keep working.
    idx_tensor = torch.zeros(max(self.num_dummy_tokens, len(train_dset)))

    # NOTE(review): the module is not switched to eval mode here, so any
    # dropout in the model is active while candidates are scored.
    with torch.no_grad():
      for i in range(len(train_dset)):
        sample = train_dset[i]
        lm_labels = sample["target_ids"].unsqueeze(0).to(self.device).clone()
        lm_labels[lm_labels[:, :] == pad_id] = -100
        attention_mask = sample["source_mask"].unsqueeze(0).to(self.device)
        target_mask = sample["target_mask"].unsqueeze(0).to(self.device)
        # Work on a copy so trying candidates does not touch the stored ids.
        input_ids = sample["source_ids"].clone()

        min_loss = 10000
        min_cluster = 0
        for j in range(self.NUM_CLUSTERS):
          input_ids[self.CLUSTER_SLOT] = self.FIRST_CLUSTER_TOKEN_ID + j
          outputs = self(
              input_ids=input_ids.unsqueeze(0).to(self.device),
              attention_mask=attention_mask,
              lm_labels=lm_labels,
              decoder_attention_mask=target_mask,
          )
          loss = outputs[0].item()
          if loss < min_loss:
            min_cluster = j
            min_loss = loss

        # Persists only if the dataset's __getitem__ hands out views of its
        # stored tensors -- TODO confirm against utils.get_dataset.
        train_dset[i]["source_ids"][self.CLUSTER_SLOT] = self.FIRST_CLUSTER_TOKEN_ID + min_cluster
        idx_tensor[i] = min_cluster

    torch.save(idx_tensor, f'idx_tensor_{sys.argv[1]}_{epoch}_{sys.argv[4]}.pt')

  def _restore_pretrained_rows(self):
    """Copy the snapshot back over the rows that existed before extension.

    The added tokens occupy the trailing rows, so the leading
    ``snapshot.shape[0]`` rows are exactly the pretrained ones. (The previous
    slice ``[:-(N - 11)]`` selected a different number of rows than the
    snapshot holds and could not be assigned.)
    """
    with torch.no_grad():
      for p in self.model.get_input_embeddings().parameters():
        p[: self.orig_init_emb.shape[0]] = self.orig_init_emb
      for p in self.model.get_output_embeddings().parameters():
        p[: self.orig_init_lin.shape[0]] = self.orig_init_lin

  def optimizer_step(self,
                     epoch=None,
                     batch_idx=None,
                     optimizer=None,
                     optimizer_idx=None,
                     optimizer_closure=None,
                     on_tpu=None,
                     using_native_amp=None,
                     using_lbfgs=None):
    """Run one optimizer step plus the per-step bookkeeping.

    Order of operations: on the first batch of the epochs listed in
    ``CLUSTER_EPOCHS`` recompute the cluster assignment; on the first batch
    of every epoch export model and tokenizer; step the optimizer, clear the
    gradients, step the scheduler; finally restore the pretrained rows.
    """
    if batch_idx == 0 and epoch in self.CLUSTER_EPOCHS:
      self._reassign_clusters(epoch)

    if batch_idx == 0:
      self._save_pretrained()

    optimizer.step(closure=optimizer_closure)
    optimizer.zero_grad()
    self.lr_scheduler.step()
    self._restore_pretrained_rows()

  def get_tqdm_dict(self):
    """Return the running loss and current learning rate for the progress bar."""
    tqdm_dict = {"loss": "{:.3f}".format(self.trainer.avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}

    return tqdm_dict

  def train_dataloader(self):
    """Build the training loader and the linear warmup/decay scheduler.

    The scheduler is attached to ``self.opt`` and stored as
    ``self.lr_scheduler``; the dataset is remembered so that
    optimizer_step() can work on it without building a second loader (which
    would also replace the scheduler).
    """
    train_dataset = get_dataset(tokenizer=self.tokenizer, type_path="train", args=self.args)
    self._train_dataset = train_dataset

    dataloader = DataLoader(train_dataset, batch_size=self.args.train_batch_size, drop_last=True, shuffle=True, num_workers=4)
    t_total = (
        (len(dataloader.dataset) // (self.args.train_batch_size * max(1, self.args.n_gpu)))
        // self.args.gradient_accumulation_steps
        * float(self.args.num_train_epochs)
    )
    scheduler = get_linear_schedule_with_warmup(
        self.opt, num_warmup_steps=self.args.warmup_steps, num_training_steps=t_total
    )
    self.lr_scheduler = scheduler
    return dataloader

  def val_dataloader(self):
    """Build the validation loader."""
    val_dataset = get_dataset(tokenizer=self.tokenizer, type_path="val", args=self.args)
    return DataLoader(val_dataset, batch_size=self.args.eval_batch_size, num_workers=4)
  

logger = logging.getLogger(__name__)


class LoggingCallback(pl.Callback):
  """Callback that reports ``trainer.callback_metrics`` through ``logger``.

  After validation the metrics are logged; after testing they are logged and
  also written to ``test_results.txt`` in the module's output directory.
  Nothing beyond the banner line is emitted when ``pl_module.is_logger()``
  is false.
  """

  # Entries of callback_metrics that are containers rather than metrics.
  _SKIPPED_KEYS = ("log", "progress_bar")

  @classmethod
  def _metric_lines(cls, metrics):
    """Return one ``"key = value\\n"`` line per metric, sorted by key."""
    return [
        "{} = {}\n".format(key, str(metrics[key]))
        for key in sorted(metrics)
        if key not in cls._SKIPPED_KEYS
    ]

  @staticmethod
  def _resolve_output_dir(pl_module):
    """Find ``output_dir`` on ``pl_module.hparams`` or, failing that, ``.args``.

    Modules that keep their configuration on ``.args`` without saving
    hyper-parameters do not expose ``hparams.output_dir``, so reading only
    ``hparams`` raised AttributeError for them.

    Raises:
      AttributeError: if neither location provides ``output_dir``.
    """
    for holder_name in ("hparams", "args"):
      holder = getattr(pl_module, holder_name, None)
      output_dir = getattr(holder, "output_dir", None)
      if output_dir is not None:
        return output_dir
    raise AttributeError(
        "{} has no 'output_dir' on .hparams or .args".format(type(pl_module).__name__)
    )

  def on_validation_end(self, trainer, pl_module):
    """Log every callback metric after a validation run."""
    logger.info("***** Validation results *****")
    if pl_module.is_logger():
      # Log results
      for line in self._metric_lines(trainer.callback_metrics):
        logger.info(line)

  def on_test_end(self, trainer, pl_module):
    """Log every callback metric after testing and save them to a file."""
    logger.info("***** Test results *****")

    if pl_module.is_logger():
      lines = self._metric_lines(trainer.callback_metrics)

      # Log and save results to file
      output_dir = self._resolve_output_dir(pl_module)
      if output_dir:
        # An empty string means "current directory"; otherwise make sure the
        # target exists before opening the file.
        os.makedirs(output_dir, exist_ok=True)
      output_test_results_file = os.path.join(output_dir, "test_results.txt")
      with open(output_test_results_file, "w") as writer:
        for line in lines:
          logger.info(line)
          writer.write(line)


# Command line: DATA_DIR OUTPUT_DIR NUM_TRAIN_EPOCHS NUM_DUMMY_TOKENS
# The first three are consumed below; the fourth is read from sys.argv by
# T5FineTuner itself. Checking everything here fails fast with a readable
# message instead of an IndexError/ValueError raised later on.
USAGE = "usage: {} DATA_DIR OUTPUT_DIR NUM_TRAIN_EPOCHS NUM_DUMMY_TOKENS".format(sys.argv[0])

if len(sys.argv) < 5:
    sys.exit(USAGE)
try:
    int(sys.argv[3])
    int(sys.argv[4])
except ValueError:
    sys.exit(USAGE + "\nNUM_TRAIN_EPOCHS and NUM_DUMMY_TOKENS must be integers")

# Default hyper-parameters; the paths and epoch count are overridden from the
# command line right below.
args_dict = dict(
    data_dir="", # path for data files
    output_dir="", # path to save the checkpoints
    model_name_or_path='', # path to molt5-large-caption2smiles
    tokenizer_name_or_path='', # path to molt5-large-caption2smiles
    max_seq_length=256,
    learning_rate=3e-1,
    weight_decay=0.0,
    adam_epsilon=1e-8,
    warmup_steps=0,
    train_batch_size=4,
    eval_batch_size=8,
    num_train_epochs=0,
    gradient_accumulation_steps=16,
    n_gpu=1,
    fp_16=False,
    opt_level='O1',
    max_grad_norm=1.0,
    seed=42,
)

args_dict.update({'data_dir': sys.argv[1], 'output_dir': sys.argv[2], 'num_train_epochs':int(sys.argv[3])})
args = argparse.Namespace(**args_dict)

# Keep the two checkpoints with the lowest logged "val_loss".
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath=args.output_dir, monitor="val_loss", mode="min", save_top_k=2
)

train_params = dict(
    gpus=args.n_gpu,
    max_epochs=args.num_train_epochs,
    precision= 16 if args.fp_16 else 32,
    gradient_clip_val=args.max_grad_norm,
    callbacks=[LoggingCallback(), checkpoint_callback],
)

model = T5FineTuner(args)
trainer = pl.Trainer(**train_params)
trainer.fit(model)

# Final export of the tuned model and the extended tokenizer.
model.model.save_pretrained(f't5_large_{sys.argv[2]}')
model.tokenizer.save_pretrained(f't5_large_{sys.argv[2]}')
