import time
import torch.optim
import argparse
from model import EnsembleModel
from utils import *
from data.datamgr import SetDataManager
from transformers import AutoTokenizer, AutoConfig, RobertaModel
import os
import torch.nn as nn
from test import test

# Turn off the Hugging Face ``tokenizers`` internal parallelism. Presumably this
# avoids the fork-after-parallelism warning / potential deadlock when data
# loader worker processes are forked after a tokenizer has been used —
# NOTE(review): confirm the data loaders actually use worker processes.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def train(base_loader, val_loader, model, optimizer, scheduler, params):
    train_log = {}
    train_log['params'] = vars(params)
    train_log['train_loss'] = []
    train_log['val_loss'] = []
    train_log['train_acc'] = []
    train_log['val_acc'] = []
    train_log['max_val_acc'] = 0.0
    train_log['max_val_acc_epoch'] = 0

    for param in model.module.text_encoder.roberta_model.parameters():
        param.requires_grad_(False)

    for epoch in range(params.epoch):
        epoch_start_time = time.time()

        model.train()

        train_loss, train_acc,_ = model.module.train_loop(base_loader, optimizer, params)

        model.eval()

        val_loss, val_acc,_ = model.module.test_loop(val_loader, params)

        if val_acc > train_log['max_val_acc']:
            print("val best model! save...")
            train_log['max_val_acc'] = val_acc
            train_log['max_val_acc_epoch'] = epoch
            outfile = os.path.join(params.model_dir, 'val_best_model.tar')
            torch.save({'epoch': epoch, 'state': model.state_dict()}, outfile)

        if (epoch + 1) % params.save_freq == 0:
            outfile = os.path.join(params.model_dir, '{:d}.tar'.format(epoch))
            torch.save({'epoch': epoch, 'state': model.state_dict()}, outfile)

        scheduler.step()

        train_log['train_loss'].append(train_loss)
        train_log['val_loss'].append(val_loss)
        train_log['train_acc'].append(train_acc)
        train_log['val_acc'].append(val_acc)

        torch.save(train_log, os.path.join(params.model_dir, 'train_log'))
        epoch_time = time.time() - epoch_start_time
        print("Epoch {:d}/{:d} | Epoch time {:.2f} minutes | Total time {:.2f} hours"
              .format(epoch, params.epoch - 1, (epoch_time) / 60,(params.epoch - epoch)/ 3600 * epoch_time))
        
        print("Train loss {:.2f} | Train acc {:.2f}% | Val loss {:.2f} | Val acc {:.2f}% | "
                .format(train_loss, train_acc, val_loss, val_acc))
        print("Val best acc epoch {:d} | Val best acc {:.2f}%"
              .format(train_log['max_val_acc_epoch'], train_log['max_val_acc']))
        print()

    return model


if __name__ == '__main__':

    params = setup_run()
    set_seed(params.seed)

    # Training loader (augmentation controlled by params.train_aug) and
    # validation loader (augmentation disabled).
    base_mgr = SetDataManager(params.dataset, params.data_dir, params.image_size,
                              params.n_way, params.n_shot, params.n_query,
                              params.train_n_episode)
    base_loader = base_mgr.get_data_loader('train', params.train_aug)

    val_mgr = SetDataManager(params.dataset, params.data_dir, params.image_size,
                             params.n_way, params.n_shot, params.n_query,
                             params.val_n_episode)
    val_loader = val_mgr.get_data_loader('val', False)

    roberta_config = AutoConfig.from_pretrained('roberta-base')
    roberta_tokenizer = AutoTokenizer.from_pretrained('roberta-base')
    # ``config`` must be passed by keyword: extra positional arguments to
    # ``from_pretrained`` are forwarded to ``RobertaModel.__init__`` (where the
    # config would silently land in ``add_pooling_layer``).
    roberta_model = RobertaModel.from_pretrained('roberta-base', config=roberta_config)

    model = EnsembleModel(roberta_model, roberta_tokenizer, roberta_config,
                          model_dict[params.backbone])
    pretrain_model_file = os.path.join(params.pretrain_model_dir, params.dataset,
                                       'pretrain_model.tar')
    load_pretrain_model(model, pretrain_model_file)

    # nn.DataParallel requires the module's parameters to live on
    # device_ids[0]; a bare ``.cuda()`` uses the current device instead, which
    # breaks runs whose first device id is not 0. ``None`` keeps the default.
    device_ids = params.device_ids
    model.cuda(device_ids[0] if device_ids else None)
    model = nn.DataParallel(model, device_ids=device_ids)
    optimizer, scheduler = prepare_optimizer(model, params)

    train(base_loader, val_loader, model, optimizer, scheduler, params)

    print("Finish training!")

    test(model, params)


