'''
Run inference with fine-tuned GLUE models to obtain their logits and accuracies,
and create the model data from which the baseline Pareto front and cascades can
be constructed.
'''

import json
from math import ceil
import os

from datasets import load_dataset
from fvcore.nn import FlopCountAnalysis, parameter_count
import pandas as pd
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification


# Hugging Face checkpoints evaluated per GLUE task. Order is preserved as
# listed; each list is consumed in this order by the inference loop.

# SST-2 (single-sentence sentiment classification).
checkpoints_sst2 = [
    'distilbert-base-uncased-finetuned-sst-2-english',
    'textattack/roberta-base-SST-2',
    'textattack/bert-base-uncased-SST-2',
    'philschmid/MiniLM-L6-H384-uncased-sst2',
    'textattack/albert-base-v2-SST-2',
    'textattack/xlnet-base-cased-SST-2',
    'howey/roberta-large-sst2',
    'howey/bert-base-uncased-sst2',
    'M-FAC/bert-tiny-finetuned-sst2',
    'yoshitomo-matsubara/bert-large-uncased-sst2',
    'echarlaix/bert-base-uncased-sst2-acc91.1-d37-hybrid',
    'howey/electra-base-sst2',
    'howey/electra-large-sst2',
    'M-FAC/bert-mini-finetuned-sst2',
    'mrm8488/deberta-v3-small-finetuned-sst2',
    'textattack/distilbert-base-cased-SST-2',
    'huawei-noah/DynaBERT_SST-2',
    'Blaine-Mason/hackMIT-finetuned-sst2',
]

# MRPC (sentence-pair paraphrase detection).
checkpoints_mrpc = [
    'howey/bert-base-uncased-mrpc',
    'howey/electra-base-mrpc',
    'howey/electra-large-mrpc',
    'howey/roberta-large-mrpc',
    'textattack/xlnet-base-cased-MRPC',
    'textattack/xlnet-large-cased-MRPC',
    'textattack/bert-base-uncased-MRPC',
    'textattack/roberta-base-MRPC',
    'textattack/albert-base-v2-MRPC',
    'textattack/distilbert-base-uncased-MRPC',
    'textattack/distilbert-base-cased-MRPC',
    'M-FAC/bert-tiny-finetuned-mrpc',
    'M-FAC/bert-mini-finetuned-mrpc',
    'mrm8488/deberta-v3-small-finetuned-mrpc',
    'yoshitomo-matsubara/bert-base-uncased-mrpc_from_bert-large-uncased-mrpc',
    'yoshitomo-matsubara/bert-large-uncased-mrpc',
]

# QQP (sentence-pair duplicate-question detection).
checkpoints_qqp = [
    'howey/electra-small-qqp',
    'howey/electra-base-qqp',
    'howey/electra-large-qqp',
    'howey/roberta-large-qqp',
    'howey/bert-base-uncased-qqp',
    'M-FAC/bert-tiny-finetuned-qqp',
    'M-FAC/bert-mini-finetuned-qqp',
    'yoshitomo-matsubara/bert-base-uncased-qqp_from_bert-large-uncased-qqp',
    'yoshitomo-matsubara/bert-large-uncased-qqp',
    'textattack/xlnet-base-cased-QQP',
    'textattack/distilbert-base-cased-QQP',
    'textattack/albert-base-v2-QQP',
]

# QNLI (sentence-pair question/answer entailment).
checkpoints_qnli = [
    'M-FAC/bert-tiny-finetuned-qnli',
    'M-FAC/bert-mini-finetuned-qnli',
    'howey/bert-base-uncased-qnli',
    'gchhablani/bert-base-cased-finetuned-qnli',
    'yoshitomo-matsubara/bert-base-uncased-qnli_from_bert-large-uncased-qnli',
    'yoshitomo-matsubara/bert-large-uncased-qnli',
    'Alireza1044/albert-base-v2-qnli',
    'anirudh21/albert-large-v2-finetuned-qnli',
    'textattack/distilbert-base-uncased-QNLI',
    'mrm8488/deberta-v3-small-finetuned-qnli',
    'howey/electra-base-qnli',
    'howey/electra-large-qnli',
    'textattack/roberta-base-QNLI',
    'howey/roberta-large-qnli',
    'textattack/xlnet-base-cased-QNLI',
]


# GLUE tasks this script knows how to tokenize.
_SINGLE_SENTENCE_TASKS = ('sst2',)
_SENTENCE_PAIR_TASKS = ('mrpc', 'qnli', 'qqp')


def _tokenize(tokenizer, columns):
    """Tokenize examples given as one text column (single sentences) or two
    text columns (sentence pairs), padded to the longest and truncated."""
    if len(columns) == 1:
        text = columns[0]
    else:
        text = [[a, b] for a, b in zip(*columns)]
    return tokenizer(text, padding=True, truncation=True, return_tensors='pt')


def _binary_f1(corr, pred):
    """Return binary F1 (in percent) for positive class 1.

    ``corr`` holds per-example correctness flags and ``pred`` the predicted
    class; together they determine each example's true label. Returns 0.0 when
    F1 is undefined (no positive predictions and no positive labels).
    """
    tp = sum(1 for c, p in zip(corr, pred) if c and p == 1)
    fp = sum(1 for c, p in zip(corr, pred) if not c and p == 1)
    fn = sum(1 for c, p in zip(corr, pred) if not c and p == 0)
    denom = tp + (fp + fn) / 2
    return 100 * (tp / denom) if denom else 0.0


# obtain logits and return: [name, top1,[ f1,] mac, param, length]
@torch.no_grad()
def infer(checkpoint, dataset_name, path_logits, getf1=False, device=0, batch_size=100):
    """Run ``checkpoint`` on the validation split of GLUE task ``dataset_name``.

    Side effects: saves the concatenated logits to
    ``path_logits + name + '.pt'`` (``name`` is ``checkpoint`` with ``/``
    replaced by ``-``) and writes per-example predictions through
    ``inference_data`` with the prefix ``'data/infer/infer_'``.

    Args:
        checkpoint: Hugging Face model id or local path.
        dataset_name: one of 'sst2', 'mrpc', 'qnli', 'qqp'.
        path_logits: path prefix for the saved logits file.
        getf1: if true, also compute binary F1 for class 1.
        device: device the model and inputs are moved to.
        batch_size: number of examples per forward pass (default 100).

    Returns:
        ``[name, top1, mac, param, length]``, or
        ``[name, top1, f1, mac, param, length]`` when ``getf1`` is true.
        ``top1``/``f1`` are percentages. ``length`` is the padded token length
        of the longest (truncated) example in the split. ``mac`` is fvcore's
        ``FlopCountAnalysis`` total for one example padded to ``length``.
        ``param`` is fvcore's total parameter count.

    Raises:
        ValueError: for an unsupported task or a non-positive ``batch_size``.
    """
    # Fail before downloading anything if the task can't be tokenized.
    if dataset_name not in _SINGLE_SENTENCE_TASKS + _SENTENCE_PAIR_TASKS:
        raise ValueError(f'unsupported GLUE task: {dataset_name!r}')
    if batch_size < 1:
        raise ValueError(f'batch_size must be positive, got {batch_size}')

    dataset = load_dataset('glue', dataset_name, cache_dir='data', split='validation')
    keys = list(dataset.features.keys())
    l = len(dataset)
    name = checkpoint.replace('/', '-')
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint).to(device).eval()

    # Materialize the text column(s) once. Indexing a Dataset by column name
    # builds the whole column, so doing it per batch made the loop quadratic.
    columns = [list(dataset[keys[0]])]
    if dataset_name in _SENTENCE_PAIR_TASKS:
        columns.append(list(dataset[keys[1]]))

    logits = []
    for start in range(0, l, batch_size):
        batch = [column[start:start + batch_size] for column in columns]
        inputs = _tokenize(tokenizer, batch)
        logits.append(model(**inputs.to(device)).logits.to('cpu'))
    logits = torch.cat(logits, 0)
    torch.save(logits, path_logits + name + '.pt')

    labels = torch.tensor(list(dataset['label']))
    pred = torch.argmax(logits, 1)
    corr = (pred == labels).tolist()
    pred = pred.tolist()
    top1 = 100 * sum(corr) / l

    # Tokenizing the whole split with padding=True pads every example to the
    # longest one, so the FLOP count below is taken at that maximum length.
    full = _tokenize(tokenizer, columns)
    length = full['input_ids'].shape[1]
    param = parameter_count(model)['']
    flops = FlopCountAnalysis(model, full['input_ids'][:1].to(device))
    mac = flops.total()
    inference_data(name, labels, logits, 'data/infer/infer_')

    if getf1:
        f1 = _binary_f1(corr, pred)
        print(name, 'top1', top1, 'f1', f1, 'mac', mac, 'param', param, 'length', length)
        return [name, top1, f1, mac, param, length]
    print(name, 'top1', top1, 'mac', mac, 'param', param, 'length', length)
    return [name, top1, mac, param, length]


# obtain nested list with model inference data
def inference_data(model, labels, logits, path_infer):
    """Write per-example predictions for ``model`` as JSON.

    Args:
        model: model name; used to build the input and output file names.
        labels: tensor of ground-truth class indices, one per example.
        logits: 2-D tensor of logits (examples x classes), or a path prefix
            from which ``logits + model + '.pt'`` is loaded.
        path_infer: output path prefix; the file written is
            ``path_infer + model + '.txt'``.

    The JSON document is a list of three per-example lists:
    ``[correct (bool), predicted class (int), max softmax probability (float)]``.
    """
    if isinstance(logits, str):
        # Load onto the CPU so the comparison with CPU labels never hits a
        # device mismatch. torch.load unpickles: only load files this
        # pipeline produced.
        logits = torch.load(logits + model + '.pt', map_location='cpu')
    predicted = torch.argmax(logits, 1)
    confidence = torch.max(F.softmax(logits, dim=1), 1)[0]
    ea = [(predicted == labels).tolist(),
          predicted.tolist(),
          confidence.tolist()]
    with open(path_infer + model + '.txt', 'w') as f:
        json.dump(ea, f, indent=2)


def main():
    """Run every listed checkpoint on its GLUE validation split and pickle a summary.

    For each task, in the order sst2, mrpc, qqp, qnli: ``infer`` is called on
    each checkpoint. Its rows are collected into a DataFrame sorted by
    ``top1`` (highest first) and pickled to ``data/df_<task>.pkl``. MRPC and
    QQP also get an ``f1`` column. Logits are written under ``data/logits``
    and per-example inference data under ``data/infer``.
    """
    os.makedirs('data/logits', exist_ok=True)
    os.makedirs('data/infer', exist_ok=True)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # (task, checkpoints, report binary F1?) -- processed in this order.
    tasks = [
        ('sst2', checkpoints_sst2, False),
        ('mrpc', checkpoints_mrpc, True),
        ('qqp', checkpoints_qqp, True),
        ('qnli', checkpoints_qnli, False),
    ]
    for task, checkpoints, getf1 in tasks:
        # Must match the row layout returned by infer().
        columns = ['model', 'top1'] + (['f1'] if getf1 else []) + ['mac', 'param', 'length']
        data = [infer(n, task, 'data/logits/logits_', getf1=getf1, device=device)
                for n in checkpoints]
        df = pd.DataFrame(data, columns=columns)
        df = df.sort_values(by=['top1'], ascending=False).reset_index(drop=True)
        df.to_pickle(f'data/df_{task}.pkl')


# Run the full inference sweep only when executed as a script, not on import.
if __name__ == '__main__':
    main()