from datasets.contextual_bandit import BanditDataset
from datasets.function_shift import SinePolyDataset
from datasets.maze_gen_new import evaluate_reward_maze
from datasets.pendulum_imitation import PendulumPolicyDataset
from datasets.radius_dataset import RadiusDataset
from datasets.ant_dataset import AntDataset
from models.gbds import MetaModel
from models.hyper import HyperLearner
#from models.leo import EigenLEO, LEO, CAVIA
import numpy as np
import torch
import argparse
from datasets.sine_dataset import SineDatasetGeneral
from datasets.pendulum import PendulumDataset
from datasets.double_pendulum import DoublePendulumTransfer
from utils import create_hyper_network, split

import matplotlib.pyplot as plt
import time
from utils import flatten_params, get_accuracy_score, reshape_param, flatten_params_torch, rapid_learning_reg, support_to_img
from models.maml import MAMLLearner
#from models.leo import LEO, ConvLEO, SubversiveLEO
#from models.leo import EigenLEO
from models.subspace import LEO, LinearLEO, CAVIA
from torch.utils.tensorboard import SummaryWriter
import os
import pickle
from collections import defaultdict
import torch.nn.functional as F
import ipdb
from tqdm import tqdm
from utils import recompute_eigen_vectors
from utils_pend.pendulum_utils import generate_save_path
from datasets.maze_dataset import DatasetGridSimple
from train_pendulum import evaluate_reward
from helper.ant_ppo.rl_code.PPO import PPO
from helper.ant_ppo.rl_code.evaluate_reward import evaluate_reward_ant


#torch.autograd.set_detect_anomaly(True)


parser = argparse.ArgumentParser(
    description='Evaluate a saved meta-learning checkpoint on the test sets of its dataset.')
# --model-dir is mandatory: metadata.pkl and the learner/shifter/hypernetwork
# checkpoints are all read from it, so fail early with a usage error instead of
# a TypeError from os.path.join(None, ...).
parser.add_argument('--model-dir', type=str, required=True,
                    help='Directory containing metadata.pkl, learner.pt, shifter.pt and hypernetwork.pt.')
parser.add_argument('--checkpoints-dir', type=str, default='checkpoints-ants-reward-local',
                    help='Root directory that the checkpoints are copied to and results are written to.')
parser.add_argument('--tensorboard-dir', type=str, default='tensorboard-ants-reward-local',
                    help='Root directory for TensorBoard event files.')

args_eval = parser.parse_args()

def load_args(path):
    """Return the training-time ``args`` object stored in ``<path>/metadata.pkl``.

    The pickle is expected to hold a mapping with an ``'args'`` entry.
    NOTE: ``pickle.load`` can execute arbitrary code; only use this on trusted
    checkpoint directories.
    """
    metadata_file = os.path.join(path, 'metadata.pkl')
    with open(metadata_file, 'rb') as handle:
        metadata = pickle.load(handle)
    return metadata['args']

# Restore the hyper-parameters the checkpoint was trained with; evaluation reuses them.
args = load_args(args_eval.model_dir)

args.data_dir = 'data'
# Evaluation is forced onto the CPU.
#device = 'cuda' if (torch.cuda.is_available() and args.use_cuda) else 'cpu'
device = 'cpu'
print("Device", device)

print(vars(args))

writer = SummaryWriter(os.path.join(args_eval.tensorboard_dir, args.model_name))

torch.manual_seed(args.seed)
np.random.seed(args.seed)

# makedirs also creates a missing --checkpoints-dir (including nested parents)
# and is a no-op when the directory already exists.
SAVED_MODEL_DIR = os.path.join(args_eval.checkpoints_dir, args.model_name)
os.makedirs(SAVED_MODEL_DIR, exist_ok=True)

MODEL_DIR = args_eval.model_dir

learner_path = os.path.join(MODEL_DIR, 'learner.pt')
shifter_path = os.path.join(MODEL_DIR, 'shifter.pt')
hypernet_path = os.path.join(MODEL_DIR, 'hypernetwork.pt')

# map_location remaps tensors onto `device`, so checkpoints saved on a GPU still
# load in this CPU-only run.
# NOTE(review): these files hold whole pickled objects; newer torch releases
# default to weights_only=True, which may reject them - confirm torch version.
learner = torch.load(learner_path, map_location=device)
shifter = torch.load(shifter_path, map_location=device)
hyper_network = torch.load(hypernet_path, map_location=device)

# Copy the evaluated checkpoints into the output directory next to the results.
torch.save(learner, os.path.join(SAVED_MODEL_DIR, 'learner.pt'))
torch.save(shifter, os.path.join(SAVED_MODEL_DIR, 'shifter.pt'))
torch.save(hyper_network, os.path.join(SAVED_MODEL_DIR, 'hypernetwork.pt'))

# TODO: This is for double pendulum
# Module-level defaults for the query-set settings; the double-pendulum branch
# below assigns the same values explicitly.
query_imgs = 0 ## Temporary
query_traj_length = None ## Temporary


# Build the training dataset, the list of test datasets and the model
# input/output dimensions for the dataset the checkpoint was trained on.
# Every branch must define `test_sets`: the test loaders iterate over it.
if args.dataset == 'shift':
    TASK = 'regression'
    dset = SinePolyDataset(args.support_size, 1, N=args.N, size=args.dataset_size, mode="", seed=0, noise_std=args.noise_std, shift=args.shift)

    # In-distribution test sets that differ only in noise_std.
    dset_test = SinePolyDataset(S=args.support_size, dim=1, N=args.N, size=args.dataset_size, mode="InDist", seed=1, noise_std=args.noise_std, shift=args.shift)
    dset_test2 = SinePolyDataset(S=args.support_size, dim=1, N=args.N, size=args.dataset_size, mode="InDist", seed=1, noise_std=0.2, shift=args.shift)
    dset_test3 = SinePolyDataset(S=args.support_size, dim=1, N=args.N, size=args.dataset_size, mode="InDist", seed=1, noise_std=1.0, shift=args.shift)
    dset_test4 = SinePolyDataset(S=args.support_size, dim=1, N=args.N, size=args.dataset_size, mode="InDist", seed=1, noise_std=2.0, shift=args.shift)

    test_sets = [dset_test, dset_test2, dset_test3, dset_test4]

    input_dim = 1
    output_dim = 1

    shift_input_dim = 1
    shift_output_dim = 1

    support_loss = 'mse'
    query_loss = 'mse'

elif args.dataset in ('ant', 'ant-8'):
    TASK = 'regression'

    # Reference PPO agent used for the reward comparison; the two variants
    # differ only in checkpoint path and z_dim.
    if args.dataset == 'ant':
        agent_path = 'helper/ant_ppo/rl_code/PPO_model.mdl'
        z_dim = 27 + 2
    else:
        agent_path = 'helper/ant_ppo/rl_code_8_legs/PPO_model_8.mdl'
        z_dim = 27 + 8
    a_dim = 8
    a_max = 1
    agent = PPO(z_dim, a_dim, a_max, device)
    agent.load_model(agent_path)

    dset = AntDataset(args.dataset, args.data_dir, 3000)

    # Test sets that differ only in noise_std.
    dset_test1 = AntDataset(args.dataset, args.data_dir, 1000, noise_std=1.0, test=True)
    dset_test2 = AntDataset(args.dataset, args.data_dir, 1000, noise_std=2.0, test=True)
    dset_test3 = AntDataset(args.dataset, args.data_dir, 1000, noise_std=4.0, test=True)
    dset_test4 = AntDataset(args.dataset, args.data_dir, 1000, noise_std=8.0, test=True)

    test_sets = [dset_test1, dset_test2, dset_test3, dset_test4]

    input_dim = 27 + 8
    output_dim = 27

    shift_input_dim = 27
    shift_output_dim = 8

    support_loss = 'mse'
    query_loss = 'mse'

elif args.dataset == 'pendulum-imitation':
    TASK = 'regression'
    # map_location keeps a GPU-saved agent loadable in this CPU-only run.
    agent = torch.load('ppo_model_ppo3.pt', map_location=device)

    dset = PendulumPolicyDataset(args.data_dir, 'train', args.support_size, args.dataset_size, args.imgs, args.shift, args.traj_length, args.seed, agent)
    dset_test = PendulumPolicyDataset(args.data_dir, 'test', args.support_size, args.dataset_size, args.imgs, args.shift, args.traj_length, args.seed + 1, agent)

    test_sets = [dset_test]

    input_dim = 4
    output_dim = 3

    # TODO: Functional shift
    if args.shift:
        shift_input_dim = 3 # State
        shift_output_dim = 1 # Action

    support_loss = 'mse'
    query_loss = 'mse'

elif args.dataset == 'double-pendulum':
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if args.imgs != 0:
        raise ValueError("Dont Use imgs, but more messy")
    TASK = 'regression'

    pendulum_dset_name = 'multi_single_double_pendulum_seed'
    dset = DoublePendulumTransfer(os.path.join(args.data_dir, f'{pendulum_dset_name}_0.pkl'), args.traj_length, use_img=0, noise_std=0.0, support_size=args.support_size, dataset_size=args.dataset_size)

    # Test sets come from the seed-1 file with a fixed support size of 20 and
    # increasing noise_std.
    test_path = os.path.join(args.data_dir, f'{pendulum_dset_name}_1.pkl')
    dset_test = DoublePendulumTransfer(test_path, args.traj_length, use_img=0, noise_std=0.4, support_size=20, dataset_size=args.dataset_size)
    dset_test1 = DoublePendulumTransfer(test_path, args.traj_length, use_img=0, noise_std=1.0, support_size=20, dataset_size=args.dataset_size)
    dset_test2 = DoublePendulumTransfer(test_path, args.traj_length, use_img=0, noise_std=2.0, support_size=20, dataset_size=args.dataset_size)
    dset_test3 = DoublePendulumTransfer(test_path, args.traj_length, use_img=0, noise_std=3.0, support_size=20, dataset_size=args.dataset_size)
    dset_test4 = DoublePendulumTransfer(test_path, args.traj_length, use_img=0, noise_std=4.0, support_size=20, dataset_size=args.dataset_size)

    test_sets = [dset_test, dset_test1, dset_test2, dset_test3, dset_test4]

    input_dim = 5
    output_dim = 4

    # TODO: Functional shift
    shift_input_dim = 12
    shift_output_dim = 11

    query_imgs = 0
    query_traj_length = None  # alternative considered: args.traj_length

    support_loss = 'mse'
    query_loss = 'mse'

elif args.dataset == 'maze':
    TASK = 'classification'
    dset = DatasetGridSimple(T=args.dataset_size, support_size=args.support_size, seed=0, grid_size=args.maze_size)
    dset_test = DatasetGridSimple(T=args.dataset_size, support_size=args.support_size, seed=1, grid_size=args.maze_size)

    test_sets = [dset_test]

    # Dimensions for the simple grid dataset (the real dataset used
    # input_dim = 5 and output_dim = 1, i.e. successful or not).
    input_dim = 2
    output_dim = 1

    support_loss = 'bce_loss'

    # TODO: Functional shift
    if args.shift:
        shift_input_dim = 4 # State
        shift_output_dim = 4 # Action
        query_loss = 'cross_entropy'

else:
    raise ValueError(f"Unknown dataset: {args.dataset}")

# One shuffled DataLoader per test set, using the batch size from the saved args.
test_loaders = [
    torch.utils.data.DataLoader(test_set, batch_size=args.batch_size, shuffle=True)
    for test_set in test_sets
]

print(f"# Model Parameters: {learner.num_params}")

# Datasets for which policy rollouts (rewards) can be evaluated.
_REWARD_DATASETS = ('pendulum-imitation', 'maze', 'ant', 'ant-8')


def _adapt_and_predict(x_s, y_s, x_q):
    """Adapt to one task's support set and predict that task's query targets.

    Args:
        x_s, y_s: support inputs/targets of a single task.
        x_q: query inputs of the same task.

    Returns:
        ``(predicted_y2, theta_1)``. ``theta_1`` holds the task parameters passed
        to ``shifter``; it is ``None`` for 'meta-fun', which predicts directly
        from the support set.

    Raises:
        NotImplementedError: for ``args.model == 'maml'``.
        ValueError: for an unrecognised ``args.model``.
    """
    if args.model == 'meta-fun':
        return learner(x_s, y_s, x_q), None

    if args.model == 'maml':
        # The hypernetwork/shifter path below needs a latent code `zs`, which the
        # MAML learner does not produce, so this path could never run.
        # NOTE(review): confirm how adapted MAML parameters should be applied.
        raise NotImplementedError("MAML evaluation is not supported by this script")

    if args.model == 'hyper':
        zs = learner.encoder(x_s, y_s).squeeze(0)
    elif args.model in ('eigen', 'leo', 'cavia'):
        zs = learner.encode(x_s, y_s)
    else:
        raise ValueError(f"Unknown model: {args.model}")

    # Map the task latent to a flat parameter vector shaped like shifter.theta_0.
    theta_1 = reshape_param(hyper_network(zs), shifter.theta_0)
    return shifter(x_q, theta_1), theta_1


def _evaluate_task_rewards(x_s, y_s, theta_1, task_params_i, ix):
    """Roll out the adapted model and the reference agent on one task.

    Returns ``(reward_imitation, reward_ppo)`` dicts; the caller indexes
    ``reward_ppo`` with the keys of ``reward_imitation``.
    """
    if args.dataset == 'pendulum-imitation':
        reward_imitation = evaluate_reward(lambda x: shifter(x, theta_1), 'imitation', task_params_i, device, args.seed)
        reward_ppo = evaluate_reward(agent, 'ppo', task_params_i, device, args.seed)
    elif args.dataset == 'maze':
        actual_support_size = (args.maze_size + 2) ** 2
        reward_imitation = evaluate_reward_maze(learner, shifter, args.maze_size, actual_support_size, device, ix)
        # No reference agent for the maze: report a constant baseline.
        reward_ppo = {'avg_reward': 1, 'final_reward': 1, 'max_reward': 1}
    else:  # 'ant' / 'ant-8'
        if args.model == 'meta-fun':
            lambda_model = lambda x: learner(x_s, y_s, x)
        else:
            lambda_model = lambda x: shifter(x, theta_1)
        reward_imitation = evaluate_reward_ant(args.dataset, lambda_model, 'imitation', task_params_i, device, seed=ix)
        reward_ppo = evaluate_reward_ant(args.dataset, agent, 'agent', task_params_i, device, seed=ix)
    return reward_imitation, reward_ppo


def train(opt, epoch, dloader, phase, label, model, saved_results, evaluate_policy=False, run_every=10):
    """Run one pass over ``dloader``, recording mean loss/accuracy and optionally rewards.

    Parameters are only updated when ``phase == 'train'``; this script calls it
    with test loaders only.

    Args:
        opt: optimizer used when ``phase == 'train'`` (may be None otherwise).
        epoch: step value used for the TensorBoard scalars.
        dloader: yields ``(x_s, y_s, x_q, y_q, task_params)`` batches of tasks.
        phase: 'train' or a test phase name; rewards are only evaluated when it
            contains 'test'.
        label: tag used in the log line and the TensorBoard keys.
        model: model name shown in the log line.
        saved_results: ``defaultdict(list)`` that receives 'mu_loss', 'mu_acc',
            'total_rewards_imitation' and 'total_rewards_ppo'.
        evaluate_policy: roll out policies and log rewards (supported datasets only).
        run_every: rewards are evaluated on every ``run_every``-th batch.

    Raises:
        ValueError: if ``dloader`` is empty or ``run_every < 1``.
    """
    if run_every < 1:
        raise ValueError(f"run_every must be >= 1, got {run_every}")
    num_batches = len(dloader)
    if num_batches == 0:
        raise ValueError("dloader yields no batches")

    learner.eval()

    should_evaluate = evaluate_policy and 'test' in phase and args.dataset in _REWARD_DATASETS

    mu_loss = 0
    mu_accuracy = 0
    total_rewards_imitation = defaultdict(int)
    total_rewards_ppo = defaultdict(int)
    reward_keys = ()  # metric names from the most recent reward evaluation
    num_eval_batches = 0

    # Regularisers are not computed in this script; they stay zero and are only
    # used in the log line and the training loss.
    theta_reg = torch.Tensor([0]).to(device)
    total_z_reg = torch.Tensor([0]).to(device)
    aux_loss = torch.Tensor([0]).to(device)

    for ix, (x_s_batch, y_s_batch, x_q_batch, y_q_batch, task_params_batch) in enumerate(dloader):
        start_time = time.time()
        num_tasks = x_s_batch.shape[0]
        is_eval_batch = ix % run_every == 0

        mse_loss = 0
        batch_accuracy_score = 0
        batch_rewards_imitation = defaultdict(int)
        batch_rewards_ppo = defaultdict(int)

        for i in range(num_tasks):
            x_s = x_s_batch[i].to(device)
            x_q = x_q_batch[i].to(device)
            y_s = y_s_batch[i].to(device)
            y_q = y_q_batch[i].to(device)
            task_params_i = task_params_batch[i].to(device)

            predicted_y2, theta_1 = _adapt_and_predict(x_s, y_s, x_q)
            task_loss = shifter.criterion(predicted_y2, y_q)
            mse_loss += task_loss

            if TASK == 'classification':
                batch_accuracy_score += get_accuracy_score(predicted_y2, y_q)

            if is_eval_batch:
                print(f"Evaluating reward {ix} / {num_batches}")
                if should_evaluate:
                    reward_imitation, reward_ppo = _evaluate_task_rewards(x_s, y_s, theta_1, task_params_i, ix)
                    reward_keys = list(reward_imitation.keys())
                    for k in reward_keys:
                        batch_rewards_imitation[k] += reward_imitation[k]
                        batch_rewards_ppo[k] += reward_ppo[k]

        if is_eval_batch:
            num_eval_batches += 1
            # Add this batch's per-task mean reward to the running totals.
            for k in reward_keys:
                total_rewards_imitation[k] += batch_rewards_imitation[k] / num_tasks
                total_rewards_ppo[k] += batch_rewards_ppo[k] / num_tasks

        mse_loss /= num_tasks
        batch_accuracy_score /= num_tasks

        if phase == 'train':
            loss = mse_loss + aux_loss
            opt.zero_grad()
            loss.backward()
            opt.step()

        mu_accuracy += batch_accuracy_score
        mu_loss += mse_loss.item()
        speed = time.time() - start_time

        print(f"{str(label).upper()} Epoch: {epoch} Batch {(ix+1) / num_batches*100:.1f}%, {str(model).upper()}: {mse_loss.item():.3f} accuracy: {batch_accuracy_score:.3f} Theta reg: {theta_reg.item():.3f} Z_reg: {total_z_reg.item():.3f} Speed/Batch: {speed:.3f}")

    mu_loss /= num_batches
    mu_accuracy /= num_batches

    # Average over the batches that were actually evaluated, which is
    # ceil(num_batches / run_every) and always >= 1. The fractional
    # num_batches / run_every inflates the mean whenever num_batches is not a
    # multiple of run_every.
    for k in reward_keys:
        total_rewards_imitation[k] /= num_eval_batches
        total_rewards_ppo[k] /= num_eval_batches

    if evaluate_policy:
        dataset_id = args.dataset
        for k, value in total_rewards_imitation.items():
            if k == 'params':
                continue
            writer.add_scalar(f'Rewards/{dataset_id}/{label}/{k}', value, epoch)
            writer.add_scalar(f'Rewards_ppo/{dataset_id}/{label}/{k}', total_rewards_ppo[k], epoch)

    saved_results['mu_loss'].append(mu_loss)
    saved_results['mu_acc'].append(mu_accuracy)

    saved_results['total_rewards_imitation'].append(total_rewards_imitation)
    saved_results['total_rewards_ppo'].append(total_rewards_ppo)


if __name__ == '__main__':

    # Only the test loaders are evaluated here, so saved_results_train.pkl is
    # written as an empty mapping.
    saved_results = defaultdict(list)
    test_results_dict_list = [defaultdict(list) for _ in test_loaders]

    to_evaluate_policy = True
    try:
        for ix, (test_loader, saved_results_test_i) in enumerate(zip(test_loaders, test_results_dict_list)):
            print(f"Running Test: {ix}")
            train(None, 1, test_loader, f'test{ix}', f'Test{ix}', model=args.model,
                  saved_results=saved_results_test_i, evaluate_policy=to_evaluate_policy)
    finally:
        # SummaryWriter writes event files asynchronously; close() flushes
        # pending scalars so they are not lost when the script exits.
        writer.close()

    with open(os.path.join(SAVED_MODEL_DIR, 'saved_results_train.pkl'), 'wb') as f:
        pickle.dump(saved_results, f)

    for ix, saved_results_test_i in enumerate(test_results_dict_list):
        with open(os.path.join(SAVED_MODEL_DIR, f'saved_results_test{ix}.pkl'), 'wb') as f:
            pickle.dump(saved_results_test_i, f)

    print("Saved results!")
