import os
import sys
import argparse
import wandb
import torch
import torch.nn.functional as F
import numpy as np
import scipy
import random

import datasets
import utils
from models import gaussian_kernel

import matplotlib.pyplot as plt

# Double precision for every tensor created without an explicit dtype.
torch.set_default_dtype(torch.float64)

# Fixed seeds for the Python, NumPy and torch RNGs so runs are reproducible.
random.seed(253)
np.random.seed(1145)
torch.manual_seed(3143)

def eval_kernel(sol, K, y_onehot):
    """Score the kernel predictions ``K.T @ sol`` against ``y_onehot``.

    Args:
        sol: Coefficients from the kernel solve.
        K: Kernel matrix; predictions are computed as ``K.T @ sol``.
        y_onehot: Targets, either 2-D with one column per output, or 1-D
            (treated as a single output column).

    Returns:
        Tuple ``(acc, loss, correct_logit_loss)``:

        - ``acc``: fraction of rows counted correct. With several columns a
          row is correct when the argmax of the prediction matches the argmax
          of the target. With a single column it is correct when the
          prediction and the target fall on the same side of 0.5.
        - ``loss``: mean squared error over all entries.
        - ``correct_logit_loss``: mean over rows of ``(pred[i, label_i] - 1) ** 2``,
          where ``label_i = y_onehot[i].argmax()``.
    """
    preds = K.T @ sol

    # Promote 1-D predictions/targets to a single column so both are (n, c).
    # This stops the MSE from broadcasting (n,) against (n, 1) into (n, n).
    # It also lets 1-D targets reuse the single-column accuracy branch.
    if preds.dim() == 1:
        preds = preds.unsqueeze(-1)
    if y_onehot.dim() == 1:
        y_onehot = y_onehot.unsqueeze(-1)

    loss = (preds - y_onehot).pow(2).mean()

    n = y_onehot.shape[0]
    labels = y_onehot.argmax(-1)
    # Pick each row's prediction at its target column in one indexing op
    # (replaces a per-row Python loop).
    rows = torch.arange(n, device=preds.device)
    correct_logits = preds[rows, labels]
    correct_logit_loss = (correct_logits - 1).pow(2).mean()

    if y_onehot.shape[1] > 1:
        count = torch.sum(labels == preds.argmax(-1))
    else:
        count = torch.sum((y_onehot > 0.5) == (preds > 0.5))
    acc = count / n

    return acc, loss, correct_logit_loss

def get_kernel(X_tr, X_te, M, bandwidth, model):
    """Build the kernel matrix between ``X_tr`` and ``X_te`` for ``model``.

    Args:
        X_tr: First sample set (every call site in this script passes the
            training inputs here).
        X_te: Second sample set.
        M: Feature matrix, forwarded to the kernel function.
        bandwidth: Kernel bandwidth, forwarded to the kernel function.
        model: ``'gaussian'`` or ``'inner_gaussian'``.

    Returns:
        Whatever the selected kernel function returns. Callers in this script
        compute ``K.T @ sol`` from it. Rows presumably index ``X_tr`` and
        columns ``X_te``; confirm against the kernel implementations.

    Raises:
        NotImplementedError: if ``model`` is not supported. This is a
            ``RuntimeError`` subclass.
    """
    if model == 'gaussian':
        return gaussian_kernel.gaussian_M(X_tr, X_te, bandwidth, M)
    if model == 'inner_gaussian':
        # The module-level imports only bring in gaussian_kernel, so import
        # this one lazily. Without it this branch raised NameError.
        # NOTE(review): assumes the module lives in the same `models`
        # package as gaussian_kernel; confirm.
        from models import inner_gaussian_kernel
        return inner_gaussian_kernel.inner_gaussian_M(X_tr, X_te, bandwidth, M)
    raise NotImplementedError(f'kernel model {model!r} is not implemented')

def update(X, x, M, sol, args, y_onehot, centering=False, K_train=None,
           return_per_class_agop=False, batch_size=2):
    """Compute the feature matrix ``M`` for the next RFM round.

    Dispatches on ``args.model`` and ``args.rfm_update`` to the matching kernel
    module's ``get_agop`` / ``get_wagop``. Each returns a pair; the first
    element becomes the new ``M`` and the second is discarded.

    Args:
        X: Sample set forwarded as the first positional argument.
        x: Sample set forwarded as ``x=`` (``main`` passes the training
            inputs for both ``X`` and ``x``).
        M: Current feature matrix.
        sol: Kernel coefficients; passed on transposed (``sol.T``).
        args: Namespace providing ``model``, ``rfm_update`` and ``bandwidth``.
        y_onehot: Targets; only used by the ``'wagop'`` update.
        centering: Forwarded unchanged.
        K_train: Forwarded as ``K=``.
        return_per_class_agop: Forwarded to ``get_agop`` only.
        batch_size: Forwarded unchanged.

    Returns:
        The updated ``M``.

    Raises:
        NotImplementedError: for an unknown ``args.model`` or an
            ``args.rfm_update`` the model does not support. An unknown update
            rule is now rejected for every model, instead of silently leaving
            ``M`` unchanged for the gaussian model.
    """
    model, rule = args.model, args.rfm_update

    if model == 'gaussian':
        if rule == 'agop':
            M, _ = gaussian_kernel.get_agop(X, sol.T, args.bandwidth, M, batch_size=batch_size,
                                         K=K_train, centering=centering, x=x,
                                         return_per_class_agop=return_per_class_agop)
            return M
        if rule == 'wagop':
            M, _ = gaussian_kernel.get_wagop(X, sol.T, args.bandwidth, M, y_onehot, batch_size=batch_size,
                                          K=K_train, centering=centering, x=x)
            return M
    elif model == 'inner_gaussian':
        if rule == 'wagop':
            # Not imported at module level, so this branch used to raise
            # NameError. NOTE(review): assumes the module lives in the same
            # `models` package as gaussian_kernel; confirm.
            from models import inner_gaussian_kernel
            M, _ = inner_gaussian_kernel.get_wagop(X, sol.T, args.bandwidth, M, y_onehot, batch_size=batch_size,
                                                K=K_train, centering=centering, x=x)
            return M
    else:
        raise NotImplementedError(f'unknown model {model!r}')

    raise NotImplementedError(f'rfm_update {rule!r} is not implemented for model {model!r}')

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--wandb_entity', default='default')
    parser.add_argument('--wandb_proj_name', default='default')
    parser.add_argument('--wandb_offline', default=False, action='store_true')
    parser.add_argument('--group_key', default='', type=str)
    parser.add_argument('--out_dir', default='./wandb')
    parser.add_argument('--data_root', default='./data')

    parser.add_argument('--dataset', default='modular_arithmetic', choices={'modular_arithmetic'})
    parser.add_argument('--operation', '-op', default="x+y")
    parser.add_argument('--prime', '-p', default=61, type=int)
    parser.add_argument('--training_fraction', default=0.5, type=float)

    parser.add_argument('--model', default='gaussian', choices={'gaussian', 'inner_gaussian'})
    parser.add_argument('--bandwidth', default=2.5, type=float)
    parser.add_argument('--ridge', default=0.0, type=float)
    parser.add_argument('--rfm_iters', default=50, type=int)
    parser.add_argument('--rfm_update', default='agop', choices={'agop', 'wagop'})
    parser.add_argument('--agop_power', default=0.5, type=float)
    args = parser.parse_args()

    utils.setup_wandb(wandb, args)

    train_dataset, test_dataset, inp_dim, out_dim = datasets.load_dataset(args)

    X_tr = train_dataset.tensors[0].double()
    y_tr_onehot = train_dataset.tensors[1].double()
    X_te = test_dataset.tensors[0].double()
    y_te_onehot = test_dataset.tensors[1].double()

    n = X_tr.shape[0]
    M = torch.eye(X_tr.shape[1]).double()

    for rfm_iter in range(args.rfm_iters):
        K_train = get_kernel(X_tr, X_tr, M, args.bandwidth, args.model)
        sol = torch.from_numpy(np.linalg.solve(K_train.numpy() + args.ridge * np.eye(n), y_tr_onehot.numpy()))

        acc, loss, correct_logit_loss = eval_kernel(sol, K_train, y_tr_onehot)
        print(f'Round {rfm_iter} Train MSE:\t{loss}')
        print(f'Round {rfm_iter} Train Acc:\t{acc}')
        wandb.log({
            'training/accuracy': acc,
            'training/loss': loss,
            'training/correct_logit_loss': correct_logit_loss,
        }, step=rfm_iter)

        K_test = get_kernel(X_tr, X_te, M, args.bandwidth, args.model)
        acc, loss, correct_logit_loss = eval_kernel(sol, K_test, y_te_onehot)
        print(f'Round {rfm_iter} Test MSE:\t{loss}')
        print(f'Round {rfm_iter} Test Acc:\t{acc}')
        print()

        wandb.log({
            'validation/accuracy': acc,
            'validation/loss': loss,
            'validation/correct_logit_loss': correct_logit_loss,
        }, step=rfm_iter)

        M = update(X_tr, X_tr, M, sol, args, y_tr_onehot, centering=True, K_train=K_train,
                   return_per_class_agop=False, batch_size=2)

        if args.agop_power != 1:
            M = utils.matrix_power(M, args.agop_power, is_torch=True)

# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
