import json
from pathlib import Path
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from transformers import BertModel, BertTokenizer, RobertaTokenizer
import random

import util
from data import SNLIDataset, snli_labels, create_collate_fn, sample_and_supplement

from self_explaining_structures.model import ExplainableModel


class BertNLI(nn.Module):
    """BERT encoder with a linear head producing three class logits."""

    def __init__(self):
        super().__init__()
        # Pretrained 'bert-base-uncased' encoder; its pooled output feeds the head.
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # Maps the 768-dimensional pooled representation to 3 logits.
        self.classifier = nn.Linear(in_features=768, out_features=3)

    def forward(self, inputs):
        """Return the classifier logits for a batch.

        Args:
            inputs: mapping that is unpacked as keyword arguments into the
                BERT encoder (presumably the tokenizer output; confirm
                against the collate function).

        Returns:
            The output of the linear head applied to the encoder's
            'pooler_output'.
        """
        encoder_outputs = self.bert(**inputs)
        pooled = encoder_outputs['pooler_output']
        return self.classifier(pooled)

def train_epoch(model, train_loader, criterion, optimizer, model_dir, val_fn, val_every=None):
    """Train `model` for one pass over `train_loader`.

    Args:
        model: module called as ``model(x)`` to produce the predictions that
            are fed to `criterion`.
        train_loader: sized iterable yielding ``(_, x, y)`` batches; the first
            element of each batch is ignored.
        criterion: loss callable; its output is reduced with ``.mean()``
            before back-propagation.
        optimizer: optimizer that is stepped and zeroed once per batch.
        model_dir: not used by this function; kept so existing callers keep
            working.
        val_fn: callable taking the model and returning a mapping that
            contains the keys 'loss', 'loss_std' and 'acc'.
        val_every: when truthy, `val_fn` is run (and its metrics printed)
            every `val_every` batches.
    """
    model.train()
    n_total_step = len(train_loader)
    print_metrics = ['loss', 'loss_std', 'acc']
    for i, (_, x, y) in tqdm(enumerate(train_loader), total=n_total_step):
        y_hat = model(x)
        loss_value = criterion(y_hat, y).mean()
        loss_value.backward()
        optimizer.step()
        optimizer.zero_grad()
        if val_every and (i+1) % val_every == 0:
            val_results = val_fn(model)
            # Validation may have switched the model to eval mode (e.g. via
            # model.eval()); switch back so the remaining batches of this
            # epoch are still trained with dropout etc. enabled.
            model.train()
            val_results_str = ', '.join([f"{m} = {val_results[m]:.3f}" for m in print_metrics])
            print(f"step: {i+1}/{n_total_step}: {val_results_str}")
            
def evaluate(model, val_loader, criterion):
    """Compute loss statistics and accuracy of `model` over `val_loader`.

    The model is switched to eval mode and is left in that mode on return.

    Args:
        model: module called as ``model(x)`` to produce logits; predictions
            are taken as the argmax along axis 1.
        val_loader: iterable yielding ``(_, x, y)`` batches and exposing a
            sized ``dataset`` attribute (used as the accuracy denominator).
        criterion: loss callable whose result supports ``.tolist()`` and
            yields one value per example (e.g. a loss with no reduction).

    Returns:
        dict with 'loss' (mean of the collected losses), 'loss_std' (their
        standard deviation) and 'acc' (correct predictions divided by
        ``len(val_loader.dataset)``).
    """
    model.eval()
    correct_count = 0
    collected_losses = []
    with torch.no_grad():
        for _, batch_inputs, batch_labels in val_loader:
            batch_logits = model(batch_inputs)
            predictions = batch_logits.argmax(axis=1)
            matches = predictions == batch_labels
            correct_count += matches.sum().item()
            collected_losses.extend(criterion(batch_logits, batch_labels).tolist())
    losses = np.array(collected_losses)
    accuracy = correct_count / len(val_loader.dataset)
    return {
        'loss': losses.mean(),
        'loss_std': losses.std(),
        'acc': accuracy,
    }


if __name__ == '__main__':

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("architecture", type=str, choices=['BERT', 'RoBERTa+SE'])
    parser.add_argument("train_data_path", type=Path)
    parser.add_argument("dev_data_path", type=Path)
    parser.add_argument("model_save_dir", type=Path)
    parser.add_argument("--supplementary-data-path", type=Path, default=None)
    parser.add_argument("--supplementary-data-count", type=int, default=50000)
    parser.add_argument("--replace-original-data", action=argparse.BooleanOptionalAction)
    parser.add_argument('--hypothesis-only', action=argparse.BooleanOptionalAction)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--num-epochs', type=int, default=10)
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--val-every', type=int, default=None, help="Run mid-epoch validation every N batches")
    parser.add_argument('--random-seed', type=int, default=42)
    args = parser.parse_args()

    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    args.model_save_dir.mkdir(parents=True, exist_ok=True)
    json.dump(util.jsonify(args.__dict__), (args.model_save_dir/'args.json').open('w'))

    label_itos = snli_labels
    label_stoi = {l:i for i,l in enumerate(label_itos)}

    train_data = SNLIDataset(args.train_data_path, 'gold_label')
    if args.supplementary_data_path:
        supp_train_data = SNLIDataset(args.supplementary_data_path, 'model_label')
        train_data = sample_and_supplement(train_data, supp_train_data, args.supplementary_data_count, 
                                           replace=args.replace_original_data)
    dev_data = SNLIDataset(args.dev_data_path, 'gold_label')

    match args.architecture:
        case 'BERT':
            model = BertNLI()
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            unfrozen_params = [
                    *model.bert.encoder.layer[10].parameters(),
                    *model.bert.encoder.layer[11].parameters(),
                    *model.classifier.parameters(),
                ]
            optimizer = torch.optim.Adam(unfrozen_params)
        case 'RoBERTa+SE':
            # this part copied & adapted from 
            # https://github.com/ShannonAI/Self_Explaining_Structures_Improve_NLP_Models/blob/master/explain/trainer.py
            model = ExplainableModel('FacebookAI/roberta-base')
            tokenizer = RobertaTokenizer.from_pretrained('FacebookAI/roberta-base')
            adam_epsilon = 10e-8 # according to appendix A
            lr = 2e-5 # not ned in the paper, so using the default in train.py
            weight_decay = 0.01 # from the RoBERTa paper.. seems standard.
            no_decay = ["bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                    "weight_decay": weight_decay,
                },
                {
                    "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                    "weight_decay": 0.0,
                },
            ]
            optimizer = torch.optim.AdamW(optimizer_grouped_parameters,
                          betas=(0.9, 0.98),  # according to RoBERTa paper 
                          lr=lr,
                          eps=adam_epsilon)
 
    
    model.to(args.device)
    criterion = nn.CrossEntropyLoss(reduction='none')

    collate_fn = create_collate_fn(tokenizer, label_stoi, args.device, 
                                   roberta_se=args.architecture=='RoBERTa+SE',
                                   hypothesis_only=args.hypothesis_only)
    train_loader = DataLoader(train_data, args.batch_size, collate_fn=collate_fn, shuffle=True)
    val_loader = DataLoader(dev_data, args.batch_size, collate_fn=collate_fn, shuffle=False)
    val_fn = lambda model: evaluate(model, val_loader, criterion)

    for epoch in range(1, args.num_epochs+1):
        print(f"epoch {epoch}/{args.num_epochs})")
        train_epoch(model, train_loader, criterion, optimizer, args.model_save_dir, val_fn, val_every=args.val_every)
        validation_scores = evaluate(model, val_loader, criterion) | {'epoch': epoch}
        util.write_jsonl(validation_scores, (args.model_save_dir/'val_results.json').open('a'))
        torch.save(model, args.model_save_dir/f'epoch-{epoch}')
