import torch
import argparse
import numpy as np

from utils import *

def main(*,
    datalen: int = 1000,
    indomainlen: int = 4,
    max_epoch: int = 10,
):
    """Train a classifier on reconstructed features and save its few-shot test outputs.

    Loads reconstructed train/test features and labels from ``./data``, keeps the
    first ``indomainlen`` 5120-wide feature blocks, trains
    ``SimpleClassifierForSentimentAnalysis`` with AdamW + cosine LR decay,
    evaluates on the few-shot test split after every epoch, and finally writes the
    classifier's raw test outputs (``api_assignment_...npy``) and the pickled model
    (``classifier_...model``) back to ``./data``.

    :param datalen: Not used by this function; kept so existing command lines work.
    :param indomainlen: Number of feature blocks / label values to use (1 to 5).
    :param max_epoch: Number of training epochs (also the cosine schedule length).
    """
    fix_seed(0)
    all_embed_dims = [5120, 5120, 5120, 5120, 5120]
    if not 1 <= indomainlen <= len(all_embed_dims):
        raise ValueError(
            f"indomainlen must be between 1 and {len(all_embed_dims)}, got {indomainlen}")
    embed_dims = all_embed_dims[:indomainlen]
    # The classifier's output width equals len(embed_dims) (the assignment array
    # below stores one row of outputs per sample with that many columns).
    num_classes = len(embed_dims)
    feat_dim = sum(embed_dims)
    train_bs = 64
    test_bs = 64

    # Training split (the "val" names are kept from the original experiments).
    reconstructed_dataset_val = np.load(
        f'./data/reconstructed_dataset_ckpt_mn{len(embed_dims)-1}_train.npy')
    reconstructed_label_val = np.load('./data/reconstructed_label_ckpt_mn0_train.npy')
    # Keep the prefix ending at the last sample labelled indomainlen-1.
    # NOTE(review): this assumes the training labels are sorted ascending — confirm.
    last_in_domain = np.nonzero(reconstructed_label_val == indomainlen - 1)[0]
    if last_in_domain.size == 0:
        raise ValueError(f"no training sample has label {indomainlen - 1}")
    split_id = last_in_domain[-1] + 1
    reconstructed_label_val = reconstructed_label_val[:split_id]
    reconstructed_dataset_val = reconstructed_dataset_val[:split_id, :feat_dim]

    reconstructed_dataset_test = np.load(
        f'./data/reconstructed_dataset_ckpt_mn{len(embed_dims)-1}_test_fewshot.npy')
    reconstructed_dataset_test = reconstructed_dataset_test[:, :feat_dim]
    reconstructed_label_test = np.load('./data/reconstructed_label_ckpt_mn0_test_fewshot.npy')

    print("reconstructed val dataset size: {}, reconstructed val label size: {}".format(
            reconstructed_dataset_val.shape, reconstructed_label_val.shape
        ))
    print("reconstructed val dataset: {}, reconstructed val label: {}".format(
            reconstructed_dataset_val, reconstructed_label_val
        ))
    print("reconstructed test dataset size: {}, reconstructed test label size: {}".format(
            reconstructed_dataset_test.shape, reconstructed_label_test.shape
        ))
    print("reconstructed_dataset: {}, reconstructed_label: {}".format(
            reconstructed_dataset_test, reconstructed_label_test
        ))

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    classifier = SimpleClassifierForSentimentAnalysis(
                num_clients=len(embed_dims), embedding_dims=embed_dims, hidden_dim=8192).to(device)
    print("* Arch of the new classifier: {}".format(classifier))

    new_data_loader = torch.utils.data.DataLoader(
        SimpleDataset(data=reconstructed_dataset_val, targets=reconstructed_label_val),
        batch_size=train_bs, shuffle=True, num_workers=4)
    # Must stay unshuffled: evaluate() writes outputs into rows in loader order.
    new_data_loader_test = torch.utils.data.DataLoader(
        SimpleDataset(data=reconstructed_dataset_test, targets=reconstructed_label_test),
        batch_size=test_bs, shuffle=False, num_workers=4)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(classifier.parameters(), lr=0.0005, weight_decay=1e-4)
    # max(..., 1) guards against a zero-length schedule when max_epoch <= 0.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(max_epoch, 1))

    def evaluate(loader, ep, model_assignment=None):
        """Print overall and per-class test accuracy; optionally record raw outputs.

        If ``model_assignment`` is given, each sample's classifier outputs are
        written into the row matching its position in ``loader`` and the array
        is returned (``None`` is returned otherwise).
        """
        correct, total = 0, 0
        correct_list = [0] * num_classes
        total_list = [0] * num_classes
        classifier.eval()
        with torch.no_grad():
            for data, target in loader:
                data, target = data.float().to(device), target.long().to(device)
                outputs = classifier(data)
                _, predicted = outputs.max(1)
                hits = predicted.eq(target)
                batch_size = target.size(0)
                if model_assignment is not None:
                    # Offset by samples seen so far, so uneven batches still line up.
                    model_assignment[total:total + batch_size] = outputs.cpu().numpy()
                total += batch_size
                correct += hits.sum().item()
                for c in range(num_classes):
                    in_class = target == c
                    correct_list[c] += (in_class & hits).sum().item()
                    total_list[c] += in_class.sum().item()
        accuracy = correct / total * 100.0 if total else float('nan')
        print("@@ EP: {}, Final {}/{}, Accuracy: {:.2f}%".format(ep, correct, total, accuracy))
        print("correct list: ", correct_list)
        # NaN marks a class with no test samples instead of raising ZeroDivisionError.
        print("acc prop: ", [hit / seen if seen else float('nan')
                             for hit, seen in zip(correct_list, total_list)])
        return model_assignment

    for ep in range(max_epoch):
        train_loss, correct, total = 0.0, 0, 0
        classifier.train()
        for batch_idx, (inputs, targets) in enumerate(new_data_loader):
            inputs, targets = inputs.float().to(device), targets.long().to(device)
            optimizer.zero_grad()
            outputs = classifier(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            if batch_idx % 10 == 0:
                print("Ep: {}, Training: {}/{}, Loss: {:.3f} | Acc: {:.3f} ({}/{})".format(
                        ep,
                        batch_idx, len(new_data_loader), train_loss/(batch_idx+1),
                        100.*correct/total, correct, total))
        # Per-epoch monitoring only; the outputs are not kept.
        evaluate(new_data_loader_test, ep)
        scheduler.step()

    api_assignment = np.zeros((reconstructed_label_test.shape[0], num_classes))
    api_assignment = evaluate(new_data_loader_test, max_epoch - 1, api_assignment)
    with open(f'./data/api_assignment_api1_feat_fewshot_indomain_len{indomainlen}.npy', 'wb') as f:
        np.save(f, api_assignment)
    # Passing a path lets torch.save open and close the file itself.
    torch.save(classifier, f"./data/classifier_fewshot_indomain_len{indomainlen}.model")

if __name__ == '__main__':
    import bdb
    import pdb
    import sys

    import defopt

    try:
        defopt.run(main)
    except bdb.BdbQuit:
        # The user quit an interactive debugger session; exit quietly.
        sys.exit()
    except Exception:
        # Report the failure and open a post-mortem debugger at the raising frame.
        # SystemExit (e.g. from --help or bad CLI arguments) and KeyboardInterrupt
        # derive from BaseException, not Exception, so they propagate normally
        # instead of dropping into pdb.
        exc_type, exc_value, tb = sys.exc_info()
        print(exc_type, exc_value)
        pdb.post_mortem(tb)