import os
import torch
import torch.nn as nn
import torch.nn.functional as F 
from torch.utils.data import DataLoader
from transformers import AutoModel, AdamW, HfArgumentParser, AutoConfig
from torch.optim import SGD, Adam
from eval_data import get_encodings, EvalDataset, subsample
from util import AverageMeter, set_deterministic
from printer import Printer
import argparse
from simcse.models import BertForCL
from train import ModelArguments, DataTrainingArguments, OurTrainingArguments
import sys
from run import write_result

def print_args(dataset_name, optimizer, percent, classification_only, mlp_only_train, theta):
    """Print a run banner and return ``(method, model_name)`` for the run.

    The mode is chosen in priority order: ``mlp_only_train`` (linear
    probing) wins over ``classification_only`` (plain fine-tuning); when
    neither flag is set the run is fine-tuning with the contrastive term,
    and ``theta`` is appended to the returned model name.

    Args:
        dataset_name: Dataset identifier, echoed in the banner.
        optimizer: Optimizer name; must be a string because it is
            concatenated onto the ``"simcse_"`` prefix.
        percent: Training-subset fraction, echoed in the banner.
        classification_only: Selects the plain fine-tuning mode.
        mlp_only_train: Selects the linear-probing mode.
        theta: Contrastive weight; only used in the default mode.

    Returns:
        A ``(method, model_name)`` tuple of strings.
    """
    print("\n=============================\n")
    print(dataset_name, optimizer, percent)
    base_name = "simcse_" + optimizer

    # Pick the banner text, method label and model tag for the active mode.
    if mlp_only_train:
        banner = ("Linear Probing",)
        method = "linear_eval_main"
        tag = base_name
    elif classification_only:
        banner = ("Finetune",)
        method = "fintune_eval_main"
        tag = base_name
    else:
        banner = ("Finetune with contrastive with theta:", theta)
        method = "finetune_contrastive_eval_main"
        tag = base_name + "_" + str(theta)

    print(*banner)
    return method, tag
    

def _build_optimizer(name, params, lr):
    """Instantiate the optimizer class called ``name`` over ``params``.

    Replaces ``eval`` on a command-line string with an explicit lookup.
    The optimizers imported by this file are tried first, then any
    ``torch.optim.Optimizer`` subclass exposed by ``torch.optim``.

    Raises:
        ValueError: If ``name`` does not resolve to an optimizer class.
    """
    known = {"SGD": SGD, "Adam": Adam, "AdamW": AdamW}
    optim_cls = known.get(name)
    if optim_cls is None:
        candidate = getattr(torch.optim, str(name), None)
        if isinstance(candidate, type) and issubclass(candidate, torch.optim.Optimizer):
            optim_cls = candidate
    if optim_cls is None:
        raise ValueError(f"Unsupported optimizer: {name!r}")
    return optim_cls(params, lr=lr)


def _to_device(batch, device):
    """Return a dict with every tensor in ``batch`` moved to ``device``."""
    return {key: value.to(device) for key, value in batch.items()}


def main():
    """Train a linear classifier on top of ``BertForCL`` and evaluate it.

    Parses the three argument dataclasses from the command line, trains
    for a fixed number of epochs in one of three modes (linear probing,
    fine-tuning, or fine-tuning with a ``theta``-weighted contrastive
    term), evaluates on the test split, and passes the final meter
    averages to ``write_result``.

    Console output is routed through ``Printer`` for the duration of
    training and evaluation; the printer is closed even when an error is
    raised, and ``write_result`` only runs after a successful evaluation.

    Raises:
        ValueError: If ``training_args.optimizer`` is not a known
            optimizer name.
    """
    num_epochs = 3
    batch_size = 16
    log_every = 200

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, OurTrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    dataset_name = data_args.dataset_name
    optimizer = training_args.optimizer

    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs("./local", exist_ok=True)
    # NOTE(review): whether Printer.close() also closes this file handle is
    # not visible here, so the handle is left for Printer to manage.
    t = Printer(sys.stdout, open(f'./local/{dataset_name}_{optimizer}_{training_args.seed}.txt', 'a')).open()

    try:
        classification_only = model_args.classification_only
        mlp_only_train = model_args.mlp_only_train
        method, model_name = print_args(
            dataset_name, optimizer, data_args.train_percent,
            classification_only, mlp_only_train, model_args.theta,
        )
        set_deterministic(training_args.seed)

        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        train_encodings, train_labels, test_encodings, test_labels, class_num = get_encodings(dataset_name)
        train_dataset = EvalDataset(train_encodings, train_labels, classification_only=classification_only)
        sub_size = int(len(train_dataset) * data_args.train_percent)
        train_dataset = subsample(train_dataset, sub_size)

        test_dataset = EvalDataset(test_encodings, test_labels)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
        train_loader_size = len(train_loader)

        config = AutoConfig.from_pretrained(model_args.model_name_or_path)
        model = BertForCL.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            model_args=model_args,
        )
        model.to(device)

        # NOTE(review): 768 is hard-coded; presumably this should follow the
        # encoder's output width (e.g. config.hidden_size) - confirm.
        classifier = nn.Linear(in_features=768, out_features=class_num, bias=True).to(device)
        if mlp_only_train:
            # Linear probing: only the classifier is optimised and the
            # encoder stays in eval mode for the whole run.
            net = classifier
            model.eval()
            optim = _build_optimizer(optimizer, net.parameters(), lr=5e-3)
        else:
            # Sequential is used only as a container so parameters() and
            # train()/eval() cover both modules; it is never called.
            net = torch.nn.Sequential(model, classifier)
            optim = _build_optimizer(optimizer, net.parameters(), lr=5e-5)

        train_loss_meter = AverageMeter(name='Train Loss')
        train_acc_meter = AverageMeter(name='Train Accuracy')
        contrastive_loss_meter = AverageMeter(name='Contrastive Loss')
        test_loss_meter = AverageMeter(name='Test Loss')
        test_acc_meter = AverageMeter(name='Test Accuracy')

        net.train()
        for epoch in range(num_epochs):
            train_loss_meter.reset()
            train_acc_meter.reset()
            contrastive_loss_meter.reset()
            for idx, batch in enumerate(train_loader):
                optim.zero_grad()
                batch = _to_device(batch, device)
                labels = batch['labels']
                batch_len = labels.shape[0]

                # In linear probing no optimizer consumes encoder gradients,
                # so skip building the encoder's autograd graph.
                with torch.set_grad_enabled(not mlp_only_train):
                    outputs = model(**batch, output_hidden_states=True, return_dict=True,
                                    sent_emb=classification_only)
                features = outputs["pooler_output"] if classification_only else outputs["logits"]
                preds = classifier(features)

                loss = F.cross_entropy(preds, labels)  # reduction is mean
                train_loss_meter.update(loss.item(), n=batch_len)

                if not classification_only:
                    # The meter above tracks only the classification loss;
                    # the contrastive term is added for the backward pass.
                    loss = loss + model_args.theta * outputs["loss"]
                    contrastive_loss_meter.update(outputs["loss"].item(), n=batch_len)

                loss.backward()
                optim.step()

                correct = (preds.argmax(dim=1) == labels).sum().item()
                train_acc_meter.update(correct / batch_len, n=batch_len)

                if idx % log_every == 0:
                    print(f"[{epoch}/{num_epochs}][{idx}/{train_loader_size}]",
                          f"contrastive loss:{contrastive_loss_meter.avg:.4}",
                          f"train loss:{train_loss_meter.avg:.4}",
                          f"train acc:{train_acc_meter.avg:.4%}")

        net.eval()
        test_loss_meter.reset()
        test_acc_meter.reset()
        # Evaluation needs no gradients; no_grad avoids building a graph
        # for every test batch.
        with torch.no_grad():
            for batch in test_loader:
                batch = _to_device(batch, device)
                labels = batch['labels']
                batch_len = labels.shape[0]

                outputs = model(**batch, output_hidden_states=True, return_dict=True, sent_emb=True)
                preds = classifier(outputs["pooler_output"])

                loss = F.cross_entropy(preds, labels)  # reduction is mean

                test_loss_meter.update(loss.item(), n=batch_len)
                correct = (preds.argmax(dim=1) == labels).sum().item()
                test_acc_meter.update(correct / batch_len, n=batch_len)
        print("++++++++++++++++++")
        print(f"test loss:{test_loss_meter.avg:.4f}", f"test acc:{test_acc_meter.avg:.2%}")
        print("++++++++++++++++++")
    finally:
        # Always close the printer, including on failure.
        t.close()

    write_result(
        training_args.output,
        [model_name, dataset_name, str(data_args.train_percent), training_args.seed, method,
         contrastive_loss_meter.avg, train_loss_meter.avg, train_acc_meter.avg,
         test_loss_meter.avg, test_acc_meter.avg, 0.0],
    )

# Run the training/evaluation pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
