import os
import re
import json
import gym
import torch
import argparse
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from transformers import BertTokenizer, BertModel
from transformers import BertConfig, BertForSequenceClassification



def custom_collate(batch):
    """Collate (description, label) samples into a batch.

    Returns a tuple of the descriptions (left as raw strings) and a tensor of
    the labels created on the module-level ``device``.
    """
    # Transpose the list of pairs into one tuple of texts and one of labels.
    texts, targets = zip(*batch)
    label_tensor = torch.tensor(targets, device=device)
    return texts, label_tensor

class DescriptionDataset(Dataset):
    """Flat dataset of (description, label) pairs.

    ``desc_dict`` maps a partner name to an iterable of descriptions. The
    label of a description is the position of its partner name in the
    iteration order of ``desc_dict``.
    """

    def __init__(self, desc_dict):
        # Flatten the mapping once, keeping descriptions and labels aligned.
        flattened = [
            (text, partner_idx)
            for partner_idx, partner_name in enumerate(desc_dict)
            for text in desc_dict[partner_name]
        ]
        self.human_descriptions = [text for text, _ in flattened]
        self.label_list = [partner_idx for _, partner_idx in flattened]

    def __len__(self):
        """Return the total number of descriptions across all partners."""
        return len(self.human_descriptions)

    def __getitem__(self, idx):
        """Return the ``(description, label)`` pair stored at ``idx``."""
        return self.human_descriptions[idx], self.label_list[idx]
            
            
def preset(args):
    """Fill every option that is still ``None`` with its derived default.

    Mutates ``args`` in place and returns the same object. Options the caller
    already set are left untouched. Later defaults build on earlier ones
    (``desc_load`` and ``finetuned_bert_save`` live under ``model_load``), so
    the order of the checks matters.
    """
    if args.env_config is None:
        args.env_config = {'layout_name': args.layout}

    if args.model_load is None:
        args.model_load = "diffusion_human_ai/models/%s" % (args.layout)

    if args.desc_load is None:
        # Pick the description file according to the diverse_desc switch.
        desc_file = "diverse_descriptions.json" if args.diverse_desc else "descriptions.json"
        args.desc_load = os.path.join(args.model_load, desc_file)

    if args.bert_path is None:
        args.bert_path = "diffusion_human_ai/models/bert-base-uncased"

    if args.finetuned_bert_save is None:
        args.finetuned_bert_save = os.path.join(args.model_load, "finetuned_bert")

    return args


def get_train_data(args):
    """Load the train and test description dictionaries from ``args.desc_load``.

    The file must contain a JSON object with top-level ``'train'`` and
    ``'test'`` keys.

    Returns:
        A ``(train, test)`` tuple holding the values stored under those keys.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
        KeyError: if either the ``'train'`` or ``'test'`` key is missing.
    """
    # Pin the encoding: the default of open() depends on the platform locale,
    # which can corrupt or reject non-ASCII descriptions.
    with open(args.desc_load, 'r', encoding='utf-8') as f:
        desc_dict = json.load(f)

    return desc_dict['train'], desc_dict['test']


def _score_batch(tokenizer, bert_classifier, criterion, desc_batch, label_batch, n_labels):
    """Tokenize one batch, run the classifier and score its predictions.

    Returns:
        A ``(loss, n_correct)`` tuple: the cross-entropy loss tensor for the
        batch and the number of correctly classified samples as a Python int.
    """
    tokenized_text = tokenizer(desc_batch, max_length=64, add_special_tokens=True, padding=True, truncation=True, return_tensors='pt')
    tokenized_text = tokenized_text.to(device)

    output = bert_classifier(**tokenized_text, labels=label_batch)
    # With labels supplied, index 0 of the output is the model's own loss and
    # index 1 holds the per-class logits.
    logits = output[1]
    predicted = torch.argmax(logits, dim=1)

    loss = criterion(logits.view(-1, n_labels), label_batch.view(-1))
    n_correct = torch.sum(predicted == label_batch).item()
    return loss, n_correct


def train(tokenizer, bert_classifier, train_dataloader, eval_dataloader, args):
    """Fine-tune ``bert_classifier`` and print train/eval metrics every epoch.

    Args:
        tokenizer: callable that turns a batch of strings into model inputs.
        bert_classifier: sequence-classification model exposing a ``.bert``
            sub-module with ``save_pretrained``.
        train_dataloader: yields ``(descriptions, labels)`` training batches.
        eval_dataloader: yields ``(descriptions, labels)`` evaluation batches.
        args: namespace providing ``weight_decay``, ``lr``, ``n_epochs``,
            ``n_partners``, ``save_interval`` and ``finetuned_bert_save``.

    Reads the module-level ``device``. Every ``args.save_interval`` epochs the
    ``.bert`` sub-module (not the classification head) is saved to
    ``args.finetuned_bert_save``.
    """
    bert_classifier = bert_classifier.to(device)

    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in bert_classifier.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in bert_classifier.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, args.lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(args.n_epochs):
        # Re-enable dropout; the evaluation pass below switches it off.
        bert_classifier.train()
        epoch_loss = 0
        epoch_acc = 0
        n_train_seen = 0
        for desc_batch, label_batch in train_dataloader:
            torch.cuda.empty_cache()

            loss, n_correct = _score_batch(tokenizer, bert_classifier, criterion, desc_batch, label_batch, args.n_partners)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            epoch_acc += n_correct
            n_train_seen += label_batch.size(0)

        # Evaluate deterministically: dropout off and no autograd bookkeeping.
        bert_classifier.eval()
        eval_loss = 0
        eval_acc = 0
        n_eval_seen = 0
        with torch.no_grad():
            for desc_batch, label_batch in eval_dataloader:
                torch.cuda.empty_cache()

                loss, n_correct = _score_batch(tokenizer, bert_classifier, criterion, desc_batch, label_batch, args.n_partners)

                eval_loss += loss.item()
                eval_acc += n_correct
                n_eval_seen += label_batch.size(0)

        # Divide by the samples actually processed: a loader built with
        # drop_last=True sees fewer samples than len(dataset). max(..., 1)
        # guards against a loader that yields no batches at all.
        train_accuracy = epoch_acc / max(n_train_seen, 1)
        eval_accuracy = eval_acc / max(n_eval_seen, 1)
        print(f"Epoch:{epoch + 1}, total_loss:{epoch_loss:.6f}, epoch_acc:{train_accuracy:.6f}, eval_loss:{eval_loss:.6f}, eval_acc:{eval_accuracy:.6f}")

        if (epoch + 1) % args.save_interval == 0:
            bert_classifier.bert.save_pretrained(args.finetuned_bert_save)
            print("Model saved at", args.finetuned_bert_save)
        
if __name__ == '__main__':
    def _str2bool(value):
        """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

        ``type=bool`` cannot be used for this: ``bool("False")`` is ``True``
        because any non-empty string is truthy.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))

    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='OvercookedMultiEnv-v0')
    parser.add_argument('--layout', type=str, default='crossway')
    parser.add_argument('--env_config', type=json.loads, default=None)

    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--n_partners', type=int, default=8)

    parser.add_argument('--batch_size', type=int, default=80)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--latent_dim', type=int, default=64)
    parser.add_argument('--n_epochs', type=int, default=1000)
    parser.add_argument('--weight_decay', type=float, default=1e-4)
    parser.add_argument('--hidden_dropout_prob', type=float, default=0.3)

    parser.add_argument('--save_interval', type=int, default=20)
    parser.add_argument('--multi_batch', type=_str2bool, default=True)

    parser.add_argument('--model_load', type=str, default=None)
    parser.add_argument('--finetuned_bert_save', type=str, default=None)
    parser.add_argument('--desc_load', type=str, default=None)
    parser.add_argument('--bert_path', type=str, default=None)

    parser.add_argument('--diverse_desc', type=_str2bool, default=True)
    parser.add_argument('--finetune_bert', type=_str2bool, default=True)

    args = parser.parse_args()
    args = preset(args)

    # Apply --seed so shuffling, dropout and classifier-head initialisation
    # are reproducible across runs.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Module-level global: custom_collate and train read `device` directly.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    train_desc, test_desc = get_train_data(args)

    train_dataset = DescriptionDataset(train_desc)
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=custom_collate, drop_last=True)

    # Evaluate on every test sample in a fixed order: shuffling adds nothing
    # here and drop_last would silently skip the final partial batch.
    eval_dataset = DescriptionDataset(test_desc)
    eval_dataloader = DataLoader(eval_dataset, batch_size=32, shuffle=False, collate_fn=custom_collate, drop_last=False)

    tokenizer = BertTokenizer.from_pretrained(args.bert_path)
    config = BertConfig.from_pretrained(args.bert_path, num_labels=args.n_partners, hidden_dropout_prob=args.hidden_dropout_prob)
    bert = BertForSequenceClassification.from_pretrained(args.bert_path, config=config).to(device)

    train(tokenizer, bert, train_dataloader, eval_dataloader, args)