from transformers import PreTrainedTokenizerFast
from datasets import load_from_disk,load_metric

from copy import deepcopy
import torch
import torch.nn as nn
from transformers import T5ForConditionalGeneration
from transformers import T5Config,AutoTokenizer
from torch.utils.data import DataLoader
from transformers import AutoConfig, T5ForConditionalGeneration, AdamW, get_linear_schedule_with_warmup
import pytorch_lightning as pl
import wandb
import os

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor

import argparse

from custom_t5_vit import CustomT5ForConditionalGeneration

import random
import numpy as np
import torch
from transformers import set_seed
from gen1M import *

from transformers import AutoModel, AutoTokenizer, T5ForConditionalGeneration
import re

from collections import deque
from functools import reduce
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from nltk.translate.bleu_score import sentence_bleu

import time

# Record the wall-clock start time; the __main__ block at the bottom of the
# file subtracts it from the end time to report the total runtime.
sbatch_start_time = time.time()


# Explicitly set TOKENIZERS_PARALLELISM to false to disable tokenizer
# parallelism and suppress the warning emitted when the process forks.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


# Load the BLEU metric from a local script (offline load for Graham).
# evaluate_model() uses this module-level `bleu` object.
bleu = load_metric("./hf_sacrebleu.py", trust_remote_code=True)
#bleu = load("sacrebleu")

# Set the float32 matrix multiplication precision to medium
# For A100
#torch.set_float32_matmul_precision('medium')

# Set the seed to ensure reproducibility.
# NOTE(review): set_random_seed is not defined or explicitly imported in this
# file; presumably it is provided by `from gen1M import *` -- confirm.
seed = 1230
set_random_seed(seed)

# Set Wandb to offline mode
os.environ["WANDB_MODE"] = "offline"


# Function to calculate total parameters
def count_parameters(model):
    """Return the total number of elements over the trainable parameters of ``model``.

    A parameter is counted only when its ``requires_grad`` flag is truthy.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

# Tokenizer shared by the Lightning module, main() and evaluate_model();
# it is loaded once at import time from a local directory.
tokenizer_name = "./tokenizer_vs22_extendarctokens"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

class CodeT5_smmodel(pl.LightningModule):
    """Lightning wrapper that trains a ``CustomT5ForConditionalGeneration``.

    The wrapped model is built from ``config`` and its token embeddings are
    resized to the length of the module-level ``tokenizer``. Optimisation uses
    AdamW with a linear warmup/decay schedule stepped once per batch.
    """

    def __init__(self, batch_size, config, train_dataloader_len, PE_type="RPE", rpe_type="abs", lr=5e-5, num_train_epochs=15, warmup_steps=1000):
        """Build the model and record the hyperparameters.

        Args:
            batch_size: Stored on the module; not otherwise used by this class.
            config: Configuration passed to ``CustomT5ForConditionalGeneration``.
            train_dataloader_len: Number of batches per training epoch, used
                to size the learning-rate schedule.
            PE_type: Forwarded to the model as ``pos_enc_type``.
            rpe_type: Forwarded to the model as ``rpe_type``.
            lr: Learning rate for AdamW.
            num_train_epochs: Number of epochs the schedule is sized for.
            warmup_steps: Number of warmup steps of the schedule.
        """
        super().__init__()
        self.pos_enc_type = PE_type
        self.config = config
        self.rpe_type = rpe_type
        self.model = CustomT5ForConditionalGeneration(
            config,
            pos_enc_type=self.pos_enc_type,
            rpe_type=self.rpe_type,
        )
        # The module-level tokenizer determines the embedding table size.
        self.tokenizer = tokenizer
        self.model.resize_token_embeddings(len(self.tokenizer))

        self.batch_size = batch_size
        self.train_dataloader_len = train_dataloader_len
        self.lr = lr
        self.num_train_epochs = num_train_epochs
        self.warmup_steps = warmup_steps
        # Makes the constructor arguments available through ``self.hparams``.
        self.save_hyperparameters()

    def forward(self, input_ids, attention_mask, input_type_ids=None, output_type_ids=None, labels=None):
        """Run the wrapped model and return its output object.

        ``input_type_ids`` is passed to the model as ``object_idx``.
        ``output_type_ids`` is accepted so that a batch containing that key
        can be unpacked into this call, but it is not forwarded.
        """
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            object_idx=input_type_ids,
            labels=labels,
        )

    def common_step(self, batch, batch_idx):
        """Unpack ``batch`` into ``forward`` and return the resulting loss."""
        return self(**batch).loss

    def training_step(self, batch, batch_idx):
        """Compute the batch loss and log it as ``training_loss``."""
        step_loss = self.common_step(batch, batch_idx)
        self.log("training_loss", step_loss)
        return step_loss

    def validation_step(self, batch, batch_idx):
        """Compute the batch loss and log it as ``validation_loss`` (``on_epoch=True``)."""
        step_loss = self.common_step(batch, batch_idx)
        self.log("validation_loss", step_loss, on_epoch=True)
        return step_loss

    def test_step(self, batch, batch_idx):
        """Compute and return the batch loss; nothing is logged here."""
        return self.common_step(batch, batch_idx)

    def configure_optimizers(self):
        """Return AdamW plus a per-step linear warmup schedule.

        The schedule length is ``num_train_epochs * train_dataloader_len``.
        """
        optimizer = AdamW(self.parameters(), lr=self.hparams.lr)
        total_steps = self.hparams.num_train_epochs * self.train_dataloader_len
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=total_steps,
        )
        scheduler_config = {
            'scheduler': scheduler,
            'name': 'learning_rate',
            'interval': 'step',
            'frequency': 1,
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler_config}

def evaluate_model(max_input_length, max_target_length, map_dataset, tokenizer, task_name, model_path, eval_model, task_id, PE_type="RPE", rpe_type="abs", force_rerun=False, batch_size=16, extra_info_cols=None, loaded_distance_matrix=None, rm_ws=False, ds_type="test"):
    """Generate predictions for ``map_dataset`` and record exact-match / BLEU.

    Three files are written:

    * ``<model_path>/<ds_type>_results.pkl``: one row per example
      (``input``, ``label``, ``generated_output`` plus ``extra_info_cols``).
    * ``<model_path>/<ds_type>_metrics.csv``: a single metrics row.
    * ``<ds_type>_summary_metrics.csv`` two directory levels above
      ``model_path``: one row is appended per call.

    Args:
        max_input_length: Padding/truncation length used when tokenizing inputs.
        max_target_length: ``max_length`` passed to ``generate``.
        map_dataset: Dataset supporting ``len()`` and slice indexing; a slice
            must yield a mapping with ``'input'``, ``'output'`` and
            ``'input_type_ids'`` entries.
        tokenizer: Tokenizer used for encoding inputs and decoding outputs.
        task_name: Label stored in the ``model_name`` column of the metrics.
        model_path: Directory that receives the results and metrics files.
        eval_model: Model to evaluate; it is moved to the selected device and
            put in eval mode.
        task_id: Identifier stored in the summary row.
        PE_type: Unused; kept for interface compatibility.
        rpe_type: Unused; kept for interface compatibility.
        force_rerun: When False, an existing results pickle and/or metrics CSV
            is reused instead of being recomputed.
        batch_size: Number of examples generated per step.
        extra_info_cols: Optional names of additional batch columns to copy
            into the results. ``None`` means no extra columns.
        loaded_distance_matrix: Passed to ``generate`` as ``distance_matrix``.
        rm_ws: When True, whitespace is stripped (in addition to the special
            tokens) from labels and predictions before comparison.
        ds_type: Prefix for the output file names (e.g. ``"test"``).
    """
    # A list default would be shared across calls; normalise None to a fresh list.
    extra_info_cols = [] if extra_info_cols is None else list(extra_info_cols)

    model = eval_model
    print(model.name_or_path)
    print(task_name)
    print(model.pos_enc_type)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()  # Ensure the model is in evaluation mode

    # Define paths for results
    results_pkl_path = os.path.join(model_path, f'{ds_type}_results.pkl')
    metrics_csv_path = os.path.join(model_path, f'{ds_type}_metrics.csv')

    result_columns = ["input", "label", "generated_output"] + extra_info_cols

    # The same clean-up is applied to references and predictions; compile once.
    if rm_ws:
        cleanup_pattern = re.compile(r'<s>|</s>|<pad>|\s+')
    else:
        cleanup_pattern = re.compile(r'<s>|</s>|<pad>')

    if os.path.exists(results_pkl_path) and not force_rerun:
        # Reuse cached predictions. Generation is skipped so the cached rows
        # are not duplicated by a second pass over the same dataset.
        results_df = pd.read_pickle(results_pkl_path)
    else:
        result_rows = []
        for start_idx in tqdm(range(0, len(map_dataset), batch_size)):
            end_idx = min(start_idx + batch_size, len(map_dataset))
            batch = map_dataset[start_idx:end_idx]

            # Tokenize inputs
            inputs = tokenizer(batch['input'], return_tensors='pt', padding="max_length", truncation=True, max_length=max_input_length).input_ids.to(device)
            object_idx_b = torch.tensor(batch['input_type_ids']).to(device)
            labels = [cleanup_pattern.sub('', label) for label in batch['output']]

            # Generate outputs
            with torch.no_grad():
                outputs = model.generate(inputs, max_length=max_target_length, distance_matrix=loaded_distance_matrix, object_idx=object_idx_b)

            # Decode with special tokens kept, then strip them with the same
            # pattern used for the labels so both sides are normalised alike.
            decoded_outputs = [tokenizer.decode(output, skip_special_tokens=False) for output in outputs]
            formatted_outputs = [cleanup_pattern.sub('', output_str) for output_str in decoded_outputs]

            for i, source_text in enumerate(batch['input']):
                result_entry = {
                    "input": source_text,
                    "label": labels[i],
                    "generated_output": formatted_outputs[i],
                }
                # Add extra information columns if they exist
                for col in extra_info_cols:
                    if col in batch:
                        result_entry[col] = batch[col][i]
                result_rows.append(result_entry)

        # Build the frame once instead of concatenating inside the loop.
        results_df = pd.DataFrame(result_rows, columns=result_columns)
        results_df.to_pickle(results_pkl_path)

    pd.set_option('display.max_colwidth', None)
    pd.set_option('display.max_columns', None)

    # NOTE(review): this header is printed without any samples following it;
    # the original carried a "Memory issue" remark here, so sample printing
    # may have been removed on purpose -- confirm before re-adding it.
    print("Sample results from the evaluation:")

    # Reuse cached metrics only when they are actually usable (non-empty).
    cached_metrics_df = None
    if os.path.exists(metrics_csv_path) and not force_rerun:
        loaded_metrics_df = pd.read_csv(metrics_csv_path)
        if not loaded_metrics_df.empty:
            cached_metrics_df = loaded_metrics_df

    if cached_metrics_df is not None:
        metrics_df = cached_metrics_df
        # The summary row below needs these values on the cached path as well.
        avg_exact_match = metrics_df["exact_match"].iloc[-1]
        corpus_bleu = metrics_df["bleu"].iloc[-1]
    else:
        references = results_df['label'].tolist()
        predictions = results_df['generated_output'].tolist()

        if references:
            exact_matches = [ref == pred for ref, pred in zip(references, predictions)]
            avg_exact_match = sum(exact_matches) / len(exact_matches)
            corpus_bleu = bleu.compute(predictions=predictions, references=[[ref] for ref in references])["score"]
        else:
            # Nothing was evaluated; report zeros instead of dividing by zero.
            avg_exact_match = 0.0
            corpus_bleu = 0.0

        metrics_df = pd.DataFrame([{
            "model_name": task_name,
            "exact_match": avg_exact_match,
            "bleu": corpus_bleu,
        }])

        # Save metrics DataFrame
        metrics_df.to_csv(metrics_csv_path, index=False)

    print(metrics_df)

    # The summary lives two directory levels above model_path and accumulates
    # one row per call.
    summary_dir = os.path.abspath(os.path.join(model_path, os.pardir, os.pardir))
    summary_csv_path = os.path.join(summary_dir, f'{ds_type}_summary_metrics.csv')

    summary_df = pd.DataFrame({
        "task_id": [task_id],
        "model_name": [task_name],
        "exact_match": [avg_exact_match],
        "bleu": [corpus_bleu],
    })

    if os.path.exists(summary_csv_path):
        existing_summary_df = pd.read_csv(summary_csv_path)
        summary_df = pd.concat([existing_summary_df, summary_df], ignore_index=True)

    summary_df.to_csv(summary_csv_path, index=False)

    print("Summary metrics saved to:", summary_csv_path)


def main(input_param, tokenizer):
    """Generate one task's dataset, train the model for one epoch and evaluate it.

    Args:
        input_param: Passed to ``generate_single_dataset_hf`` as ``task_idx``.
        tokenizer: Tokenizer used to size the model vocabulary, fill in the
            special-token ids of the config, and run the evaluation.

    Side effects: creates directories under ``./arc_x2y_models/<task_name>/``,
    saves the trained model there, and writes the evaluation files produced
    by ``evaluate_model`` for the test split and for a training sample.
    """
    print(f"Input parameter received: {input_param}")

    max_input_length = 1124
    max_target_length = 1124
    # Still OOM on P100?
    #batch_size = 16
    #For P100
    batch_size = 8

    # Available values for PE_mix_strategy, kept for reference.
    embedding_strategies = [
        'default',                 # Standard addition of input and positional embeddings
        'hardcoded_normalization', # Normalizes both input and positional embeddings before adding
        'learnable_scaling',       # Uses a learnable scaling factor for the positional embeddings
        'weighted_sum',            # Normalizes both embeddings and then uses a weighted sum with learnable weights
        'weighted_sum_no_norm',     # Weighted sum with learnable weights, without Norm
        'learnable_scaling_vec',       # Uses a learnable scaling factor for the positional embeddings
        'weighted_sum_vec',            # Normalizes both embeddings and then uses a weighted sum with learnable weights
        'weighted_sum_no_norm_vec',     # Weighted sum with learnable weights, without Norm
        'layer_norm',              # Adds the embeddings and applies layer normalization
    ]

    #embedding_strategies_sel = embedding_strategies[input_param]
    embedding_strategies_sel = 'default'
    #embedding_strategies_sel = 'weighted_sum_no_norm_vec'

    use_objidx="yes" # yes or no

    config = T5Config(
        vocab_size=len(tokenizer),
        d_model=128,
        num_layers=3,
        num_decoder_layers=3,
        num_heads=8,
        d_ff=256,
        relative_attention_num_buckets=32,
        dropout_rate=0.1,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        mask_token_id=tokenizer.mask_token_id,
        sep_token_id=tokenizer.sep_token_id,
        decoder_start_token_id=tokenizer.pad_token_id,
        PE_mix_strategy=embedding_strategies_sel,
        use_objidx=use_objidx,
        grid_max_height=33,
        grid_max_width=34,
    )

    #PE_type = "SinusoidalAPE2D"
    PE_type = "APEAlibi-duo"
    #PE_type = "SinusoidalAPE"
    #PE_type = "LearnedAPE"

    n_examples = 1000000
    n_examples_str = "1M"
    sample_size = 1000
    task_id, dataset = generate_single_dataset_hf(task_idx=input_param,path='/scratch/', seed=seed, n_examples=n_examples, testsize=sample_size)

    # Sample up to `sample_size` training examples for the train-split
    # evaluation. The sample is taken before set_format() is applied below so
    # it keeps the raw columns evaluate_model() reads. Clamping avoids the
    # ValueError random.sample raises when the split is smaller than the
    # requested sample.
    train_dataset = dataset['train']
    n_train_eval = min(sample_size, len(train_dataset))
    indices = random.sample(range(len(train_dataset)), n_train_eval)
    train_sample = train_dataset.select(indices)

    # Deep copy of the sampled dataset, used for the train-split evaluation.
    train_sample_copy = deepcopy(train_sample)

    tensor_columns = ['input_ids', 'attention_mask', 'labels', 'input_type_ids', 'output_type_ids']
    dataset['train'].set_format(type="torch", columns=tensor_columns)
    dataset['validation'].set_format(type="torch", columns=tensor_columns)

    # The 'test' split is left in its original format; it is consumed directly
    # by evaluate_model(), so no DataLoader is built for it.

    train_dataloader_obj = DataLoader(dataset['train'], shuffle=True, num_workers=4, batch_size=batch_size)
    valid_dataloader_obj = DataLoader(dataset['validation'], num_workers=4, batch_size=batch_size)

    train_dataloader_len = len(train_dataloader_obj)

    print(f"task_id: {task_id}")
    print(f"PE_strat_idx: {input_param}")
    print(f"PE_strat: {embedding_strategies_sel}")
    print(f"max_input_length: {max_input_length}")
    print(f"max_target_length: {max_target_length}")
    print(f"batch_size: {batch_size}")

    # NOTE(review): CodeT5_smmodel sizes its LR schedule with its default
    # num_train_epochs=15, while the Trainer below runs max_epochs=1, so the
    # linear decay covers only the first fifteenth of the schedule. Left
    # unchanged to keep training behaviour identical -- confirm whether
    # num_train_epochs should be passed here to match max_epochs.
    model = CodeT5_smmodel(batch_size,config,train_dataloader_len,PE_type=PE_type)
    model.hparams.max_input_length = max_input_length

    task_name = f"ViT5-{task_id}-{PE_type}-obj_idx-ratio_211-{embedding_strategies_sel}-2Mmodel-1EP-nw4bs{batch_size}-seed{seed}-n_examples{n_examples_str}"
    # Only used by the optional WandbLogger below.
    project_name = "Proj"

    directory_path = f"./arc_x2y_models/{task_name}/Checkpoints/"
    os.makedirs(directory_path, exist_ok=True)

    wandb_directory_path = f"./arc_x2y_models/{task_name}/wandb/"
    os.makedirs(wandb_directory_path, exist_ok=True)

    #wandb.finish()
    #wandb_logger = WandbLogger(name=f'{task_name}', project=project_name, offline=True, save_dir=wandb_directory_path)
    #wandb_logger = WandbLogger(name=f'{task_name}', project=project_name)

    early_stop_callback = EarlyStopping(
        monitor='validation_loss',
        patience=3,
        strict=False,
        verbose=False,
        mode='min'
    )
    lr_monitor = LearningRateMonitor(logging_interval='step')

    #trainer = Trainer(accelerator="cpu",
    trainer = Trainer(accelerator="auto",
                      default_root_dir=directory_path,
                      #logger=wandb_logger,
                      max_epochs=1,
                      #max_epochs=5,
                      log_every_n_steps=5,
                      callbacks=[early_stop_callback, lr_monitor])

    # CPU for debug msg
    #model = model.to('cpu')

    trainer.fit(model,train_dataloaders=train_dataloader_obj, val_dataloaders=valid_dataloader_obj)

    model_save_directory = f"./arc_x2y_models/{task_name}/model/"
    os.makedirs(model_save_directory, exist_ok=True)
    model.model.save_pretrained(model_save_directory)

    #wandb.finish()

    # Evaluate the underlying model (not the Lightning wrapper).
    eval_model = model.model

    # Eval Test
    evaluate_model(max_input_length, max_target_length, dataset["test"], tokenizer, task_name, model_save_directory, eval_model, task_id, PE_type=PE_type, rpe_type="abs", force_rerun=True, batch_size=batch_size, rm_ws=True, ds_type="test")

    # Eval Train
    evaluate_model(max_input_length, max_target_length, train_sample_copy, tokenizer, task_name, model_save_directory, eval_model, task_id, PE_type=PE_type, rpe_type="abs", force_rerun=True, batch_size=batch_size, rm_ws=True, ds_type="train")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process some input.")
    # The single positional argument is handed to main(), which passes it on
    # as the task index; max_input_length is hard-coded inside main(). The
    # argument is therefore named for what it is. It stays positional, so
    # existing invocations are unaffected.
    parser.add_argument("task_idx", type=int, help="index of the task to generate, train on and evaluate")
    args = parser.parse_args()
    main(args.task_idx, tokenizer)

    # Total runtime is measured from the module-level start timestamp.
    sbatch_end_time = time.time()
    sbatch_runtime = sbatch_end_time - sbatch_start_time

    # Print the runtime in seconds
    print(f"Total runtime: {sbatch_runtime} seconds")

    # Convert runtime to hours, minutes, and seconds
    hours, rem = divmod(sbatch_runtime, 3600)
    minutes, seconds = divmod(rem, 60)

    # Print the runtime in human-readable form
    print(f"Total runtime: {int(hours)} hrs {int(minutes)} min {seconds:.2f} seconds")

