import torch
from torch import optim
from torch.nn import functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.transforms.functional import to_pil_image
from torchvision.transforms import ToTensor

import numpy as np
import argparse
import os

from PIL import Image

from utils import ExpDataset, reparameterize

from model import DisentangledVAE, CarlaDisentangledVAE,IthorDisentangledVAE \
    ,ViTDisentangledVAE,ResnetDisentangledVAE

import clip 
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Command-line configuration for the training run. The arguments are parsed
# once, at import time, into the module-level `args` read by the code below.
parser = argparse.ArgumentParser()

# Where the data lives and how it is partitioned.
parser.add_argument('--data-dir', default='./', type=str, help='path to the data')
parser.add_argument('--data-tag', default='car', type=str, help='files with data_tag in name under data directory will be considered as collected states')
parser.add_argument('--num-splitted', default=10, type=int, help='number of files that the states from one domain are splitted into')

# Optimisation settings.
parser.add_argument('--batch-size', default=10, type=int)
parser.add_argument('--num-epochs', default=50, type=int)
parser.add_argument('--num-workers', default=4, type=int)
parser.add_argument('--learning-rate', default=0.0001, type=float)
# Weight applied to the KL term inside vae_loss.
parser.add_argument('--beta', default=10, type=float)
# A preview image is written every `save-freq` batches within a split.
parser.add_argument('--save-freq', default=1000, type=int)
# Parsed as float so fractional weights (e.g. 0.5) are accepted; whole-number
# values given on the command line keep working exactly as before.
parser.add_argument('--bloss-coef', default=1.0, type=float, help='weight of the backward (cycle) loss in the total loss')

# Model shape and variant.
parser.add_argument('--class-latent-size', default=8, type=int)
parser.add_argument('--content-latent-size', default=16, type=int)
parser.add_argument('--flatten-size', default=1024, type=int)
parser.add_argument('--model-name', default='carla', type=str,
                    help='model variant: carracing, carla, ithor_cnn, ithor_ViT or ithor_resnet')

args = parser.parse_args()

# Map every accepted --model-name value to the VAE class it selects.
_MODEL_CLASSES = {
    'carracing': DisentangledVAE,
    'carla': CarlaDisentangledVAE,
    'ithor_cnn': IthorDisentangledVAE,
    'ithor_ViT': ViTDisentangledVAE,
    'ithor_resnet': ResnetDisentangledVAE,
}

# Fail immediately with a usage message on an unknown name; otherwise `Model`
# would stay undefined and only surface later as a NameError inside main().
if args.model_name not in _MODEL_CLASSES:
    parser.error(
        "unknown --model-name %r (expected one of: %s)"
        % (args.model_name, ", ".join(_MODEL_CLASSES))
    )
Model = _MODEL_CLASSES[args.model_name]

# if 'da_data' in args.data_dir:
#     type = args.data_dir.split('data/da_data/')[1]
# else:
#     type = args.data_dir.split('data/')[1]

# Experiment name; main() uses it as the sub-directory name under
# checkimages/ and ./checkpoints/.
exp_name = 'domain_factor'

def updateloader(loader, dataset):
    """Call ``dataset.loadnext()`` and return a fresh shuffled loader over it.

    Args:
        loader: Not read by this function; it is only part of the signature.
        dataset: Object exposing ``loadnext()`` (per the call-site comment,
            this moves to the next data split -- confirm in ``ExpDataset``).

    Returns:
        A new ``DataLoader`` over ``dataset`` whose batch size and worker
        count come from the module-level ``args``.
    """
    dataset.loadnext()
    return DataLoader(
        dataset,
        shuffle=True,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )


def vae_loss(x, mu, logsigma, recon_x, beta=1):
    """Return the reconstruction MSE plus a beta-weighted KL term.

    Args:
        x: Input tensor being reconstructed.
        mu: Mean of the approximate posterior.
        logsigma: Log-scale of the approximate posterior. The KL expression
            uses ``logsigma.exp()`` in the position where the closed form has
            the variance, i.e. it is treated as a log-variance here.
        recon_x: Reconstruction of ``x``.
        beta: Multiplier applied to the KL term.

    Returns:
        Scalar tensor: mean squared error between ``x`` and ``recon_x`` plus
        ``beta`` times the KL sum divided by the number of elements in ``x``.
    """
    # Summed over every latent entry, then scaled by the element count of x so
    # it is on a comparable scale to the element-wise mean of the MSE term.
    kl_terms = 1 + logsigma - mu.pow(2) - logsigma.exp()
    kl_per_element = (-0.5 * torch.sum(kl_terms)) / torch.numel(x)

    reconstruction = F.mse_loss(x, recon_x, reduction='mean')
    return reconstruction + kl_per_element * beta


def forward_loss(x, model, beta):
    """Return the summed VAE loss of two reconstructions of ``x``.

    The content code sampled from the encoder output is decoded twice: once
    paired with class codes permuted across the batch, once with each
    sample's own class code. Both reconstructions are scored against ``x``.

    Args:
        x: Batch of inputs fed to ``model.encoder``.
        model: Object exposing ``encoder`` (returning mu, logsigma, classcode)
            and ``decoder``.
        beta: KL weight forwarded to ``vae_loss``.

    Returns:
        ``vae_loss`` of the permuted-class reconstruction plus ``vae_loss`` of
        the own-class reconstruction.
    """
    mu, logsigma, classcode = model.encoder(x)
    contentcode = reparameterize(mu, logsigma)

    # Permute the class codes along the batch dimension.
    batch_order = torch.randperm(classcode.shape[0])
    swapped_classcode = classcode[batch_order]

    # Decode with the permuted class code first, then with the original one.
    recon_swapped, recon_own = [
        model.decoder(torch.cat([contentcode, code], dim=1))
        for code in (swapped_classcode, classcode)
    ]

    loss_swapped = vae_loss(x, mu, logsigma, recon_swapped, beta)
    loss_own = vae_loss(x, mu, logsigma, recon_own, beta)
    return loss_swapped + loss_own


def backward_loss(x, model, device):
    """Return the L1 gap between content codes recovered from two decodes.

    A single random content vector is decoded with each sample's own class
    code and with class codes permuted across the batch. Both (detached)
    images are re-encoded, content codes are sampled from the results, and
    the mean absolute difference between the two content codes is returned.

    Args:
        x: Batch of inputs fed to ``model.encoder``.
        model: Object exposing ``encoder`` (returning mu, logsigma, classcode)
            and ``decoder``.
        device: Device the random content tensor is moved to.

    Returns:
        Scalar tensor from ``F.l1_loss`` over the two re-encoded content codes.
    """
    # mu        -> domain-invariant feature
    # classcode -> domain-specific feature
    mu, _, classcode = model.encoder(x)
    permuted_classcode = classcode[torch.randperm(classcode.shape[0])]
    random_content = torch.randn_like(mu).to(device)

    # Decode the same random content under both class codes. detach() keeps
    # gradients from flowing back through this decoding pass.
    generated = [
        model.decoder(torch.cat([random_content, code], dim=1)).detach()
        for code in (classcode, permuted_classcode)
    ]

    # Re-encode both images, then sample a content code from each encoding.
    encodings = [model.encoder(image) for image in generated]
    content_own, content_permuted = [
        reparameterize(cycle_mu, cycle_logsigma)
        for cycle_mu, cycle_logsigma, _ in encodings
    ]

    return F.l1_loss(content_own, content_permuted)

def _save_check_images(model, imgs, path):
    """Write a preview grid of inputs and reconstructions to ``path``.

    The grid (9 images per row) stacks, in order: a random sample of up to 9
    images, a second random sample, the reconstructions of the first sample,
    and the first sample's content decoded with the second sample's class code.
    """
    rand_idx = torch.randperm(imgs.shape[0])
    imgs1 = imgs[rand_idx[:9]]
    imgs2 = imgs[rand_idx[-9:]]
    with torch.no_grad():
        mu, _, classcode1 = model.encoder(imgs1)
        _, _, classcode2 = model.encoder(imgs2)
        recon_imgs1 = model.decoder(torch.cat([mu, classcode1], dim=1))
        recon_combined = model.decoder(torch.cat([mu, classcode2], dim=1))

    saved_imgs = torch.cat([imgs1, imgs2, recon_imgs1, recon_combined], dim=0)
    save_image(saved_imgs, path, nrow=9)


def main():
    """Train the selected model and save the final model and encoder weights.

    Reads all settings from the module-level ``args``. For every epoch and
    every data split it optimises the forward (reconstruction) loss plus the
    weighted backward (cycle) loss, logs both to TensorBoard, and writes a
    preview image every ``args.save_freq`` batches. Weights are saved once,
    after the last epoch, under ``./checkpoints/<exp_name>``.
    """
    # create directory
    image_path = f"checkimages/{exp_name}"
    checkpoint_path = f"./checkpoints/{exp_name}"
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(image_path, exist_ok=True)
    os.makedirs(checkpoint_path, exist_ok=True)

    # create dataset and loader
    transform = transforms.Compose([transforms.ToTensor()])
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = ExpDataset(args.data_dir, args.data_tag, args.num_splitted, transform)
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

    # create model
    model = Model(class_latent_size=args.class_latent_size, content_latent_size=args.content_latent_size)
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

    # do the training
    writer = SummaryWriter()
    batch_count = 0
    try:
        for i_epoch in range(args.num_epochs):
            for i_split in range(args.num_splitted):
                for i_batch, imgs in enumerate(loader):
                    batch_count += 1

                    # forward circle
                    # (batch, domain, c, h, w) -> (domain, batch, c, h, w)
                    imgs = imgs.permute(1, 0, 2, 3, 4).to(device, non_blocking=True)
                    # interpolate treats a 5-D tensor as (N, C, D, H, W), so the
                    # size triple lines up with the trailing (c, h, w) dims:
                    # 3 channels kept, spatial dims resized to 224x224.
                    imgs = F.interpolate(imgs, (3, 224, 224))
                    optimizer.zero_grad()

                    # Average the forward loss over the domain dimension.
                    floss = sum(forward_loss(image, model, args.beta) for image in imgs)
                    floss = floss / imgs.shape[0]

                    # backward circle: merge domain and batch dims into one batch
                    imgs = imgs.reshape(-1, *imgs.shape[2:])
                    bloss = backward_loss(imgs, model, device)

                    (floss + bloss * args.bloss_coef).backward()
                    optimizer.step()

                    # write log
                    writer.add_scalar('floss', floss.item(), batch_count)
                    writer.add_scalar('bloss', bloss.item(), batch_count)

                    # save image to check
                    if i_batch % args.save_freq == 0:
                        print("%d Epochs, %d Splitted Data, %d Batches is Done." % (i_epoch, i_split, i_batch))
                        _save_check_images(
                            model,
                            imgs,
                            image_path + "/%d_%d_%d.png" % (i_epoch, i_split, i_batch),
                        )

                # load next splitted data; keep the loader that updateloader
                # returns instead of discarding it.
                loader = updateloader(loader, dataset)

        torch.save(model.state_dict(), checkpoint_path + f"/model_{args.num_epochs}_{args.model_name}.pt")
        torch.save(model.encoder.state_dict(), checkpoint_path + f"/encoder_{args.num_epochs}_{args.model_name}.pt")
    finally:
        # Flush and close the TensorBoard writer even if training raises.
        writer.close()

# Start training only when this file is executed as a script.
if __name__ == '__main__':
    main()
