from datasets.ant_dataset import AntDataset
from datasets.contextual_bandit import BanditDataset
from datasets.function_shift import SinePolyDataset
from datasets.maze_gen_new import evaluate_reward_maze
from datasets.pendulum_imitation import PendulumPolicyDataset
from datasets.radius_dataset import RadiusDataset
from models.gbds import MetaModel
from models.hyper import HyperLearner, OutNet
#from models.leo import EigenLEO, LEO, CAVIA
import numpy as np
import torch
import argparse
from datasets.sine_dataset import SineDatasetGeneral
from datasets.pendulum import PendulumDataset
from datasets.double_pendulum import DoublePendulumTransfer
from utils import split
import torch.nn as nn

import matplotlib.pyplot as plt
import time
from utils import flatten_params, get_accuracy_score, reshape_param, flatten_params_torch, rapid_learning_reg, support_to_img
from models.maml import MAMLLearner
#from models.leo import LEO, ConvLEO, SubversiveLEO
#from models.leo import EigenLEO
from models.subspace import LEO, LinearLEO, CAVIA
from torch.utils.tensorboard import SummaryWriter
import os
import pickle
from collections import defaultdict
import torch.nn.functional as F
import ipdb
from tqdm import tqdm
from utils import recompute_eigen_vectors
from utils_pend.pendulum_utils import generate_save_path
from datasets.maze_dataset import DatasetGridSimple
from train_pendulum import evaluate_reward


#torch.autograd.set_detect_anomaly(True)


# Command-line interface, declared as (flag, add_argument keyword arguments)
# pairs. The order of the table is the order in which the options are
# registered, and therefore the order in which --help lists them.
_CLI_OPTIONS = [
    # Experiment selection
    ('--dataset', dict(type=str, default='sine')),
    ('--model', dict(type=str, help="leo or maml or implicit")),
    ('--model-name', dict(type=str, default='test')),
    ('--use-meta-curvature', dict(type=int, default=0)),
    ('--epochs', dict(type=int, default=100)),
    ('--num-layers', dict(type=int, default=3)),
    ('--lambda-theta', dict(type=float, default=1.0)),
    ('--leo-layers', dict(type=int, default=0)),
    ('--k', dict(type=int, default=2, help="Num dimensions for latent space for LEO")),
    ('--learnt-v', dict(type=int, default=0)),
    ('--adapt-top-layers', dict(type=int, default=0)),
    ('--num-shots', dict(type=int, default=1)),
    ('--use-cuda', dict(type=int, default=1)),
    ('--batch-size', dict(type=int, default=4)),
    ('--support-size', dict(type=int, default=20)),
    ('--num-classes', dict(type=int, default=5)),
    # Input / output locations
    ('--data-dir', dict(type=str, default='data')),
    ('--tensorboard-dir', dict(type=str, default='tensorboard')),
    ('--checkpoints-dir', dict(type=str, default='checkpoints')),
    ('--seed', dict(type=int, default=1)),
    # Dataset options
    ('--dataset-size', dict(type=int, default=100, help="Size of sine dataset")),
    ('--N', dict(type=int, default=1, help="N=Dimension-1 of the sines")),
    ('--use-complex-function', dict(type=int, default=0)),
    ('--mse-dim', dict(type=int, help="dimension for MSE", default=1)),
    ('--sine-classification', dict(type=int, default=0)),
    ('--num-bins', dict(type=int, default=5)),
    ('--biased-x', dict(type=int, default=0)),
    ('--shift', dict(type=int, default=1)),
    ('--eigen-hyper', dict(type=int, default=0)),
    ('--imgs', dict(type=int, default=0)),
    ('--traj-length', dict(type=int, default=4)),
    ('--maze-size', dict(type=int, default=5)),
    # Encoder / hypernetwork options
    ('--set-encoder', dict(type=str, default='deepsets', help="pointnet, deepsets, set-transformer")),
    ('--num-hypernet-layers', dict(type=int, default=1, help="Hypernet output layers")),
    ('--noise-std', dict(type=float, default=0)),
]

parser = argparse.ArgumentParser()
for _flag, _flag_kwargs in _CLI_OPTIONS:
    parser.add_argument(_flag, **_flag_kwargs)

# Parsed once at import time; the rest of the module reads this namespace.
args = parser.parse_args()

# Expand the user-supplied run name with the settings that identify this run
# (support size, dataset size, latent size, seed). The expanded name is used
# for both the checkpoint directory and the TensorBoard log directory.
args.model_name = f"task-param-{args.model_name}-supp-sz-{args.support_size}-dataset-sz-{args.dataset_size}-k-{args.k}-seed-{args.seed}"

print(vars(args))

# Seed the torch and numpy RNGs from --seed.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

MODEL_DIR = os.path.join(args.checkpoints_dir, args.model_name)
FIGURES_DIR = os.path.join(MODEL_DIR, 'figures')
LOSS_LANDSCAPE_DIR = os.path.join(FIGURES_DIR, 'loss-landscape')

# makedirs(exist_ok=True) replaces the exists()-then-mkdir() pairs: it cannot
# race with a concurrent run creating the same directory, and it also works
# when --checkpoints-dir is a nested path whose parents do not exist yet.
for _dir_path in (args.checkpoints_dir, MODEL_DIR, FIGURES_DIR, LOSS_LANDSCAPE_DIR):
    os.makedirs(_dir_path, exist_ok=True)

MODEL_PATH = os.path.join(MODEL_DIR, 'model.pt')

# Persist the parsed arguments next to the checkpoints. The context manager
# guarantees the file is flushed and closed even if pickling fails.
meta_path = os.path.join(MODEL_DIR, 'metadata.pkl')
with open(meta_path, 'wb') as meta_file:
    pickle.dump({'args': args}, meta_file)

writer = SummaryWriter(os.path.join(args.tensorboard_dir, args.model_name))

# TODO: This is for double pendulum
# Defaults forwarded to the 'hyper' shifter network; the double-pendulum
# dataset branch assigns them again.
query_imgs = 0 ## Temporary
query_traj_length = None ## Temporary


# Dataset selection. Every branch must define:
#   dset                         training dataset
#   test_sets                    list of evaluation datasets
#   input_dim / output_dim       sizes passed to the learner
#   shift_input_dim / shift_output_dim   sizes passed to the 'hyper' shifter
#   support_loss / query_loss    loss names passed to the learner / shifter
# A single if/elif/else chain is used so that an unsupported --dataset fails
# here with a clear message instead of a NameError further down.
if args.dataset == 'shift':
    TASK = 'regression'
    dset = SinePolyDataset(args.support_size, 1, N=args.N, size=args.dataset_size, mode="", seed=0, noise_std=args.noise_std, shift=args.shift)

    # Test sets share every constructor argument (mode "InDist", seed 1) and
    # differ only in noise_std; the first one uses the training noise level.
    test_sets = [
        SinePolyDataset(S=args.support_size, dim=1, N=args.N, size=args.dataset_size, mode="InDist", seed=1, noise_std=noise_std, shift=args.shift)
        for noise_std in (args.noise_std, 0.2, 1.0, 2.0)
    ]

    input_dim = 1
    output_dim = 1

    shift_input_dim = 1
    shift_output_dim = 1

    support_loss = 'mse'
    query_loss = 'mse'

elif args.dataset == 'ant':
    TASK = 'regression'
    dset = AntDataset(args.data_dir, 3000)

    # Test sets differ only in noise_std.
    test_sets = [
        AntDataset(args.data_dir, 1000, noise_std=noise_std, test=True)
        for noise_std in (0.2, 1.0, 2.0, 4.0, 8.0)
    ]

    input_dim = 27 + 8 # state + action
    output_dim = 27 # state

    shift_input_dim = 27
    shift_output_dim = 8

    support_loss = 'mse'
    query_loss = 'mse'

elif args.dataset == 'double-pendulum':
    # Raised explicitly (not asserted) so the check survives `python -O`.
    if args.imgs != 0:
        raise ValueError("Dont Use imgs, but more messy")
    TASK = 'regression'

    pendulum_dset_name = 'multi_single_double_pendulum_seed'
    train_pkl = os.path.join(args.data_dir, f'{pendulum_dset_name}_0.pkl')
    test_pkl = os.path.join(args.data_dir, f'{pendulum_dset_name}_1.pkl')

    dset = DoublePendulumTransfer(train_pkl, args.traj_length, use_img=0, noise_std=0.0, support_size=args.support_size, dataset_size=args.dataset_size)

    # NOISE: test sets differ only in noise_std (these noise levels were
    # shifted at some point, per the original note).
    # NOTE(review): support_size is fixed to 20 here rather than taken from
    # --support-size -- confirm this is intentional.
    # An earlier evaluation variant instead held noise_std at 3.0 and varied
    # support_size over 5, 10, 50 and 100 on 'single_double_pendulum_seed_1.pkl'.
    test_sets = [
        DoublePendulumTransfer(test_pkl, args.traj_length, use_img=0, noise_std=noise_std, support_size=20, dataset_size=args.dataset_size)
        for noise_std in (0.4, 0.6, 1.5, 2.0)
    ]

    input_dim = 5
    output_dim = 4

    # TODO: Functional shift
    shift_input_dim = 12
    shift_output_dim = 11

    query_imgs = 0
    #query_traj_length = args.traj_length
    query_traj_length = None

    support_loss = 'mse'
    query_loss = 'mse'

else:
    raise ValueError(
        f"Unsupported --dataset {args.dataset!r}; expected one of 'shift', 'ant', 'double-pendulum'"
    )


train_loader = torch.utils.data.DataLoader(dset, batch_size=args.batch_size, shuffle=True)

# One loader per evaluation set, in the same order as test_sets.
test_loaders = [
    torch.utils.data.DataLoader(test_set, batch_size=args.batch_size, shuffle=True)
    for test_set in test_sets
]

device = 'cuda' if (torch.cuda.is_available() and args.use_cuda) else 'cpu'
print("Device", device)

# Peek at one training batch to read the feature sizes; the last dimension of
# task_params gives the number of values the probe has to predict.
x_s, y_s, _, _, task_params = next(iter(train_loader))
task_dim = x_s.shape[-1] + y_s.shape[-1]
num_task_params = task_params.shape[-1]

# Linear probe from the k-dimensional latent code to the task parameters.
# It is built before the learner; keep this order so that seeded parameter
# initialisation stays the same from run to run.
classifier = nn.Linear(args.k, num_task_params).to(device)

# Latent-space learners that are constructed with the same keyword arguments.
subspace_learners = {'eigen': LinearLEO, 'leo': LEO, 'cavia': CAVIA}

if args.model == 'hyper':

    # The shifter is only consulted for its parameter count (the learner's
    # hypernetwork output size); it is not handed to the optimizer.
    shifter = HyperLearner(shift_input_dim, args.k, 1, args.num_layers, 'relu', loss=query_loss, output_dim=shift_output_dim, imgs=query_imgs, traj_length=query_traj_length).to(device)

    learner = HyperLearner(input_dim, args.k, 1, args.num_layers, 'relu', loss=support_loss, output_dim=output_dim, num_output_params=shifter.num_params, imgs=args.imgs, traj_length=args.traj_length, encoder=args.set_encoder, num_hypernet_layers=args.num_hypernet_layers).to(device)

elif args.model in subspace_learners:
    learner_params = dict(input_dim=input_dim, k=args.k, num_layers=args.num_layers, output_dim=output_dim, use_imgs=args.imgs, traj_length=args.traj_length, loss=support_loss)
    learner_model = subspace_learners[args.model]
    learner = learner_model(**learner_params).to(device)

else:
    # Without this guard an unknown (or missing) --model surfaced as a
    # NameError on `learner` below.
    raise ValueError(
        f"Unsupported --model {args.model!r}; expected one of 'hyper', 'eigen', 'leo', 'cavia'"
    )

# The optional OutNet head is currently disabled for every model; the
# checkpointing code in __main__ checks for None before saving it.
out_net = None

# Single optimizer over the learner and the probe.
opt = torch.optim.Adam(list(learner.parameters()) + list(classifier.parameters()), lr=1e-3)

print(f"# Model Parameters: {learner.num_params}")

def train(opt, epoch, dloader, phase, label, model, saved_results, evaluate_policy=False):
    """Run one pass over ``dloader``, fitting or evaluating the task-parameter probe.

    For each task in a batch, the support set is encoded into a latent code by
    the module-level ``learner`` and the module-level ``classifier`` predicts
    the task parameters from that code. The per-task MSE is averaged over the
    batch; in the training phase that average is back-propagated and ``opt``
    is stepped once per batch.

    Side effects: prints one progress line per batch, writes the mean loss to
    TensorBoard under ``Loss/<dataset>/<label>``, appends to ``saved_results``
    and saves a scatter plot of the latent codes to
    ``FIGURES_DIR/zs_<epoch>_<phase>_<label>.png``.

    Args:
        opt: Optimizer used when ``phase == 'train'``; ignored otherwise, so
            ``None`` is accepted for evaluation.
        epoch: Epoch index, used for logging and in the figure file name.
        dloader: Loader yielding ``(x_s, y_s, _, _, task_params)`` batches;
            must support ``len()``.
        phase: ``'train'`` to optimise. Any other value (e.g. ``'test0'``)
            evaluates with the learner in eval mode and gradients disabled.
        label: Name used in the console log, TensorBoard tag and figure name.
        model: Model name shown in the console log.
        saved_results: Mapping of lists; one value is appended to each of
            ``'mu_loss'`` and ``'mu_acc'``.
        evaluate_policy: Accepted for call compatibility; currently unused.

    Raises:
        ValueError: If ``args.model`` is not one of 'hyper', 'eigen', 'leo',
            'cavia'.
    """
    is_training = phase == 'train'

    # Callers name evaluation phases 'test0', 'test1', ..., so comparing
    # against the literal 'test' never matched and the learner stayed in
    # train mode during evaluation. Everything that is not 'train' now
    # selects eval mode.
    if is_training:
        learner.train()
    else:
        learner.eval()

    # The encoder entry point depends only on the model type; resolve it once.
    if args.model == 'hyper':
        encode = learner.encoder
    elif args.model in ('eigen', 'leo', 'cavia'):
        encode = learner.encode
    else:
        raise ValueError(
            f"Unsupported model {args.model!r}; expected one of 'hyper', 'eigen', 'leo', 'cavia'"
        )

    mu_loss = 0
    mu_accuracy = 0
    all_zs = []

    for ix, (x_s_batch, y_s_batch, _, _, task_params_batch) in enumerate(dloader):

        start_time = time.time()
        num_tasks = x_s_batch.shape[0]

        # Accuracy and the two regularisers are not computed in this script;
        # they are kept as zeros so the log line and the saved 'mu_acc'
        # history keep their existing format.
        theta_reg = 0.0
        total_z_reg = 0.0
        batch_accuracy_score = 0.0

        task_losses = []

        # No autograd graph is needed when evaluating.
        with torch.set_grad_enabled(is_training):
            for x_s, y_s, task_params_i in zip(x_s_batch, y_s_batch, task_params_batch):
                x_s = x_s.to(device)
                y_s = y_s.to(device)
                task_params_i = task_params_i.to(device)

                # Latent code of the task, inferred from its support set.
                zs = encode(x_s, y_s).squeeze(0)

                # Stored for the scatter plot. Rounded to 3 decimals once,
                # here, instead of re-rounding the whole history every batch.
                all_zs.append(np.round(zs.detach().cpu().numpy(), 3))

                # Probe loss: predict the task parameters from the code.
                predicted_params = classifier(zs)
                task_losses.append(F.mse_loss(predicted_params, task_params_i))

            mse_loss = sum(task_losses) / num_tasks

        if is_training:
            opt.zero_grad()
            mse_loss.backward()
            opt.step()

        mu_accuracy += batch_accuracy_score
        mu_loss += mse_loss.item()
        end_time = time.time()
        speed = end_time - start_time

        print(f"{str(label).upper()} Epoch: {epoch} Batch {(ix+1) / len(dloader)*100:.1f}%, {str(model).upper()}: {mse_loss.item():.3f} accuracy: {batch_accuracy_score:.3f} Theta reg: {theta_reg:.3f} Z_reg: {total_z_reg:.3f} Speed/Batch: {speed:.3f}")

    mu_loss /= len(dloader)
    mu_accuracy /= len(dloader)

    dataset_id = f"{args.dataset}"
    writer.add_scalar(f'Loss/{dataset_id}/{label}', mu_loss, epoch)

    saved_results['mu_loss'].append(mu_loss)
    saved_results['mu_acc'].append(mu_accuracy)

    figure_path=f"{FIGURES_DIR}/zs_{epoch}_{phase}_{label}.png"
    print("Saving figure", figure_path)

    fig, ax = plt.subplots()
    all_zs = np.array(all_zs)
    # Plot the first two latent coordinates. With a single coordinate
    # (e.g. --k 1) there is no second column, so plot against zeros instead
    # of failing. Assumes each code is a 1-D vector -- TODO confirm.
    first_coord = all_zs[:, 0]
    second_coord = all_zs[:, 1] if all_zs.shape[1] > 1 else np.zeros_like(first_coord)
    ax.scatter(first_coord, second_coord, label='ZS')
    fig.savefig(figure_path)

    plt.close('all')


if __name__ == '__main__':

    # Metric histories: one for the training split, one per test loader.
    saved_results = defaultdict(list)
    test_results_dict_list = [defaultdict(list) for _ in test_loaders]

    # Checkpoint locations are the same every epoch (files are overwritten).
    learner_path = os.path.join(MODEL_DIR, 'learner.pt')
    classifier_path = os.path.join(MODEL_DIR, 'classifier.pt')
    out_net_path = os.path.join(MODEL_DIR, 'out_net.pt')
    train_results_path = os.path.join(MODEL_DIR, 'saved_results_train.pkl')

    for epoch in range(1, args.epochs + 1):

        # 1) Optimise on the training loader.
        train(opt, epoch, train_loader, 'train', 'Train', model=args.model, saved_results=saved_results)

        # 2) Evaluate on every test loader; the policy-evaluation flag is
        #    raised on every tenth epoch.
        run_policy_eval = (epoch % 10 == 0)
        for test_ix, (test_loader, test_history) in enumerate(zip(test_loaders, test_results_dict_list)):
            print(f"Running Test: {test_ix}")
            train(None, epoch, test_loader, f'test{test_ix}', f'Test{test_ix}', model=args.model, saved_results=test_history, evaluate_policy=run_policy_eval)

        # 3) Checkpoint the modules; out_net is saved only when it exists and
        #    the model is one of the latent-space learners.
        torch.save(learner, learner_path)
        torch.save(classifier, classifier_path)
        if out_net is not None and args.model in ('eigen', 'leo', 'cavia'):
            torch.save(out_net, out_net_path)

        print("Saved model")

        # 4) Dump the metric histories collected so far.
        with open(train_results_path, 'wb') as results_file:
            pickle.dump(saved_results, results_file)

        for test_ix, test_history in enumerate(test_results_dict_list):
            test_results_path = os.path.join(MODEL_DIR, f'saved_results_test{test_ix}.pkl')
            with open(test_results_path, 'wb') as results_file:
                pickle.dump(test_history, results_file)

        print("Saved results!")
