import argparse
import torch
from con_formation.dataset_cf import ConForDataset
from cfins import GNN, Baseline, Linear, Linear_dynamics, RF_vel, EGNN_vel_CIs, EGNN_vel, ClofNet, ClofNet_CIs, GMN, GMN_CIs, CFINs, CFINs_diff_invar
import os
from torch import nn, optim
import json
import time
import logging

# Command-line interface for the consensus-formation training script.
# Help texts state the actual defaults used below.
parser = argparse.ArgumentParser(description='cdif for consensus formation')
parser.add_argument('--exp_name', type=str, default='egnn_vel', metavar='N', help='experiment_name')
parser.add_argument('--batch_size', type=int, default=100, metavar='N',
                    help='input batch size for training (default: 100)')
parser.add_argument('--epochs', type=int, default=1000, metavar='N',
                    help='number of epochs to train (default: 1000)')
parser.add_argument('--case', type=int, default=0, metavar='N',
                    help='0,1,2,3,4,5')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log_interval', type=int, default=1, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--test_interval', type=int, default=5, metavar='N',
                    help='how many epochs to wait before logging test')
parser.add_argument('--outf', type=str, default='conformation_logs', metavar='N',
                    help='folder to output')
parser.add_argument('--drop', type=float, default=0, metavar='N',
                    help='dropout rate')
parser.add_argument('--lr', type=float, default=0.001, metavar='N',
                    help='learning rate')
parser.add_argument('--nf', type=int, default=64, metavar='N',
                    help='dimension of hidden layers')
parser.add_argument('--model', type=str, default='cdif', metavar='N',
                    help='available models: cdif, gnn, etc')
parser.add_argument('--attention', type=int, default=0, metavar='N',
                    help='attention in the ae model')
parser.add_argument('--n_layers', type=int, default=4, metavar='N',
                    help='number of layers for the MPNN')
parser.add_argument('--max_training_samples', type=int, default=6000, metavar='N',
                    help='maximum amount of training samples')
parser.add_argument('--level', type=str, default="medium", metavar='N',
                    help='easy, medium, difficult, hell.')
parser.add_argument('--num_agents', type=int, default=5, metavar='N',
                    help='5, 10')
parser.add_argument('--time_exp', type=int, default=0, metavar='N',
                    help='timing experiment')
parser.add_argument('--weight_decay', type=float, default=1e-12, metavar='N',
                    help='weight decay of training')
# NOTE(review): this help text looks copy-pasted from --time_exp; the real
# meaning of --div is not visible in this file, so it is left unchanged.
parser.add_argument('--div', type=float, default=1, metavar='N',
                    help='timing experiment')
# SECURITY NOTE: the three flags below use `type=eval`, i.e. the raw
# command-line string is evaluated as Python code (so "True"/"False" work).
# Acceptable for a locally run research script; never expose these flags to
# untrusted input.
parser.add_argument('--norm_diff', type=eval, default=False, metavar='N',
                    help='normalize_diff')
parser.add_argument('--tanh', type=eval, default=False, metavar='N',
                    help='use tanh')
parser.add_argument('--LR_decay', type=eval, default=True, metavar='N',
                    help='LR_decay')
parser.add_argument('--decay', type=float, default=0.9, metavar='N',
                    help='learning rate decay')
parser.add_argument('--degree', type=int, default=2, metavar='N',
                    help='degree of the TFN and SE3')

parser.add_argument('--viz', action='store_true', help='enable visualization')

    
# Accumulates forward-pass wall time and the number of timed batches when
# --time_exp is enabled.
time_exp_dic = {'time': 0, 'counter': 0}


args = parser.parse_args()
# CUDA is used only when it is both available and not disabled via --no-cuda.
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Apply --seed so that weight initialisation and DataLoader shuffling are
# reproducible; previously the flag was parsed but never used.
torch.manual_seed(args.seed)

device = torch.device("cuda" if args.cuda else "cpu")
print(device)
print(args.level)
# Mean-squared error between predicted and target locations.
loss_mse = nn.MSELoss()

print(args)
# Create <outf>/<exp_name> (and <outf> itself) if missing. `exist_ok=True`
# tolerates an already existing directory while still surfacing real failures
# (e.g. permission errors) immediately instead of silently ignoring them.
os.makedirs(os.path.join(args.outf, args.exp_name), exist_ok=True)

import torch

def get_velocity_attr(loc, vel, rows, cols, eps=1e-12):
    """Project the velocity of each edge's ``rows`` node onto the edge direction.

    For every edge ``(rows[i], cols[i])`` the direction is the unit vector
    pointing from ``loc[rows[i]]`` to ``loc[cols[i]]``.

    Args:
        loc: Node positions; indexed by ``rows``/``cols`` along dim 0 and
            normed along dim 1 (i.e. a 2-D ``(num_nodes, dim)`` tensor).
        vel: Node velocities, same layout as ``loc``.
        rows: Index tensor selecting the first endpoint of each edge.
        cols: Index tensor selecting the second endpoint of each edge.
        eps: Lower bound applied to the edge length before dividing, so that
            coincident endpoints yield ``0`` instead of ``NaN``.

    Returns:
        Tensor of shape ``(num_edges, 1)`` holding the signed projection.
    """
    diff = loc[cols] - loc[rows]
    # Clamp to avoid 0/0 when both endpoints share the same position.
    norm = torch.norm(diff, p=2, dim=1, keepdim=True).clamp_min(eps)
    u = diff / norm
    return torch.sum(vel[rows] * u, dim=1, keepdim=True)


def main():
    """Build the datasets and the model chosen by ``--model``, then train it.

    Every ``--test_interval`` epochs the model is evaluated on the validation
    and test loaders; the test loss is appended to the results written to
    ``<outf>/<exp_name>/losess.json`` and the test loss at the best validation
    epoch is tracked.

    Returns:
        tuple: ``(best_val_loss, best_test_loss, best_epoch)``.

    Raises:
        ValueError: If ``--model`` names an unknown model.
    """
    dataset_train = ConForDataset(partition='train', level=args.level, max_samples=args.max_training_samples, num_agents=args.num_agents)
    loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, drop_last=True)

    dataset_val = ConForDataset(partition='val', level=args.level, num_agents=args.num_agents)
    loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, drop_last=False)

    dataset_test = ConForDataset(partition='test', level=args.level, num_agents=args.num_agents)
    loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size, shuffle=False, drop_last=False)

    if args.model == 'cdif':
        # NOTE(review): norm_diff is hard-coded to False here rather than
        # taken from --norm_diff — confirm this is intentional.
        model = CFINs(in_node_nf=1, in_edge_nf=1, hidden_nf=args.nf, device=device, dropout_rate = args.drop, n_layers=args.n_layers, recurrent=False, norm_diff=False, dim=3, no_infer = False)
    elif args.model == 'cdif_diff_invar':
        model = CFINs_diff_invar(in_node_nf=1, in_edge_nf=1, hidden_nf=args.nf, device=device, dropout_rate = args.drop, n_layers=args.n_layers, recurrent=False, dim=3, no_infer = False, case = args.case)
    elif args.model == 'egnn_vel':
        model = EGNN_vel(in_node_nf=1, in_edge_nf=2, hidden_nf=args.nf, device=device, n_layers=args.n_layers, recurrent=False, norm_diff=args.norm_diff, tanh=args.tanh)
    elif args.model == 'gnn':
        model = GNN(input_dim=6, hidden_nf=args.nf, n_layers=args.n_layers, device=device, recurrent=True)
    elif args.model == 'gmn':
        model = GMN(in_node_nf=1, in_edge_nf=2, hidden_nf=args.nf, n_layers=args.n_layers, device=device, recurrent=False, norm_diff=args.norm_diff, tanh=args.tanh)
    elif args.model == 'clof_vel':
        model = ClofNet(in_node_nf=1, in_edge_nf=2, hidden_nf=args.nf, n_layers=args.n_layers, device=device, recurrent=False, norm_diff=args.norm_diff, tanh=args.tanh)
    elif args.model == 'baseline':
        model = Baseline()
    elif args.model == 'linear_vel':
        model = Linear_dynamics(device=device)
    elif args.model == 'linear':
        model = Linear(6, 3, device=device)
    elif args.model == 'rf_vel':
        model = RF_vel(hidden_nf=args.nf, edge_attr_nf=2, device=device, act_fn=nn.SiLU(), n_layers=args.n_layers)
    elif args.model == 'se3_transformer' or args.model == 'tfn':
        # Imported lazily so the SE(3) dependencies are only needed for these models.
        from se3_dynamics.dynamics import OurDynamics as SE3_Transformer
        # NOTE(review): div is hard-coded to 1 although a --div flag exists — confirm.
        model = SE3_Transformer(n_particles=args.num_agents, n_dimesnion=3, nf=int(args.nf/args.degree), n_layers=args.n_layers, model=args.model, num_degrees=args.degree, div=1, device=device)
    else:
        raise ValueError("Wrong model")

    print(model)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    scheduler = None
    if args.LR_decay:
        # Decay the learning rate five times over the run. StepLR computes
        # `epoch % step_size`, so the step must be at least 1 even when fewer
        # than five epochs are requested.
        step_size = max(1, args.epochs // 5)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=args.decay, last_epoch=-1)

    results = {'epochs': [], 'losess': []}
    results_path = os.path.join(args.outf, args.exp_name, "losess.json")
    best_val_loss = 1e8
    best_test_loss = 1e8
    best_epoch = 0
    for epoch in range(args.epochs):
        train(model, optimizer, epoch, loader_train)
        if scheduler is not None:
            scheduler.step()
        if epoch % args.test_interval == 0:
            val_loss = train(model, optimizer, epoch, loader_val, backprop=False)
            test_loss = train(model, optimizer, epoch, loader_test, backprop=False)
            results['epochs'].append(epoch)
            results['losess'].append(test_loss)
            # Model selection is done on the validation loss; the test loss
            # is only reported for the selected epoch.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_test_loss = test_loss
                best_epoch = epoch
            print("*** Best Val Loss: %.5f \t Best Test Loss: %.5f \t Best epoch %d" % (best_val_loss, best_test_loss, best_epoch))

            # `results` only changes on evaluation epochs, so persisting it
            # here keeps the file up to date without rewriting it every epoch.
            with open(results_path, "w") as outfile:
                json.dump(results, outfile, indent=4)
    return best_val_loss, best_test_loss, best_epoch


def _edge_attr_with_dist(loc, edges, edge_attr):
    """Return ``edge_attr`` with the squared endpoint distance appended as a column."""
    rows, cols = edges
    loc_dist = torch.sum((loc[rows] - loc[cols]) ** 2, 1).unsqueeze(1)  # relative distances among locations
    return torch.cat([edge_attr, loc_dist], 1).detach()  # concatenate all edge properties


def _speed(vel):
    """Return the Euclidean norm of every row of ``vel`` as a detached column."""
    return torch.sqrt(torch.sum(vel ** 2, dim=1)).unsqueeze(1).detach()


def _predict_loc(model, loc, vel, edges, edge_attr, num_agents):
    """Run the forward pass that matches ``args.model`` and return predicted locations."""
    name = args.model
    if name.startswith('cdif'):  # 'cdif', 'cdif_diff_invar'
        nodes = torch.ones(loc.size(0), 1).to(device)
        # These models also return a velocity prediction; only locations are trained on.
        loc_pred, _ = model(nodes.detach(), loc, vel, edges, edge_attr.detach(), num_agents=num_agents)
        return loc_pred
    if name in ('gmn', 'rf_vel'):
        return model(_speed(vel), loc.detach(), edges, vel, _edge_attr_with_dist(loc, edges, edge_attr))
    if name == 'gnn':
        return model(torch.cat([loc, vel], dim=1), edges, edge_attr)
    if name == 'egnn_vel':
        nodes = torch.ones(loc.size(0), 1).to(device)
        return model(nodes, loc.detach(), edges, vel, _edge_attr_with_dist(loc, edges, edge_attr))
    if name == 'clof_vel':
        return model(_speed(vel), loc.detach(), edges, vel, _edge_attr_with_dist(loc, edges, edge_attr), n_nodes=num_agents)
    if name == 'baseline':
        return model(loc)
    if name == 'linear':
        return model(torch.cat([loc, vel], dim=1))
    if name == 'linear_vel':
        return model(loc, vel)
    if name in ('se3_transformer', 'tfn'):
        nodes = torch.ones(loc.size(0), 1).to(device)
        return model(loc, vel, nodes.detach())
    raise Exception("Wrong model")


def train(model, optimizer, epoch, loader, backprop=True):
    """Run one pass over ``loader`` and return the sample-weighted mean MSE loss.

    Args:
        model: Network to run; its call signature must match ``args.model``.
        optimizer: Optimizer stepped after every batch when ``backprop`` is true.
        epoch: Epoch index, used only for logging.
        loader: DataLoader yielding ``(loc, vel, edge_attr, loc_end, vel_end)``
            batches of 3-D tensors; its dataset must provide ``get_edges`` and
            ``partition``.
        backprop: When true the model is put in training mode and updated;
            otherwise it is put in eval mode and only the loss is measured.

    Returns:
        float: Sum of per-batch losses weighted by batch size, divided by the
        number of samples seen.
    """
    if backprop:
        model.train()
    else:
        model.eval()
    # NOTE(review): evaluation passes still build an autograd graph; wrapping
    # them in torch.no_grad() would save memory if no model needs gradients in
    # its forward pass — confirm before changing.

    res = {'epoch': epoch, 'loss': 0, 'coord_reg': 0, 'counter': 0}

    for batch_idx, data in enumerate(loader):
        batch_size, num_agents, _ = data[0].size()

        # Flatten (batch, agents, features) into (batch * agents, features).
        data = [d.to(device) for d in data]
        data = [d.view(-1, d.size(2)) for d in data]
        loc, vel, edge_attr, loc_end, vel_end = data

        edges = loader.dataset.get_edges(batch_size, num_agents)
        edges = [edges[0].detach().to(device), edges[1].detach().to(device)]

        optimizer.zero_grad()

        if args.time_exp:
            # Synchronise only when CUDA is in use; the call fails on CPU-only runs.
            if args.cuda:
                torch.cuda.synchronize()
            t1 = time.time()

        if args.model == 'baseline':
            # The baseline is never optimised.
            backprop = False
        loc_pred = _predict_loc(model, loc, vel, edges, edge_attr, num_agents)
        loss = loss_mse(loc_pred, loc_end)

        if args.time_exp:
            if args.cuda:
                torch.cuda.synchronize()
            t2 = time.time()
            time_exp_dic['time'] += t2 - t1
            time_exp_dic['counter'] += 1

            avg_time = time_exp_dic['time'] / time_exp_dic['counter']
            print("Forward average time: %.6f" % avg_time)
            logging.info("Forward average time: %.6f", avg_time)

        if backprop:
            loss.backward()
            optimizer.step()

        res['loss'] += loss.item() * batch_size
        res['counter'] += batch_size
        if batch_idx % args.log_interval == 0 and (args.model == "se3_transformer" or args.model == "tfn"):
            print('===> {} Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(loader.dataset.partition,
                epoch, batch_idx * batch_size, len(loader.dataset),
                100. * batch_idx / len(loader),
                loss.item()))

    # Passes without parameter updates are marked with an arrow in the log.
    prefix = "" if backprop else "==> "
    avg_loss = res['loss'] / res['counter']
    print('%s epoch %d avg loss: %.5f' % (prefix+loader.dataset.partition, epoch, avg_loss))

    return avg_loss


# Start training only when this file is executed as a script.
if __name__ == "__main__":
    main()




