import torch
import argparse
import numpy as np

from utils import *

def main(*,
    datalen: int = 1000,
    ind_list: str = "[1, 3]",
    max_epoch: int = 10,
):
    """Train and evaluate a classifier on a subset of tasks, then save it.

    Loads the reconstructed train/test embeddings and labels from ``./data``,
    keeps the feature columns and training rows that belong to the tasks in
    ``ind_list`` (training labels are remapped to ``0..len(ind_list)-1``),
    trains a ``SimpleClassifierForSentimentAnalysis`` with AdamW and cosine
    annealing, evaluates after every epoch, and finally writes the raw
    classifier outputs on the test set and the pickled classifier to ``./data``.

    Args:
        datalen: Currently unused; kept for interface compatibility.
        ind_list: Python list literal of task ids to keep, e.g. ``"[1, 3]"``.
            An already-parsed list/tuple of ints is accepted as well.
        max_epoch: Number of training epochs (also the cosine schedule length).

    Raises:
        ValueError: If ``ind_list`` is not a non-empty list of valid task ids
            or ``max_epoch`` is smaller than 1.
    """
    import ast

    if max_epoch < 1:
        raise ValueError(f"max_epoch must be >= 1, got {max_epoch}")

    fix_seed(0)
    embed_dims = [5120, 5120, 5120, 5120, 5120]
    test_bs = 64

    # The value normally arrives as a command-line string. literal_eval only
    # accepts Python literals, so (unlike eval) it cannot execute code.
    if isinstance(ind_list, str):
        try:
            ind_list = ast.literal_eval(ind_list)
        except (ValueError, SyntaxError) as err:
            raise ValueError(
                f"ind_list must be a list literal such as '[1, 3]', got {ind_list!r}"
            ) from err
    if (
        not isinstance(ind_list, (list, tuple))
        or not ind_list
        or not all(isinstance(i, int) and 0 <= i < len(embed_dims) for i in ind_list)
    ):
        raise ValueError(
            f"ind_list must be a non-empty list of ints in [0, {len(embed_dims)}), "
            f"got {ind_list!r}"
        )
    ind_list = list(ind_list)

    reconstructed_dataset_val = np.load('./data/reconstructed_dataset_ckpt_mn3_train.npy')
    reconstructed_label_val = np.load('./data/reconstructed_label_ckpt_mn0_train.npy')
    reconstructed_dataset_test = np.load('./data/reconstructed_dataset_ckpt_mn3_test_fewshot.npy')
    reconstructed_label_test = np.load('./data/reconstructed_label_ckpt_mn0_test_fewshot.npy')

    # Column indices of the embedding slices owned by the selected tasks:
    # task i occupies columns [offsets[i], offsets[i + 1]).
    offsets = np.concatenate(([0], np.cumsum(embed_dims)))
    ind_index = np.concatenate(
        [np.arange(offsets[i], offsets[i + 1]) for i in ind_list]
    )

    # The test feature slice does not depend on the task, so compute it once.
    test_features = reconstructed_dataset_test[:, ind_index]

    task_emb_train = []
    label_emb_train = []
    task_emb_test = []
    label_emb_test = []
    for idx, i in enumerate(ind_list):
        select_id_train = np.nonzero(reconstructed_label_val == i)[0]
        task_emb_train.append(reconstructed_dataset_val[select_id_train, :][:, ind_index])
        # Subtracting (i - idx) maps original label i onto class index idx.
        diff = i - idx
        label_emb_train.append(reconstructed_label_val[select_id_train] - diff)
        # NOTE(review): unlike the training split, the test split is not
        # filtered by label. Every selected task appends the whole test set
        # with labels shifted by that task's offset, so shifted labels can fall
        # outside 0..len(ind_list)-1 and the per-class counts include rows of
        # other tasks. Left unchanged because the row layout of the saved
        # api_assignment array depends on it -- confirm this is intended.
        task_emb_test.append(test_features)
        label_emb_test.append(reconstructed_label_test - diff)
    reconstructed_dataset_val = np.concatenate(task_emb_train, axis=0)
    reconstructed_label_val = np.concatenate(label_emb_train, axis=0)

    reconstructed_dataset_test = np.concatenate(task_emb_test, axis=0)
    reconstructed_label_test = np.concatenate(label_emb_test, axis=0)

    print("reconstructed val dataset size: {}, reconstructed val label size: {}".format(
            reconstructed_dataset_val.shape, reconstructed_label_val.shape
        ))
    print("reconstructed val dataset: {}, reconstructed val label: {}".format(
            reconstructed_dataset_val, reconstructed_label_val
        ))

    print("reconstructed test dataset size: {}, reconstructed test label size: {}".format(
            reconstructed_dataset_test.shape, reconstructed_label_test.shape
        ))
    print("reconstructed_dataset: {}, reconstructed_label: {}".format(
            reconstructed_dataset_test, reconstructed_label_test
        ))

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    classifier = SimpleClassifierForSentimentAnalysis(
        num_clients=len(ind_list),
        embedding_dims=[embed_dims[i] for i in ind_list],
        hidden_dim=8192,
    ).to(device)

    print("* Arch of the new classifier: {}".format(classifier))
    new_data_loader = torch.utils.data.DataLoader(
        SimpleDataset(data=reconstructed_dataset_val, targets=reconstructed_label_val),
        batch_size=64, shuffle=True, num_workers=4)
    new_data_loader_test = torch.utils.data.DataLoader(
        SimpleDataset(data=reconstructed_dataset_test, targets=reconstructed_label_test),
        batch_size=test_bs, shuffle=False, num_workers=4)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(classifier.parameters(), lr=0.0005, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epoch)

    def _eval(new_data_loader_test, classifier, ep, test_bs, model_assignment=None):
        """Print overall and per-class accuracy on the test loader.

        When ``model_assignment`` is given, the raw classifier outputs of each
        batch are written into its rows (batch ``b`` fills rows
        ``b*test_bs:(b+1)*test_bs``). Returns ``model_assignment``.
        """
        correct, total = 0, 0
        classifier.eval()
        num_classes = len(ind_list)
        correct_list = [0] * num_classes
        total_list = [0] * num_classes
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(new_data_loader_test):
                data, target = data.float().to(device), target.long().to(device)
                outputs = classifier(data)
                _, predicted = outputs.max(1)
                if model_assignment is not None:
                    start = batch_idx * test_bs
                    model_assignment[start:start + test_bs] = outputs.cpu().numpy()
                hits = predicted.eq(target)
                total += target.size(0)
                correct += hits.sum().item()
                for cls in range(num_classes):
                    in_class = target == cls
                    correct_list[cls] += (in_class & hits).sum().item()
                    total_list[cls] += in_class.sum().item()
        # Report NaN instead of dividing by zero when a split or class is empty.
        nan = float("nan")
        print("@@ EP: {}, Final {}/{}, Accuracy: {:.2f}%".format(
                ep, correct, total, correct/total*100.0 if total else nan))
        print("correct list: ", correct_list)
        print("acc prop: ", [c / t if t else nan for c, t in zip(correct_list, total_list)])
        return model_assignment

    for ep in range(max_epoch):
        train_loss, correct, total = 0, 0, 0
        classifier.train()
        for batch_idx, (inputs, targets) in enumerate(new_data_loader):
            inputs, targets = inputs.float().to(device), targets.long().to(device)
            optimizer.zero_grad()
            outputs = classifier(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            if batch_idx % 10 == 0:
                print("Ep: {}, Training: {}/{}, Loss: {:.3f} | Acc: {:.3f} ({}/{})".format(
                        ep,
                        batch_idx, len(new_data_loader), train_loss/(batch_idx+1),
                        100.*correct/total, correct, total))
        # Per-epoch evaluation only reports accuracy; the outputs are collected
        # once, after training, below.
        _eval(new_data_loader_test, classifier, ep=ep, test_bs=test_bs)
        scheduler.step()

    api_assignment = np.zeros((reconstructed_label_test.shape[0], len(ind_list)))
    api_assignment = _eval(new_data_loader_test, classifier, ep=max_epoch - 1,
                           test_bs=test_bs, model_assignment=api_assignment)
    with open(f'./data/api_assignment_api1_feat_fewshot_select_{str(ind_list)}.npy', 'wb') as f:
        np.save(f, api_assignment)
    # Use a context manager so the model file is flushed and closed.
    with open(f"./data/classifier_fewshot_select_{str(ind_list)}.model", "wb") as f:
        torch.save(classifier, f)

def _main(*,
    datalen: int = 1000,
    ind_list: str = "[1, 3]",
    max_epoch: int = 10,
    select_num: int = 2,
    generate_all: bool = True,
):
    """Run training for one task subset or for every subset of a given size.

    Args:
        datalen: Passed through to the training routine.
        ind_list: List literal of task ids, used only when generate_all is off.
        max_epoch: Passed through to the training routine.
        select_num: Size of each task subset when generate_all is on.
        generate_all: Train once per select_num-sized subset of the four task
            ids 0..3 instead of once for ind_list.
    """
    if not generate_all:
        main(datalen=datalen, ind_list=ind_list, max_epoch=max_epoch)
        return

    from itertools import combinations
    all_len = 4  # task ids to choose from are 0..all_len-1
    for c_list in combinations(range(all_len), select_num):
        # main expects the subset as a list literal string, e.g. "[0, 2]".
        main(datalen=datalen, ind_list=str(list(c_list)), max_epoch=max_epoch)
    
if __name__ == '__main__':
    # Imported before the try so the names exist when the handlers run.
    import bdb
    import pdb
    import sys

    import defopt
    try:
        defopt.run(_main)
    except bdb.BdbQuit:
        # The user quit an interactive debugger session; exit quietly.
        sys.exit()
    except Exception:
        # Catch Exception rather than everything so that SystemExit (raised by
        # the argument parser for --help and usage errors) and
        # KeyboardInterrupt terminate the script instead of opening pdb.
        exc_type, exc_value, tb = sys.exc_info()
        print(exc_type, exc_value)
        pdb.post_mortem(tb)