import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from datasets.regression_datasets import Sine, Poly
from datasets.ode_dataset import ODEDataset
from datasets.deformable_mass_spring import DeformableMassSpring
from datasets.omni_dataset import OmnipushDataset
from datasets.Omniglot_dataset import OmniglotData


from datasets.MiniImageNet_dataset import MiniImagenetDataset

from models import GBML, VR_MAML
from torch.utils.tensorboard import SummaryWriter
import argparse
import copy
from utils import create_dir
from timeit import default_timer as timer
from collections import defaultdict
import pickle

parser = argparse.ArgumentParser()

parser.add_argument('--model_name', type=str, default="test")
# NOTE(review): argparse does not apply `type` to non-string defaults, so when the
# flag is omitted `args.model_type` is the int 3, not the string "3" — confirm the
# learner expects that.
parser.add_argument('--model_type', default=3, type=str, help="model type ID")
parser.add_argument('--ambient', default=8, type=int, help="dim of theta")
# The previous default (the int 0) matched no dataset and crashed later inside
# os.path.join; requiring an explicit, validated choice fails fast with a clear message.
parser.add_argument('--dataset', type=str, required=True,
                    choices=['sine', 'poly', 'spring', 'pendulum', 'deformable',
                             'omnipush', 'omniglot', 'imagenet'],
                    help="type of experiment")
parser.add_argument('--epochs', default=1000, type=int, help="number of epochs")
parser.add_argument('--n_s', default=5, type=int, help="support size")
parser.add_argument('--n_tasks', default=1000, type=int, help="number of training tasks")
parser.add_argument('--reg', default=1.0, type=float, help="hessian regularization")
parser.add_argument('--lr0', default=0.1, type=float, help="initial value of alpha")

parser.add_argument('--seed', default=0, type=int, help="seed")
# Number of inner-loop adaptation steps passed to GBML.adapt.
parser.add_argument('--inner-steps', default=1, type=int)

parser.add_argument('--tensorboard-dir', type=str, default='logs')
parser.add_argument('--checkpoints-dir', type=str, default='checkpoints')

# Integer flags (0/1) forwarded to the model / image size selection.
parser.add_argument('--resnet', type=int, default=0)
parser.add_argument('--use-batch-norm', type=int, default=0)

parser.add_argument('--img-size', type=int, default=6)

# Parsed at import time because `forward` reads `args.inner_steps` as a global.
args = parser.parse_args()


def forward(batch, model, phase, device):
    """Adapt ``model`` to a single task and score it on that task's query set.

    Args:
        batch: Indexable whose items 0..3 are the support inputs, support
            targets, query inputs and query targets of one task.
        model: A ``VR_MAML`` or ``GBML`` learner.
        phase: Not used by this function.
        device: Device the four task tensors are moved to.

    Returns:
        Tuple ``(training_loss_old, training_loss, query_loss, det_H, accuracy)``.
        ``training_loss_old`` is the query loss obtained from
        ``model.theta_0[0]`` for ``VR_MAML`` and the int ``0`` otherwise.
        ``det_H`` is the value returned by the (last) ``model.adapt`` call.
        ``accuracy`` is a Python float when ``model.loss_type`` is 1 or 2,
        else ``0``.
    """
    x_s, y_s, x_q, y_q = (batch[k].to(device) for k in range(4))

    if isinstance(model, VR_MAML):
        # Adapt from both stored parameter sets; only the second det_H is kept.
        theta_prev, _ = model.adapt(x_s, y_s, model.theta_0[0])
        theta, det_H = model.adapt(x_s, y_s, model.theta_0[1])
        y_q_hat_prev = model(x_q, theta_prev)
        y_q_hat = model(x_q, theta)
        loss_prev = model.criterion(y_q_hat_prev, y_q)
        query_loss = model.criterion(y_q_hat, y_q)
        training_loss_old = 0 + loss_prev
    else:
        theta, det_H = model.adapt(x_s, y_s, model.theta_0, steps=args.inner_steps)
        if model.model_type == 'metamix':
            # `mix` returns replacement targets together with the predictions.
            y_q_hat, y_q = model.mix(x_s, y_s, x_q, y_q, theta)
        else:
            y_q_hat = model(x_q, theta)
        query_loss = model.criterion(y_q_hat, y_q)
        training_loss_old = 0

    training_loss = 0 + query_loss

    loss_type = model.loss_type
    if loss_type == 1:
        hits = torch.argmax(y_q_hat, -1) == y_q
    elif loss_type == 2:
        # Presumably class labels are one-hot from column 2 onward — confirm.
        hits = torch.argmax(y_q_hat, -1) == torch.argmax(y_q[:, 2:], -1)
    else:
        hits = None
    accuracy = 0 if hits is None else torch.mean(hits * 1.).detach().cpu().item()

    return training_loss_old, training_loss, query_loss, det_H, accuracy


def fit(learner, opt, epoch, dloader, phase, writer, loss_records=None):
    """Run one epoch of meta-training or meta-evaluation over ``dloader``.

    Each batch holds several tasks stacked along dim 0. Every task is adapted
    and scored independently via ``forward``, and the per-task losses are
    averaged into the meta-objective. In the ``'train'`` phase that objective
    updates ``learner``; in any other phase the learner is only evaluated.

    Args:
        learner: ``GBML`` or ``VR_MAML`` model.
        opt: Optimizer over ``learner``'s parameters. Not used for
            ``VR_MAML`` (which updates itself via ``gradient_step``) or
            outside the ``'train'`` phase.
        epoch: Epoch index, used for logging.
        dloader: Sized iterable of task batches.
        phase: ``'train'`` or ``'test'``; also used as the TensorBoard tag prefix.
        writer: TensorBoard ``SummaryWriter``.
        loss_records: Optional mapping of lists (e.g. ``defaultdict(list)``);
            the epoch means are appended under ``'mu_loss'``,
            ``'comp_time'`` and ``'mu_accuracy'``.

    Returns:
        ``loss_records`` with this epoch's means appended, or ``None`` when
        ``loss_records`` is ``None``.
    """
    if phase == 'train':
        learner.train()
    elif phase == 'test':
        learner.eval()

    is_train = phase == 'train'
    is_vr = isinstance(learner, VR_MAML)

    mu_loss = 0
    mu_acc = 0
    mu_time = 0

    for ix, batch in enumerate(dloader):

        start = timer()
        total_loss = 0
        tot_acc = 0

        num_tasks = batch[0].shape[0]
        # Meta-objective accumulators; only built in the train phase so that
        # evaluation does not keep every task's autograd graph alive.
        expected_loss = 0
        expected_loss_old = 0

        if is_train and is_vr:
            learner.copy_params()

        for i in range(num_tasks):
            task = [batch[j][i] for j in range(len(batch))]
            training_loss_old, training_loss, query_loss, _, accuracy = forward(task, learner, phase, device)

            tot_acc += accuracy
            total_loss += query_loss.detach().cpu().item()

            if is_train:
                expected_loss += training_loss
                if is_vr:
                    expected_loss_old += training_loss_old

        if is_train:
            expected_loss /= num_tasks

            if is_vr:
                # Average the "old" loss over the whole meta-batch as well;
                # previously only the last task's losses reached gradient_step.
                expected_loss_old /= num_tasks
                learner.gradient_step(expected_loss_old, expected_loss)
            elif learner.use_v:
                theta_0_t = learner.theta_0.clone()
                opt.zero_grad()
                expected_loss.backward()
                opt.step()
                theta_0_t1 = learner.theta_0.clone()
                learner.update_v(batch, theta_0_t, theta_0_t1)
            else:
                opt.zero_grad()
                expected_loss.backward()
                opt.step()

        batch_time = timer() - start
        mu_time += batch_time

        total_loss /= num_tasks
        tot_acc /= num_tasks

        mu_loss += total_loss
        mu_acc += tot_acc

        if ix % 10 == 0:
            print(f"{phase.upper()} Epoch: {epoch} Batch: {ix + 1} / {len(dloader)} Loss: {total_loss:.3f} Accuracy: {100*tot_acc:.2f}% Time: {batch_time:.3f}")

    # Guard against an empty loader instead of raising ZeroDivisionError.
    n_batches = max(len(dloader), 1)
    mu_loss /= n_batches
    mu_acc /= n_batches
    mu_time /= n_batches

    writer.add_scalar(phase + "/final_loss", mu_loss, epoch)
    writer.add_scalar(phase + "/accuracy", mu_acc, epoch)
    writer.add_scalar(phase + "/time", mu_time, epoch)

    if loss_records is None:
        return None

    loss_records['mu_loss'].append(mu_loss)
    loss_records['comp_time'].append(mu_time)
    loss_records['mu_accuracy'].append(mu_acc)
    return loss_records



if __name__ == '__main__':

    # `device` must remain a module-level global: `fit` reads it directly.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Device: {device}")

    # Reject unknown datasets before any log/checkpoint directories are created.
    known_datasets = ('sine', 'poly', 'spring', 'pendulum', 'deformable',
                      'omnipush', 'omniglot', 'imagenet')
    if args.dataset not in known_datasets:
        raise NotImplementedError("Dataset does not exist")

    n_layers = 4
    n_s = args.n_s
    n_q = 100
    n_tasks = args.n_tasks  # number of training tasks (e.g. distinct sine waves)
    batch_size = 2  # tasks per meta-batch
    ambient_space_dims = args.ambient
    lr_out = 0.001
    test_frq = 1  # evaluate every `test_frq` epochs

    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    model_type_name = args.model_type
    print(args.dataset, model_type_name)

    ## Create directories
    # NOTE: the name records args.n_s even for omniglot, which overrides n_s below.
    config_name = f"{args.model_name}_{args.dataset}_{args.model_type}_Ns={n_s}_dimz={args.ambient}_Ntasks={args.n_tasks}_inner_steps={args.inner_steps}_reg={args.reg}_resnet={args.resnet}_epochs={args.epochs}_seed={args.seed}"
    # Use the configured directories as given: a "./" prefix would turn an
    # absolute path such as "/data/logs" into the relative ".//data/logs".
    writer = SummaryWriter(os.path.join(args.tensorboard_dir, args.dataset, config_name))

    CHECKPOINTS_DIR = args.checkpoints_dir
    MODEL_DIR = os.path.join(CHECKPOINTS_DIR, config_name)
    MODEL_PATH = os.path.join(MODEL_DIR, 'model.pt')

    create_dir(CHECKPOINTS_DIR)
    create_dir(MODEL_DIR)

    with open(os.path.join(MODEL_DIR, 'metadata.pkl'), 'wb') as f:
        pickle.dump({'args': args}, f)

    deconv = 0
    n_ways = 0
    loss_type = 0
    if args.dataset == 'sine':
        dset_train = Sine(n_s, n_q, n_tasks, "train", 0.0, device)
        dset_test = Sine(n_s, n_q, n_tasks, "test", 0.0, device)
        cnn = 0  # 0: MLP, 1: CNN
        x_size = 1
        y_size = 1

    elif args.dataset == 'poly':
        degree = ambient_space_dims - 1
        dset_train = Poly(n_s, n_q, n_tasks, "train", degree, 0.0, device)
        dset_test = Poly(n_s, n_q, n_tasks, "test", degree, 0.0, device)
        cnn = 0
        x_size = 1
        y_size = 1

    elif args.dataset == 'spring':
        dset_train = ODEDataset(n_s, n_tasks, 'train', 'spring', device)
        dset_test = ODEDataset(n_s, n_tasks, 'test', 'spring', device)
        cnn = 0
        x_size = 1
        y_size = 1

    elif args.dataset == 'pendulum':
        dset_train = ODEDataset(n_s, n_tasks, 'train', 'pendulum', device)
        dset_test = ODEDataset(n_s, n_tasks, 'test', 'pendulum', device)
        cnn = 0
        x_size = 4
        y_size = 2

    elif args.dataset == 'deformable':
        dset_train = DeformableMassSpring(n_s, n_q, "train", data_dir='./data')
        dset_test = DeformableMassSpring(n_s, n_q, "test", data_dir='./data')
        cnn = 0
        x_size = 1
        y_size = 2

    elif args.dataset == 'omnipush':
        dset_train = OmnipushDataset(n_s, "train", device=device)
        dset_test = OmnipushDataset(n_s, "test", device=device)
        cnn = 0
        x_size = 3
        y_size = 3
        loss_type = 0
        deconv = 0

    elif args.dataset == 'omniglot':
        n_ways = 5
        n_s = 5  # fixed support size; overrides --n_s

        dset_train = OmniglotData(n_ways, n_s, n_q, 1223, 0, 'train', device)
        dset_test = OmniglotData(n_ways, n_s, n_q, 1223, 0, 'test', device)
        cnn = 1
        x_size = [28, 1]
        y_size = n_ways
        loss_type = 1

    elif args.dataset == 'imagenet':
        n_ways = 5
        k_shot = args.n_s
        k_query = 20

        dset_train = MiniImagenetDataset(240, n_ways, k_shot, k_query, 'train', device)
        dset_test = MiniImagenetDataset(100, n_ways, k_shot, k_query, 'test', device)

        # ResNet backbones use 224px inputs, the default CNN 84px.
        x_size = [224, 3] if args.resnet == 1 else [84, 3]
        cnn = 1
        y_size = n_ways
        loss_type = 1

    else:
        # Unreachable after the early check unless `known_datasets` gains a name
        # without a matching branch here.
        raise NotImplementedError("Dataset does not exist")

    gamma = 0.99

    train_loader = torch.utils.data.DataLoader(dset_train, batch_size=batch_size, num_workers=0, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dset_test, batch_size=batch_size, num_workers=0, shuffle=True)

    # Sanity check that batches unpack as (x_s, y_s, x_q, y_q). Kept deliberately:
    # iterating a shuffled DataLoader draws from torch's global RNG, so removing
    # this line would change the results of seeded runs.
    xs, ys, xq, yq = next(iter(train_loader))

    if args.model_type == "vr-maml":
        learner = VR_MAML(x_size, ambient_space_dims, y_size, n_layers, args.lr0, lr_out, gamma, cnn, deconv, loss_type).to(device)
    else:
        learner = GBML(x_size, ambient_space_dims, y_size, n_layers, args.lr0, args.reg, cnn, loss_type, model_type=args.model_type, use_batch_norm=args.use_batch_norm).to(device)

    opt = torch.optim.Adam(learner.parameters(), lr=1e-3)

    test_loss_records = defaultdict(list)
    train_loss_records = defaultdict(list)

    try:
        for epoch in tqdm(range(1, args.epochs + 1)):
            train_loss_records = fit(learner, opt, epoch, train_loader, 'train', writer, train_loss_records)
            with open(os.path.join(MODEL_DIR, "results_train.pkl"), 'wb') as f:
                pickle.dump(train_loss_records, f)

            if epoch % test_frq == 0:
                test_loss_records = fit(learner, None, epoch, test_loader, 'test', writer, test_loss_records)
                with open(os.path.join(MODEL_DIR, "results_test.pkl"), 'wb') as f:
                    pickle.dump(test_loss_records, f)

            # Checkpoint the whole module every epoch (overwrites the previous one).
            torch.save(learner, MODEL_PATH)
    finally:
        # Flush TensorBoard events even if training is interrupted or fails.
        writer.close()
