import argparse

import numpy as np
import torch
import torch.nn as nn
from common.utils import make_optimizer, set_seed
from scipy.stats import kendalltau
from sklearn.metrics import r2_score
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from model import InvariantNFN
from common.data_utils import SmallMoeZooDatasetAugmented, SmallMoEZooDataset
from common.weight_space import (
    MoEWeightSpaceFeatures,
    LinearWeightSpaceFeatures,
    moe_network_spec_from_wsfeat,
    network_spec_from_wsfeat,
)


def _unpack_batch(batch, device):
    """Split a collated batch into model inputs and regression targets.

    The classifier and encoder weights are wrapped in their weight-space feature
    containers and moved to ``device``, together with ``batch["accuracy"]``.

    Returns:
        (embedding, classifier, encoder, true_acc)
    """
    # NOTE(review): the embedding weights are passed through without a device move,
    # exactly as the original loops did; presumably InvariantNFN places them — confirm.
    embedding = batch["embedding"]
    classifier = LinearWeightSpaceFeatures(batch["classifier"]["weight"], batch["classifier"]["bias"]).to(device)
    encoder = MoEWeightSpaceFeatures(**batch["encoder"]).to(device)
    true_acc = batch["accuracy"].to(device)
    return embedding, classifier, encoder, true_acc


@torch.no_grad()
def evaluate(nfn_net, test_loader, loss_fn, device=None):
    """Evaluate ``nfn_net`` on every batch of ``test_loader``.

    Args:
        nfn_net: model called as ``nfn_net(embedding, classifier, encoder)``.
        test_loader: iterable of collated batches with "embedding", "classifier",
            "encoder" and "accuracy" entries.
        loss_fn: criterion called as ``loss_fn(pred, target)``. The per-sample
            weighting below is exact for mean-reduction losses such as ``nn.BCELoss()``.
        device: where batch tensors are moved. Defaults to the device of the model's
            first parameter (or "cuda" if the model has no parameters).

    Returns:
        (avg_err, avg_loss, rsq, tau): mean absolute error and mean loss per target
        element, R^2 score, and Kendall's tau between targets and predictions.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    nfn_net.eval()
    if device is None:
        first_param = next(nfn_net.parameters(), None)
        device = first_param.device if first_param is not None else "cuda"

    weighted_loss, n_targets = 0.0, 0
    preds, actuals = [], []
    for batch in tqdm(test_loader):
        embedding, classifier, encoder, true_acc = _unpack_batch(batch, device)
        pred_acc = nfn_net(embedding, classifier, encoder)
        # Weight each batch's mean loss by its size so a short final batch is not over-counted.
        weighted_loss += loss_fn(pred_acc, true_acc).item() * true_acc.numel()
        n_targets += true_acc.numel()
        preds.append(pred_acc.cpu().numpy())
        actuals.append(true_acc.cpu().numpy())

    if n_targets == 0:
        raise ValueError("evaluate() received a loader that yielded no samples")

    actual, pred = np.concatenate(actuals), np.concatenate(preds)
    avg_err = float(np.abs(pred - actual).mean())
    avg_loss = weighted_loss / n_targets
    rsq = r2_score(actual, pred)
    # Element 0 is tau for both the legacy KendalltauResult and the newer SignificanceResult.
    tau = kendalltau(actual, pred)[0]
    return avg_err, avg_loss, rsq, tau


def _make_eval_loader(dataset, args):
    """Build a deterministic (non-shuffled) evaluation loader using the dataset's own collate_fn."""
    return DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=False,
                      num_workers=args.n_workers, collate_fn=dataset.collate_fn)


def _build_nfn(sample_batch, args):
    """Derive the embedding/classifier/encoder network specs from one batch and build the InvariantNFN."""
    classifier_network_spec = network_spec_from_wsfeat(
        LinearWeightSpaceFeatures(sample_batch['classifier']['weight'], sample_batch['classifier']['bias']).to("cpu"),
        set_all_dims=True)
    embedding_network_spec = network_spec_from_wsfeat(
        LinearWeightSpaceFeatures(sample_batch['embedding']['weight'], sample_batch['embedding']['bias']).to("cpu"),
        set_all_dims=True)
    encoder_network_spec = moe_network_spec_from_wsfeat(
        MoEWeightSpaceFeatures(**sample_batch['encoder']).to("cpu"), set_all_dims=True)
    return InvariantNFN(embedding_network_spec=embedding_network_spec,
                        classifier_network_spec=classifier_network_spec,
                        encoder_network_spec=encoder_network_spec,
                        classifier_nfn_channels=args.classifier_nfn_channels,
                        MoE_nfn_channels=args.moe_nfn_channels,
                        num_out_classify=args.num_out_classify, num_out_embedding=args.num_out_embedding,
                        num_out_encoder=args.num_out_encoder, init_type=args.init_type,
                        enc_mode=args.enc_mode, cls_mode=args.cls_mode, emb_mode=args.emb_mode,
                        out_dim_inv=args.out_dim_inv)


def _make_schedulers(optimizer, args):
    """Return (linear_warmup, cos_decay): linear warmup to the peak lr, then cosine decay to 1e-5.

    The caller steps ``linear_warmup`` while ``epoch < args.warmup_epochs`` and
    ``cos_decay`` afterwards.
    """
    # Clamp so that --warmup_epochs 0 means "no warmup" instead of 1/0.
    # LinearLR then starts at factor 1.0 and leaves the lr untouched.
    warmup_epochs = max(args.warmup_epochs, 1)
    linear_warmup = optim.lr_scheduler.LinearLR(optimizer, start_factor=1 / warmup_epochs, end_factor=1.0,
                                                total_iters=warmup_epochs - 1, last_epoch=-1)
    cos_decay = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=args.epochs - args.warmup_epochs,
                                                     eta_min=1e-5)
    return linear_warmup, cos_decay


def _train_one_epoch(nfn_net, train_loader, optimizer, loss_fn, device):
    """Run one optimisation pass of ``nfn_net`` over ``train_loader``."""
    nfn_net.train()
    for batch in tqdm(train_loader):
        embedding, classifier, encoder, true_acc = _unpack_batch(batch, device)
        optimizer.zero_grad()
        pred_acc = nfn_net(embedding, classifier, encoder)
        try:
            loss = loss_fn(pred_acc, true_acc)  # NOTE: Placeholder
        except (RuntimeError, ValueError):
            # Dump the offending tensors for debugging, then propagate. Continuing would
            # call backward() on an undefined or stale loss from a previous batch.
            print(pred_acc)
            print(true_acc)
            raise
        loss.backward()
        optimizer.step()


def main(args):
    """Train an InvariantNFN to predict ``batch["accuracy"]`` and log metrics every epoch.

    Each epoch runs one optimisation pass over the training split. It then reports
    loss / Kendall tau / mean absolute error / R^2 on the train, val and test splits.
    When ``args.augment_factor > 1`` it also reports on an augmented copy of the test
    split. Finally it advances the learning-rate schedule (linear warmup, then
    cosine decay).
    """
    print("Start to load dataset")
    if args.device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device

    common = dict(data_path=args.data_path, n_heads=args.n_heads, name=args.dataset,
                  cut_off=args.cut_off, load_cache=args.load_cache)

    # SmallMoeZooDatasetAugmented(split="train", augment_factor=args.augment_factor, **common)
    # can be swapped in here to train on group-action-augmented data.
    train_set = SmallMoEZooDataset(split="train", **common)
    train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.n_workers, collate_fn=train_set.collate_fn)
    val_loader = _make_eval_loader(SmallMoeZooDatasetAugmented(split="val", augment_factor=1, **common), args)
    test_loader = _make_eval_loader(SmallMoeZooDatasetAugmented(split="test", augment_factor=1, **common), args)

    eval_splits = [("Train", train_loader), ("Val", val_loader), ("Test", test_loader)]
    if args.augment_factor > 1:
        augment_set = SmallMoeZooDatasetAugmented(split="test", augment_factor=args.augment_factor,
                                                  keep_original=args.keep_original, **common)
        eval_splits.append(("Augment", _make_eval_loader(augment_set, args)))

    # Network specs are read off the shapes of one real training batch.
    nfn_net = _build_nfn(next(iter(train_loader)), args)
    print(nfn_net)
    num_params = sum(p.numel() for p in nfn_net.parameters() if p.requires_grad)
    print(f"Total params in NFN: {num_params}.")

    # Move the model before building the optimizer so optimizer state lives on `device`.
    nfn_net.to(device)
    optimizer = make_optimizer(optimizer=args.optimizer, lr=args.lr, wd=args.wd, model=nfn_net)
    linear_warmup, cos_decay = _make_schedulers(optimizer, args)
    loss_fn = nn.BCELoss()

    for epoch in tqdm(range(args.epochs)):
        _train_one_epoch(nfn_net, train_loader, optimizer, loss_fn, device)

        for split, loader in eval_splits:
            split_err, split_loss, split_rsq, split_tau = evaluate(nfn_net, loader, loss_fn, device)
            print(f"Epoch: {epoch}, {split} Loss: {split_loss}, {split} Tau: {split_tau}, "
                  f"{split} err: {split_err}, {split} rsq: {split_rsq}")

        # Update learning rate using schedulers
        if epoch < args.warmup_epochs:
            linear_warmup.step()
        else:
            cos_decay.step()

def _str2bool(value):
    """Parse a command-line boolean strictly.

    ``type=bool`` is unsuitable for argparse because ``bool("False")`` is True.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MoE nfn training from scratch')

    parser.add_argument('--seed', type=int, default=3, help='random seed')
    parser.add_argument('--device', type=str, default=None)

    # Model dimension arguments
    parser.add_argument("--ws_dim", type=int, default=10, help="The number weight in the nfn network stacked on each other")
    parser.add_argument("--classifier_dim", type=int, default=1024, help="Dimension of the classification network")

    # Training Arguments
    parser.add_argument('--epochs', type=int, default=100, help='number of training epochs')
    parser.add_argument('--warmup_epochs', type=int, default=10, help='number of epochs to warmup learning rate')
    parser.add_argument('--batch_size', type=int, default=16, help='batch size')
    parser.add_argument('--n_workers', type=int, default=4, help='number of workers for data loaders')
    parser.add_argument('--optimizer', type=str, default="adam", choices=["adam", "sgd", "sgd_momentum", "rmsprop"] ,help='choice of optimizer')
    parser.add_argument('--lr', type=float, default=1e-4, help='peak learning rate')
    parser.add_argument('--wd', type=float, default=0, help="Weight decay factor")
    parser.add_argument('--init_type', type=str, default="xavier_normal", choices=["pytorch_default", "kaiming_normal", "xavier_normal", "xavier_uniform","uniform"], help='Init mode of network parameters')

    # NFN architecture arguments (comma-separated lists are parsed into list[int])
    parser.add_argument("--classifier_nfn_channels", type=lambda value: [int(x) for x in value.split(',')], default=[100, 100], help="Channels for classifier NFN")
    parser.add_argument("--moe_nfn_channels", type=lambda value: [int(x) for x in value.split(',')], default=[4], help="Channels for MoE NFN")
    parser.add_argument("--num_out_classify", type=int, default=100, help="Number of output classes for classifier")
    parser.add_argument("--num_out_embedding", type=int, default=100, help="Number of output classes for embedding")
    parser.add_argument("--num_out_encoder", type=int, default=100, help="Number of output classes for encoder")
    parser.add_argument("--out_dim_inv", type=int, default=5, help="Dimension of the output of invariant layer")

    parser.add_argument('--enc_mode', default='moe_invariant', choices=['no', 'moe_invariant', "transformer_invariant", 'mlp'])
    parser.add_argument('--cls_mode', default='mlp', choices=['no', 'hnps', 'mlp'])
    parser.add_argument('--emb_mode', default='mlp', choices=['no', 'mlp'])

    # Data arguments
    parser.add_argument('--dataset', type=str, default='mnist', choices=['mnist', "ag_news"], help='dataset to use')
    parser.add_argument('--cut_off', type=float, default=0.1, help='cut off rate for accuracy')
    # Strict boolean parsing: with type=bool, "--load_cache False" silently meant True.
    parser.add_argument('--load_cache', type=_str2bool, default=True, help='load cache')
    parser.add_argument('--n_heads', type=int, default=2, help='Number of heads in MoE input network')
    parser.add_argument('--data_path', type=str, default='/root/repos/MoE_nfn_private/data', help='path to dataset')
    parser.add_argument('--augment_factor', type=int, default=1, help='Augment factor for the dataset based on the group action')
    parser.add_argument('--keep_original', type=int, default=1, help='Choose to keep the original data or not')

    args = parser.parse_args()

    # Validate with parser.error rather than assert, which is stripped under `python -O`.
    if all(mode == 'no' for mode in (args.enc_mode, args.cls_mode, args.emb_mode)):
        parser.error("at least one of --enc_mode, --cls_mode, --emb_mode must be different from 'no'")

    set_seed(manualSeed=args.seed)

    main(args)