import os
# os.environ['HF_HOME'] = "checkpoints"
# os.environ['HTTP_PROXY'] = 'http://127.0.0.1:8118'
# os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:8118'
from torch.utils.data import DataLoader
from transformers import T5Tokenizer, T5ForConditionalGeneration, get_scheduler
from datasets import load_dataset
import torch

from datasets import Dataset
from torch.utils.data import TensorDataset, random_split
from torch.utils.data import DataLoader, SequentialSampler
# GLUE task names; presumably the "available" tasks (not referenced anywhere in the visible code).
# NOTE(review): 'mrpc' has a loader (getDataLoader_mrpc) and is accepted by get_glue_tasks,
# but it is not listed here -- confirm whether the omission is intentional.
ava = ['cola','rte','sst2','qnli','qqp','wnli','mnli',]

def _build_seq2seq_loader(args, tokenizer, name, inputs, targets):
    """Tokenize prompt/target strings and wrap them in a sequential DataLoader.

    Shared back end of every ``getDataLoader_<task>`` function.

    Args:
        args: Namespace providing ``batch`` (batch size) and the dicts
            ``inputs`` / ``targets``; the raw strings are stored in those
            dicts under ``name``.
        tokenizer: Callable invoked as ``tokenizer(texts, padding=True, ...,
            return_tensors="pt")`` that returns a mapping with the keys
            ``'input_ids'`` and ``'attention_mask'``.
        name: Key under which the raw strings are recorded (``get_glue_tasks``
            passes ``'train'`` and ``'validation'``).
        inputs: Prompt strings, one per example.
        targets: Target label strings, aligned with ``inputs``.

    Returns:
        A ``DataLoader`` over ``(source_ids, source_mask, target_ids,
        target_mask)`` tuples, iterated in the original example order.
    """
    # Only the source side is truncated; the targets are short label words.
    encoded_inputs = tokenizer(inputs, padding=True, truncation=True, return_tensors="pt")
    encoded_targets = tokenizer(targets, padding=True, return_tensors="pt")

    batch_size = args.batch
    data = TensorDataset(
        encoded_inputs['input_ids'],
        encoded_inputs['attention_mask'],
        encoded_targets['input_ids'],
        encoded_targets['attention_mask'],
    )
    sampler = SequentialSampler(data)
    # Keep the untokenized strings available to the caller through args.
    args.inputs[name], args.targets[name] = inputs, targets
    return DataLoader(data, sampler=sampler, batch_size=batch_size)


def getDataLoader_cola(args, input, tokenizer, name):
    """Build the CoLA loader.

    Prompts are ``"cola sentence: <sentence>"``; a truthy label maps to
    ``'acceptable'`` and a falsy one to ``'unacceptable'``.
    """
    inputs = ["cola sentence: " + sentence for sentence in input['sentence']]
    targets = ['acceptable' if label else 'unacceptable' for label in input['label']]
    return _build_seq2seq_loader(args, tokenizer, name, inputs, targets)


def getDataLoader_rte(args, input, tokenizer, name):
    """Build the RTE loader.

    Prompts are ``"rte sentence1: <s1> sentence2: <s2>"``; label 0 maps to
    ``'entailment'`` and every other label to ``'not_entailment'``.
    """
    inputs = [
        "rte sentence1: " + first + " sentence2: " + second
        for first, second in zip(input['sentence1'], input['sentence2'])
    ]
    targets = ['entailment' if label == 0 else 'not_entailment' for label in input['label']]
    return _build_seq2seq_loader(args, tokenizer, name, inputs, targets)


def getDataLoader_mrpc(args, input, tokenizer, name):
    """Build the MRPC loader.

    Prompts are ``"mrpc sentence1: <s1> sentence2: <s2>"``; label 1 maps to
    ``'equivalent'`` and every other label to ``'not_equivalent'``.
    """
    inputs = [
        "mrpc sentence1: " + first + " sentence2: " + second
        for first, second in zip(input['sentence1'], input['sentence2'])
    ]
    targets = ['equivalent' if label == 1 else 'not_equivalent' for label in input['label']]
    return _build_seq2seq_loader(args, tokenizer, name, inputs, targets)


def getDataLoader_mnli(args, input, tokenizer, name):
    """Build the MNLI loader.

    Prompts are ``"mnli premise: <p> hypothesis: <h>"``; label 0 maps to
    ``'entailment'``, 1 to ``'neutral'`` and anything else to
    ``'contradiction'``.
    """
    inputs = [
        "mnli premise: " + premise + " hypothesis: " + hypothesis
        for premise, hypothesis in zip(input['premise'], input['hypothesis'])
    ]
    label_words = {0: 'entailment', 1: 'neutral'}
    targets = [label_words.get(label, 'contradiction') for label in input['label']]
    return _build_seq2seq_loader(args, tokenizer, name, inputs, targets)


def getDataLoader_sst2(args, input, tokenizer, name):
    """Build the SST-2 loader.

    Prompts are ``"sst2 sentence: <sentence>"``; label 0 maps to
    ``'negative'`` and every other label to ``'positive'``.
    """
    inputs = ["sst2 sentence: " + sentence for sentence in input['sentence']]
    targets = ["negative" if label == 0 else "positive" for label in input['label']]
    return _build_seq2seq_loader(args, tokenizer, name, inputs, targets)


def getDataLoader_qnli(args, input, tokenizer, name):
    """Build the QNLI loader.

    Prompts are ``"qnli question: <q> sentence: <s>"``; label 0 maps to
    ``'entailment'`` and every other label to ``'not_entailment'``.
    """
    inputs = [
        "qnli question: " + question + " sentence: " + sentence
        for question, sentence in zip(input['question'], input['sentence'])
    ]
    targets = ['entailment' if label == 0 else 'not_entailment' for label in input['label']]
    return _build_seq2seq_loader(args, tokenizer, name, inputs, targets)


def getDataLoader_qqp(args, input, tokenizer, name):
    """Build the QQP loader.

    Prompts are ``"qqp question1: <q1> question2: <q2>"``; label 1 maps to
    ``'duplicate'`` and every other label to ``'not_duplicate'``.
    """
    inputs = [
        "qqp question1: " + first + " question2: " + second
        for first, second in zip(input['question1'], input['question2'])
    ]
    targets = ["duplicate" if label == 1 else "not_duplicate" for label in input['label']]
    return _build_seq2seq_loader(args, tokenizer, name, inputs, targets)


def getDataLoader_wnli(args, input, tokenizer, name):
    """Build the WNLI loader.

    Prompts are ``"wnli sentence1: <s1> sentence2: <s2>"``; label 1 maps to
    ``'entailment'`` and every other label to ``'not_entailment'`` (note the
    polarity differs from the RTE/QNLI loaders, which test for label 0).
    """
    inputs = [
        "wnli sentence1: " + first + " sentence2: " + second
        for first, second in zip(input['sentence1'], input['sentence2'])
    ]
    targets = ['entailment' if label == 1 else 'not_entailment' for label in input['label']]
    return _build_seq2seq_loader(args, tokenizer, name, inputs, targets)


def get_glue_tasks(args, fixed_old=False):
    """Build the train and validation DataLoaders for the GLUE task in ``args.dataset``.

    Args:
        args: Mutable namespace. Read: ``dataset``, ``base_model`` (only when
            the tokenizer fallback is needed) and whatever the per-task
            loader reads. Written: ``inputs`` and ``targets`` (reset to empty
            dicts before anything else), ``lr`` (set to ``3e-4`` for every
            known task) and ``tokenizer_model`` (set to ``base_model`` when
            it is missing or falsy).
        fixed_old: Unused by this function; kept so existing callers that
            pass it keep working.

    Returns:
        ``(dl_train, dl_val)``: loaders for the ``train`` split and for the
        ``validation`` split (``validation_matched`` for MNLI).

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported tasks.
            ``ValueError`` subclasses ``Exception``, so callers catching
            ``Exception`` are unaffected.
    """
    args.inputs, args.targets = {}, {}

    # The trailing notes are the (batch, epoch) values and timing remarks that
    # were left as commented-out hints next to each task.
    task_loaders = {
        'cola': getDataLoader_cola,  # batch 128, epoch 32
        'mnli': getDataLoader_mnli,  # batch 16, epoch 8
        'mrpc': getDataLoader_mrpc,  # batch 32, epoch 32
        'rte': getDataLoader_rte,    # batch 16, epoch 32
        'sst2': getDataLoader_sst2,  # batch 64, epoch 16
        'qnli': getDataLoader_qnli,  # batch 32, epoch 16; about 1h per epoch
        'qqp': getDataLoader_qqp,    # batch 32, epoch 16; about 1h per epoch
        'wnli': getDataLoader_wnli,  # batch 32, epoch 16; about 1h per epoch
    }
    build_loader = task_loaders.get(args.dataset)
    if build_loader is None:
        raise ValueError(f'{args.dataset} dataset is not defined!')

    # Every supported task uses the same learning rate.
    args.lr = 3e-4

    # MNLI is the only task whose validation split has a different name.
    validation_split = 'validation_matched' if args.dataset == 'mnli' else 'validation'

    # The test split used to be loaded here as well but was never used or
    # returned, so it is no longer loaded.
    dataset_train = load_dataset("nyu-mll/glue", args.dataset, split='train')
    dataset_validation = load_dataset("nyu-mll/glue", args.dataset, split=validation_split)

    # Fall back to the base model's tokenizer when none is configured.
    # Done with getattr instead of `assert` inside a bare `except`, because
    # asserts are stripped under `python -O` and a bare except would also
    # swallow KeyboardInterrupt/SystemExit.
    if not getattr(args, 'tokenizer_model', None):
        args.tokenizer_model = args.base_model

    tokenizer = T5Tokenizer.from_pretrained(args.tokenizer_model)
    dl_train = build_loader(args, dataset_train, tokenizer, 'train')
    dl_val = build_loader(args, dataset_validation, tokenizer, 'validation')
    return dl_train, dl_val
    
    
def for_Trainer(dl):
    """Flatten an iterable of 4-tensor batches into a single ``Dataset``.

    Each batch must unpack into ``(source_ids, source_mask, labels,
    target_mask)``. Label entries equal to 0 are overwritten with -100
    in place on the batch tensor (presumably 0 is the tokenizer's pad id and
    -100 is used so the loss ignores padding -- confirm against the model).

    Returns:
        ``Dataset.from_dict`` over the columns ``input_ids``,
        ``attention_mask``, ``labels`` and ``decoder_attention_mask``, each
        the concatenation of the per-batch tensors along dimension 0.
    """
    columns = ('input_ids', 'attention_mask', 'labels', 'decoder_attention_mask')
    collected = {column: [] for column in columns}

    for batch in dl:
        source_ids, source_mask, lm_labels, target_mask = batch
        is_zero = lm_labels[:, :] == 0
        lm_labels[is_zero] = -100
        for column, tensor in zip(columns, (source_ids, source_mask, lm_labels, target_mask)):
            collected[column].append(tensor)

    merged = {
        column: torch.concat(chunks, axis=0)
        for column, chunks in collected.items()
    }
    return Dataset.from_dict(merged)
    


# def add_glue_tasks(args):
#     for dataset in args.datasets:
#         args.dataset = dataset
#         train_dataloader, test_dataloader, validation_dataloader = glue_tasks(args)
#         args.train_dataloaders += [train_dataloader]
#         args.test_dataloaders += [test_dataloader]
#         args.validation_dataloaders += [validation_dataloader]
#         args.epochs += [args.epoch]
#         args.batchs += [args.batch]