import os
import torch
import h5py
import torchvision.transforms as transforms
from torchvision.utils import save_image


class ModelLoader:
    """Persist and restore GAN checkpoints, training curves and generated samples.

    File locations are built by ``set_paths``, which must be called before any
    method that reads or writes files. Layout produced by ``set_paths``::

        {generator_name}/{dataset_name}_{model_name}/      <- main_folder
            {info_path}, {data_path}, images/*.png
            {epochs}epochs/                                 <- epochs_folder
                {critic_path}, {generator_path}, {lists_path}
    """

    # dataset_name -> (latent noise gets trailing 1x1 spatial dims,
    #                  per-sample shape the generator output is reshaped to,
    #                  or None to keep the generator's output shape as-is)
    _SAMPLE_SPECS = {
        'MNIST': (False, None),
        'CELEBA': (True, (3, 64, 64)),
        '8gaussians': (False, (2,)),
        'Cifar10': (True, (3, 32, 32)),
    }

    def __init__(self, dataset_name, model_name, device, step):
        """Store the identifiers of a training run.

        Args:
            dataset_name: Dataset key; used in folder names and to pick the
                sampling spec in ``generated_images``.
            model_name: Model identifier; used in folder names.
            device: Passed as ``map_location`` when loading checkpoints.
            step: Epoch interval ``load_model`` uses to choose which
                generated-data snapshots to read back.
        """
        self.dataset_name = dataset_name
        self.model_name = model_name
        self.device = device
        self.step = step

    def set_paths(self, epochs, critic_path='critic.pth', generator_path='generator.pth',
                  lists_path='loss_lists.h5', info_path='info.txt',
                  data_path='generated_data.h5', generator_name='GAN'):
        """Build every artefact path for a run trained for ``epochs`` epochs.

        Run-wide files (``info_path``, ``data_path``) live in ``main_folder``.
        Per-checkpoint files (critic, generator, loss lists) live in
        ``epochs_folder``. Both folder strings end with ``'/'``.
        """
        self.main_folder = f'{generator_name}/{self.dataset_name}_{self.model_name}/'
        self.epochs_folder = f'{self.main_folder}{epochs}epochs/'
        self.info_path = f'{self.main_folder}{info_path}'
        self.gen_data_path = f'{self.main_folder}{data_path}'
        self.generator_path = f'{self.epochs_folder}{generator_path}'
        self.critic_path = f'{self.epochs_folder}{critic_path}'
        self.list_path = f'{self.epochs_folder}{lists_path}'

    def _tensor_to_numpy(self, data):
        """Convert ``data`` into something h5py can store.

        - Lists/tuples are converted element-wise (returned as a list), so a
          list of per-step loss tensors (CPU, CUDA, or requiring grad) can be
          written directly.
        - Tensors are detached, moved to CPU and converted to numpy arrays.
        - Other objects exposing ``.numpy()`` are converted with it.
        - Anything else (numpy arrays, Python scalars) is returned unchanged.
        """
        if isinstance(data, (list, tuple)):
            return [self._tensor_to_numpy(item) for item in data]
        if hasattr(data, 'cpu'):  # PyTorch tensor
            return data.detach().cpu().numpy()
        if hasattr(data, 'numpy'):
            return data.numpy()
        return data

    def load_model(self, epochs, critic, generator):
        """Restore weights and training history written by ``save_model``.

        Args:
            epochs: Inclusive upper bound of the epochs whose generated-data
                snapshots are read from ``gen_data_path``.
            critic: Module that receives the critic weights, or None to skip
                both the critic weights and the ``loss_C`` curve.
            generator: Module that receives the generator weights.

        Returns:
            ``(loss_C, loss_G, generator_grad_norm, gen_data_list)`` when
            ``critic`` is given, otherwise
            ``(loss_G, generator_grad_norm, gen_data_list)``.
            ``gen_data_list`` holds the snapshots stored under the keys
            ``step, 2*step, ...`` up to ``epochs``.
        """
        if critic is not None:
            critic.load_state_dict(torch.load(self.critic_path, map_location=self.device))
            print(f'Critic loaded from {self.critic_path}.')
        generator.load_state_dict(torch.load(self.generator_path, map_location=self.device))
        print(f'Generator loaded from {self.generator_path}.')

        with h5py.File(self.list_path, 'r') as f:
            loss_C_list = f['loss_C'][:] if critic is not None else None
            loss_G_list = f['loss_G'][:]
            generator_grad_norm_list = f['generator_grad_norm'][:]

        epochs_to_plot = range(self.step, epochs + 1, self.step)
        with h5py.File(self.gen_data_path, 'r') as f:
            gen_data_list = [f[str(epoch)][:] for epoch in epochs_to_plot]
        print(f'Loss and gradient norms loaded from {self.list_path}.')

        if critic is not None:
            return loss_C_list, loss_G_list, generator_grad_norm_list, gen_data_list
        return loss_G_list, generator_grad_norm_list, gen_data_list

    def save_model(self, gan, epochs, batch_size, n_critic=None, lambda_gp=None,
                   clip_value=None, info=False):
        """Write checkpoints, loss curves and the latest generated sample of ``gan``.

        ``gan`` is read for ``G``, ``optimizer_G``, ``loss_G_list``,
        ``generator_grad_norm_list``, ``gen_data`` and ``gen_data_list``.
        ``C``, ``optimizer_C`` and ``loss_C_list`` are optional.

        Args:
            gan: Trainer object holding the networks and training history.
            epochs: Epoch number; used in the info log and as the snapshot key
                in ``gen_data_path``.
            batch_size: Logged to ``info_path`` when ``info`` is True.
            n_critic, lambda_gp, clip_value: Logged to ``info_path`` when not
                None. For ``lambda_gp``/``clip_value`` the value stored on
                ``gan`` is preferred; the argument is the fallback.
            info: Append a hyper-parameter summary to ``info_path``.
        """
        if info:
            os.makedirs(os.path.dirname(self.info_path), exist_ok=True)
            with open(self.info_path, 'a') as f:
                f.write(f'Experiment for {epochs} epochs:\n')
                f.write(f'Batch size: {batch_size}\n')
                if n_critic is not None:
                    f.write(f'Critic updates: {n_critic}\n')
                if lambda_gp is not None:
                    f.write(f"Lambda GP: {getattr(gan, 'lambda_gp', lambda_gp)}\n")
                if clip_value is not None:
                    f.write(f"Clip value: {getattr(gan, 'clip_value', clip_value)}\n")
                f.write(f'Generator Optimizer:\n{str(gan.optimizer_G)}\n')
                opt_c = getattr(gan, 'optimizer_C', None)
                if opt_c is not None:
                    f.write(f'Critic Optimizer:\n{str(opt_c)}\n')

        critic = getattr(gan, 'C', None)
        if critic is not None:
            os.makedirs(os.path.dirname(self.critic_path), exist_ok=True)
            torch.save(critic.state_dict(), self.critic_path)
            print(f'Model saved to {self.critic_path}.')

        os.makedirs(os.path.dirname(self.generator_path), exist_ok=True)
        torch.save(gan.G.state_dict(), self.generator_path)
        print(f'Model saved to {self.generator_path}.')

        # Training loss and gradient-norm curves; loss_C is only present for
        # models that have a critic.
        loss_C_list = getattr(gan, 'loss_C_list', None)
        if gan.loss_G_list is not None and gan.generator_grad_norm_list is not None:
            with h5py.File(self.list_path, 'w') as f:
                if loss_C_list is not None:
                    f.create_dataset('loss_C', data=self._tensor_to_numpy(loss_C_list))
                f.create_dataset('loss_G', data=self._tensor_to_numpy(gan.loss_G_list))
                f.create_dataset('generator_grad_norm',
                                 data=self._tensor_to_numpy(gan.generator_grad_norm_list))

        if gan.gen_data is not None and gan.gen_data_list is not None:
            key = str(epochs)
            if epochs == 0:
                # Mode 'w' recreates the file, discarding any earlier snapshots.
                with h5py.File(self.gen_data_path, 'w') as f:
                    f.create_dataset(key, data=self._tensor_to_numpy(gan.gen_data))
            else:
                with h5py.File(self.gen_data_path, 'a') as f:
                    if key in f:  # create_dataset refuses to overwrite an existing key
                        del f[key]
                    f.create_dataset(key, data=self._tensor_to_numpy(gan.gen_data_list[-1]))

    # Gaussian data
    def generated_distribution(self, generator, latent_dim, device, num_samples=1000,
                               save=False, name=''):
        """Sample ``num_samples`` outputs from ``generator`` and return them flattened.

        With ``save=True``, a cached ``generated_data{name}.h5`` in
        ``epochs_folder`` is returned if it exists; otherwise the fresh samples
        are written there. Sampling runs in eval mode under ``torch.no_grad``,
        and the generator's previous train/eval mode is restored afterwards.

        Returns:
            A flattened numpy array.
        """
        data_path = f'{self.epochs_folder}generated_data{name}.h5'

        if save and os.path.exists(data_path):
            with h5py.File(data_path, 'r') as f:
                generated_data = f['generated_data'][:]
            print(f'Generated data loaded from {data_path}.')
            return generated_data.flatten()

        # Restore the caller's mode so sampling mid-training does not leave the
        # generator stuck in eval mode.
        was_training = generator.training
        generator.eval()
        try:
            with torch.no_grad():
                z = torch.randn(num_samples, latent_dim).to(device)
                generated_data = generator(z).cpu().numpy().flatten()
        finally:
            generator.train(was_training)

        if save:
            os.makedirs(os.path.dirname(data_path), exist_ok=True)
            with h5py.File(data_path, 'w') as f:
                f.create_dataset('generated_data', data=generated_data)
            print(f'Generated data saved to {data_path}.')
        return generated_data

    def generated_images(self, generator, latent_dim, device, epoch, num_samples=50,
                         save=False, save_img=True, name='', nrow=5):
        """Sample ``num_samples`` outputs from ``generator`` for ``self.dataset_name``.

        The latent shape and output reshape come from ``_SAMPLE_SPECS``.
        Unlike ``generated_distribution``, the generator's train/eval mode is
        left untouched.

        Args:
            epoch: Used in the saved image file name.
            save: If True, return the cached ``generated_data{name}.h5`` in
                ``epochs_folder`` when it exists (as a numpy array), otherwise
                write the fresh samples there.
            save_img: Save an image grid under ``main_folder/images``.
            nrow: Images per grid row.

        Returns:
            The generated tensor, or the cached numpy array (see ``save``).

        Raises:
            ValueError: If ``self.dataset_name`` has no sampling spec.
        """
        data_path = f'{self.epochs_folder}generated_data{name}.h5'
        img_path = os.path.join(self.main_folder, 'images', f'generated_image_{epoch}epochs.png')

        if save and os.path.exists(data_path):
            with h5py.File(data_path, 'r') as f:
                generated_data = f['generated_data'][:]
            print(f'Generated data loaded from {data_path}.')
            return generated_data

        spec = self._SAMPLE_SPECS.get(self.dataset_name)
        if spec is None:
            raise ValueError(
                f'Unsupported dataset {self.dataset_name!r} for image generation; '
                f'expected one of {sorted(self._SAMPLE_SPECS)}.'
            )
        spatial_latent, sample_shape = spec

        latent_shape = (num_samples, latent_dim, 1, 1) if spatial_latent else (num_samples, latent_dim)
        z = torch.randn(latent_shape, device=device)
        with torch.no_grad():
            gen_imgs = generator(z).to(device)
        if sample_shape is not None:
            gen_imgs = gen_imgs.reshape(gen_imgs.size(0), *sample_shape)

        if save_img:
            os.makedirs(os.path.dirname(img_path), exist_ok=True)
            # Keep only whole grid rows; with fewer samples than one row, save
            # them all (trimming to zero would hand save_image an empty tensor).
            num_to_save = (num_samples // nrow) * nrow or num_samples
            save_image(gen_imgs[:num_to_save], img_path, nrow=nrow, normalize=True)

        if save:
            os.makedirs(os.path.dirname(data_path), exist_ok=True)
            with h5py.File(data_path, 'w') as f:
                f.create_dataset('generated_data', data=self._tensor_to_numpy(gen_imgs))
            print(f'Generated data saved to {data_path}.')
        return gen_imgs

    def save_generated_images(self, gen_data, epoch, name='', nrow=5, data_num=25):
        """Save the first ``data_num`` items of ``gen_data`` as an image grid.

        ``gen_data`` may be a tensor or array-like (e.g. a snapshot returned by
        ``load_model``). The file goes to
        ``main_folder/images/generated_image_{epoch}epochs{name}.png``.
        """
        img_path = os.path.join(self.main_folder, 'images',
                                f'generated_image_{epoch}epochs{name}.png')
        os.makedirs(os.path.dirname(img_path), exist_ok=True)
        # as_tensor avoids the copy and warning torch.tensor() gives for tensor input.
        gen_data_tensor = torch.as_tensor(gen_data)
        save_image(gen_data_tensor[:data_num], img_path, nrow=nrow, normalize=True)