from datasets.function_shift import SinePolyDataset
from datasets.maze_gen_new_simple import evaluate_reward_maze
from datasets.ant_dataset import AntDataset
from models.hyper import HyperLearner
from models.meta_fun import MetaFun
import numpy as np
import torch
import argparse
from datasets.double_pendulum import DoublePendulumTransfer
from utils import create_hyper_network

import matplotlib.pyplot as plt
import time
from utils import get_accuracy_score, reshape_param
from models.subspace import LEO, LinearLEO, CAVIA
from torch.utils.tensorboard import SummaryWriter
import os
import pickle
from collections import defaultdict
import torch.nn.functional as F
import ipdb
from tqdm import tqdm
from datasets.maze_dataset import DatasetGridSimple
from helper.ant_ppo.rl_code.PPO import PPO


# Command-line options. Several int options (e.g. --shift, --imgs) act as 0/1
# switches: later code compares them with 0/1 or tests their truthiness.
parser = argparse.ArgumentParser()

# --- Experiment selection ---------------------------------------------------
# NOTE: the default dataset name matches none of the dataset branches further
# down in this script, so --dataset has to be passed explicitly.
parser.add_argument('--dataset', type=str, default='sine',
                    help="shift, ant, ant-8, double-pendulum or maze "
                         "(the default value is not handled and must be overridden)")
# Only these model names get a learner constructed further down.
parser.add_argument('--model', type=str, help="hyper, eigen, leo, cavia or meta-fun")
parser.add_argument('--model-name', type=str, default='test')
parser.add_argument('--use-meta-curvature', type=int, default=0)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--num-layers', type=int, default=3)

parser.add_argument('--lambda-theta', type=float, default=1.0)

parser.add_argument('--leo-layers', type=int, default=0)

parser.add_argument('--k', type=int, default=2, help="Num dimensions for latent space for LEO")


parser.add_argument('--learnt-v', type=int, default=0)

parser.add_argument('--adapt-top-layers', type=int, default=0)
parser.add_argument('--num-shots', type=int, default=1)

parser.add_argument('--use-cuda', type=int, default=1)

parser.add_argument('--batch-size', type=int, default=4)

parser.add_argument('--support-size', type=int, default=20)
parser.add_argument('--num-classes', type=int, default=5)

# --- Output / input locations -----------------------------------------------
parser.add_argument('--data-dir', type=str, default='data')
parser.add_argument('--tensorboard-dir', type=str, default='tensorboard')
parser.add_argument('--checkpoints-dir', type=str, default='checkpoints')


parser.add_argument('--seed', type=int, default=1)

# --- Dataset-specific options -----------------------------------------------
parser.add_argument('--dataset-size', type=int, default=100, help="Size of sine dataset")
parser.add_argument('--N', type=int, default=1, help="N=Dimension-1 of the sines")
parser.add_argument('--use-complex-function', type=int, default=0)

parser.add_argument('--mse-dim', type=int, help="dimension for MSE", default=1)
parser.add_argument('--sine-classification', type=int, default=0)
parser.add_argument('--num-bins', type=int, default=5)
parser.add_argument('--biased-x', type=int, default=0)

parser.add_argument('--shift', type=int, default=1)

parser.add_argument('--eigen-hyper', type=int, default=0)

parser.add_argument('--imgs', type=int, default=0)
parser.add_argument('--traj-length', type=int, default=4)
parser.add_argument('--maze-size', type=int, default=5)

parser.add_argument('--set-encoder', type=str, default='deepsets', help="pointnet, deepsets, set-transformer")
parser.add_argument('--num-hypernet-layers', type=int, default=1, help="Hypernet output layers")

parser.add_argument('--noise-std', type=float, default=0)

args = parser.parse_args()

# Encode the distinguishing hyper-parameters in the run name; the same name is
# used for the checkpoint directory and for the TensorBoard run below.
args.model_name = f"{args.model_name}-supp-sz-{args.support_size}-dataset-sz-{args.dataset_size}-k-{args.k}-num-hypernet-layers-{args.num_hypernet_layers}-seed-{args.seed}"

print(vars(args))

# Seed the torch and numpy global RNGs.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

MODEL_DIR = os.path.join(args.checkpoints_dir, args.model_name)
FIGURES_DIR = os.path.join(MODEL_DIR, 'figures')
LOSS_LANDSCAPE_DIR = os.path.join(FIGURES_DIR, 'loss-landscape')

# LOSS_LANDSCAPE_DIR is the deepest output directory, so creating it (with all
# missing ancestors) also creates the checkpoints, model and figures
# directories. exist_ok avoids the check-then-create race and allows a nested
# --checkpoints-dir such as "runs/exp1".
os.makedirs(LOSS_LANDSCAPE_DIR, exist_ok=True)

MODEL_PATH = os.path.join(MODEL_DIR, 'model.pt')

# Persist the parsed arguments next to the checkpoints; the context manager
# makes sure the file is flushed and closed.
meta_path = os.path.join(MODEL_DIR, 'metadata.pkl')
with open(meta_path, 'wb') as meta_file:
    pickle.dump({'args': args}, meta_file)

writer = SummaryWriter(os.path.join(args.tensorboard_dir, args.model_name))

# TODO: This is for double pendulum
query_imgs = 0 ## Temporary
query_traj_length = None ## Temporary

# CUDA selection is currently disabled, so --use-cuda has no effect.
#device = 'cuda' if (torch.cuda.is_available() and args.use_cuda) else 'cpu'
device = 'cpu'
print("Device", device)

# Build the training set, the evaluation sets and the per-dataset dimensions.
# Code further down reads, from whichever branch ran: TASK, dset, test_sets,
# input_dim, output_dim, shift_input_dim, shift_output_dim, support_loss and
# query_loss. Unsupported configurations raise here instead of surfacing later
# as a NameError.
if args.dataset == 'shift':
    TASK = 'regression'
    dset = SinePolyDataset(args.support_size, 1, N=args.N, size=args.dataset_size, mode="", seed=0, noise_std=args.noise_std, shift=args.shift)

    # Evaluation sets share mode/seed and differ only in noise_std.
    dset_test = SinePolyDataset(S=args.support_size, dim=1, N=args.N, size=args.dataset_size, mode="InDist", seed=1, noise_std=args.noise_std, shift=args.shift)
    dset_test2 = SinePolyDataset(S=args.support_size, dim=1, N=args.N, size=args.dataset_size, mode="InDist", seed=1, noise_std=0.2, shift=args.shift)
    dset_test3 = SinePolyDataset(S=args.support_size, dim=1, N=args.N, size=args.dataset_size, mode="InDist", seed=1, noise_std=1.0, shift=args.shift)
    dset_test4 = SinePolyDataset(S=args.support_size, dim=1, N=args.N, size=args.dataset_size, mode="InDist", seed=1, noise_std=2.0, shift=args.shift)

    test_sets = [dset_test, dset_test2, dset_test3, dset_test4]
    #test_sets = [dset_test]

    input_dim = 1
    output_dim = 1

    shift_input_dim = 1
    shift_output_dim = 1

    support_loss = 'mse'
    query_loss = 'mse'

elif args.dataset in ('ant', 'ant-8'):
    TASK = 'regression'

    # The two ant variants differ only in the checkpoint path and in z_dim.
    if args.dataset == 'ant':
        agent_path = 'helper/ant_ppo/rl_code/PPO_model.mdl'
        z_dim = 27 + 2
    else:
        agent_path = 'helper/ant_ppo/rl_code_8_legs/PPO_model_8.mdl'
        z_dim = 27 + 8
    a_dim = 8
    a_max = 1
    agent = PPO(z_dim, a_dim, a_max, device)
    agent.load_model(agent_path)

    dset = AntDataset(args.dataset, args.data_dir, 3000)

    # Evaluation sets differ only in noise_std.
    dset_test1 = AntDataset(args.dataset, args.data_dir, 1000, noise_std=1.0, test=True)
    dset_test2 = AntDataset(args.dataset, args.data_dir, 1000, noise_std=2.0, test=True)
    dset_test3 = AntDataset(args.dataset, args.data_dir, 1000, noise_std=4.0, test=True)
    dset_test4 = AntDataset(args.dataset, args.data_dir, 1000, noise_std=8.0, test=True)

    test_sets = [dset_test1, dset_test2, dset_test3, dset_test4]

    input_dim = 27 + 8
    output_dim = 27

    shift_input_dim = 27
    shift_output_dim = 8

    support_loss = 'mse'
    query_loss = 'mse'

elif args.dataset == 'double-pendulum':
    assert args.imgs == 0, "Dont Use imgs, but more messy"
    TASK = 'regression'

    pendulum_dset_name = 'multi_single_double_pendulum_seed'
    dset = DoublePendulumTransfer(os.path.join(args.data_dir, f'{pendulum_dset_name}_0.pkl'), args.traj_length, use_img=0, noise_std=0.0, support_size=args.support_size, dataset_size=args.dataset_size)

    # NOISE
    # NOTE(review): the evaluation sets hard-code support_size=20 rather than
    # using args.support_size (whose default is also 20) - confirm this is intended.
    dset_test = DoublePendulumTransfer(os.path.join(args.data_dir, f'{pendulum_dset_name}_1.pkl'), args.traj_length, use_img=0, noise_std=0.4, support_size = 20, dataset_size=args.dataset_size)
    dset_test1 = DoublePendulumTransfer(os.path.join(args.data_dir, f'{pendulum_dset_name}_1.pkl'), args.traj_length, use_img=0, noise_std=1.0, support_size = 20, dataset_size=args.dataset_size)
    dset_test2 = DoublePendulumTransfer(os.path.join(args.data_dir, f'{pendulum_dset_name}_1.pkl'), args.traj_length, use_img=0, noise_std=2.0, support_size = 20, dataset_size=args.dataset_size)
    dset_test3 = DoublePendulumTransfer(os.path.join(args.data_dir, f'{pendulum_dset_name}_1.pkl'), args.traj_length, use_img=0, noise_std=3.0, support_size = 20, dataset_size=args.dataset_size)

    test_sets = [dset_test, dset_test1, dset_test2, dset_test3]

    input_dim = 5
    output_dim = 4

    # TODO: Functional shift
    shift_input_dim = 12
    shift_output_dim = 11

    query_imgs = 0
    query_traj_length = None

    support_loss = 'mse'
    query_loss = 'mse'

elif args.dataset == 'maze':
    # Only the shift setting defines the shifter dimensions and the query
    # loss; fail early with a clear message for the no-shift case.
    if not args.shift:
        raise NotImplementedError(
            "--dataset maze requires a non-zero --shift: no shifter dimensions "
            "or query loss are defined for the no-shift case"
        )

    TASK = 'classification'
    dset = DatasetGridSimple(T=args.dataset_size, support_size = args.support_size, seed=0, grid_size=args.maze_size)
    dset_test = DatasetGridSimple(T = args.dataset_size, support_size = args.support_size, seed=1, grid_size=args.maze_size)
    test_sets = [dset_test]
    input_dim = 2
    output_dim = 1

    support_loss = 'bce_loss'

    shift_input_dim = 4 # State
    shift_output_dim = 4 # Action
    query_loss = 'cross_entropy'

else:
    raise ValueError(
        f"Unsupported --dataset {args.dataset!r}; expected one of: "
        "shift, ant, ant-8, double-pendulum, maze"
    )


train_loader = torch.utils.data.DataLoader(dset, batch_size=args.batch_size, shuffle=True)

# One shuffled loader per evaluation set, in the same order as test_sets.
test_loaders = [
    torch.utils.data.DataLoader(test_set, batch_size=args.batch_size, shuffle=True)
    for test_set in test_sets
]


# Peek at one training batch to read the support input/target widths.
# task_dim is not used below; the draw is kept because iterating a shuffled
# loader consumes global RNG state, so removing it would change seeded runs.
x_s, y_s, _, _, _ = next(iter(train_loader))
task_dim = x_s.shape[-1] + y_s.shape[-1]

# The shifter is the network evaluated on query points; its weights are
# produced by hyper_network from a k-dimensional task embedding.
shifter = HyperLearner(shift_input_dim, args.k, 1, args.num_layers, 'relu', loss=query_loss, output_dim=shift_output_dim, imgs=query_imgs, traj_length=query_traj_length).to(device)
# Keep the hyper-network on the same device as the learner and the shifter.
hyper_network = create_hyper_network(args.num_hypernet_layers, args.k, shifter.num_params).to(device)

if args.model == 'hyper':
    learner = HyperLearner(input_dim, args.k, 1, args.num_layers, 'relu', loss=support_loss, output_dim=output_dim, num_output_params=shifter.num_params, imgs=args.imgs, traj_length=args.traj_length, encoder=args.set_encoder, num_hypernet_layers=args.num_hypernet_layers).to(device)

elif args.model in ('eigen', 'leo', 'cavia'):
    learner_params = dict(input_dim=input_dim, k=args.k, num_layers=args.num_layers, output_dim=output_dim, use_imgs=args.imgs, traj_length=args.traj_length, loss=support_loss)

    # Currently unused: no shifter is built from these parameters.
    shift_params = dict(input_dim=shift_input_dim, k=args.k, num_layers=args.num_layers, output_dim=shift_output_dim, use_imgs=0, traj_length=args.traj_length, loss=query_loss)

    learner_model = {'eigen': LinearLEO, 'leo': LEO, 'cavia': CAVIA}[args.model]
    learner = learner_model(**learner_params).to(device)

elif args.model == 'meta-fun':
    learner = MetaFun(input_dim, output_dim, shift_input_dim, shift_output_dim, args.k, 'rbf').to(device)

else:
    # Without this branch an unknown or missing --model only fails at the
    # optimizer line below with a NameError on `learner`.
    raise ValueError(
        f"Unsupported --model {args.model!r}; expected one of: "
        "hyper, eigen, leo, cavia, meta-fun"
    )

# The learner (task encoder) and the hyper-network are optimised jointly.
opt = torch.optim.Adam(list(learner.parameters()) + list(hyper_network.parameters()), lr=1e-3)

print(f"# Model Parameters: {learner.num_params}")


def train(opt, epoch, dloader, phase, label, model, saved_results, evaluate_policy=False):
    """Run one pass over ``dloader``, optimising only when ``phase == 'train'``.

    The same function drives training and evaluation. For every task in a
    batch the support set is encoded, the hyper-network turns the encoding
    into shifter weights, and the shifter is scored on the query set. Batch
    averages are printed, epoch averages are written to TensorBoard and
    appended to ``saved_results``. For the 'shift' regression dataset a
    prediction figure and a latent scatter plot are saved as well.

    Relies on module-level state: ``args``, ``learner``, ``shifter``,
    ``hyper_network``, ``writer``, ``device``, ``TASK`` and ``FIGURES_DIR``.

    Args:
        opt: Optimizer stepped once per batch in the 'train' phase; not used
            otherwise (evaluation callers pass ``None``).
        epoch: Epoch number, used as the TensorBoard step and in figure names.
        dloader: Iterable with ``len()`` yielding 5-tuples
            ``(x_s_batch, y_s_batch, x_q_batch, y_q_batch, task_params_batch)``.
        phase: ``'train'`` to optimise; any other value (callers use
            ``'test0'``, ``'test1'``, ...) only evaluates.
        label: Tag used in console output, TensorBoard tags and figure names.
        model: Model name shown in the progress line.
        saved_results: Mapping of lists; the keys 'mu_loss', 'mu_acc',
            'total_rewards_imitation' and 'total_rewards_ppo' are appended to.
        evaluate_policy: When true, test phases on supported datasets also run
            the policy roll-out evaluation and log its rewards.

    Returns:
        None.
    """
    is_training = phase == 'train'
    # Evaluation phases are named 'test0', 'test1', ... by the caller, so every
    # non-training phase has to switch the learner to eval mode (comparing
    # against the literal 'test' never matched).
    if is_training:
        learner.train()
    else:
        learner.eval()

    mu_loss = 0
    mu_accuracy = 0
    all_zs = []

    total_rewards_imitation = defaultdict(int)
    total_rewards_ppo = defaultdict(int)
    reward_imitation = dict()
    reward_ppo = dict()

    for ix, (x_s_batch, y_s_batch, x_q_batch, y_q_batch, task_params_batch) in enumerate(dloader):

        start_time = time.time()
        num_tasks = x_s_batch.shape[0]

        mse_loss = 0
        # Placeholders for auxiliary / regularisation terms. Nothing below adds
        # to them, so they remain zero; they are kept for the log output.
        aux_loss = torch.Tensor([0]).to(device)
        theta_reg = torch.Tensor([0]).to(device)
        total_z_reg = torch.Tensor([0]).to(device)
        batch_accuracy_score = 0

        batch_rewards_imitation = defaultdict(int)
        batch_rewards_ppo = defaultdict(int)

        for i in range(num_tasks):
            x_s = x_s_batch[i].to(device)
            x_q = x_q_batch[i].to(device)
            y_s = y_s_batch[i].to(device)
            y_q = y_q_batch[i].to(device)

            # Give 1-D targets an explicit trailing dimension.
            if len(y_s.shape) == 1:
                y_s = y_s.unsqueeze(-1)
            if len(y_q.shape) == 1:
                y_q = y_q.unsqueeze(-1)

            # Inner step: obtain the task embedding (or, for meta-fun, the
            # query prediction directly).
            if args.model == 'maml':
                # NOTE(review): this branch never defines `zs`, so the
                # hyper-network step below fails for maml - confirm how maml
                # is meant to produce the query prediction.
                if args.shift == 1:
                    raise NotImplementedError
                else:
                    predicted_y1 = learner(x_s)
                    l1 = learner.criterion(predicted_y1, y_s)
                    theta_1 = learner.adapt(l1)

            elif args.model == 'hyper':
                zs = learner.encoder(x_s, y_s).squeeze(0)
                all_zs.append(zs.cpu().data.numpy())

            elif args.model == 'eigen' or args.model == 'leo' or args.model == 'cavia':
                zs = learner.encode(x_s, y_s)
                all_zs.append(zs.cpu().data.numpy())

            elif args.model == 'meta-fun':
                predicted_y2 = learner(x_s, y_s, x_q)

            if args.model != 'meta-fun':
                # Map the embedding to a flat weight vector, reshape it to the
                # shifter's parameter layout and predict on the query inputs.
                theta_1_arr = hyper_network(zs)
                theta_1 = reshape_param(theta_1_arr, shifter.theta_0)

                ## Outer loss
                predicted_y2 = shifter(x_q, theta_1)

            task_loss = shifter.criterion(predicted_y2, y_q)
            mse_loss += task_loss

            if TASK == 'classification':
                batch_accuracy_score += get_accuracy_score(predicted_y2, y_q)

        ## Evaluate imitation learning policy (once per batch, not per task)
        if args.dataset in ['pendulum-imitation', 'maze', 'ant'] and 'test' in phase and evaluate_policy:
            if args.dataset == 'maze':
                actual_support_size = args.maze_size * 2

                reward_imitation = evaluate_reward_maze(args.model, learner, shifter, hyper_network, args.maze_size, actual_support_size, device, ix)
                reward_ppo = {'avg_reward': 1, 'final_reward': 1, 'max_reward': 1}
            elif args.dataset == 'ant':
                pass
                #reward_imitation = evaluate_reward_ant(lambda x : shifter(x, theta_1), 'imitation', task_params_i, device, seed=ix)
                #reward_ppo = evaluate_reward_ant(agent, 'agent', task_params_i, device, seed=ix)

            for k in reward_imitation.keys():
                batch_rewards_imitation[k] += reward_imitation[k]
                batch_rewards_ppo[k] += reward_ppo[k]

        # NOTE(review): rewards are accumulated once per batch above but are
        # divided by the number of tasks here - confirm this scaling is intended.
        for k in reward_imitation.keys():
            batch_rewards_imitation[k] /= num_tasks
            batch_rewards_ppo[k] /= num_tasks

            total_rewards_imitation[k] += batch_rewards_imitation[k]
            total_rewards_ppo[k] += batch_rewards_ppo[k]

        # Turn per-batch sums into per-task means.
        mse_loss /= num_tasks
        theta_reg /= num_tasks
        total_z_reg /= num_tasks
        batch_accuracy_score /= num_tasks
        aux_loss /= num_tasks

        if is_training:
            loss = mse_loss + aux_loss
            opt.zero_grad()
            loss.backward()
            opt.step()

        mu_accuracy += batch_accuracy_score
        mu_loss += mse_loss.item()
        end_time = time.time()
        speed = end_time - start_time

        print(f"{str(label).upper()} Epoch: {epoch} Batch {(ix+1) / len(dloader)*100:.1f}%, {str(model).upper()}: {mse_loss.item():.3f} accuracy: {batch_accuracy_score:.3f} Theta reg: {theta_reg.item():.3f} Z_reg: {total_z_reg.item():.3f} Speed/Batch: {speed:.3f}")

    mu_loss /= len(dloader)
    mu_accuracy /= len(dloader)

    for k in reward_imitation.keys():
        total_rewards_imitation[k] /= len(dloader)
        total_rewards_ppo[k] /= len(dloader)

    dataset_id = f"{args.dataset}"
    writer.add_scalar(f'Loss/{dataset_id}/{label}', mu_loss, epoch)
    # The regularisation scalars are the values from the last batch.
    writer.add_scalar(f'Z_reg/{dataset_id}/{label}', total_z_reg.item(), epoch)
    writer.add_scalar(f'Theta_reg/{dataset_id}/{label}', theta_reg.item(), epoch)
    writer.add_scalar(f'Accuracy/{dataset_id}/{label}', mu_accuracy, epoch)

    if evaluate_policy:
        for k in total_rewards_imitation.keys():
            if k in ['params']:
                continue
            writer.add_scalar(f'Rewards/{dataset_id}/{k}', total_rewards_imitation[k], epoch)
            writer.add_scalar(f'Rewards_ppo/{dataset_id}/{k}', total_rewards_ppo[k], epoch)

    saved_results['mu_loss'].append(mu_loss)
    saved_results['mu_acc'].append(mu_accuracy)

    saved_results['total_rewards_imitation'].append(total_rewards_imitation)
    saved_results['total_rewards_ppo'].append(total_rewards_ppo)

    if TASK != 'classification' and args.dataset == 'shift' and args.model != 'gbds' and args.model != 'meta-fun':
        # The figure shows the last task of the last batch: x_q, y_q and
        # predicted_y2 are whatever the loops above left behind.
        with torch.no_grad():
            predicted_y1 = learner(x_q)
        figure_path=f"{FIGURES_DIR}/{epoch}_{phase}_{label}.png"
        zs_figure_path=f"{FIGURES_DIR}/zs_{epoch}_{phase}_{label}.png"
        print("Saving figure", figure_path)

        query_x = x_q.cpu().data.numpy()[:,0]
        fig, ax = plt.subplots()
        ax.scatter(query_x, predicted_y1.cpu().data.numpy()[:,0], label='before adaptation')
        ax.scatter(query_x, y_q.cpu().data.numpy()[:,0], label='true', color='black')
        ax.scatter(query_x, predicted_y2.cpu().data.numpy()[:,0], label='ours')

        ax.set_ylim(-5,5)

        ax.legend()
        fig.savefig(figure_path)

        # Round the collected embeddings once here instead of re-converting
        # and re-rounding the whole list after every batch.
        zs_array = np.round(np.array(all_zs), 3)
        if zs_array.ndim == 2 and zs_array.shape[1] >= 2:
            fig, ax = plt.subplots()
            ax.scatter(zs_array[:,0], zs_array[:,1])
            fig.savefig(zs_figure_path)
        else:
            # A 2-D scatter needs at least two latent dimensions (--k >= 2).
            print(f"Skipping latent scatter plot: embeddings have shape {zs_array.shape}")

        plt.close('all')


if __name__ == '__main__':
    # Metric histories: one for training and one per evaluation loader.
    saved_results = defaultdict(list)
    test_results_dict_list = [defaultdict(list) for _ in test_loaders]

    for epoch in range(1, args.epochs + 1):
        # Optimisation pass over the training tasks.
        train(opt, epoch, train_loader, 'train', 'Train', model=args.model, saved_results=saved_results)

        # Evaluation pass over every test loader, each with its own history.
        to_evaluate_policy = False
        for ix, (test_loader, history) in enumerate(zip(test_loaders, test_results_dict_list)):
            print(f"Running Test: {ix}")
            train(None, epoch, test_loader, f'test{ix}', f'Test{ix}', model=args.model, saved_results=history, evaluate_policy=to_evaluate_policy)

        # Checkpoint the three networks; files are overwritten every epoch.
        checkpoints = (
            (learner, 'learner.pt'),
            (shifter, 'shifter.pt'),
            (hyper_network, 'hypernetwork.pt'),
        )
        for network, file_name in checkpoints:
            torch.save(network, os.path.join(MODEL_DIR, file_name))
        print("Saved model")

        # Persist the metric histories collected so far.
        train_results_path = os.path.join(MODEL_DIR, 'saved_results_train.pkl')
        with open(train_results_path, 'wb') as out_file:
            pickle.dump(saved_results, out_file)

        for ix, history in enumerate(test_results_dict_list):
            test_results_path = os.path.join(MODEL_DIR, f'saved_results_test{ix}.pkl')
            with open(test_results_path, 'wb') as out_file:
                pickle.dump(history, out_file)

        print("Saved results!")
