import torch
import torch.nn as nn
import numpy as np
import tools
import Model_pointnet
import train_pointnet_params as Params
import scipy.io
import os
import argparse
from data import ModelNetDataset, partialDataset
import tools
from test_pointnet import test
import torch.nn.functional as F

def train_one_iteraton(pc, param, model, optimizer, iteration, rgt, lam, w_lr, svdinf):
    """Run one optimisation step on a single batch and return the loss tensor.

    For datasets other than 'partial', the input cloud ``pc`` is rotated on the
    fly by rotations from ``tools.get_sampled_rotation_matrices_by_axisAngle``
    and those rotations are the ground truth. For the 'partial' dataset ``pc``
    is fed to the network as-is and ``rgt`` supplies the ground truth.

    Args:
        pc: batch of point clouds; reshaped below as (batch, param.sample_num, 3).
        param: experiment parameters (reads dataset, sample_num, manifold,
            supervision and logger).
        model: network called as ``model(points)`` returning
            (rotation, raw output, _, _).
        optimizer: optimizer that is zeroed and stepped once.
        iteration: global iteration counter; forwarded to RPMG and used to log
            scalars every 100 iterations.
        rgt: ground-truth rotations; only used for the 'partial' dataset.
        lam: regularisation weight forwarded to ``Model_pointnet.RPMG``.
        w_lr: step size forwarded to ``Model_pointnet.RPMG``.
        svdinf: when true (and ``param.manifold == 0``) the loss is computed on
            the raw network output reshaped to 3x3 instead of the returned
            rotation.

    Returns:
        The loss tensor, after ``backward()`` and ``optimizer.step()``.

    Raises:
        NotImplementedError: for self supervision on the 'partial' dataset,
            where this function has no unrotated reference cloud to compare to.
    """
    self_supervised = param.supervision == 'self'
    if self_supervised and param.dataset == 'partial':
        # The self-supervised losses compare against the unrotated cloud pc1,
        # which is only built for non-'partial' datasets; this combination
        # used to crash with an UnboundLocalError on pc1.
        raise NotImplementedError(
            "self supervision is not supported for the 'partial' dataset")

    optimizer.zero_grad()
    batch = pc.shape[0]
    point_num = param.sample_num
    if param.dataset != 'partial':
        ### get training data: rotate each cloud by a sampled rotation ###
        # torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4.
        pc1 = pc.float().cuda()
        gt_rmat = tools.get_sampled_rotation_matrices_by_axisAngle(batch)  # batch*3*3
        # Repeat each sample's rotation once per point so a single bmm
        # rotates every point of every cloud.
        gt_rmats = (gt_rmat.contiguous()
                    .view(batch, 1, 3, 3)
                    .expand(batch, point_num, 3, 3)
                    .contiguous()
                    .view(-1, 3, 3))
        pc2 = torch.bmm(gt_rmats, pc1.view(-1, 3, 1))  # (batch*point_num)*3*1
        pc2 = pc2.view(batch, point_num, 3)  # batch*point_num*3
    else:
        gt_rmat = rgt.float().cuda()
        pc2 = pc.float().cuda()

    ### network forward ###
    out_rmat, out_nd, _, _ = model(pc2.transpose(1, 2))  # output [batch(*sample_num),3,3]

    ### compute loss ###
    if param.manifold == 0:
        # svdinf: supervise the raw output reshaped to 3x3 rather than the rotation.
        pred = out_nd.view(-1, 3, 3) if svdinf else out_rmat
        if self_supervised:
            # Rows of pc2 are (R p)^T, so pc2 @ R recovers pc1 when pred == R.
            loss = ((torch.bmm(pc2, pred) - pc1.view(-1, point_num, 3)) ** 2).mean()
        else:
            loss = ((gt_rmat - pred) ** 2).mean()
    else:
        out_9d = Model_pointnet.RPMG.apply(out_nd, w_lr, lam, gt_rmat, iteration)
        if self_supervised:
            loss = ((torch.bmm(pc2, out_9d) - pc1.view(-1, point_num, 3)) ** 2).mean()
        else:
            # NOTE(review): .sum() here vs .mean() in every other branch; kept
            # as-is because it changes the gradient scale - confirm intended.
            loss = ((gt_rmat - out_9d) ** 2).sum()
        if iteration % 100 == 0 and model.out_rotation_mode == 'ortho6d':
            u, v = out_nd[:, :3], out_nd[:, 3:]
            param.logger.add_scalar('u_norm', u.norm(dim=1).mean().item(), iteration)
            param.logger.add_scalar('v_norm', v.norm(dim=1).mean().item(), iteration)
            # NOTE: despite the tag, this logs the mean dot product u.v, not an angle.
            param.logger.add_scalar('u_v_angle', (u * v).sum(dim=1).mean().item(), iteration)

    loss.backward()
    optimizer.step()

    if iteration % 100 == 0:
        param.logger.add_scalar('train_loss', loss.item(), iteration)
        if param.manifold == 1:
            param.logger.add_scalar('k', w_lr, iteration)
            param.logger.add_scalar('lambda', lam, iteration)
            param.logger.add_scalar('nd_norm', out_nd.norm(dim=1).mean().item(), iteration)

    return loss

        
# pc_lst: [point_num*3]
def _w_lr_schedule(w_lr_kind, iteration, total_iteration):
    """Return the RPMG step size ``w_lr`` selected by ``w_lr_kind`` at ``iteration``.

    Kinds 1, 2, 4, 5 and 6 are constants (1/4, 1/20, -1, 25 and 1). Kind -1
    is presumably a sentinel interpreted by ``Model_pointnet.RPMG`` (the
    script entry point labels it "gt") - TODO confirm.

    Kind 3 ramps linearly from 1/20 to 1/4 in ten stages of
    ``total_iteration // 10`` iterations. It stays at 1/4 from the tenth stage on.

    Raises:
        ValueError: if ``w_lr_kind`` is not one of 1..6.
    """
    if w_lr_kind == 3:
        # max(..., 1) avoids a ZeroDivisionError when total_iteration < 10.
        stage_len = max(total_iteration // 10, 1)
        return 1 / 20 + (1 / 4 - 1 / 20) / 9 * min(iteration // stage_len, 9)
    constant_kinds = {1: 1 / 4, 2: 1 / 20, 4: -1, 5: 25, 6: 1}
    if w_lr_kind not in constant_kinds:
        raise ValueError(
            "unknown w_lr_kind %r; expected an integer in 1..6" % (w_lr_kind,))
    return constant_kinds[w_lr_kind]


def _build_train_loader(param):
    """Return ``(train DataLoader, validation folder)`` for ``param.dataset``."""
    train_folder = os.path.join(param.data_folder, 'train')
    if param.dataset == 'partial':
        print("use %s Dataset" % (param.dataset))
        val_folder = os.path.join(param.data_folder, 'test')
        train_dataset = partialDataset(train_folder, sample_num=param.sample_num)
    else:
        print("use ModelNet Dataset")
        val_folder = os.path.join(param.data_folder, 'test_sampled')
        train_dataset = ModelNetDataset(train_folder, sample_num=param.sample_num)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=param.batch,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    return train_loader, val_folder


def _validate_and_save(param, model, optimizer, val_folder, iteration, train_loss):
    """Evaluate on ``val_folder``, log validation metrics and write a checkpoint.

    Leaves ``model`` in eval mode.
    """
    print("############# Iteration " + str(iteration) + " #####################")
    print('train loss: ' + str(train_loss.item()))

    model.eval()
    with torch.no_grad():
        angle_list, val_loss = test(val_folder, model, param.dataset, param.supervision)
    print('val loss: ' + str(val_loss.item()))

    logger = param.logger
    logger.add_scalar('val_loss', val_loss.item(), iteration)
    logger.add_scalar('val_median', np.median(angle_list), iteration)
    logger.add_scalar('val_mean', angle_list.mean(), iteration)
    logger.add_scalar('val_max', angle_list.max(), iteration)
    for threshold in (5, 3, 1):
        # Fraction of validation samples whose error is below the threshold
        # (units are whatever test() returns - presumably degrees, confirm).
        logger.add_scalar('val_%daccuracy' % threshold,
                          (angle_list < threshold).sum() / len(angle_list),
                          iteration)

    # os.path.join matches the path train() resumes from. Plain string
    # concatenation dropped the separator when write_weight_folder had no
    # trailing slash, so those checkpoints could never be reloaded.
    path = os.path.join(param.write_weight_folder, "model_%07d.weight" % iteration)
    state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'iteration': iteration}
    torch.save(state, path)


# pc_lst: [point_num*3]
def train(param, check_num, reg, w_lr_kind, svdinf):
    """Train the rotation model until ``param.total_iteration`` iterations.

    Every ``param.save_weight_iteration`` iterations the model is validated and
    a checkpoint ``model_%07d.weight`` is written to ``param.write_weight_folder``.

    Args:
        param: experiment parameters (device, lr, model_kind, out_rotation_mode,
            dataset, data_folder, sample_num, batch, start_iteration,
            total_iteration, save_weight_iteration, write_weight_folder, logger).
        check_num: iteration number of a checkpoint in ``write_weight_folder``
            to resume from; 0 starts fresh at ``param.start_iteration``.
        reg: regularisation weight, passed to train_one_iteraton as ``lam``.
        w_lr_kind: selects the ``w_lr`` schedule (1..6), see _w_lr_schedule.
        svdinf: forwarded unchanged to train_one_iteraton.

    Raises:
        ValueError: for an unknown ``w_lr_kind`` (checked before any setup).
    """
    # Fail fast on an unknown schedule before building the model and data.
    _w_lr_schedule(w_lr_kind, 0, param.total_iteration)

    torch.cuda.set_device(param.device)

    print("####Initiate model AE")

    model = Model_pointnet.Model(out_rotation_mode=param.out_rotation_mode, kind=param.model_kind).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=param.lr)
    if check_num != 0:
        read_path = os.path.join(param.write_weight_folder, "model_%07d.weight" % check_num)
        print("Load " + read_path)
        checkpoint = torch.load(read_path)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_iteration = checkpoint['iteration']
    else:
        print('start from beginning')
        start_iteration = param.start_iteration

    print("start train")
    train_loader, val_folder = _build_train_loader(param)

    iteration = start_iteration
    while iteration < param.total_iteration:
        for data, rgt in train_loader:
            # Validation switches the model to eval mode, so re-enter train mode.
            model.train()
            # Step decay: x0.7 every 3000 iterations, floored at 1e-5.
            lr = max(param.lr * (0.7 ** (iteration // 3000)), 1e-5)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            iteration += 1
            w_lr = _w_lr_schedule(w_lr_kind, iteration, param.total_iteration)
            train_loss = train_one_iteraton(data, param, model, optimizer, iteration, rgt, reg, w_lr, svdinf)
            if iteration % param.save_weight_iteration == 0:
                _validate_and_save(param, model, optimizer, val_folder, iteration, train_loss)
            # Stop exactly at total_iteration instead of finishing the epoch.
            if iteration >= param.total_iteration:
                break

def str2bool(value):
    """Parse a command-line boolean such as 'true'/'false', 'yes'/'no' or '1'/'0'.

    argparse's ``type=bool`` calls ``bool(text)``, which is True for every
    non-empty string, so ``-s False`` used to *enable* the flag. This parser
    interprets the text instead; the empty string stays False as before.

    Raises:
        argparse.ArgumentTypeError: for text that is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


if __name__ == "__main__":

    # Human-readable label for each --lrkind; its keys are also the valid choices.
    lr_kind_descriptions = {
        1: '1/4',
        2: '1/20',
        3: '1/20->1/4',
        4: 'gt',
        5: '25',
        6: '1',
    }

    arg_parser = argparse.ArgumentParser(
        description="Demo optimization"
    )
    arg_parser.add_argument(
        "--config",
        dest="config",
        type=str,
        required=True,
        help="Path to config",
    )
    arg_parser.add_argument(
        "--check",
        '-c',
        dest="check",
        default=0,
        type=int,
        help="Checkpoint",
    )
    arg_parser.add_argument(
        "--reg",
        '-r',
        default=0.01,
        type=float,
        help="lambda for regularization",
    )
    arg_parser.add_argument(
        "--lrkind",
        default=1,
        type=int,
        choices=sorted(lr_kind_descriptions),
        help="lr kind for w",
    )
    arg_parser.add_argument(
        "--svdinf",
        "-s",
        default=False,
        type=str2bool,
        help="compute the loss on the raw network output reshaped to 3x3 "
             "instead of the output rotation (only when manifold == 0)",
    )
    param = Params.Parameters()
    args = arg_parser.parse_args()

    param.read_config(os.path.join("../experiments//configs", args.config))
    if param.model_kind == 3:
        print("Model: PointNet++ cls")
    else:
        raise NotImplementedError

    print("Loss: origin" if param.manifold == 0 else "Loss: Manifold")

    print('Lambda for Reg = ', args.reg)
    print('Lr kind for w = ' + lr_kind_descriptions[args.lrkind])
    print('Supervision:', param.supervision)
    Model_pointnet.logger_init(param.logger)

    os.makedirs(param.write_weight_folder, exist_ok=True)

    train(param, args.check, args.reg, args.lrkind, args.svdinf)



