import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import time
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import PIL 
from tqdm import tqdm
import matplotlib.pyplot as plt
# tsne and pca
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import utils
from DeepTaxonNet import DeepTaxonNet
import argparse
import os
import sys
import wandb
import json

MODEL_SAVE_PATH_PREFIX = './models'

def set_seed(seed):
    """Make a run reproducible by seeding every random number generator in use.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU, the current
    CUDA device and all CUDA devices), and configures cuDNN to be deterministic
    with its benchmark autotuner disabled.

    Args:
        seed: integer seed shared by all generators.
    """
    # Library-level RNGs.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch RNGs: CPU first, then the current and all visible GPUs.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Deterministic cuDNN kernels; benchmark mode could pick different
    # algorithms run to run, so it is switched off.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = False
    cudnn_backend.deterministic = True

#######################################################

# Command-line interface: the only flag is the path to a JSON config file;
# every hyper-parameter used below is read from that file into `args`.
parser = argparse.ArgumentParser(
    description="Train Deep TaxonNet using configuration from a JSON file."
)
parser.add_argument('--config', type=str, required=True,
                    help='Path to the JSON configuration file.')
initial_args = parser.parse_args()

try:
    args = utils.load_config_from_json(initial_args.config)
except (FileNotFoundError, ValueError) as err:
    # Route config problems through argparse so they are reported like any
    # other command-line error (usage message + non-zero exit).
    parser.error(str(err))

# Echo the resolved configuration, sorted by key so the output is stable
# between runs, together with each value's Python type.
separator = "-" * 30
print("Configuration loaded successfully:")
print(separator)
for cfg_key, cfg_value in sorted(vars(args).items()):
    print(f"  {cfg_key}: {cfg_value} (Type: {type(cfg_value).__name__})")
print(separator)


# Seed all RNGs before any model/data randomness, then choose the device:
# the configured GPU index when CUDA is present, otherwise the CPU.
set_seed(args.seed)
if torch.cuda.is_available():
    device = torch.device(f"cuda:{args.device_id}")
else:
    device = torch.device("cpu")

#######################################################
# Optional Weights & Biases tracking; the full config is attached to the run.
if args.wandb:
    wandb_project, wandb_run_name = args.wandb_project, args.wandb_run_name
    wandb.init(project=wandb_project, name=wandb_run_name)
    wandb.config.update(args)

#######################################################

# define data loader
print('Loading data...')
# Without the contrastive objective, CIFAR runs switch to the '-eval' variant
# of the dataset. NOTE(review): presumably the plain CIFAR loaders yield pairs of
# augmented views for contrastive training (the training loop unpacks
# `x_aug_1, x_aug_2 = batch` in that mode) — confirm in utils.get_data_loader.
dataset = args.dataset
if not args.use_contrastive_loss and args.dataset in ('cifar-10', 'cifar-100'):
    dataset = f'{args.dataset}-eval'
train_loader, test_loader, train_data, test_data = utils.get_data_loader(dataset, args.batch_size, args.normalize)

# define model
# Every DeepTaxonNet constructor argument is taken verbatim from the config
# attribute of the same name, so the list of names below is the full mapping.
_model_arg_names = (
    'n_layers',
    'input_dim',
    'enc_hidden_dim',
    'dec_hidden_dim',
    'latent_dim',
    'encoder_name',
    'decoder_name',
    'kl1_weight',
    'recon_weight',
    'dkl_margin',
    'dkl_weight_lambda',
    'convex_weight_lambda',
    'logvar_init_range',
)
_model_kwargs = {arg_name: getattr(args, arg_name) for arg_name in _model_arg_names}
model = DeepTaxonNet(**_model_kwargs).to(device)


# define optimizer
optimizer = optim.AdamW(model.parameters(), lr=args.lr)

# Learning-rate schedule; stepped once per epoch by the training loop.
# NOTE(review): the 'linear-up' / 'linear-down' names look inverted relative to
# what they do: 'linear-up' decays the LR from 1.0x to 0.1x of args.lr, while
# 'linear-down' ramps it from 0.1x up to 1.0x. Behaviour is kept unchanged so
# existing configs reproduce; rename in the configs if this was unintended.
lr_scheduler_name = args.lr_scheduler
# StepLR period in epochs; optional config key, defaulting to the previously
# hard-coded value of 100.
lr_step_size = getattr(args, 'lr_step_size', None)
if lr_step_size is None:
    lr_step_size = 100

scheduler = None
if lr_scheduler_name == 'step':
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_step_size, gamma=0.1)
elif lr_scheduler_name == 'linear-up':
    scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=args.epochs)
elif lr_scheduler_name == 'linear-down':
    scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=args.epochs)
elif lr_scheduler_name == 'cosine':
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-5)
elif lr_scheduler_name not in (None, '', 'none', 'constant'):
    # Previously a typo here silently trained with a constant LR; make it visible.
    print(f"Warning: unknown lr_scheduler '{lr_scheduler_name}'; "
          f"training with a constant learning rate.")
    

# Evaluation-only loaders (batch size 256, no normalisation) are built on first
# use and reused at every later evaluation instead of being rebuilt each time.
_eval_loader_cache = {}


def _get_eval_loaders(name):
    """Return the cached ``(train_loader, test_loader)`` pair for dataset ``name``."""
    if name not in _eval_loader_cache:
        eval_train, eval_test, _, _ = utils.get_data_loader(name, 256, False)
        _eval_loader_cache[name] = (eval_train, eval_test)
    return _eval_loader_cache[name]


def _evaluate(eval_train_loader, eval_test_loader, n_classes, epoch, prefix=''):
    """Evaluate the global ``model`` and optionally log the metrics to wandb.

    Node labels are assigned from ``eval_train_loader`` with
    ``utils.label_annotation``. Accuracy, NMI, soft dendrogram purity and leaf
    purity are then computed on ``eval_test_loader``. When wandb is enabled the
    four scalar metrics are logged under ``prefix + <metric name>`` together
    with ``epoch``.

    Args:
        eval_train_loader: loader used to annotate tree nodes with labels.
        eval_test_loader: loader the metrics are computed on.
        n_classes: number of ground-truth classes passed to the annotator.
        epoch: current epoch, logged alongside the metrics.
        prefix: string prepended to every metric key (e.g. ``'CIFAR-20 '``).

    Returns:
        ``(annotation, acc, nmi, dendrogram_purity, overall_leaf_purity,
        per_leaf_purities)``.
    """
    # Switch dropout / batch-norm layers (if any) to inference behaviour; the
    # next epoch's model.train() restores training mode. Harmless if the utils
    # helpers already do this themselves.
    model.eval()
    annotation = utils.label_annotation(model, eval_train_loader, n_classes, device)
    acc = utils.basic_node_evaluation(model, annotation, eval_test_loader, device)
    nmi = utils.compute_nmi(model, annotation, eval_test_loader, device)
    dendrogram_purity = utils.soft_dendrogram_purity(model, eval_test_loader, device)
    overall_leaf_purity, per_leaf_purities = utils.leaf_purity(model, eval_test_loader, device)
    if args.wandb:
        wandb.log({f'{prefix}Accuracy': acc,
                   f'{prefix}NMI': nmi,
                   f'{prefix}Dendrogram Purity': dendrogram_purity,
                   f'{prefix}Leaf Purity': overall_leaf_purity,
                   'epoch': epoch})
    return annotation, acc, nmi, dendrogram_purity, overall_leaf_purity, per_leaf_purities


print('Start training...')
steps = 0
epochs = args.epochs
for epoch in range(epochs):
    print(f'Epoch {epoch}')
    model.train()
    for batch in train_loader:
        optimizer.zero_grad()

        if args.use_contrastive_loss:
            # Two augmented views per sample, stacked along dim 0
            # (batch_size * 2 rows in total).
            x_aug_1, x_aug_2 = batch
            x = torch.cat((x_aug_1, x_aug_2), dim=0)
        else:
            x, _ = batch
        x = x.to(device)

        loss, recon_loss, kl1, kl2, _, pcx, _, _, z, pcx_contrast = model(x)

        if args.use_contrastive_loss:
            # Cosine-similarity contrastive term on the embeddings z.
            z = F.normalize(z, p=2, dim=1)
            z_similarity = torch.matmul(z, z.T)  # batch_size * 2, batch_size * 2
            z_sim_loss = utils.contrastive_loss(z_similarity, args.embed_temp) * args.contrastive_loss_weight

            # Same contrastive term on the pcx_contrast output.
            pcx_contrast = F.normalize(pcx_contrast, p=2, dim=1)
            pcx_similarity = torch.matmul(pcx_contrast, pcx_contrast.T)  # batch_size * 2, batch_size * 2
            pcx_sim_loss = utils.contrastive_loss(pcx_similarity, args.pcx_temp) * args.contrastive_loss_weight

            loss = loss + z_sim_loss + pcx_sim_loss

        if args.wandb:
            wandb.log({'loss': loss.item(),
                       'recon_loss': recon_loss.item(),
                       'kl1': -kl1.item(),
                       'kl2': -kl2.item(),
                       'z_sim_loss': z_sim_loss.item() if args.use_contrastive_loss else 0,
                       'pcx_sim_loss': pcx_sim_loss.item() if args.use_contrastive_loss else 0,
                       'learning_rate': optimizer.param_groups[0]['lr'],
                       'steps': steps})

        loss.backward()
        optimizer.step()
        steps += 1

    # The LR schedule advances once per epoch.
    if scheduler is not None:
        scheduler.step()

    # Checkpoint and evaluate every `model_save_interval` epochs, and always
    # after the final epoch so the fully trained model is never lost.
    is_last_epoch = epoch == epochs - 1
    if epoch % args.model_save_interval == 0 or is_last_epoch:
        model_save_path = f'{MODEL_SAVE_PATH_PREFIX}/{args.model_save_path}'
        os.makedirs(model_save_path, exist_ok=True)
        checkpoint_file = f'{model_save_path}/deep_taxon_{epoch}.pt'
        torch.save(model.state_dict(), checkpoint_file)
        print(f'Model saved at {checkpoint_file}')

        if args.dataset in ('cifar-10', 'cifar-100'):
            cifar_train_loader, cifar_test_loader = _get_eval_loaders(f'{args.dataset}-eval')
            (annotation, acc, nmi, dendrogram_purity,
             overall_leaf_purity, per_leaf_purities) = _evaluate(
                cifar_train_loader, cifar_test_loader, args.n_classes, epoch)

            if args.dataset == 'cifar-100':
                # Also evaluate on the 20-class 'cifar-20-eval' dataset
                # (presumably the CIFAR-100 superclasses).
                cifar20_train_loader, cifar20_test_loader = _get_eval_loaders('cifar-20-eval')
                (annotation, acc, nmi, dendrogram_purity,
                 overall_leaf_purity, per_leaf_purities) = _evaluate(
                    cifar20_train_loader, cifar20_test_loader, 20, epoch, prefix='CIFAR-20 ')
        else:
            (annotation, acc, nmi, dendrogram_purity,
             overall_leaf_purity, per_leaf_purities) = _evaluate(
                train_loader, test_loader, args.n_classes, epoch)