import sys
import os
import math
import torchvision.transforms as transforms
import torch
import torch.nn.functional as F
import time
from torch.utils.data import Dataset
from torchvision.datasets import CIFAR10
import torch.optim as optim
from torchvision import models

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from Cifar10_Experiment.architectures.conv import Generator, Critic, weights_init

from magnitude import *
from model_loader import ModelLoader


class GAN:
    def __init__(
        self,
        batch_size=64,
        lr=1e-4,
        opt_betas=(0.5, 0.999),
        n_cpu: int = 4,
        latent_dim: int = 100,
        img_size: int = 32,
        channels: int = 3,
        n_critic: int = 5,
        step: int = 50,
        device=None,
        name: str = 'WGAN_improved',
        dataset_name: str = 'Cifar10',
    ):
        """Set up the data pipeline, the generator and its optimizer.

        Args:
            batch_size: Mini-batch size of the training dataloader.
            lr: Learning rate of the generator's Adam optimizer.
            opt_betas: ``betas`` passed to Adam.
            n_cpu: Number of dataloader worker processes.
            latent_dim: Size of the latent vector passed to ``Generator``.
            img_size: Side length of the (square) images.
            channels: Number of image channels.
            n_critic: Stored on the instance; not used in this method.
            step: Stored as ``self.step``.
            device: Torch device spec; defaults to CUDA when available, else CPU.
            name: Model name. The generator and optimizer are only created
                when it contains ``'conv'``.
            dataset_name: Only ``'Cifar10'`` creates ``self.dataloader``.
        """
        self.batch_size = batch_size
        self.lr = lr
        self.opt_betas = opt_betas
        self.n_cpu = n_cpu
        self.latent_dim = latent_dim
        self.img_size = img_size
        self.channels = channels
        self.n_critic = n_critic
        self.step = step
        self.device = torch.device(device if device else ("cuda" if torch.cuda.is_available() else "cpu"))
        self.dataset_name = dataset_name

        # NOTE(review): any other dataset_name leaves self.dataloader undefined,
        # so training will fail later with AttributeError.
        if dataset_name == 'Cifar10':
            print(f'Preparing {dataset_name}')
            self.img_shape = (self.channels, self.img_size, self.img_size)

            # Data preprocessing to set the image values from range [0, 1] to [-1, 1]
            preprocessing_ops = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ])

            # Dataset download (stored next to this package, in ../cifar_data)
            cifar_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../cifar_data'))
            os.makedirs(cifar_dir, exist_ok=True)

            train_dataset = CIFAR10(
                root=cifar_dir,
                train=True,
                transform=preprocessing_ops,
                download=True,
            )

            class_label = 3  # cat
            # Filter on the raw label list instead of iterating the dataset:
            # iterating would decode and transform every image just to read its label.
            train_idx = [i for i, label in enumerate(train_dataset.targets) if label == class_label]
            train_dataset = torch.utils.data.Subset(train_dataset, train_idx)

            self.dataloader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=self.n_cpu,
                # Pinned host memory only helps (and only avoids a warning) on CUDA.
                pin_memory=self.device.type == 'cuda',
            )

            print(f"CIFAR-10 dataloader created with {len(train_dataset)} images")

        self.model_name = name

        # NOTE(review): the default name 'WGAN_improved' does not contain 'conv',
        # so with default arguments self.G / self.optimizer_G are never created.
        if 'conv' in name:
            gen_filters = (1024, 512, 256)
            leaky_relu_alpha = 0.2

            # Only the generator is built here; no critic is instantiated.
            self.G = Generator(gen_filters, leaky_relu_alpha=leaky_relu_alpha, latent_dim=latent_dim).to(self.device)
            self.G.apply(weights_init)

            self.optimizer_G = optim.Adam(self.G.parameters(), lr=lr, betas=opt_betas)

        # Training history, appended to by the training loop.
        self.loss_G_list = []
        self.generator_grad_norm_list = []
        self.gen_data_list = []

    def train_MagGN_multiscale(self, n_epochs, batch_size: int = 64,  normalize=False, t_list = None, epoch_list = None, loss_normalize=False, feature_space=True, avg_pool_size=4, hybrid_coeff = 0):
        """Train the generator with a multi-scale magnitude-distance loss.

        If a generator checkpoint and its loss/gradient list for ``n_epochs``
        already exist on disk, they are loaded via ``self.model_loader`` and
        returned instead of training.

        The values in ``t_list`` are sorted ascending and activated
        progressively: at each epoch listed in ``epoch_list`` the number of
        active scales becomes that epoch's position in the list plus one. The
        per-batch loss is the sum of ``norm_diff_magnitude_distance_grad`` over
        the active scales.

        Args:
            n_epochs: Number of epochs to train.
            batch_size: Only used in the start-up message and passed to the
                final ``save_model`` call; the dataloader itself uses
                ``self.batch_size``.
            normalize: Forwarded to ``norm_diff_magnitude_distance_grad``.
            t_list: Non-empty sequence of scale values ``t``.
            epoch_list: Epochs at which the next scale is activated. Defaults
                to evenly spaced epochs starting at 1.
            loss_normalize: Divide the loss by the number of currently active
                scales instead of ``len(t_list)``.
            feature_space: Compare flattened activations of a frozen,
                pretrained ResNet-18 trunk instead of raw pixels.
            avg_pool_size: When positive, adaptive-average-pool the feature
                maps to this spatial size before flattening.
            hybrid_coeff: When positive, add this multiple of the pixel-space
                loss for every active scale.

        Returns:
            ``(loss_G_list, generator_grad_norm_list, gen_data_list)`` after
            training, or whatever ``model_loader.load_model`` returns when a
            cached run is found.

        Raises:
            ValueError: If ``t_list`` is ``None`` or empty and no cached run
                exists.
        """
        self.n_epochs = n_epochs
        # Set model name based on parameters
        self.model_name = self._get_maggn_model_name(normalize = normalize, name = self.model_name, t_list = t_list, epoch_list = epoch_list, loss_normalize=loss_normalize, feature_space=feature_space, avg_pool_size=avg_pool_size, hybrid_coeff=hybrid_coeff)
        self.model_loader = ModelLoader(self.dataset_name, self.model_name, self.device, self.step)
        self.model_loader.set_paths(self.n_epochs, generator_name = 'MagGN_multiscale')

        if os.path.exists(self.model_loader.generator_path) and os.path.exists(self.model_loader.list_path):
            return self.model_loader.load_model(epochs=self.n_epochs, critic=None, generator=self.G)
        print(f'Generator path: {self.model_loader.generator_path}, exist: {os.path.exists(self.model_loader.generator_path)}')
        print(f'Loss and gradient norms path: {self.model_loader.list_path}, exist: {os.path.exists(self.model_loader.list_path)}')
        print(f'Starting training {self.model_name} for {self.n_epochs} epochs with batch size {batch_size}..')

        # Fail early with a clear message instead of a TypeError from sorted(None)
        # or a ZeroDivisionError further down.
        if t_list is None or len(t_list) == 0:
            raise ValueError('t_list must be a non-empty sequence of scale values t.')

        sorted_t_list = sorted(t_list)
        t_schedule_list = []

        if epoch_list is None:
            interval_step = int(n_epochs / len(t_list))
            epoch_list = [i * interval_step + 1 for i in range(len(t_list))]

        print(f'Interval for adding t values: {epoch_list} epochs.')
        if feature_space:
            resnet = models.resnet18(pretrained=True)
            # Keep every child module except the last two (avgpool and FC head)
            self.feature_extractor = torch.nn.Sequential(*list(resnet.children())[:-2]).to(self.device)
            self.feature_extractor.eval()
            # The extractor is frozen; gradients still flow through it to the generator.
            for param in self.feature_extractor.parameters():
                param.requires_grad = False
            print("Feature extractor initialized.")

        # Most recent values, kept for the checkpoint log line; NaN until the
        # first generator update has happened.
        last_loss = float('nan')
        last_grad_norm = float('nan')

        for epoch in range(1, self.n_epochs+1):
            # Point the loader at this epoch's output paths
            self.model_loader.set_paths(epoch, generator_name = 'MagGN_multiscale')

            if epoch in epoch_list:
                num_t = epoch_list.index(epoch) + 1
                print(f"Epoch {epoch}: Adding magnitude with {num_t} t values {sorted_t_list[:num_t]} for loss computation.")
                t_schedule_list = sorted_t_list[:num_t]

            if not t_schedule_list:
                # Happens when a custom epoch_list does not contain epoch 1: there
                # is no loss term yet, so there is nothing to optimize.
                print(f"Epoch {epoch}: no t value scheduled yet, skipping generator updates.")
            batches = self.dataloader if t_schedule_list else ()

            for imgs, _ in batches:
                real_imgs = imgs.to(self.device)
                z = torch.randn(imgs.shape[0], self.latent_dim, 1, 1).to(self.device)
                gen_imgs = self.G(z)
                real_pixels = real_imgs.view(real_imgs.size(0), -1)
                gen_pixels = gen_imgs.view(gen_imgs.size(0), -1)

                if feature_space:
                    # Real features need no graph; generated features must keep
                    # it so the loss can back-propagate into the generator.
                    with torch.no_grad():
                        real_input = self._pooled_features(real_imgs, avg_pool_size)
                    gen_input = self._pooled_features(gen_imgs, avg_pool_size)
                else:
                    real_input, gen_input = real_pixels, gen_pixels

                loss_G = 0
                for t in t_schedule_list:
                    loss_G += norm_diff_magnitude_distance_grad(real_input, gen_input, device=self.device, t=t, normalize=normalize, eps = 0)
                    if hybrid_coeff > 0:
                        loss_pixel = norm_diff_magnitude_distance_grad(real_pixels, gen_pixels, device=self.device, t=t, normalize=normalize, eps = 0)
                        loss_G += hybrid_coeff * loss_pixel
                # Average over the active scales, or over all scales so the loss
                # grows as more scales are switched on.
                loss_G = loss_G / (len(t_schedule_list) if loss_normalize else len(t_list))

                self.optimizer_G.zero_grad()
                loss_G.backward()

                last_grad_norm = self._calculate_gradient_norm(self.G)
                self.optimizer_G.step()

                last_loss = loss_G.item()
                self.loss_G_list.append(last_loss)
                self.generator_grad_norm_list.append(last_grad_norm)

            if (epoch) % self.step == 0:
                self.gen_data = self.model_loader.generated_images(self.G, self.latent_dim, self.device, epoch, num_samples=64, save=False, nrow=8)
                self.gen_data_list.append(self.gen_data)
                # Save model, losses, and generated data
                self.model_loader.save_model(self, epoch, self.batch_size, info=False)
                print(f"Epoch {epoch}/{self.n_epochs}| G loss: {last_loss:.6f} | G grad norm: {last_grad_norm:.4f}")

        # NOTE(review): this passes the batch_size argument while the periodic
        # checkpoints above pass self.batch_size — confirm which one is intended.
        self.model_loader.save_model(self, self.n_epochs, batch_size, info=True)
        return self.loss_G_list, self.generator_grad_norm_list, self.gen_data_list

    def _pooled_features(self, imgs, avg_pool_size):
        """Return flattened ``self.feature_extractor`` activations for ``imgs``, average-pooled to ``avg_pool_size`` x ``avg_pool_size`` when that size is positive."""
        features = self.feature_extractor(imgs)
        if avg_pool_size > 0:
            # NOTE(review): for 32x32 inputs the ResNet-18 trunk presumably yields
            # a 1x1 spatial map, so pooling to a larger size would only replicate
            # values — confirm the intended avg_pool_size.
            features = F.adaptive_avg_pool2d(features, (avg_pool_size, avg_pool_size))
        return features.view(features.size(0), -1)
         

    def _calculate_gradient_norm(self, model):
        """Calculate the total gradient norm for a model."""
        total_norm = 0
        for p in model.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        return total_norm ** 0.5

    def _get_maggn_model_name(self, normalize, name, t_list, epoch_list, loss_normalize, feature_space, avg_pool_size, hybrid_coeff):
        """Build the run name that encodes the magnitude-training options.

        The name is derived from ``name`` and returned; ``self.model_name`` is
        not modified here (the caller decides whether to store the result).

        Args:
            normalize: Prefix the name with ``Norm``.
            name: Base model name to decorate.
            t_list: Scale values, appended as ``-t_multiscale-<t1>_<t2>...``;
                skipped when ``None``.
            epoch_list: Activation epochs, appended as ``-epochs-<e1>_<e2>...``;
                only used when ``t_list`` is given.
            loss_normalize: Append ``_with_normalized_loss``.
            feature_space: Wrap the name as
                ``feature_space_<name>_with_avgpool_<avg_pool_size>``.
            avg_pool_size: Pool size recorded in the feature-space suffix.
            hybrid_coeff: When positive (and ``feature_space`` is set), append
                ``_hybridcoeff_<hybrid_coeff>``.

        Returns:
            The decorated model name.
        """
        model_name = name

        if t_list is not None:
            t_list_str = '_'.join(str(t) for t in t_list)
            model_name = f'{model_name}-t_multiscale-{t_list_str}'
            if epoch_list is not None:
                epoch_list_str = '_'.join(str(e) for e in epoch_list)
                model_name = f'{model_name}-epochs-{epoch_list_str}'

        if normalize:
            model_name = f'Norm{model_name}'

        if loss_normalize:
            model_name = f'{model_name}_with_normalized_loss'

        if feature_space:
            model_name = f'feature_space_{model_name}_with_avgpool_{avg_pool_size}'
            # NOTE(review): the hybrid suffix is only added in feature-space mode,
            # so pixel-space runs with different hybrid_coeff share one name.
            if hybrid_coeff > 0:
                model_name = f'{model_name}_hybridcoeff_{hybrid_coeff}'

        return model_name
    
    def compute_magnitude_overlap(self, epoch, max_t = 5, min_t = 0, steps = 10, num_samples = 1000, normalize = False):
        """Print and return the maximum magnitude overlap of real vs. generated images.

        ``num_samples`` images are obtained from
        ``self.model_loader.generated_images`` and compared against the first
        batch drawn from ``self.dataloader``. Both sets are flattened to one row
        per image and handed to ``norm_max_magnitude_overlap_grad`` together
        with ``min_t``, ``max_t``, ``steps`` and ``normalize``.

        Relies on ``self.model_loader``, which is not created in ``__init__``.

        Returns:
            ``(max_overlap, t_arg_max)`` as produced by
            ``norm_max_magnitude_overlap_grad``.
        """
        print(f'Computing magnitude overlap at epoch {epoch} for {self.model_name}', flush=True)
        fake_samples = self.model_loader.generated_images(
            self.G,
            self.latent_dim,
            self.device,
            epoch,
            num_samples=num_samples,
            save=False,
        )

        # Only the first real batch is used for the comparison.
        for first_batch, _ in self.dataloader:
            real_imgs = first_batch.to(self.device)
            break

        real_flat = real_imgs.view(real_imgs.size(0), -1)
        fake_flat = fake_samples.view(fake_samples.size(0), -1)
        overlap_result = norm_max_magnitude_overlap_grad(
            real_flat,
            fake_flat,
            device=self.device,
            normalize=normalize,
            eps = 0,
            max_t = max_t,
            min_t = min_t,
            steps = steps,
        )
        max_overlap, t_arg_max = overlap_result
        print(f'Max magnitude overlap at t={t_arg_max}: {max_overlap}', flush=True)
        return max_overlap, t_arg_max
 
    
