import os
import argparse
import json
import numpy as np
import torch
from sklearn.preprocessing import MinMaxScaler
import pickle
from tqdm import tqdm

import functools
import torch.optim as optim
from torch.optim import Adam, AdamW
from torch_ema import ExponentialMovingAverage
from torch.utils.data import Dataset, DataLoader

import pandas as pd

from NCSBAD.utils import marginal_prob_std, loss_fn_norm
from NCSBAD.models.mlp_models import MLPDiffusion, MLPDiffusionPara, MLPDiffusionICL, MLPDiffusionVAE, MLPDiffusionTabSyn, MLPDiffusionBig512
from NCSBAD.models.mlp_models import MLPDiffusionTabSyn1024, MLPDiffusionVAE1024
from NCSBAD.models.mlp_models import MLP2048
from NCSBAD.models.scorewave import SCOREWAVENET
from NCSBAD.models.ddpm import ResNetDiffusion
import matplotlib.pyplot as plt

# Compute device shared by training and inference: the first CUDA GPU when
# one is visible to PyTorch, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device('cpu')


def plot_greyscale_images(first_image, test_image, score, cols=2):
    """Plot training image, test image, their difference and the score map.

    The figure is written to ``./mnist/Interpretability.png`` and then shown
    with ``plt.show()``.

    Args:
        first_image: Training image, passed to ``imshow`` as-is.
        test_image: Test image, passed to ``imshow`` as-is; must be
            broadcast-compatible with ``first_image`` for the difference.
        score: Array-like holding exactly 784 values; it is reshaped to
            28x28 regardless of how many singleton axes surround the data.
        cols: Unused; kept so existing callers remain valid.
    """
    # Reshape directly instead of squeezing two leading axes first, so any
    # layout with 784 elements is accepted.
    score_image = np.asarray(score).reshape((28, 28))

    # Subtract in floating point: if the images were loaded as unsigned
    # integers, a plain ``test - train`` would wrap around modulo 256
    # wherever the test pixel is darker than the training pixel.
    image_diff = (np.asarray(test_image, dtype=float)
                  - np.asarray(first_image, dtype=float))

    # One row with four panels.
    fig, axes = plt.subplots(1, 4, figsize=(12, 4))

    axes[0].imshow(first_image, cmap='gray')
    axes[0].set_title('Training Image')

    axes[1].imshow(test_image, cmap='gray')
    axes[1].set_title('Test Image')

    axes[2].imshow(image_diff, cmap='gray')
    axes[2].set_title('Difference (Test Image - Training Image)')

    axes[3].imshow(score_image, cmap='gray')
    axes[3].set_title('Score as Image')

    # Adjust layout for better spacing
    plt.tight_layout()
    # Make sure the target directory exists so savefig cannot fail on a
    # fresh checkout.
    os.makedirs('./mnist', exist_ok=True)
    plt.savefig('./mnist/Interpretability.png')
    plt.show()


def _build_network(model, d_in):
    """Instantiate the score network selected by ``model`` on ``device``.

    Args:
        model: Integer id of the architecture. Supported ids are
            0, 1, 3, 6, 9, 10, 11, 14, 15, 18, 21 and 22.
        d_in: Length of the flattened input vector, forwarded to the
            MLP/ResNet constructors (the SCOREWAVENET variants ignore it).

    Returns:
        Tuple ``(net, name)``: the network moved to ``device`` and a short
        label for the architecture.

    Raises:
        ValueError: If ``model`` is not a supported id.
    """
    if model == 0:
        model_config = {
            "in_channels": 1,
            "out_channels": 1,
            "num_res_layers": 3,
            "res_channels": 32,
            "skip_channels": 32,
            "dilation_cycle": 3,
            "diffusion_step_embed_dim_in": 16,
            "diffusion_step_embed_dim_mid": 64,
            "diffusion_step_embed_dim_out": 64,
        }
        net = SCOREWAVENET(**model_config).to(device)
        name = "SCOREWAVENET"
    elif model == 1:
        model_config = {
            "in_channels": 1,
            "out_channels": 1,
            "num_res_layers": 36,
            "res_channels": 256,
            "skip_channels": 256,
            "dilation_cycle": 12,
            "diffusion_step_embed_dim_in": 128,
            "diffusion_step_embed_dim_mid": 512,
            "diffusion_step_embed_dim_out": 512,
        }
        net = SCOREWAVENET(**model_config).to(device)
        name = "SCOREWAVENET"
    elif model == 3:
        net = MLPDiffusion(d_in=d_in).to(device)
        name = "MLPDiffusion"
    elif model == 6:
        params = {
            "d_main": 128,
            "n_blocks": 3,
            "d_hidden": 512,
            "dropout_first": 0,
            "dropout_second": 0,
        }
        net = ResNetDiffusion(params, d_in=d_in).to(device)
        name = "ResNet"
    elif model == 9:
        net = MLPDiffusionPara(d_in=d_in).to(device)
        name = "MLPDiffusionPara"
    elif model == 10:
        net = MLPDiffusionICL(d_in=d_in).to(device)
        name = "MLPDiffusionICL"
    elif model == 11:
        net = MLPDiffusionVAE(d_in=d_in, dim_t=4 * d_in).to(device)
        name = "MLPDiffusionVAE"
    elif model == 14:
        net = MLPDiffusionTabSyn(d_in=d_in).to(device)
        name = "MLPDiffusionTabSyn"
    elif model == 15:
        net = MLPDiffusionBig512(d_in=d_in).to(device)
        name = "MLPDiffusionBig512"
    elif model == 18:
        net = MLPDiffusionTabSyn1024(d_in=d_in).to(device)
        name = "MLPDiffusionTabsyn1024"
    elif model == 21:
        net = MLPDiffusionVAE1024(d_in=d_in).to(device)
        name = "MLPDiffusionVAE1024"
    elif model == 22:
        net = MLP2048(d_in=d_in).to(device)
        name = "MLP2048"
    else:
        # Previously this case only printed a message and then crashed with
        # a NameError on the first use of ``net``.
        raise ValueError(f'Model chosen not available: {model!r}')
    return net, name


def train(model=22, num_epochs=200):
    """Fit a score network to the single image stored in ``./mnist/test1.npy``.

    The image is flattened to 784 values, standardized with its own mean and
    standard deviation, replicated 3000 times and fed to the network in
    mini-batches of 16. After the last epoch the model and optimizer state
    are written to ``./mnist/ckp/best.pkl``.

    Args:
        model: Integer id of the architecture to train (see
            ``_build_network``). Defaults to 22 (``MLP2048``).
        num_epochs: Number of passes over the replicated data. Defaults
            to 200.

    Returns:
        Tuple ``(net, name)``: the trained network and its architecture label.

    Raises:
        ValueError: If ``model`` is not a supported id.
    """
    d_in = 784  # flattened image length; matches the view(-1, 784) below

    # Build the network first so an unsupported id fails before any
    # directory is created or any file is read.
    net, name = _build_network(model, d_in)

    # Get shared output_directory ready
    output_directory = './mnist'
    output_directory_ckp = os.path.join(output_directory, "ckp")
    print(output_directory)
    ckp_dir_is_new = not os.path.isdir(output_directory_ckp)
    os.makedirs(output_directory_ckp, exist_ok=True)
    if ckp_dir_is_new:
        # As before, permissions are only adjusted when the checkpoint
        # directory had to be created.
        os.chmod(output_directory, 0o775)
    print("output directory", output_directory, flush=True)

    first_image = np.load('./mnist/test1.npy')
    scaled_array = torch.from_numpy(first_image)
    scaled_array = scaled_array.view(-1, d_in).unsqueeze(1).to(torch.float32)
    # Standardize with the statistics of this one image.
    scaled_array = (scaled_array - scaled_array.mean()) / scaled_array.std()
    print(scaled_array.shape)

    # Replicate the standardized image so one epoch consists of many
    # noise-perturbed views of the same sample.
    num_copies = 3000
    batch = scaled_array.unsqueeze(0).repeat(num_copies, 1, 1, 1)
    train_loader = DataLoader(batch, batch_size=16, shuffle=False)

    # Optimizer, scheduler and EMA are the same for every architecture; only
    # model 22 uses earlier LR milestones. With the default 200 epochs the
    # [600, 800] milestones are never reached.
    milestones = [120, 160] if model == 22 else [600, 800]
    optimizer = AdamW(net.parameters(), lr=1e-4, weight_decay=0.001)
    # The deprecated ``verbose`` argument is no longer passed; False was
    # already its default.
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=milestones, gamma=0.1)
    # NOTE(review): the EMA weights are updated during training but are never
    # copied into ``net`` or written to the checkpoint -- confirm intent.
    ema = ExponentialMovingAverage(net.parameters(), decay=0.9999)

    print(net)
    marginal_prob_std_fn = functools.partial(
        marginal_prob_std, sigma=0.01, device=device)

    epoch = 0  # keeps the checkpoint metadata valid when num_epochs == 0

    # 1-based epochs: exactly ``num_epochs`` passes (the former
    # ``range(num_epochs + 1)`` ran one extra epoch).
    for epoch in tqdm(range(1, num_epochs + 1)):
        net.train()
        total_loss = 0.
        num_items = 0
        for X in tqdm(train_loader):
            x = X.view(-1, d_in).unsqueeze(1).to(device)
            loss = loss_fn_norm(net, x, marginal_prob_std_fn)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            ema.update()
            # Weight by batch size so the epoch average is per sample.
            total_loss += loss.item() * x.shape[0]
            num_items += x.shape[0]

        scheduler.step()

        avg_loss = total_loss / num_items

        print(f"Epoch: {epoch}")

        print('-> Average Loss: {:.2f}'.format(avg_loss))

    checkpoint_name = 'best.pkl'
    torch.save({'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'ckp_epoch': epoch},
               os.path.join(output_directory_ckp, checkpoint_name))
    print('model at iteration %s is saved' % epoch)

    return net, name


def inference(fin, net):
    """Score ``./mnist/test2.npy`` with ``net`` and plot the per-pixel result.

    The test image is standardized with the mean and standard deviation of
    the training image (``./mnist/test1.npy``). Each test sample is perturbed
    with Gaussian noise ``num_iterations`` times at ``t = 1e-3``; the squared
    sum of network output and noise is accumulated over the draws and handed
    to ``plot_greyscale_images`` together with both raw images.

    Args:
        fin: When equal to 1, the weights in ``./mnist/ckp/best.pkl`` are
            loaded into ``net`` before scoring; otherwise ``net`` is used as
            passed in.
        net: Network called as ``net(z, t)`` with ``z`` of shape
            ``(batch, 1, 784)`` and ``t`` of shape ``(batch,)``.

    Raises:
        ValueError: If the test file yields no samples.
    """
    ckp_directory = os.path.join('./mnist', "ckp")

    # Statistics of the training image; the test image is normalized with
    # these rather than with its own.
    first_image = np.load('./mnist/test1.npy')
    reference = torch.from_numpy(first_image)
    reference = reference.view(-1, 784).unsqueeze(1).to(torch.float32)
    mean = reference.mean()
    std_dev = reference.std()

    test_image = np.load('./mnist/test2.npy')
    scaled_array = torch.from_numpy(test_image)
    scaled_array = scaled_array.view(-1, 784).unsqueeze(1).to(torch.float32)

    # Standardize the tensor
    scaled_array = (scaled_array - mean) / std_dev
    print(scaled_array.shape)

    test_loader = DataLoader(scaled_array, batch_size=1, shuffle=False)

    if fin == 1:
        model_path = os.path.join(ckp_directory, "best.pkl")
        checkpoint = torch.load(model_path, map_location='cpu')

        print(checkpoint['ckp_epoch'])
        # Only the model weights are needed for scoring; the optimizer state
        # stored in the checkpoint is deliberately not restored.
        net.load_state_dict(checkpoint['model_state_dict'])

        print('Successfully loaded model at iteration best')

    marginal_prob_std_fn = functools.partial(
        marginal_prob_std, sigma=0.01, device=device)

    net.eval()

    num_iterations = 10  # noise draws per test sample
    per_batch_scores = []

    with torch.no_grad():
        for x in tqdm(test_loader):
            print(x.shape)
            x = x.to(device)  # (B, 1, L)
            sample_batch_size = x.shape[0]
            sample_length = x.size(2)

            t = torch.ones(sample_batch_size, device=device) * 1e-3

            # Stack num_iterations copies so all noise draws go through the
            # network in a single forward pass.
            extended_batch = x.unsqueeze(0).repeat(num_iterations, 1, 1, 1)  # (N, B, 1, L)
            extended_t = t.unsqueeze(0).repeat(num_iterations, 1)  # (N, B)

            std = marginal_prob_std_fn(extended_t)[:, :, None, None]
            n = torch.randn_like(extended_batch, device=device)
            z = extended_batch + n * std

            # Merge the draw and batch axes to match the network input.
            z_flat = z.view(-1, z.shape[-2], z.shape[-1])  # (N*B, 1, L)
            t_flat = extended_t.view(-1)  # (N*B,)

            scores = net(z_flat, t_flat)
            scores = scores.view(
                num_iterations, sample_batch_size, 1, sample_length)

            # Residual between network output and the injected noise.
            # NOTE(review): presumably the network is trained to predict
            # -n, making this the prediction error -- confirm against
            # loss_fn_norm.
            scores = scores + n

            # Squared residual summed over the noise draws; the final layout
            # (B, 1, L, 1) is the one the plotting code expects.
            scores = (scores ** 2).permute(0, 1, 3, 2)  # (N, B, L, 1)
            scores = scores.sum(dim=0).unsqueeze(1)  # (B, 1, L, 1)

            per_batch_scores.append(scores)

    if not per_batch_scores:
        raise ValueError('No test samples found in ./mnist/test2.npy')

    batch_scores = torch.cat(per_batch_scores, dim=0)
    preds = batch_scores.cpu().detach().sum(1, keepdim=True).numpy()

    plot_greyscale_images(first_image, test_image, preds)

    





if __name__ == "__main__":
    # Script entry point: train the network, then score the test image.
    # Passing 1 as the first argument makes inference() reload the weights
    # from the checkpoint that train() has just written.
    trained_net, _arch_name = train()
    inference(1, trained_net)
