import numpy as np
import torch
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import MultiStepLR
from ema_pytorch import EMA
from tqdm.auto import tqdm

from gaussian_diffusion import GaussianDiffusion
from siren_model import ConditionalSirenNet
from silu_model import ConditionalSiLUNet
from wavediff_model import DiffWave

from utils import great_circle_distance_loss, great_circle_distance_numpy

from dataset import GeoDataset, PretrainedDataset
from torch.utils.data import TensorDataset, DataLoader

import os
import yaml
import json

from torch.utils.tensorboard import SummaryWriter

def initialize(config):
    """Build every object the training run needs from the parsed configuration.

    Creates the checkpoint and sample output directories, picks the compute
    device, builds the train/valid/test data loaders for the configured
    training mode, and constructs the denoising network, the diffusion
    wrapper, the optimizer, the LR scheduler and the EMA tracker.

    Args:
        config: Nested dict (as loaded from ``config.yaml``). The ``global``,
            ``checkpoint``, ``sample``, ``train``, ``valid``, ``test``,
            ``model`` and ``diffusion`` sections are read here; ``dataset``
            is only read when the training mode is ``"pretrained"``.

    Returns:
        Tuple ``(device, model, diffusion, trainloader, validloader,
        testloader, pretrain_opt, opt, scheduler, ema)``. ``pretrain_opt`` is
        currently always ``None``.

    Raises:
        ValueError: If ``config["train"]["train_mode"]`` is neither
            ``"raw_image"`` nor ``"pretrained"``.
    """
    experiment_name = config["global"]["experiment_name"]

    # Output folders are "<save_path><experiment_name>" by plain string
    # concatenation (no separator is inserted), so the configured save paths
    # are expected to carry their own trailing separator.
    for save_root in (config["checkpoint"]["checkpoint_save_path"], config["sample"]["sample_save_path"]):
        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(save_root + experiment_name, exist_ok=True)

    # Global configurations
    if torch.cuda.is_available():
        print("Using GPU backend")
        device = torch.device("cuda:{}".format(config["global"]["gpu_id"]))
    else:
        print("Using CPU backend")
        device = torch.device("cpu")

    # Dataloader configuration
    # Notice! Always use a scaled dataset (i.e., scaling to [-1, 1])
    train_mode = config["train"]["train_mode"]

    if train_mode == "raw_image":
        train_data = GeoDataset("datasets/YFCC/yfcc25600_places365.csv", "datasets/YFCC/yfcc4k_rgb_images", device)
        trainloader = DataLoader(train_data, batch_size=config["train"]["train_batch_size"], shuffle=True, pin_memory=True)

        valid_data = GeoDataset("datasets/Im2GPS/im2gps_places365.csv", "datasets/Im2GPS/im2gps_rgb_images", device)
        validloader = DataLoader(valid_data, batch_size=config["valid"]["valid_batch_size"], shuffle=False)

        # NOTE(review): the test split points at the same CSV/images as the
        # training split -- confirm this is intentional.
        test_data = GeoDataset("datasets/YFCC/yfcc25600_places365.csv", "datasets/YFCC/yfcc4k_rgb_images", device)
        testloader = DataLoader(test_data, batch_size=config["test"]["test_batch_size"], shuffle=False)

    elif train_mode == "pretrained":
        train_data = PretrainedDataset(
            config["dataset"]["train_dataset_path"] + config["dataset"]["train_dataset_name"],
            train=True,
            use_augmentation=config["train"]["use_augmentation"],
            img_perturb=config["train"]["train_img_perturb"],
            # "gps_purturb" (sic) is the keyword name used by this call; keep
            # the spelling in sync with PretrainedDataset.
            gps_purturb=config["train"]["train_gps_perturb"],
        )
        trainloader = DataLoader(train_data, batch_size=config["train"]["train_batch_size"], shuffle=True)

        valid_data = PretrainedDataset(config["dataset"]["valid_dataset_path"] + config["dataset"]["valid_dataset_name"], train=False)
        validloader = DataLoader(valid_data, batch_size=config["valid"]["valid_batch_size"], shuffle=False)

        test_data = PretrainedDataset(config["dataset"]["test_dataset_path"] + config["dataset"]["test_dataset_name"], train=False)
        testloader = DataLoader(test_data, batch_size=config["test"]["test_batch_size"], shuffle=False)
    else:
        # A real exception instead of `assert False`: asserts vanish under -O,
        # which would let execution continue with no loaders defined.
        raise ValueError("Training mode not supported! Got {!r}".format(train_mode))

    # Model configuration
    model = ConditionalSirenNet(
        dim_hidden=config["model"]["hidden_dimension"],
        dim_cond=config["model"]["condition_dimension"],
        dropout_rate=config["model"]["dropout_rate"],
        emb_dropout_rate=config["model"]["emb_dropout_rate"]
    )
    # Alternative backbone, kept for reference:
    # model = DiffWave(
    #     residual_layers=6,
    #     residual_channels=config["model"]["hidden_dimension"],
    #     n_mels=config["model"]["condition_dimension"],
    #     dilation_cycle_length=5,
    #     n_timesteps=config["diffusion"]["train_timestep_number"]
    # )

    # Diffusion configuration
    diffusion = GaussianDiffusion(
        model,
        dim_inputs=config["diffusion"]["input_dimension"],
        location_encoding_type=config["diffusion"]["location_encoding_type"],
        dim_encoding=config["diffusion"]["location_encoding_dimension"],
        dim_hidden=config["model"]["hidden_dimension"],
        dim_condition=config["model"]["condition_dimension"],
        train_mode=train_mode,
        data_size=len(train_data),
        num_train_grid_points=config["diffusion"]["num_grid_points"],
        num_sample_grid_points=config["diffusion"]["num_sample_grid_points"],
        train_grid_filepath=config["diffusion"]["train_grid_filepath"],
        sample_grid_filepath=config["diffusion"]["sample_grid_filepath"],
        device=device,
        objective=config["diffusion"]["objective"],
        timesteps=config["diffusion"]["train_timestep_number"],    # number of steps
        sample_save_interval=config["diffusion"]["sample_save_interval"],
        sampling_timesteps=config["diffusion"]["sample_timestep_number"],
        noise_amplifier=config["diffusion"]["noise_amplifier"],
    )

    model.to(device)
    diffusion.to(device)

    # Pretraining is currently disabled; the optimizer slot is kept so the
    # return signature stays stable.
    # pretrain_opt = SGD(diffusion.location_decoder.parameters(), lr=config["pretrain"]["pretrain_lr"], weight_decay=config["pretrain"]["pretrain_weight_decay"])
    pretrain_opt = None
    opt = Adam(diffusion.parameters(), lr=config["train"]["train_lr"], betas=config["train"]["adam_betas"], weight_decay=config["train"]["train_weight_decay"])
    scheduler = MultiStepLR(opt, milestones=config["train"]["train_milestones"], gamma=config["train"]["train_gamma"])
    ema = EMA(diffusion, beta=config["train"]["ema_decay"], update_every=config["train"]["ema_update_every"])
    ema.to(device)

    return device, model, diffusion, trainloader, validloader, testloader, pretrain_opt, opt, scheduler, ema

def save_checkpoint(diffusion_dict, opt_dict, loss, epoch, config):
    """Write a training checkpoint and, on the first checkpoint, the config.

    The checkpoint is saved with ``torch.save`` to
    ``<checkpoint_save_path><experiment_name>/chkpt-<epoch>.pt``. The folder is
    not created here and must already exist.

    Args:
        diffusion_dict: State dict stored under ``'model_state_dict'``.
        opt_dict: State dict stored under ``'optimizer_state_dict'``.
        loss: Value stored under ``'loss'`` as-is.
        epoch: Zero-based epoch index; stored under ``'epoch'`` and used in
            the checkpoint file name.
        config: Run configuration; supplies the target folder and is itself
            dumped to ``config.json``.
    """
    checkpoint_folder = config["checkpoint"]["checkpoint_save_path"] + config["global"]["experiment_name"]
    torch.save({
        'epoch': epoch,
        'model_state_dict': diffusion_dict,
        'optimizer_state_dict': opt_dict,
        'loss': loss,
    }, checkpoint_folder + "/chkpt-{}.pt".format(epoch))

    # The config is written only when epoch + 1 equals the checkpoint interval,
    # i.e. alongside the first periodic checkpoint.
    if (epoch + 1) == config["checkpoint"]["checkpoint_interval"]:
        # Context manager guarantees the file is flushed and closed.
        with open(checkpoint_folder + "/config.json", "w") as config_file:
            json.dump(config, config_file)

def evaluation_metrics(dists):
    """Print the fraction of distances that fall under each distance threshold.

    The input is multiplied by 6371 (presumably the Earth radius in km, turning
    unit-sphere great-circle distances into kilometres -- confirm against the
    distance helpers) and compared against 1, 25, 200, 750 and 2500.

    Args:
        dists: Array-like of distances. Lists and other sequences are accepted
            and converted to a float array first, so the scaling is always
            element-wise. Fractions are taken over all elements.

    Returns:
        None. One line per threshold is printed; an empty input prints ``nan``.
    """
    earth_radius = 6371
    thresholds = (1, 25, 200, 750, 2500)

    # Convert once: multiplying a plain list by an int would repeat the list
    # instead of scaling its values.
    scaled = earth_radius * np.asarray(dists, dtype=float)

    for threshold in thresholds:
        fraction = np.mean(scaled < threshold) if scaled.size else float("nan")
        print(f"Percentage of < {threshold}: {fraction:.4f}")

def train(diffusion, trainloader, validloader, testloader, opt, scheduler, ema, config, device):
    """Run the training loop with periodic validation, sampling and checkpointing.

    Before the first epoch, every training batch is passed once through
    ``diffusion.location_encoder`` without gradients and
    ``_set_coeff_scale()`` is called on the encoder. Each epoch then optimizes
    ``loss1_weight.mean() + 10 * loss4_weight.mean()``, steps the LR scheduler
    once, prints per-batch averages of the four reported losses, and runs
    validation / sampling / checkpointing when ``epoch + 1`` is a multiple of
    the corresponding configured interval.

    Args:
        diffusion: Diffusion wrapper; called as
            ``diffusion(gps, idx, imgs, True)`` and expected to return ten
            values (four losses, four weighted losses, model output, target).
        trainloader: Iterable of ``(imgs, gps, idx)`` training batches.
        validloader: Loader handed to ``validate``.
        testloader: Loader handed to ``sample``.
        opt: Optimizer over the diffusion parameters.
        scheduler: LR scheduler, stepped once per epoch.
        ema: EMA tracker, updated after every optimizer step.
        config: Run configuration; the ``train``, ``valid``, ``test``,
            ``sample`` and ``checkpoint`` sections are read here.
        device: Device that image and GPS batches are moved to.
    """
    # Gradient-free pass over the whole training set through the location
    # encoder, followed by its scale calibration. Presumably this populates
    # per-sample encoder state keyed by train_idx -- confirm in the encoder.
    with torch.no_grad():
        for train_imgs, train_gps, train_idx in tqdm(trainloader):
            train_gps = train_gps.to(device)
            diffusion.location_encoder(train_gps, train_idx, preload=False)

        diffusion.location_encoder._set_coeff_scale()

    for epoch in range(config["train"]["train_num_epochs"]):

        # get_last_lr() reports the LR currently in effect; get_lr() is only
        # meant to be called from inside scheduler.step() and can misreport
        # the value at milestone epochs.
        print("Start training epoch {} with lr {}".format(epoch+1, scheduler.get_last_lr()))
        loss = None
        diffusion.model.train()

        # Per-epoch running sums of the per-batch mean losses.
        total_loss = 0.
        total_loss1, total_loss2, total_loss3, total_loss4 = 0., 0., 0., 0.

        for train_imgs, train_gps, train_idx in tqdm(trainloader):

            opt.zero_grad()

            train_imgs = train_imgs.to(device)
            train_gps = train_gps.to(device)

            loss1, loss2, loss3, loss4, loss1_weight, loss2_weight, loss3_weight, loss4_weight, model_out, target = diffusion(train_gps, train_idx, train_imgs, True)

            total_loss1 += loss1.mean().item()
            total_loss2 += loss2.mean().item()
            total_loss3 += loss3.mean().item()
            total_loss4 += loss4.mean().item()

            # Only the first and fourth weighted terms are optimized; loss2 and
            # loss3 are tracked for reporting only.
            loss = loss1_weight.mean() + 10 * loss4_weight.mean()
            # loss = loss2_weight.mean()
            total_loss += loss.item()

            loss.backward()

            opt.step()
            ema.update()

        scheduler.step()

        print(f'Reverse KL loss: {total_loss1 / len(trainloader):.4f}, Great Circle loss: {total_loss2 / len(trainloader):.4f}, Embedding Cosine loss: {total_loss3 / len(trainloader):.4f}, Embedding MSE loss: {total_loss4 / len(trainloader):.4f}, Total loss: {total_loss / len(trainloader):.4f}')

        if (epoch + 1) % config["valid"]["valid_interval"] == 0:
            diffusion.model.eval()
            validate(diffusion, validloader, device)
            # sanity_check(diffusion, "valid", validloader, 1000, device)

        if (epoch + 1) % config["test"]["test_interval"] == 0:
            # Test evaluation is currently disabled; only the mode switch runs.
            diffusion.model.eval()
            # test(diffusion, testloader, device)
            # sanity_check(diffusion, "test", testloader, 1000, device)

        if (epoch + 1) % config["sample"]["sample_interval"] == 0:
            diffusion.model.eval()

            # sample(diffusion, trainloader, "train", epoch, config, device)
            # sample(diffusion, validloader, "valid", epoch, config, device)
            sample(diffusion, testloader, "test", epoch, config, device)
            # sanity_check(diffusion, "sample", testloader, 1000, device)

        if (epoch + 1) % config["checkpoint"]["checkpoint_interval"] == 0:
            # `loss` is the last batch's loss of this epoch (None if the
            # loader yielded no batches).
            save_checkpoint(diffusion.state_dict(), opt.state_dict(), loss, epoch, config)

def validate(diffusion, validloader, device):
    """Evaluate the diffusion model on the validation loader and print metrics.

    Every batch is run through ``diffusion(..., False, last_step_only=True)``
    without gradients. The per-batch means of the four returned losses are
    averaged over the number of batches and printed, then the collected
    ``loss2`` values are passed to ``evaluation_metrics``.

    Args:
        diffusion: Diffusion wrapper returning ten values per call.
        validloader: Iterable of ``(imgs, gps, idx)`` batches.
        device: Device that image and GPS batches are moved to.
    """
    print("\nStart evaluating on the validation set...\n")

    # Running sums of the per-batch means for loss1..loss4, in that order.
    running = [0., 0., 0., 0.]
    collected_dists = []

    with torch.no_grad():
        for batch_imgs, batch_gps, batch_idx in validloader:
            batch_imgs = batch_imgs.to(device)
            batch_gps = batch_gps.to(device)

            # The weighted losses, model output and target are not needed here.
            loss1, loss2, loss3, loss4, _, _, _, _, _, _ = diffusion(
                batch_gps, batch_idx, batch_imgs, False, last_step_only=True
            )

            for slot, batch_loss in enumerate((loss1, loss2, loss3, loss4)):
                running[slot] += batch_loss.mean().item()

            collected_dists.extend(loss2.detach().cpu().numpy().tolist())

        kl_sum, gc_sum, cos_sum, mse_sum = running
        print(f"Evaluation: Reverse KL loss: {kl_sum / len(validloader):.4f}, Great Circle loss:  {gc_sum / len(validloader):.4f}, Embedding Cosine Loss: {cos_sum / len(validloader):.4f}, Embedding MSE loss:  {mse_sum / len(validloader):.4f}\n")

        evaluation_metrics(np.array(collected_dists).flatten())

def test(diffusion, testloader, device):
    """Evaluate the diffusion model on the test loader and print metrics.

    Every batch is run through ``diffusion(..., False, last_step_only=True)``
    without gradients; the returned ``loss2`` values are collected and passed
    to ``evaluation_metrics``.

    Args:
        diffusion: Diffusion wrapper returning ten values per call.
        testloader: Iterable of ``(imgs, gps, idx)`` batches.
        device: Device that image and GPS batches are moved to.
    """
    with torch.no_grad():
        print("\nStart evaluating on the test set...\n")
        dists = []
        for test_imgs, test_gps, test_idx in testloader:
            test_imgs = test_imgs.to(device)
            test_gps = test_gps.to(device)

            loss1, loss2, loss3, loss4, loss1_weight, loss2_weight, loss3_weight, loss4_weight, model_out, target = diffusion(test_gps, test_idx, test_imgs, False, last_step_only=True)

            # Only loss2 feeds the metrics; model outputs, targets and loss1
            # are no longer copied to host lists since nothing read them.
            dists += loss2.detach().cpu().numpy().tolist()

        evaluation_metrics(np.array(dists).flatten())

def sample(diffusion, loader, eval, epoch, config, device):
    """Draw conditional samples for every batch, print metrics and save results.

    The loader is iterated ``config["sample"]["replica_size"]`` times. For each
    batch the ground-truth GPS is encoded and ``diffusion.sample`` is called
    with ``return_all_timesteps=True``. Distance metrics are printed for the
    raw samples and for the per-location mean and median across replicas
    (always using the last returned timestep), and all arrays are written to
    an ``.npz`` file in the experiment's sample folder.

    Args:
        diffusion: Diffusion wrapper providing ``location_encoder`` and
            ``sample``.
        loader: Iterable of ``(imgs, gps, idx)`` batches. The mean/median
            aggregation reshapes the replica-major concatenation, so it
            assumes the loader yields the same order on every pass
            (i.e. no shuffling) -- confirm for any loader passed in.
        eval: Split label used in the log message and the output file name.
        epoch: Epoch index used in the output file name.
        config: Run configuration; ``sample`` and ``global`` sections are read.
        device: Device that image and GPS batches are moved to.
    """
    with torch.no_grad():
        print(f"\nStart conditional sampling on the {eval} set... \n")

        grd_gps_list, sampled_gps_list = [], []
        grd_gps_embedding_list, sampled_gps_embedding_list = [], []

        replica_size = config["sample"]["replica_size"]
        for i in range(replica_size):
            print("Sampling replica {} out of {}...".format(i+1, replica_size))
            for imgs, gps, idx in tqdm(loader):
                imgs, gps = imgs.to(device), gps.to(device)

                # NOTE(review): the original author questioned whether encoding
                # the ground truth with preload=False is correct here.
                grd_gps_embedding = diffusion.location_encoder(gps, idx, preload=False)

                sampled_gps_embeddings, sampled_gps = diffusion.sample(imgs, return_all_timesteps=True)

                grd_gps_list.append(gps.detach().cpu().numpy())
                sampled_gps_list.append(sampled_gps.detach().cpu().numpy())
                grd_gps_embedding_list.append(grd_gps_embedding.detach().cpu().numpy())
                sampled_gps_embedding_list.append(sampled_gps_embeddings.detach().cpu().numpy())

        grd_gps = np.concatenate(grd_gps_list, axis=0)
        sampled_gps = np.concatenate(sampled_gps_list, axis=0)

        # Split the leading axis into (replica, sample) and keep the trailing
        # axes as they are, instead of hard-coding the number of returned
        # timesteps. For the previous fixed layout (.., 11, 2) this is identical.
        per_replica = sampled_gps.reshape((replica_size, -1) + sampled_gps.shape[1:])
        mean_sampled_gps = np.mean(per_replica, axis=0)
        median_sampled_gps = np.median(per_replica, axis=0)

        grd_gps_embeddings = np.concatenate(grd_gps_embedding_list, axis=0)
        sampled_gps_embeddings = np.concatenate(sampled_gps_embedding_list, axis=0)

        print(grd_gps.shape, sampled_gps.shape, grd_gps_embeddings.shape, sampled_gps_embeddings.shape)

        # Ground truth is identical across replicas, so the first replica's
        # slice is used as the reference for the aggregated estimates.
        first_replica_gps = grd_gps.reshape((replica_size, -1, 2))[0]

        print("Raw eval: ")
        dists = great_circle_distance_numpy(grd_gps, sampled_gps[:, -1, :])
        evaluation_metrics(np.array(dists).flatten())

        print("Mean eval: ")
        dists = great_circle_distance_numpy(first_replica_gps, mean_sampled_gps[:, -1, :])
        evaluation_metrics(np.array(dists).flatten())

        print("Median eval: ")
        dists = great_circle_distance_numpy(first_replica_gps, median_sampled_gps[:, -1, :])
        evaluation_metrics(np.array(dists).flatten())

        np.savez(config["sample"]["sample_save_path"] + config["global"]["experiment_name"] + "/conditional_{}_sampling_{}".format(eval, epoch), grd_gps=grd_gps, sampled_gps=sampled_gps
                 ,grd_gps_embedding=grd_gps_embeddings, sampled_gps_embedding=sampled_gps_embeddings)

if __name__ == "__main__":
    # Script entry point: select the 'spawn' start method, load config.yaml
    # from the working directory, build all components and start training.
    torch.multiprocessing.set_start_method("spawn")

    with open("config.yaml", "r") as config_file:
        config = yaml.safe_load(config_file)

    print(config)

    (
        device,
        model,
        diffusion,
        trainloader,
        validloader,
        testloader,
        pretrain_opt,
        opt,
        scheduler,
        ema,
    ) = initialize(config)

    # Pretraining is disabled:
    # pretrain(diffusion, pretrain_opt)
    train(
        diffusion,
        trainloader,
        validloader,
        testloader,
        opt,
        scheduler,
        ema,
        config,
        device,
    )
