import os
import time
import sys
import copy

import multiprocessing as mp
# from progressbar import ProgressBar

import argparse
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
import matplotlib.pylab as plt

from config import gen_args
from data import Dynamics_Dataset
from data import prepare_input, get_scene_info, get_env_group, prepare_input_boundary_free
from models import Model, ChamferLoss
from utils import make_graph, check_gradient, set_seed, AverageMeter, get_lr, Tee, dy_loss, dy_vis, dy_plot
from utils import count_parameters, my_collate, my_collate_extra, ChamferLoss, capture_scene_image, coalition_loss_batch
from tqdm import tqdm
from chamferdist import ChamferDistance

# ctx = torch.multiprocessing.get_context("spawn")
# ctx.set_start_method('spawn')


# Parse the run configuration and apply the configured random seed.
args = gen_args()
set_seed(args.random_seed)

# Debug overrides (uncomment to force a small configuration):
# args.n_rollout = 100
# args.time_step = 10
# args.n_his = 10
# args.sequence_length = 11

# The output directory name is suffixed with n_rollout, lr, the *negated*
# align_step and the mode.
args.outf = (
    f"{args.outf}_nroll{args.n_rollout}_lr{args.lr}"
    f"_alignstep{-args.align_step}_mode{args.mode}"
)

# Create the output tree with os.makedirs instead of shelling out to `mkdir`:
# no shell parses the path (spaces / metacharacters in args.outf are safe),
# it works on platforms without a `mkdir -p`, and directories that already
# exist (e.g. when resuming) are not reported as errors.
os.makedirs(args.outf, exist_ok=True)
os.makedirs(os.path.join(args.outf, "model"), exist_ok=True)
os.makedirs(os.path.join(args.outf, "curve"), exist_ok=True)

# Tee is opened on <outf>/train.log in write mode.
# NOTE(review): presumably it mirrors stdout into the log file - confirm in utils.
tee = Tee(os.path.join(args.outf, 'train.log'), 'w')


# When extra rollout steps are trained, every sampled sequence has to be
# longer by that many frames.
if args.train_extra_steps:
    args.sequence_length = args.sequence_length + args.train_extra_steps_num

### training

# The timer starts here, so the reported duration covers saving the
# arguments as well as constructing the datasets.
time_st_load = time.time()

# Persist the run arguments next to the outputs. The namespace is not an
# array, so NumPy stores it as a pickled object (loading it back requires
# allow_pickle=True).
np.save(f"{args.outf}/args.npy", args)

# Evaluation-only runs skip the training phase.
if args.eval == 0:
    phases = ['train', 'valid']
else:
    phases = ['valid']

datasets = {}
for phase in phases:
    datasets[phase] = Dynamics_Dataset(args, phase)

print("DATASET LOADED ---- TIME CMD = ", time.time() - time_st_load)
#
# p=datasets['train'].__getitem__(100)

# Either generate the data for each phase or load it from disk. The
# environments listed below are skipped by the explicit load_data call.
for phase in phases:
    dataset = datasets[phase]
    if args.gen_data:
        dataset.gen_data(args.env)
    elif args.env not in ('pour', 'shake', 'pour_extra', 'shake_extra', 'granular_push'):
        dataset.load_data(args.env)

# The "*_extra" environments and 'granular_push' use the extended collate
# function; every other environment uses the default one.
if args.env == 'granular_push' or 'extra' in args.env:
    collate_fn = my_collate_extra
else:
    collate_fn = my_collate

# One loader per phase; only the training loader is shuffled.
dataloaders = {}
for phase in phases:
    dataloaders[phase] = DataLoader(
        datasets[phase],
        batch_size=args.batch_size,
        shuffle=(phase == 'train'),
        num_workers=args.num_workers,
        pin_memory=True,
        collate_fn=collate_fn)

# create model and train
# Build the model; the GPU is used whenever CUDA is available.
use_gpu = torch.cuda.is_available()
model = Model(args, use_gpu)

print("model #params: %d" % count_parameters(model))


# checkpoint to reload model from
model_path = None

# resume training of a saved model (if given)
if args.resume == 0:
    print("Randomly initialize the model's parameters")

elif args.resume == 1:
    model_path = os.path.join(args.outf, 'model/net_epoch_%d_iter_%d.pth' % (
        args.resume_epoch, args.resume_iter))
    print("Loading saved ckp from %s" % model_path)

    if args.stage == 'dy':
        # Deserialize onto the CPU: a checkpoint written on a GPU machine can
        # then be resumed on a CPU-only host, and no temporary copy of the
        # weights is allocated on the GPU. load_state_dict copies the values
        # into the model's own parameters afterwards.
        pretrained_dict = torch.load(model_path, map_location='cpu')
        model_dict = model.state_dict()

        # only load parameters in dynamics_predictor (and only those whose
        # key also exists in the current model)
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items()
            if 'dynamics_predictor' in k and k in model_dict}
        # strict=False because the filtered dict does not cover every key
        model.load_state_dict(pretrained_dict, strict=False)


# Only the 'dy' (dynamics) stage is supported by this script.
if args.stage != 'dy':
    raise AssertionError("unknown stage: %s" % args.stage)

# Move the model to the GPU *before* the optimizer is constructed, as the
# torch.optim documentation recommends, so the optimizer is guaranteed to
# reference the parameters that are actually trained.
if use_gpu:
    model = model.cuda()

# optimizer: only the dynamics predictor's parameters are optimized
params = model.dynamics_predictor.parameters()

if args.optimizer == 'Adam':
    optimizer = torch.optim.Adam(
        params, lr=args.lr, betas=(args.beta1, 0.999))
elif args.optimizer == 'SGD':
    optimizer = torch.optim.SGD(
        params, lr=args.lr, momentum=0.9)
else:
    raise AssertionError("unknown optimizer: %s" % args.optimizer)

# reduce learning rate (x0.8) when the monitored metric has stopped
# decreasing for more than 3 scheduler steps
scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.8, patience=3, verbose=True)

# define loss
particle_dist_loss = torch.nn.L1Loss()

# Record the full configuration in the log.
print(args)

# First epoch to run: the resume epoch when it is positive, otherwise 0.
st_epoch = max(0, args.resume_epoch)
best_valid_loss = np.inf

# Histories handed to the plotting routine:
# - total loss
# - div loss
# - col loss
train_loss_list, valid_loss_list = [], []

# Per-iteration prediction losses of the first, second and third predicted
# step, kept separately for the training and validation phases.
train_loss_list_1, train_loss_list_2, train_loss_list_3 = ([] for _ in range(3))
valid_loss_list_1, valid_loss_list_2, valid_loss_list_3 = ([] for _ in range(3))
train_col_loss_list, valid_col_list = [], []

# Per-epoch average losses.
train_loss_epc, valid_loss_epc = [], []

# Best validation loss known at the time of each recording.
best_valid_loss_list = []

crit = ChamferDistance()

# Number of predicted steps per sample: the first step plus any extra steps.
if args.train_extra_steps:
    pre_seq_len = 1 + args.train_extra_steps_num
else:
    pre_seq_len = 1

def _pad_relation(rel, n_rel, n_node):
    """Zero-pad a 2-D relation matrix to shape (n_rel, n_node).

    Rows are appended at the bottom and columns on the right. The padding is
    allocated on the device of `rel`, so the helper works for CPU and GPU
    inputs alike.
    """
    rel = torch.cat(
        [rel, torch.zeros(n_rel - rel.size(0), rel.shape[-1], device=rel.device)], 0)
    return torch.cat(
        [rel, torch.zeros(n_rel, n_node - rel.shape[-1], device=rel.device)], -1)


# Main loop: for every epoch run each phase once. In the 'train' phase the
# dynamics predictor is optimized; in the 'valid' phase the average loss
# drives the LR scheduler and the best-checkpoint selection.
for epoch in range(st_epoch, args.n_epoch):

    for phase in phases:
        n_iter = len(dataloaders[phase])
        print("TOTAL ITER: {}".format(n_iter))

        model.train(phase == 'train')

        # one meter per predicted step, plus one for the summed loss
        meter_loss_div = [AverageMeter() for _ in range(pre_seq_len)]
        meter_loss_sum = AverageMeter()

        if args.coalition_loss:
            meter_loss_col = AverageMeter()

        # Visualize about args.vis times per phase. The interval is clamped to
        # at least 1 so that a loader with a single batch does not lead to a
        # modulo-by-zero below.
        vis_interval = n_iter // args.vis
        if vis_interval == 0:
            vis_interval = max(1, n_iter - 1)

        # tqdm wraps the loader itself (not the enumerate object) so that the
        # progress bar knows the total number of iterations.
        for i, data in enumerate(tqdm(dataloaders[phase])):
            # each "data" is a trajectory of sequence_length time steps

            if args.time_watch:
                time_st = time.time()
                if i > 0:
                    print(f"time load : {time_st - time_end}")

            # per-step prediction losses of this batch
            pred_loss_div = []
            # The total loss is accumulated over all predicted steps. It stays
            # None until the first dy_loss call so that a configuration that
            # computes no loss at all is reported explicitly.
            loss, col_loss = None, None

            if args.stage == 'dy':
                # attrs: B x (n_p + n_s) x attr_dim
                # particles: B x seq_length x (n_p + n_s) x state_dim
                # n_particles: B
                # n_shapes: B
                # scene_params: B x param_dim
                # Rrs, Rss: B x seq_length x n_rel x (n_p + n_s)
                attrs, particles, n_particles, n_shapes, scene_params, box_info, Rrs, Rss = data

                if use_gpu:
                    attrs = attrs.cuda()
                    particles = particles.cuda()
                    Rrs, Rss = Rrs.cuda(), Rss.cuda()

                # statistics (particle / shape counts are taken from the first
                # sample of the batch)
                B = attrs.size(0)
                n_particle = n_particles[0].item()
                n_shape = n_shapes[0].item()

                # p_rigid: B x n_instance
                # p_instance: B x n_particle x n_instance
                # physics_param: B x n_particle
                groups_gt = get_env_group(args, n_particles, n_shapes, scene_params, use_gpu=use_gpu)

                # gradients are only tracked in the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    # state_cur (unnormalized): B x n_his x (n_p + n_s) x state_dim
                    state_cur = particles[:, :args.n_his]

                    # Rr_cur, Rs_cur: B x n_rel x (n_p + n_s)
                    Rr_cur = Rrs[:, args.n_his - 1]
                    Rs_cur = Rss[:, args.n_his - 1]

                    # predict the velocity at the next time step
                    inputs = [attrs, state_cur, Rr_cur, Rs_cur, groups_gt]

                    # pred_pos (unnormalized): B x n_p x state_dim
                    # pred_motion_norm (normalized): B x n_p x state_dim
                    pred_pos, pred_motion_norm = model.predict_dynamics(inputs)

                    # ground truth of the first predicted frame (xyz only)
                    gt_pos = particles[:, args.n_his, :, :3]

                    if args.time_watch:
                        time_infer = time.time()
                        print(f"infer position : {time_infer - time_st}")

                    # the loss of the first predicted step is only used when
                    # align_step == -2
                    if args.align_step == -2:
                        pred_loss, col_loss, loss = dy_loss(gt_pos, pred_pos, n_particles, args, box_info=box_info if 'shake' in args.env else None)
                        pred_loss_div.append(pred_loss.item())

                        if args.time_watch:
                            time_cal_loss = time.time()
                            print(f"cal loss : {time_cal_loss - time_infer}")

                        if phase == 'train':
                            train_loss_list_1.append(pred_loss.item())
                        else:
                            valid_loss_list_1.append(pred_loss.item())

                    if args.train_extra_steps:
                        # sliding window over the last n_his states
                        state_queue = particles[:, :args.n_his]  # B * n_his * (N_P+N_S) * 3

                        for extra_steps_id in range(args.train_extra_steps_num):

                            # update queue of 1. particle states, 2. relations
                            if not (args.add_norm_vector or args.boundary_free):
                                newest_state = pred_pos.unsqueeze(1)
                            else:
                                # padding back the norm vector dim with zeros
                                newest_state = torch.cat(
                                    [pred_pos.unsqueeze(1),
                                     torch.zeros(pred_pos.shape, device=pred_pos.device).unsqueeze(1)],
                                    -1)
                            # drop the oldest state, append the prediction
                            state_queue = torch.cat([state_queue[:, 1:], newest_state], 1)

                            # update gt pos
                            gt_pos = particles[:, args.n_his + extra_steps_id + 1]

                            # build new relations for the last frame, sample
                            # by sample, using prepare_input from .data
                            max_n = pred_pos.shape[1]
                            max_n_rel = 0
                            padded_particles, rel_recv, rel_send = [], [], []
                            for b in range(pred_pos.shape[0]):
                                if not args.boundary_free:
                                    attr_, particle_, Rr_, Rs_ = prepare_input(
                                        pred_pos[b], n_particles[b].item(), n_shapes[b], args, var=True)
                                else:
                                    attr_, particle_, Rr_, Rs_ = prepare_input_boundary_free(
                                        pred_pos[b], n_particles[b].item(), n_shapes[b], args,
                                        var=True, norm=None)
                                # pad every sample to max_n rows
                                padded_particles.append(torch.cat(
                                    [particle_,
                                     torch.zeros(max_n - particle_.shape[0], particle_.shape[1],
                                                 device=particle_.device)],
                                    0))
                                rel_recv.append(Rr_)
                                rel_send.append(Rs_)
                                max_n_rel = max(max_n_rel, Rs_.shape[0])

                            if args.boundary_free:
                                # replace the newest state by the particles
                                # returned by prepare_input_boundary_free
                                state_queue[:, -1, :] = torch.stack(padded_particles)

                            # pad Rr / Rs to a common (max_n_rel, max_n) shape
                            # so they can be stacked into a batch
                            Rr = torch.stack(
                                [_pad_relation(r, max_n_rel, max_n) for r in rel_recv]
                            ).to(state_queue.device)
                            Rs = torch.stack(
                                [_pad_relation(s, max_n_rel, max_n) for s in rel_send]
                            ).to(state_queue.device)

                            # NOTE: the relations are passed in the same
                            # (Rr, Rs) order as for the first-step prediction.
                            inputs = [attrs, state_queue, Rr, Rs, groups_gt]
                            # pred_pos (unnormalized): B x n_p x state_dim
                            # pred_motion_norm (normalized): B x n_p x state_dim
                            pred_pos, pred_motion_norm = model.predict_dynamics(inputs)

                            pred_loss, col_loss, loss_new = dy_loss(gt_pos[..., :3], pred_pos[..., :3], n_particle, args)
                            pred_loss_div.append(pred_loss.item())
                            # out-of-place accumulation; also valid when no
                            # first-step loss was computed
                            loss = loss_new if loss is None else loss + loss_new

                            if args.time_watch:
                                time_extra_infer = time.time()
                                print(f"extra infer: {time_extra_infer - time_infer}")

                            # keep the per-iteration losses of the 2nd and 3rd
                            # predicted step for plotting
                            if extra_steps_id == 0:
                                if phase == 'train':
                                    train_loss_list_2.append(pred_loss.item())
                                else:
                                    valid_loss_list_2.append(pred_loss.item())
                            elif extra_steps_id == 1:
                                if phase == 'train':
                                    train_loss_list_3.append(pred_loss.item())
                                else:
                                    valid_loss_list_3.append(pred_loss.item())

                    if loss is None:
                        raise RuntimeError(
                            "no loss was computed for this batch: set args.align_step "
                            "to -2 and/or enable args.train_extra_steps with "
                            "args.train_extra_steps_num > 0")

                    meter_loss_sum.update(loss.item(), B)
                    if args.coalition_loss:
                        meter_loss_col.update(col_loss.item(), B)

                    # zip stops at the shorter sequence, so a missing
                    # first-step loss cannot cause an index error
                    for meter, step_loss in zip(meter_loss_div, pred_loss_div):
                        meter.update(step_loss)

                # visualization
                if i % vis_interval == 0:
                    vis_pth = os.path.join(args.outf, f"{phase}_vis", str(epoch))
                    os.makedirs(vis_pth, exist_ok=True)

                    dy_vis(args, vis_pth, pred_pos, n_particles, gt_pos, n_shape, i)

                if i % args.log_per_iter == 0:
                    if args.coalition_loss:
                        print('%s epoch[%d/%d] iter[%d/%d] LR: %.6f, loss: %.6f (%.6f), loss_col: %.8f (%.8f)' % (
                            phase, epoch, args.n_epoch, i, n_iter, get_lr(optimizer),
                            loss.item(), meter_loss_sum.avg, col_loss.item(), meter_loss_col.avg))
                    else:
                        print('%s epoch[%d/%d] iter[%d/%d] LR: %.6f, loss: %.6f (%.6f)' % (
                            phase, epoch, args.n_epoch, i, n_iter, get_lr(optimizer),
                            loss.item(), meter_loss_sum.avg))
                time_end = time.time()

            # update model parameters
            if phase == 'train':
                if args.time_watch:
                    # timed from here so the measurement does not depend on
                    # whether extra steps were run
                    time_bp_st = time.time()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if args.time_watch:
                    print(f'back propogation: {time.time() - time_bp_st}')

            # periodic checkpoint
            if phase == 'train' and i > 0 and i % args.ckp_per_iter == 0:
                model_path = '%s/model/net_epoch_%d_iter_%d.pth' % (args.outf, epoch, i)
                torch.save(model.state_dict(), model_path)

        print('%s epoch[%d/%d] Loss: %.6f, Best valid: %.6f' % (
            phase, epoch, args.n_epoch, meter_loss_sum.avg, best_valid_loss))

        if phase == 'train':
            train_loss_epc.append(meter_loss_sum.avg)
        else:
            valid_loss_epc.append(meter_loss_sum.avg)

        # recorded once per phase, before a possible update below
        if epoch > 0:
            best_valid_loss_list.append(best_valid_loss)

        if phase == 'valid' and not args.eval:
            # the validation loss drives the LR schedule and the selection of
            # the best checkpoint
            scheduler.step(meter_loss_sum.avg)
            if meter_loss_sum.avg < best_valid_loss:
                best_valid_loss = meter_loss_sum.avg
                torch.save(model.state_dict(), '%s/net_best.pth' % (args.outf))

        # plot (refreshed after every phase)
        dy_plot(train_loss_list, valid_loss_list, best_valid_loss_list, valid_loss_epc, train_loss_epc,
                train_loss_list_1, train_loss_list_2, train_loss_list_3, valid_loss_list_1,
                valid_loss_list_2, valid_loss_list_3, args)


