from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, default_data_collator, get_linear_schedule_with_warmup,AutoModelForSequenceClassification
from transformers import DebertaV2Tokenizer, DebertaV2ForSequenceClassification,LlamaTokenizer, LlamaForSequenceClassification,OPTForSequenceClassification,T5Config
from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, PrefixTuningConfig,LoraConfig, TaskType, PromptTuningInit, PromptTuningConfig,AdaLoraConfig,IA3Config
from datasets import load_dataset
from torch.utils.data import DataLoader,ConcatDataset,DistributedSampler,WeightedRandomSampler
from tqdm import tqdm
import torch
import torch.nn as nn
import os
import pickle
import pandas as pd
import numpy as np
import random
import evaluate
from accelerate import Accelerator, DistributedType
import gc
from itertools import cycle
import math
import matplotlib.pyplot as plt
import ray
from ray import train, tune
from ray.tune.search.optuna import OptunaSearch
import os

# Attach to Ray before anything else runs.
# NOTE(review): the 'auto' address is assumed to mean "connect to an
# already-running cluster" — confirm against the deployed Ray version.
ray.init('auto')
###############Variables#########################
# Batch sizes used when the per-task DataLoaders are built further down.
batch_size = 16
eval_batch_size = 32
# Encoder inputs are padded/truncated to this many tokens in tokenize_function.
max_length = 512
# Upper bound on training epochs inside objective().
epochs = 150
# NOTE(review): `lr` is not read anywhere in the visible code; objective()
# takes its learning rate from the Ray Tune config instead.
lr = 1e-4
#lr = 1e-2
#dr= 0.01
# Number of worker processes handed to datasets.map(num_proc=...).
tokenize_pre_proc=1


# Expose eight GPUs to this process. This is set before the first CUDA query
# below; NOTE(review): torch is already imported at this point — confirm the
# variable still takes effect in the target environment.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
#################Initialization#########################
seed_value = 42



# NOTE(review): empty_cache() is called for its side effect; it is not expected
# to return anything useful, so this line most likely prints "None".
print(torch.cuda.empty_cache())
print(torch.cuda.is_available())
print(torch.cuda.device_count())

# Seed every RNG the script touches (NumPy, PyTorch CPU/GPU, Python's random)
# with the same value, and ask cuDNN for deterministic kernels.
np.random.seed(seed_value) # cpu vars
torch.manual_seed(seed_value) # cpu  vars
random.seed(seed_value) # Python
torch.cuda.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value) # gpu vars
torch.backends.cudnn.deterministic = True  #needed
torch.backends.cudnn.benchmark = False

#################Dataset #########################
# Shared tokenizer; the [TGT*] markers are registered as special tokens because
# tokenize_function wraps WiC/WSC target spans with them.
tokenizer =  AutoTokenizer.from_pretrained("t5-large")
special_tokens = {'additional_special_tokens': ['[TGT]', '[TGT1]', '[TGT2]']}
tokenizer.add_special_tokens(special_tokens)

# Task configs to train on. Despite the variable name these are loaded from the
# "super_glue" dataset further down, in this order.
glue_tasks = [ "boolq","rte","copa","multirc", "wic","wsc"]
# Name of the column holding the integer class label in every task.
label_column = "label"

def _highlight_wsc_spans(text, span1_text, span1_index, span2_text, span2_index):
    """Wrap the two WSC spans of *text* in ``[TGT1]`` / ``[TGT2]`` markers.

    SuperGLUE WSC stores ``span1_index`` / ``span2_index`` as positions in the
    space-separated word list of ``text`` (not as character offsets), so the
    markers are inserted on word boundaries. The spans may appear in either
    order in the sentence.
    """
    words = text.split(" ")
    spans = [
        (span1_index, len(span1_text.split(" ")), "[TGT1]"),
        (span2_index, len(span2_text.split(" ")), "[TGT2]"),
    ]
    # Insert around the right-most span first so that the word indices of the
    # span further to the left are still valid when it is processed.
    for start, n_words, marker in sorted(spans, reverse=True):
        words.insert(min(start + n_words, len(words)), marker)
        words.insert(start, marker)
    return " ".join(words)


def tokenize_function(examples, task_name):
    """Turn one batch of raw task rows into tokenized seq2seq features.

    Each row is rendered as a single text prompt whose layout depends on
    ``task_name``, and its integer label is mapped to the target word
    ``'positive'`` or ``'negative'``.

    Args:
        examples: Column-oriented batch (column name -> list of values), as
            passed by ``datasets.map(..., batched=True)``.
        task_name: One of ``"boolq"``, ``"rte"``, ``"copa"``, ``"multirc"``,
            ``"wic"``, ``"wsc"``.

    Returns:
        The tokenizer output for the prompts (padded/truncated to
        ``max_length``) with an added ``"labels"`` entry holding the target
        token ids, where padding positions are replaced by -100.

    Raises:
        ValueError: If ``task_name`` is not one of the supported tasks.
    """
    raw_labels = examples[label_column]

    if task_name == "boolq":
        inputs = [
            f"question: {question} passage: {passage}"
            for question, passage in zip(examples['question'], examples['passage'])
        ]
        targets = ['positive' if label == 1 else 'negative' for label in raw_labels]

    elif task_name == "rte":
        inputs = [
            f"premise: {premise} hypothesis: {hypothesis}"
            for premise, hypothesis in zip(examples['premise'], examples['hypothesis'])
        ]
        # Label 1 is mapped to 'negative' for this task (inverse of boolq).
        targets = ['negative' if label == 1 else 'positive' for label in raw_labels]

    elif task_name == "copa":
        inputs = [
            f"premise: {premise} question: {question} choice1: {choice1} choice2: {choice2}"
            for premise, question, choice1, choice2 in zip(
                examples['premise'], examples['question'],
                examples['choice1'], examples['choice2'],
            )
        ]
        targets = ['positive' if label == 1 else 'negative' for label in raw_labels]

    elif task_name == "multirc":
        inputs = [
            f"paragraph: {paragraph} question: {question} answer: {answer}"
            for paragraph, question, answer in zip(
                examples['paragraph'], examples['question'], examples['answer']
            )
        ]
        targets = ['negative' if label == 0 else 'positive' for label in raw_labels]

    elif task_name == "wic":
        inputs = []
        for sentence1, start1, end1, sentence2, start2, end2, word in zip(
            examples['sentence1'], examples['start1'], examples['end1'],
            examples['sentence2'], examples['start2'], examples['end2'],
            examples['word']
        ):
            # start/end are used as character offsets of the target word.
            sentence1_highlighted = f"{sentence1[:start1]} [TGT] {sentence1[start1:end1]} [TGT] {sentence1[end1:]}"
            sentence2_highlighted = f"{sentence2[:start2]} [TGT] {sentence2[start2:end2]} [TGT] {sentence2[end2:]}"
            inputs.append(
                f"sentence1: {sentence1_highlighted} sentence2: {sentence2_highlighted} word: {word}"
            )
        targets = ['negative' if label == 0 else 'positive' for label in raw_labels]

    elif task_name == "wsc":
        inputs = [
            "sentence: " + _highlight_wsc_spans(text, span1_text, span1_index, span2_text, span2_index)
            for text, span1_text, span1_index, span2_text, span2_index in zip(
                examples['text'], examples['span1_text'], examples['span1_index'],
                examples['span2_text'], examples['span2_index']
            )
        ]
        targets = ['negative' if label == 0 else 'positive' for label in raw_labels]

    else:
        # Fail loudly instead of hitting an unbound `inputs` further down.
        raise ValueError(f"Unsupported task name: {task_name!r}")

    model_inputs = tokenizer(inputs, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt")
    labels = tokenizer(targets, max_length=2, padding="max_length", truncation=True, return_tensors="pt")
    labels = labels["input_ids"]
    # -100 marks positions the loss should skip.
    labels[labels == tokenizer.pad_token_id] = -100
    model_inputs["labels"] = labels
    return model_inputs

# Load every task, carve a 10% held-out split off its training data, keep a
# small random training subset, and tokenize both splits once up front.
train_subset_size = 300  # maximum number of training rows kept per task
tokenized_datasets_dict = {}
for task_name in glue_tasks:

    dataset = load_dataset("super_glue", task_name, trust_remote_code=True)
    # Only the original "train" split is used; it is re-split into train/test.
    dataset = dataset["train"].train_test_split(test_size=0.1)

    print(dataset)
    # Clamp the sample size so a task with fewer rows than the requested
    # subset does not make random.sample raise ValueError.
    available_rows = len(dataset["train"])
    train_indices = random.sample(range(available_rows), min(train_subset_size, available_rows))
    dataset["train"] = dataset["train"].select(train_indices)

    print(dataset["train"][0])

    print(f"Preprocessing {task_name} task...")

    tokenized_datasets = dataset.map(
        lambda examples: tokenize_function(examples, task_name),
        batched=True,
        num_proc=tokenize_pre_proc,  # Adjust this for parallel processing
        # Drop the raw text columns so only tokenizer outputs remain.
        remove_columns=dataset["train"].column_names,
        load_from_cache_file=False,
        desc=f"Running tokenizer on {task_name} dataset"
    )
    tokenized_datasets_dict[task_name] = tokenized_datasets
    print(f"Completed preprocessing for {task_name} task!")
    
################################tokensize_dataset#############################################################
# For every tokenized task build a shuffled training loader and an ordered
# evaluation loader, stored under the task name.
dataloaders = {}
for task, dataset_dict in tokenized_datasets_dict.items():
    train_dataset, eval_dataset = dataset_dict["train"], dataset_dict["test"]

    # Options common to both loaders of a task.
    shared_loader_kwargs = {"collate_fn": default_data_collator, "pin_memory": True}

    train_dataloader = DataLoader(
        train_dataset, shuffle=True, batch_size=batch_size, **shared_loader_kwargs
    )
    eval_dataloader = DataLoader(
        eval_dataset, batch_size=eval_batch_size, **shared_loader_kwargs
    )

    dataloaders[task] = dict(
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
    )
    
###########################Dataloaders##############################################
# Module-level aliases for the per-task loaders; objective() hands these to
# accelerator.prepare(). (`global` statements are no-ops at module scope, so
# plain assignments are sufficient.)
boolq_train_dataloader1 = dataloaders["boolq"]["train_dataloader"]
boolq_eval_dataloader1 = dataloaders["boolq"]["eval_dataloader"]

rte_train_dataloader1 = dataloaders["rte"]["train_dataloader"]
rte_eval_dataloader1 = dataloaders["rte"]["eval_dataloader"]

multirc_train_dataloader1 = dataloaders["multirc"]["train_dataloader"]
multirc_eval_dataloader1 = dataloaders["multirc"]["eval_dataloader"]

copa_train_dataloader1 = dataloaders["copa"]["train_dataloader"]
copa_eval_dataloader1 = dataloaders["copa"]["eval_dataloader"]

wic_train_dataloader1 = dataloaders["wic"]["train_dataloader"]
wic_eval_dataloader1 = dataloaders["wic"]["eval_dataloader"]

wsc_train_dataloader1 = dataloaders["wsc"]["train_dataloader"]
wsc_eval_dataloader1 = dataloaders["wsc"]["eval_dataloader"]


# Both lists follow the task order of `glue_tasks`.
train_dataloaders = [boolq_train_dataloader1,rte_train_dataloader1,copa_train_dataloader1,multirc_train_dataloader1,wic_train_dataloader1,wsc_train_dataloader1]
eval_dataloaders = [boolq_eval_dataloader1,rte_eval_dataloader1,copa_eval_dataloader1,multirc_eval_dataloader1,wic_eval_dataloader1,wsc_eval_dataloader1]

# Number of batches each task contributes per optimizer step, expressed
# relative to COPA (whose ratio is 1 by definition).
copa_batches = len(copa_train_dataloader1)
boolq_ratio = int(round(len(boolq_train_dataloader1)/copa_batches))
rte_ratio = int(round(len(rte_train_dataloader1)/copa_batches))
multirc_ratio = int(round(len(multirc_train_dataloader1)/copa_batches))
wic_ratio = int(round(len(wic_train_dataloader1)/copa_batches))
# Measured against COPA like every other task (it used to be divided by its
# own length, which always produced 1 regardless of the dataset sizes).
wsc_ratio = int(round(len(wsc_train_dataloader1)/copa_batches))
copa_ratio=1
# Optimizer steps per epoch: limited by the task with the fewest batches.
steps=min(len(loader) for loader in train_dataloaders)


####################################early_stopping################################################################
class EarlyStopping:
    """Track a score that should increase and flag when it has stalled.

    Call the instance once per evaluation with the latest score. A call whose
    score does not exceed ``best_score + delta`` increments ``counter``; once
    ``counter`` reaches ``patience`` the ``early_stop`` flag is set. A call
    with a better score stores it as the new best and resets ``counter``.
    """

    def __init__(self, patience=25, delta=0, trace_func=print):
        # Mutable tracking state.
        self.best_score = None
        self.counter = 0
        self.early_stop = False
        # Configuration.
        self.patience = patience
        self.delta = delta
        self.trace_func = trace_func

    def __call__(self, average_accuracy):
        # The very first score only establishes the baseline.
        if self.best_score is None:
            self.best_score = average_accuracy
            return

        threshold = self.best_score + self.delta
        if average_accuracy <= threshold:
            # No improvement: count it, report it, and maybe trip the flag.
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
            return

        # Improvement: remember the new best and start counting afresh.
        self.best_score = average_accuracy
        self.counter = 0

def _run_task_batches(model, accelerator, device, batch_iter, num_batches):
    """Run forward/backward on up to ``num_batches`` batches from ``batch_iter``.

    Gradients are accumulated (no optimizer step is taken here). If the
    iterator runs dry early the remaining batches are silently skipped, which
    mirrors the best-effort behaviour the training loop relies on.
    """
    for _ in range(num_batches):
        try:
            raw_batch = next(batch_iter)
        except StopIteration:
            break
        batch = {k: v.to(device) for k, v in raw_batch.items()}
        outputs = model(**batch)
        accelerator.backward(outputs.loss)


def objective(config):
    """Ray Tune trainable: multi-task fine-tuning of t5-large with PEFT.

    Builds the model (AdaLoRA wrapped by prefix tuning), trains it on the six
    tasks with gradient accumulation across tasks, evaluates after each epoch
    and reports the average accuracy to Ray Tune under the key ``"metric"``.

    Args:
        config: Hyperparameter dict supplied by Ray Tune. Keys read here:
            ``"prefix_length"`` (number of virtual prefix tokens),
            ``"learning_rate"`` and ``"decay_rate"`` (AdamW weight decay).

    Side effects:
        Writes per-epoch checkpoints under ``./checkpoint``, appends to
        ``log.txt`` and ``/location/log_s.txt``, and saves the best model
        under ``/location/s_model``.
    """
    accelerator = Accelerator()
    device = accelerator.device
    #######################Model###########################
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-large")

    peft_config = AdaLoraConfig(
        init_r=12,
        target_r=8,
        beta1=0.85,
        beta2=0.85,
        tinit=200,
        tfinal=1000,
        deltaT=10,
        lora_alpha=32,
        lora_dropout=0.1,
        task_type=TaskType.SEQ_2_SEQ_LM,
        inference_mode=False,
    )
    model = get_peft_model(model, peft_config)

    # Second adapter layer: prefix tuning on top of the AdaLoRA-wrapped model.
    peft_config = PrefixTuningConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, num_virtual_tokens=config["prefix_length"], prefix_projection=True)
    model = get_peft_model(model, peft_config)

    print(model)

    # Log the trainable state of every parameter and re-enable gradients for
    # any LoRA parameter that is currently frozen.
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"Trainable parameter: {name}")
        else:
            print(f"Non-Trainable parameter: {name}")
            if "lora" in name:
                param.requires_grad = True

    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"Trainable parameter: {name}")

    model = model.to(device)
    model.print_trainable_parameters()

    ################################optimizer_learningrates##########################################################

    # Always pass weight_decay explicitly: omitting it would fall back to
    # AdamW's non-zero default, so a sampled decay_rate of 0 would not
    # actually disable weight decay.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config["learning_rate"],
        weight_decay=config["decay_rate"],
    )

    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=(len(copa_train_dataloader1) * epochs),
    )

    ################################Accelerator_prepare################################################################

    (
        model, optimizer, lr_scheduler,
        boolq_train_dataloader, boolq_eval_dataloader,
        rte_train_dataloader, rte_eval_dataloader,
        copa_train_dataloader, copa_eval_dataloader,
        multirc_train_dataloader, multirc_eval_dataloader,
        wic_train_dataloader, wic_eval_dataloader,
        wsc_train_dataloader, wsc_eval_dataloader,
    ) = accelerator.prepare(
        model, optimizer, lr_scheduler,
        boolq_train_dataloader1, boolq_eval_dataloader1,
        rte_train_dataloader1, rte_eval_dataloader1,
        copa_train_dataloader1, copa_eval_dataloader1,
        multirc_train_dataloader1, multirc_eval_dataloader1,
        wic_train_dataloader1, wic_eval_dataloader1,
        wsc_train_dataloader1, wsc_eval_dataloader1,
    )
    early_stopping = EarlyStopping()

    # (task, prepared train loader, batches per optimizer step, whether
    # leftover batches are spread over the steps). COPA is the reference task
    # and always contributes exactly `copa_ratio` batches per step.
    train_plan = [
        ("boolq", boolq_train_dataloader, boolq_ratio, True),
        ("rte", rte_train_dataloader, rte_ratio, True),
        ("copa", copa_train_dataloader, copa_ratio, False),
        ("multirc", multirc_train_dataloader, multirc_ratio, True),
        ("wic", wic_train_dataloader, wic_ratio, True),
        ("wsc", wsc_train_dataloader, wsc_ratio, True),
    ]
    # Evaluate on the loaders returned by accelerator.prepare() so that
    # gather_for_metrics() sees dataloaders it knows about.
    eval_plan = [
        ("boolq", boolq_eval_dataloader),
        ("rte", rte_eval_dataloader),
        ("copa", copa_eval_dataloader),
        ("multirc", multirc_eval_dataloader),
        ("wic", wic_eval_dataloader),
        ("wsc", wsc_eval_dataloader),
    ]
    # compute() clears the accumulated batches, so one metric object can be
    # reused for every task and epoch.
    task_metric = evaluate.load("accuracy")

    #################################################Checkpoints#######################################################################

    current_dir = os.getcwd()
    checkpoint_dir = os.path.join(current_dir, "checkpoint")
    checkpoint_filename = "checkpoint_epoch.pth"

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
        print(f"Checkpoint folder created at {checkpoint_dir}")

    checkpoint_path = os.path.join(checkpoint_dir, checkpoint_filename)

    if os.path.exists(checkpoint_path):
        # NOTE: only the presence of the file is reported; restoring the
        # weights from it is currently not performed.
        print(f"Checkpoint found at {checkpoint_path}. Loading model from checkpoint...")
    else:
        print("No checkpoint found. Starting training from scratch.")

    #####################################################Training Loop#######################################################################

    for epoch in range(epochs):
        # Evaluation at the end of the previous epoch switched the model to
        # eval mode; switch back so dropout is active while training.
        model.train()

        batch_iters = {name: iter(loader) for name, loader, _, _ in train_plan}
        # Batches a task would have left over after `steps` steps at its
        # ratio; they are consumed one per step until used up.
        leftover = {
            name: (len(loader) - ratio * steps) if spread_leftover else 0
            for name, loader, ratio, spread_leftover in train_plan
        }

        for _ in tqdm(range(steps)):
            # Accumulate gradients from every task, then take a single step.
            for name, _, ratio, _ in train_plan:
                take_extra = leftover[name] > 0
                _run_task_batches(
                    model, accelerator, device, batch_iters[name],
                    ratio + (1 if take_extra else 0),
                )
                if take_extra:
                    leftover[name] -= 1

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        ##################################################### Evaluation ###########################################################
        model.eval()
        total_accuracy = 0

        with torch.no_grad():
            for task_name, dataloader in eval_plan:
                dataloader = tqdm(dataloader, desc=f"Evaluating {task_name}", dynamic_ncols=True)
                for batch in dataloader:
                    batch = {k: v.to(device) for k, v in batch.items()}

                    outputs = model(**batch)

                    # Accuracy is computed over flattened label positions.
                    predictions = outputs.logits.argmax(dim=-1)
                    references = batch["labels"]
                    predictions, references = accelerator.gather_for_metrics(
                        (predictions.to(torch.int64).view(-1), references.to(torch.int64).view(-1))
                    )

                    task_metric.add_batch(predictions=predictions.cpu().numpy(), references=references.cpu().numpy())

                eval_metric = task_metric.compute()
                print(f"Task: {task_name}, Evaluation Metric: {eval_metric}")
                with open('log.txt', 'a') as f:
                    f.write(f"Task: {task_name}, Evaluation Metric: {eval_metric}")

                total_accuracy += eval_metric["accuracy"]

        # Compute the average accuracy across all evaluated tasks
        average_accuracy = total_accuracy / len(eval_plan)
        print(f"Average Accuracy across all tasks: {average_accuracy}")

        train.report({"metric": average_accuracy})

        checkpoint_epoch_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pth")
        print(f"Saving checkpoint to {checkpoint_epoch_path}...")
        torch.save(model.state_dict(), checkpoint_epoch_path)

        with open('log.txt', 'a') as f:
            f.write(f"Average Accuracy across all tasks: TMVE={average_accuracy}")

        torch.cuda.empty_cache()
        gc.collect()

        early_stopping(average_accuracy)
        if early_stopping.early_stop:
            print("Early stopping")
            with open('/location/log_s.txt', 'a') as f:
                f.write("\nEarly Stopping\n")
            break

        print("TTT")
        pid = os.getpid()
        print(pid)
        print(early_stopping.counter)
        # A counter of 0 means this epoch set (or tied for setting) the best
        # score so far, so persist the adapter weights for this trial.
        if early_stopping.counter == 0:
            model.save_pretrained(f"/location/s_model/pid{pid}.pth")

            print(f"Model_of_pid={pid}_Saved")
            with open('/location/log_s.txt', 'a') as f:
                f.write("\npid= {}\n".format(pid))
                f.write("\nModel saved at epoch {}\n".format(epoch))
                f.write("\nAvg_Acc {}\n".format(average_accuracy))

    print("Training complete!")

# Hyperparameter search: Optuna-driven sampling over learning rate, weight
# decay and prefix length, maximizing the "metric" value reported by objective.
search_space = {
    "learning_rate": tune.qloguniform(0.001, 0.02, 5e-5),
    "decay_rate": tune.choice([0, 1e-5, 1e-4, 1e-3, 1e-2]),
    "prefix_length": tune.randint(10, 51),
}
algo = OptunaSearch()

# Each trial is given one GPU.
trainable_with_gpu = tune.with_resources(objective, {"gpu": 1})
search_settings = tune.TuneConfig(
    metric="metric",
    mode="max",
    num_samples=50,
    search_alg=algo,
)
run_settings = train.RunConfig(stop={"training_iteration": 75})

tuner = tune.Tuner(
    trainable_with_gpu,
    tune_config=search_settings,
    run_config=run_settings,
    param_space=search_space,
)
results = tuner.fit()
best_trial_config = results.get_best_result().config
print("Best config is:", best_trial_config)