#%% config

import argparse
def _str2bool(value):
    """Convert a command-line string into a bool for argparse.

    ``type=bool`` must not be used with argparse: ``bool(s)`` is True for
    every non-empty string, including ``"False"`` and ``"0"``, so such a
    flag could never be switched off from the command line.

    Accepts (case-insensitive) true/t/yes/y/1 and false/f/no/n/0.

    Raises:
        argparse.ArgumentTypeError: for any other spelling, so argparse
            reports a normal usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser()

parser.add_argument('--dataset', type=str, default = 'dsprites', help='Dataset')
parser.add_argument('--encoder', type=str, default = 'burgess', help='encoder')
parser.add_argument('--condition', type=str, default = 'recomb2element', help='Specify Partitioning Training and Validation dataset')
parser.add_argument('--beta', type = float, default=1, help='Coefficient of Group Loss')
parser.add_argument('--n_blocks', type = int, default= None, help='Number of blocks of CycleGAN')
parser.add_argument('--save_folder', type = str, default='../data/sims/test', help='Folder to Save')
parser.add_argument('--group_ratio', type = int, nargs='+', default=None, help='Ratio of identity inverse assoc')
parser.add_argument('--latent_size', type = int, default=10, help='Dimension of Latent Variable')
parser.add_argument('--batch_size', type = int, default=64, help='Batch_size')
parser.add_argument('--lr_scheduling', type = _str2bool, default=False, help='lr_scheduling')
parser.add_argument('--beta_hessian', type = float, default= None, help='beta_hessian')
parser.add_argument('--beta_z', type = float, default= 0, help='beta_z')
parser.add_argument('--beta_var', type = float, default= 0, help='beta_var')
parser.add_argument('--beta_ga', type = float, default= 0, help='beta_ga')
parser.add_argument('--epochs', type = int, default= 100, help='epochs')
parser.add_argument('--ratio', type = float, default= None, help='ratio')
parser.add_argument("--so_n_size", default = None, nargs='+', type=int)
# Boolean switches: parsed with _str2bool so "False"/"0" really disable them.
parser.add_argument("--randomness", default = False, type=_str2bool)
parser.add_argument("--shuffle_axis", default = False, type=_str2bool)
parser.add_argument("--monoid", default = False, type=_str2bool)


args = parser.parse_args()

import sys
class _config:
    # Experiment configuration resolved from the module-level CLI ``args``.
    # NOTE(review): the class object itself (not only the instance) is passed
    # to ``config_sl`` below, which presumably serializes its class
    # attributes -- so avoid adding temporary class-level names here.

    # Known encoders: 'higgins', 'burgess', 'burgess_v2', 'mpcnn', 'mathieu', 'kim'
    encoder_name = args.encoder
    decoder_name = 'CycleGAN'
    decoder_params = {'n_blocks': 6}
    if args.n_blocks:
        decoder_params['n_blocks'] = args.n_blocks
    latent_size = args.latent_size

    dataset_type = args.dataset
    # Image shape per dataset (presumably channels, height, width).
    if args.dataset == 'dsprites':
        input_size = [1,64,64]
    elif args.dataset in ('shapes3d', 'mpi3d', 'celeba'):
        input_size = [3,64,64]
    else:
        # A bare ``raise`` here has no active exception and would only report
        # "RuntimeError: No active exception to reraise".
        raise ValueError(f"Unsupported dataset: {args.dataset!r}")

    so_n_size = args.so_n_size  # e.g. [2, 3, 3]

    # [condition, variant] later passed to load_dataset_pivot.
    # Dataset-specific splits:
    #   dsprites: recomb2range/shape2tx,   recomb2element/leave1out, extrp/blank_side
    #   shapes3d: recomb2range/shape2ohue, recomb2element/leave1out, extrp/fhue_gt50
    #   mpi3d:    recomb2range/cyl2horz,   recomb2element/leave1out, extrp/horz_gt20
    if args.condition == "test":
        condvar = ['test', 0.001]
    elif args.condition == "entire":
        condvar = ['entire', None]
    elif args.condition == "part":
        condvar = ['entire', args.ratio]
    elif args.condition == "recomb2element":
        if args.dataset in ('dsprites', 'shapes3d', 'mpi3d'):
            condvar = ['recomb2element', 'leave1out']
    elif args.condition == "recomb2range":
        if args.dataset == 'dsprites':
            condvar = ['recomb2range', 'shape2tx']
        elif args.dataset == 'shapes3d':
            condvar = ['recomb2range', 'shape2ohue']
        elif args.dataset == 'mpi3d':
            condvar = ['recomb2range', 'cyl2horz']
    elif args.condition == "extrp":
        if args.dataset == 'dsprites':
            condvar = ['extrp', 'blank_side']
        elif args.dataset == 'shapes3d':
            condvar = ['extrp', 'fhue_gt50']
        elif args.dataset == 'mpi3d':
            condvar = ['extrp', 'horz_gt20']
    else:
        raise ValueError(f"Unsupported condition: {args.condition!r}")
    # CelebA always uses [None, None], whatever condition was requested.
    if args.dataset == 'celeba':
        condvar = [None, None]

    group_ratio = args.group_ratio

    # Optimisation
    batch_size = args.batch_size
    learning_rate = 5e-4
    epochs = args.epochs
    beta = args.beta
    beta_hessian = args.beta_hessian
    beta_z = args.beta_z
    beta_var = args.beta_var
    beta_ga = args.beta_ga
    optimizer = 'adam'
    l2_norm = 0.00
    lr_scheduling= args.lr_scheduling

    # Metrics, losses and output
    metrics_config = [{'name': 'recons_nll', 'params': {'loss': 'bce'}, 'output': ['y_pred', 'y']}]
    group_loss = 'l1'
    image_loss = 'bce'
    save_folder= args.save_folder
    randomness = args.randomness

    shuffle_axis = args.shuffle_axis

    monoid = args.monoid
config = _config()


#%% 
import sys
import torch

# Make the project's source and script directories importable (once each).
for _extra_path in ('../src', '../scripts/'):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)
del _extra_path

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Write the resolved configuration into the run's output folder.
from configs.config import config_sl
config_structure = config_sl(_config)
config_structure.save(config.save_folder)
#%% Define model, dataloader, criterion, optimizer

from ingredients.training import ModelCheckpoint

from dataset.utils import init_loader,  ptr_to_datas,index_to_group_input, random_index_to_group_input, ShuffledDataset, load_dataset_pivot
from models.init_model import init_plgm, init_soplgm
from models.so_n_representation import SO_n_product, eye_like
from training.loss import latent_diff_loss, image_diff_loss, get_metrics_list, Variance_loss
from training.optimizer import init_optimizer,init_lr_scheduler
from training.handlers import Saver_concat
from training.hessian_penalty import hessian_penalty

# Build the model. The SO(n) variant also needs the product representation
# built from config.so_n_size.
if config.so_n_size:
    so_n_repre = SO_n_product(config.so_n_size, device=device)
    model = init_soplgm(
        config.encoder_name,
        config.decoder_name,
        config.input_size,
        config.latent_size,
        so_n_repre,
        decoder_params=config.decoder_params,
    )
else:
    model = init_plgm(
        config.encoder_name,
        config.decoder_name,
        config.input_size,
        config.latent_size,
        decoder_params=config.decoder_params,
    )

condition, variant = config.condvar

(train_dataset, test_dataset,
 pivot1_dataset, pivot2_dataset, pivot3_dataset) = load_dataset_pivot(
    config.dataset_type, condition, variant, pivot=[1, 2, 3])
print(f'train_dataset: {len(train_dataset)}, test_dataset: {len(test_dataset)}, pivot_dataset: {len(pivot1_dataset)}')

# Wrap each pivot split together with the test split; these feed
# evalutate_model_pairing, which unpacks batches as (train, _, test, _).
pivot1_pairing_dataset, pivot2_pairing_dataset, pivot3_pairing_dataset = [
    ShuffledDataset(pivot_split, test_dataset)
    for pivot_split in (pivot1_dataset, pivot2_dataset, pivot3_dataset)
]

train_dataloader, pivot1_pairing_dataloader, pivot2_pairing_dataloader, pivot3_pairing_dataloader = [
    init_loader(source, batch_size=config.batch_size)
    for source in (train_dataset, pivot1_pairing_dataset,
                   pivot2_pairing_dataset, pivot3_pairing_dataset)
]

# Loss functions
criterion_latent = latent_diff_loss(config.group_loss)
criterion_image = image_diff_loss(config.image_loss)
criterion_var = Variance_loss()

# Optimizer, plus a reduce-on-plateau LR scheduler when requested
optimizer = init_optimizer(
    config.optimizer,
    model.parameters(),
    lr=config.learning_rate,
    l2_norm=config.l2_norm,
)
if config.lr_scheduling:
    scheduler = init_lr_scheduler(optimizer, 'reduce-on-plateau', lr_decay=0.1, patience=10)

# One metric collection for the trainer, one shared by the pivot validators
metrics_train, metrics_pivot = (get_metrics_list(config.metrics_config) for _ in range(2))

def shuffle_batch(tensor):
    """Randomly reorder ``tensor`` along its first dimension.

    Returns:
        (perm, shuffled): the index permutation drawn with ``torch.randperm``
        and ``tensor`` reordered by it (same shape as ``tensor``).
    """
    perm = torch.randperm(tensor.shape[0])
    shuffled = tensor[perm]
    return perm, shuffled
def shuffle_axis(x):
    """Permute ``x`` along dim 0 independently for every other index.

    Each column (for 2-D input: each ``x[:, j]``) gets its own random
    permutation of the first dimension. Rows of the result therefore mix
    values taken from different rows of ``x``. The random keys are drawn on
    ``x``'s device, so no host-to-device transfer is needed. ``torch.gather``
    makes this work for inputs of any rank, not only 2-D.

    Returns:
        (None, shuffled): ``None`` stands in for the permutation index that
        ``shuffle_batch`` returns, so both can be used interchangeably.
    """
    keys = torch.rand(x.shape, device=x.device)
    indices = torch.argsort(keys, dim=0)
    return None, torch.gather(x, 0, indices)
# Shuffling used for the latent re-encoding loss: per-column (shuffle_axis)
# or whole rows (shuffle_batch).
shuffle_function = shuffle_axis if config.shuffle_axis else shuffle_batch
#%% Engine
from ignite.engine import Events, Engine
def update_model(engine, batch):
    """Ignite process function: run one optimisation step on ``batch``.

    The images are expanded into group-structured inputs
    (``index_to_group_input`` / ``random_index_to_group_input``), encoded,
    and trained with the weighted sum of:

    * image reconstruction loss ``loss_recon``;
    * latent group-axiom loss ``loss_group`` (identity, inverse,
      associativity), weight ``config.beta``;
    * latent re-encoding loss on shuffled codes ``loss_recon_z``, weight
      ``config.beta_z``, and variance loss, weight ``config.beta_var``;
    * group-action reconstruction loss ``loss_recon_group``, weight
      ``config.beta_ga``;
    * for SO(n) models, the matrix counterparts of the above;
    * optionally a Hessian penalty, weight ``config.beta_hessian``.

    Args:
        engine: the ignite ``Engine`` (unused).
        batch: ``(imgs, _)``; the second element is ignored.

    Returns:
        dict of tensors used by metrics/savers and scalar loss values.
    """
    import numpy as np  # used by the SO(n) randomness branch

    imgs, _ = batch
    # evalutate_model_pairing() switches the shared model to eval mode;
    # make sure every training step runs in train mode.
    model.train()
    model.to(device)
    imgs = imgs.to(device)

    if config.group_ratio:
        batch_concat_data_conat, batch_concat_data1, batch_concat_data2, ptrs = random_index_to_group_input(imgs, config.group_ratio)
    else:
        batch_concat_data_conat, batch_concat_data1, batch_concat_data2, ptrs = index_to_group_input(imgs)

    if not config.so_n_size:
        z = model.encoder(batch_concat_data_conat)
        recon = model.decode(z, batch_concat_data1)
    else:
        recon, z, cs_list, matrices_so_n = model(batch_concat_data_conat, batch_concat_data1)

    # Reconstruction loss
    loss_recon = criterion_image(recon, batch_concat_data2)

    # Group loss: ptr_to_datas splits each batch-like object into the six
    # slices (identity, inverse gh/hg, associativity gh/hk/gk) given by ptrs.
    z_iden, z_inv_gh, z_inv_hg, z_assoc_gh, z_assoc_hk, z_assoc_gk = ptr_to_datas(z, ptrs)
    data1_iden, data1_inv_gh, data1_inv_hg, data1_assoc_gh, data1_assoc_hk, data1_assoc_gk = ptr_to_datas(batch_concat_data1, ptrs)
    data2_iden, data2_inv_gh, data2_inv_hg, data2_assoc_gh, data2_assoc_hk, data2_assoc_gk = ptr_to_datas(batch_concat_data2, ptrs)

    loss_iden = criterion_latent(z_iden, torch.zeros_like(z_iden))
    if config.monoid:
        # No inverses in a monoid; keep a zero tensor on z's device so it can
        # be summed and logged like the other terms.
        loss_inv = z.new_zeros(())
    else:
        loss_inv = criterion_latent(z_inv_gh + z_inv_hg, torch.zeros_like(z_inv_gh))
    loss_assoc = criterion_latent(z_assoc_gh + z_assoc_hk, z_assoc_gk)
    loss_group = loss_iden + loss_inv + loss_assoc

    if config.so_n_size:
        # Same axioms for each SO(n) factor: identity ~ I, gh^T hg ~ I for
        # inverse pairs, gk^T (hk gh) ~ I for associativity; summed over factors.
        parsed_matrices_so_n = ptr_to_datas(matrices_so_n, ptrs)
        loss_so_n_iden = z.new_zeros(())
        loss_so_n_inv = z.new_zeros(())
        loss_so_n_assoc = z.new_zeros(())
        for so_n_iden, so_n_inv_gh, so_n_inv_hg, so_n_assoc_gh, so_n_assoc_hk, so_n_assoc_gk in parsed_matrices_so_n:
            loss_so_n_iden = loss_so_n_iden + criterion_latent(so_n_iden, eye_like(so_n_iden))
            if not config.monoid:
                loss_so_n_inv = loss_so_n_inv + criterion_latent(torch.bmm(so_n_inv_gh.transpose(2,1), so_n_inv_hg), eye_like(so_n_inv_hg))
            loss_so_n_assoc = loss_so_n_assoc + criterion_latent(torch.bmm(so_n_assoc_gk.transpose(2,1), torch.bmm(so_n_assoc_hk, so_n_assoc_gh)), eye_like(so_n_assoc_hk))
        loss_so_n_group = loss_so_n_iden + loss_so_n_inv + loss_so_n_assoc

    # Variance loss and z re-encoding loss: shuffle the codes from offset
    # ptrs[1] onward, decode them onto imgs, re-encode, and compare.
    z_true = z[ptrs[1]:]
    loss_variance = criterion_var(z_true)
    _, z_permuted = shuffle_function(z_true)
    z_gt = z_permuted[:imgs.shape[0]]

    if not config.so_n_size:
        decoded = model.decode(z_gt.detach(), imgs)
        z_recon = model.encoder(torch.cat([imgs, decoded], dim = 1))
        loss_recon_z = criterion_latent(z_recon, z_gt)
    else:
        matrices_permuted = [shuffle_batch(matrix)[1][:imgs.shape[0]].detach() for matrix in matrices_so_n]
        decoded = model.z_matrix_to_decoded(z_gt.detach(), matrices_permuted, imgs)
        z_recon, z_linear_recon, z_so_n_recon, _, matrix_list_recon = model.encode(torch.cat([imgs, decoded], dim = 1))
        loss_recon_z = criterion_latent(z_linear_recon, z_gt)
        # Accumulate over every SO(n) factor, as the group losses above do.
        loss_recon_z_so_n = z.new_zeros(())
        for matrix, matrix_recon in zip(matrices_permuted, matrix_list_recon):
            loss_recon_z_so_n = loss_recon_z_so_n + criterion_latent(torch.bmm(matrix.transpose(2,1), matrix_recon), eye_like(matrix_recon))
        loss_variance_so_n = sum((criterion_var(matrix) for matrix in matrices_permuted), z.new_zeros(()))

    # Group-action loss: decoding identity / inverse / composed actions onto
    # the source images should reproduce the target images.
    if not config.so_n_size:
        if not config.randomness:
            loss_recon_iden = criterion_image(model.decode(torch.zeros_like(z_iden),data1_iden), data2_iden)
        else:
            ratio = torch.rand(z_iden.shape[0]).unsqueeze(1).to(device)
            loss_recon_iden = criterion_image(model.decode(z_iden * ratio,data1_iden), data2_iden)
        # NOTE(review): unlike the SO(n) branch, the inverse term is kept here
        # even when config.monoid is set -- confirm this is intended.
        recon_hg = model.decode(-z_inv_gh, data1_inv_hg)
        recon_gh = model.decode(-z_inv_hg, data1_inv_gh)
        loss_recon_inv = criterion_image(recon_hg , data2_inv_hg) + criterion_image(recon_gh, data2_inv_gh)

        recon_gk = model.decode(z_assoc_gh+ z_assoc_hk, data1_assoc_gk)
        loss_recon_assoc = criterion_image(recon_gk, data2_assoc_gk)
        loss_recon_group = loss_recon_iden + loss_recon_inv + loss_recon_assoc
    else:
        idx_iden, idx_inv_gh, idx_inv_hg, idx_assoc_gh, idx_assoc_hk, idx_assoc_gk = range(6)
        if not config.randomness:
            iden_matrices = [eye_like(parsed[idx_iden]) for parsed in parsed_matrices_so_n]
            loss_recon_iden = criterion_image(model.z_matrix_to_decoded(torch.zeros_like(z_iden), iden_matrices, data1_iden), data2_iden)
        else:
            ratio = torch.rand(z_iden.shape[0]).unsqueeze(1).to(device)
            # Per SO(n) factor, randomly keep the encoded identity matrix or
            # substitute the exact identity.
            index = np.random.choice(2, len(parsed_matrices_so_n), p = [0.5, 0.5])
            iden_matrices = [parsed[idx_iden] if i == 0 else eye_like(parsed[idx_iden]) for i, parsed in zip(index, parsed_matrices_so_n)]
            # Scale the encoded identity latent, mirroring the non-SO(n)
            # branch (zeros_like(z_iden) * ratio is always zero).
            loss_recon_iden = criterion_image(model.z_matrix_to_decoded(z_iden * ratio, iden_matrices, data1_iden), data2_iden)
        # The inverse / associativity terms apply regardless of
        # config.randomness; recon_* and loss_recon_group are needed below.
        recon_hg = model.z_matrix_to_decoded(-z_inv_gh, [parsed[idx_inv_gh].transpose(2,1) for parsed in parsed_matrices_so_n], data1_inv_hg)
        recon_gh = model.z_matrix_to_decoded(-z_inv_hg, [parsed[idx_inv_hg].transpose(2,1) for parsed in parsed_matrices_so_n], data1_inv_gh)
        if config.monoid:
            loss_recon_inv = z.new_zeros(())
        else:
            loss_recon_inv = criterion_image(recon_hg , data2_inv_hg) + criterion_image(recon_gh, data2_inv_gh)
        recon_gk = model.z_matrix_to_decoded(z_assoc_gh+ z_assoc_hk, [torch.bmm(parsed[idx_assoc_hk], parsed[idx_assoc_gh]) for parsed in parsed_matrices_so_n] ,data1_assoc_gk)
        loss_recon_assoc = criterion_image(recon_gk, data2_assoc_gk)
        loss_recon_group = loss_recon_iden + loss_recon_inv + loss_recon_assoc

    # Loss sum
    loss = loss_recon + loss_recon_group * config.beta_ga + config.beta * loss_group  + config.beta_z * loss_recon_z + config.beta_var * loss_variance
    if config.so_n_size:
        loss = loss + loss_so_n_group * config.beta + loss_recon_z_so_n * config.beta_z + config.beta_var * loss_variance_so_n

    if config.beta_hessian:
        loss_hessian = hessian_penalty(model.decode, z, img = batch_concat_data1)
        loss = loss + loss_hessian * config.beta_hessian
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if config.lr_scheduling:
        # The reduce-on-plateau scheduler is stepped on each batch's training loss.
        scheduler.step(loss)
    return_dict = {'y_pred': recon,
            'y': batch_concat_data2,
            'template': batch_concat_data1,
            'recon_hg': recon_hg,
            'data2_inv_hg': data2_inv_hg,
            'data1_inv_hg': data1_inv_hg,
            'recon_gh': recon_gh,
            'data2_inv_gh': data2_inv_gh,
            'data1_inv_gh': data1_inv_gh,
            'recon_gk': recon_gk,
            'data1_assoc_gk': data1_assoc_gk,
            'data2_assoc_gk': data2_assoc_gk,

            'z': z,
            'ptrs': ptrs,
            'loss': loss.item(),
            'loss_recon_group': loss_recon_group.item(),
            'loss_recon_iden': loss_recon_iden.item(),
            'loss_recon_inv': loss_recon_inv.item(),
            'loss_recon_assoc': loss_recon_assoc.item(),
            'loss_variance': loss_variance.item(),
            'loss_recon': loss_recon.item(),
            'loss_z': loss_recon_z.item(),
            'loss_group': loss_group.item(),
            'loss_iden': loss_iden.item(),
            'loss_inv': loss_inv.item(),
            'loss_assoc': loss_assoc.item()}
    if config.so_n_size:
        return_dict['loss_so_n_group'] = loss_so_n_group.item()
        return_dict['loss_so_n_iden'] = loss_so_n_iden.item()
        return_dict['loss_so_n_inv'] = loss_so_n_inv.item()
        return_dict['loss_so_n_assoc'] = loss_so_n_assoc.item()
        return_dict['loss_recon_z_so_n'] = loss_recon_z_so_n.item()
        return_dict['loss_variance_so_n'] =loss_variance_so_n.item()
    return return_dict


def evalutate_model_pairing(engine, batch):
    """Ignite process function: evaluate reconstruction on one pairing batch.

    ``batch`` unpacks as ``(imgs_train, _, imgs_test, _)``. The two image
    tensors are concatenated along dim 1 and encoded. The latent is decoded
    on top of ``imgs_train`` and compared with ``imgs_test``.

    The model is put in eval mode for the forward pass. Its previous
    train/eval mode is restored afterwards, because the trainer shares the
    same ``model`` and would otherwise keep training in eval mode.

    Returns:
        dict with ``y_pred``/``y``/``template`` for metrics, ``z``, the
        scalar ``loss_recon``, and for SO(n) models ``matrices_so_n`` and
        ``cs_list``.
    """
    was_training = model.training
    model.eval()
    imgs_train, _, imgs_test, _ = batch
    model.to(device)
    imgs_train = imgs_train.to(device)
    imgs_test = imgs_test.to(device)
    imgs_concat = torch.cat([imgs_train, imgs_test], dim = 1)

    try:
        with torch.no_grad():
            if not config.so_n_size:
                z = model.encoder(imgs_concat)
                recon = model.decode(z, imgs_train)
            else:
                recon, z, cs_list, matrices_so_n = model(imgs_concat, imgs_train)
            loss_recon = criterion_image(recon, imgs_test)
    finally:
        model.train(was_training)

    return_dict = {'y_pred': recon,
            'y': imgs_test,
            'template': imgs_train,
            'z': z,
            'loss_recon': loss_recon.item()}
    if config.so_n_size:
        return_dict['matrices_so_n'] = matrices_so_n
        return_dict['cs_list'] = cs_list
    return return_dict

# One training engine and one evaluation engine per pivot split.
trainer = Engine(update_model)
validator_pivot1, validator_pivot2, validator_pivot3 = (
    Engine(evalutate_model_pairing) for _ in range(3)
)

#%% Handlers

# Metrics: the trainer gets its own collection; the three pivot validators
# all attach the instances from metrics_pivot.
for name, metric in metrics_train.items():
    metric.attach(trainer, name)
for validator in (validator_pivot1, validator_pivot2, validator_pivot3):
    for name, metric in metrics_pivot.items():
        metric.attach(validator, name)

# Values recorded by the savers.
output_list = ['loss','loss_recon','loss_group', 'loss_iden','loss_inv', 'loss_assoc', 'loss_z', 'loss_variance', 'loss_recon_iden', 'loss_recon_inv', 'loss_recon_assoc']
if config.so_n_size:
    output_list.extend(['loss_so_n_group', 'loss_so_n_iden', 'loss_so_n_inv', 'loss_so_n_assoc', 'loss_recon_z_so_n', 'loss_variance_so_n'])
metric_saver_train = Saver_concat(
    save_folder=config.save_folder, output=output_list, name='train', train=True,
).attach(trainer)
for pivot_name, validator in (('pivot1', validator_pivot1),
                              ('pivot2', validator_pivot2),
                              ('pivot3', validator_pivot3)):
    metric_saver_pivot = Saver_concat(
        save_folder=config.save_folder, output=['loss_recon'], name=pivot_name,
    ).attach(validator)

@trainer.on(Events.EPOCH_COMPLETED)
def validate(engine):
    """After every training epoch, evaluate pivot1, pivot2 and pivot3 in order."""
    schedule = (
        (validator_pivot1, pivot1_pairing_dataloader),
        (validator_pivot2, pivot2_pairing_dataloader),
        (validator_pivot3, pivot3_pairing_dataloader),
    )
    for evaluator, loader in schedule:
        evaluator.run(loader)



# Attach model checkpoint
def score_fn(engine):
    """Return the negated value of the first pivot metric on ``engine``.

    Negation turns a loss-like metric into a higher-is-better score.
    NOTE(review): assumes ModelCheckpoint keeps the highest scores, as
    ignite's does -- confirm for ingredients.training.ModelCheckpoint.
    """
    first_metric_name = list(metrics_pivot)[0]
    value = engine.state.metrics[first_metric_name]
    return -value

# Settings shared by both checkpoint handlers.
_checkpoint_common = dict(
    dirname=config.save_folder,
    create_dir=True,
    require_empty=False,
    save_as_state_dict=True,
)

# Best model by score_fn, checked when the last pivot validator finishes.
best_checkpoint = ModelCheckpoint(
    filename_prefix='diff_best_nll',
    score_function=score_fn,
    **_checkpoint_common,
)
validator_pivot3.add_event_handler(Events.COMPLETED, best_checkpoint, {'model': model})

# Periodic snapshot every 10 epochs, keeping epochs // 10 of them.
periodic_checkpoint = ModelCheckpoint(
    filename_prefix='diff_interval',
    n_saved=config.epochs // 10,
    **_checkpoint_common,
)
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=10), periodic_checkpoint, {'model': model})

#%% Run
trainer.run(train_dataloader, max_epochs=config.epochs)