import sys
import os
import torchvision.transforms as transforms
from torchvision import datasets
import torch

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
# from CelebA_Experiment.architectures.improved_wgangp import Generator
# from CelebA_Experiment.architectures.improved_wgangp import Critic
from CelebA_Experiment.architectures.conv import Generator, Critic, initialize_weights
from magnitude import *
from model_loader import ModelLoader
import time
from torch.utils.data import Dataset

class RobustCelebADataset(Dataset):
    """Wrapper that retries failed file reads and skips persistently broken files.

    ``__getitem__`` retries ``FileNotFoundError`` for the requested index up to
    ``max_retries`` times, sleeping ``retry_delay * 2**attempt`` seconds between
    attempts. If the index still cannot be read, one uniformly random index is
    tried instead. If that also fails, an all-zero placeholder image of shape
    ``dummy_shape`` with label ``0`` is returned, so one bad file does not abort
    the epoch.

    Args:
        celeba_dataset: any indexable dataset with ``__len__`` (here a
            ``torchvision.datasets.CelebA`` instance).
        max_retries: number of read attempts for the requested index. Values
            below 1 are treated as 1, so an item is always attempted.
        retry_delay: base back-off delay in seconds.
        dummy_shape: shape of the last-resort placeholder image. It should match
            the wrapped dataset's transformed images.
    """

    def __init__(self, celeba_dataset, max_retries=3, retry_delay=0.1, dummy_shape=(3, 64, 64)):
        self.dataset = celeba_dataset
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.dummy_shape = tuple(dummy_shape)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        import random

        # Guard against max_retries <= 0, which previously skipped loading and returned None.
        attempts = max(1, self.max_retries)
        for attempt in range(attempts):
            try:
                return self.dataset[idx]
            except FileNotFoundError:
                if attempt < attempts - 1:
                    # Exponential back-off, presumably to ride out transient filesystem hiccups.
                    time.sleep(self.retry_delay * (2 ** attempt))

        # Every attempt failed: substitute a random sample.
        print(f"Warning: Failed to load index {idx} after {self.max_retries} attempts. Using random sample instead.")
        new_idx = random.randint(0, len(self.dataset) - 1)
        # Try the substitute only once (no recursion) so a fully broken dataset cannot loop forever.
        try:
            return self.dataset[new_idx]
        except Exception:
            # `Exception`, not a bare except: Ctrl-C / SystemExit must still propagate.
            print("Warning: Fallback sample also failed. Returning dummy data.")
            return torch.zeros(*self.dummy_shape), 0

class GAN:
    """WGAN / WGAN-GP trainer for the convolutional CelebA architecture.

    The constructor prepares the CelebA training dataloader when
    ``dataset_name == 'CELEBA'``. When ``'conv'`` appears in ``name``, it also
    builds the ``Generator`` / ``Critic`` pair from
    ``CelebA_Experiment.architectures.conv`` with RMSprop optimizers.

    ``weight_decay``, ``beta1``, ``beta2``, ``n_cpu`` and ``sample_size`` are
    stored but not used by the RMSprop optimizers or the single-process
    dataloader built here.
    """

    def __init__(self, batch_size=32, sample_size=36, lr=0.1e-04, weight_decay=1e-04, beta1=0.5, beta2=.999, n_cpu: int = 8, latent_dim: int = 100, img_size: int = 64, channels: int = 3, n_critic: int = 5, step: int = 50, device=None, name: str = 'WGAN_improved', dataset_name: str = 'CELEBA'):
        self.batch_size = batch_size
        self.sample_size = sample_size
        self.lr = lr
        self.weight_decay = weight_decay
        self.beta1 = beta1
        self.beta2 = beta2
        self.n_cpu = n_cpu
        self.latent_dim = latent_dim
        self.img_size = img_size
        self.channels = channels
        self.n_critic = n_critic
        self.step = step
        self.device = torch.device(device if device else ("cuda" if torch.cuda.is_available() else "cpu"))
        self.dataset_name = dataset_name
        # Stays None if the dataset is unsupported or could not be prepared.
        # Training then raises a clear error instead of an AttributeError.
        self.dataloader = None
        try:
            if dataset_name == 'CELEBA':
                self._prepare_celeba()
        except Exception as e:
            # Best effort: the object can still be used to load saved checkpoints without data.
            print(f"Error preparing CelebA dataset: {e}")

        self.model_name = name
        # NOTE: G, C and their optimizers are only created for 'conv' model names;
        # for any other name (including the default) they remain undefined.
        if 'conv' in name:
            self.features = 64
            # Use the resolved self.device: the raw `device` argument may be None.
            self.G = Generator(z_dim=self.latent_dim, channels_img=self.channels, features_g=self.features).to(self.device)
            self.C = Critic(self.channels, features_c=self.features).to(self.device)
            initialize_weights(self.G)
            initialize_weights(self.C)

            self.optimizer_G = torch.optim.RMSprop(self.G.parameters(), lr=self.lr)
            self.optimizer_C = torch.optim.RMSprop(self.C.parameters(), lr=self.lr)

        # Training histories: one entry per generator step (losses, grad norm)
        # and one entry per checkpoint epoch (generated samples).
        self.loss_C_list = []
        self.loss_G_list = []
        self.generator_grad_norm_list = []
        self.gen_data_list = []

    def _prepare_celeba(self):
        """Extract and annotate CelebA under ``$DATA_ROOT``, then build ``self.dataloader``."""
        print(f'Preparing {self.dataset_name} dataset')
        self.img_shape = (self.channels, self.img_size, self.img_size)
        celeba_dir = os.environ.get("DATA_ROOT", "/home/s2670758/celeba_data")
        os.makedirs(celeba_dir, exist_ok=True)
        anno_dir = os.path.join(celeba_dir, 'celeba')

        # Extract the aligned images once if only the zip archive is present.
        img_dir = os.path.join(anno_dir, 'img_align_celeba')
        if not os.path.exists(img_dir):
            print("Extracting CelebA dataset...")
            import zipfile
            zip_path = os.path.join(anno_dir, 'img_align_celeba.zip')
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(anno_dir)

        self._download_missing_annotations(anno_dir)

        dataset = datasets.CelebA(
            celeba_dir,
            split='train',
            download=False,
            transform=transforms.Compose([
                # NOTE(review): fixed at 64x64 and ignores self.img_size, presumably to
                # match the conv Generator's output size. Confirm before changing.
                transforms.Resize([64, 64]),
                transforms.ToTensor(),
                transforms.Normalize([0.5] * self.channels, [0.5] * self.channels),
            ]),
        )
        dataset = RobustCelebADataset(dataset, max_retries=5, retry_delay=0.2)
        self.dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=0,
            # Pinned host memory only helps (and only avoids a warning) with a CUDA target.
            pin_memory=self.device.type == 'cuda',
        )

    @staticmethod
    def _download_missing_annotations(anno_dir):
        """Best-effort download of any missing CelebA annotation files from Google Drive."""
        file_ids = {
            'list_attr_celeba.txt': '0B7EVK8r0v71pblRyaVFSWGxPY0U',
            'identity_CelebA.txt': '1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS',
            'list_bbox_celeba.txt': '0B7EVK8r0v71pbThiMVRxWXZ4dU0',
            'list_landmarks_align_celeba.txt': '0B7EVK8r0v71pd0FJY3Blby1HUTQ',
            'list_eval_partition.txt': '0B7EVK8r0v71pY0NSMzRuSXJEVkk',
        }
        missing_files = [f for f in file_ids if not os.path.exists(os.path.join(anno_dir, f))]
        if not missing_files:
            return

        print(f"Missing annotation files: {missing_files}")
        print("Downloading annotation files...")
        from torchvision.datasets.utils import download_url

        base_url = "https://drive.google.com/uc?export=download&id="
        for filename in missing_files:
            try:
                print(f"Downloading {filename}...")
                download_url(base_url + file_ids[filename], anno_dir, filename=filename)
            except Exception as e:
                # Keep going: a later CelebA(...) construction reports what is still missing.
                print(f"Failed to download {filename}: {e}")
                print(f"Please download manually from Google Drive and place in {anno_dir}")

    def gradient_penalty(self, critic, real_samples, fake_samples, device):
        """Return the WGAN-GP penalty ``mean((||grad C(x_hat)||_2 - 1) ** 2)``.

        ``x_hat`` is a random per-sample interpolation between ``real_samples``
        and ``fake_samples``, which must have the same shape with batch first.
        Any input rank is supported: the mixing weight is broadcast over all
        non-batch dimensions. ``create_graph=True`` keeps the penalty
        differentiable so it can be back-propagated into ``critic``.
        """
        alpha_shape = (real_samples.size(0),) + (1,) * (real_samples.dim() - 1)
        alpha = torch.rand(alpha_shape).to(device)
        interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
        d_interpolates = critic(interpolates)

        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(d_interpolates),
            create_graph=True,
            retain_graph=True,
        )[0]

        # reshape (not view): autograd may hand back a non-contiguous gradient.
        gradients = gradients.reshape(gradients.size(0), -1)
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    def _wgan_critic_loss(self, real_imgs, fake_imgs):
        """Critic loss ``mean(C(fake)) - mean(C(real))`` (negated Wasserstein estimate)."""
        return -torch.mean(self.C(real_imgs).reshape(-1)) + torch.mean(self.C(fake_imgs).reshape(-1))

    def train_WGAN(self, n_epochs, n_critic, batch_size: int = 64, clip_value=0.01):
        """Train with the WGAN objective and critic weight clipping.

        Args:
            n_epochs: number of passes over ``self.dataloader``.
            n_critic: critic updates per generator update (must be >= 1).
            batch_size: unused for training, since the dataloader batch size is
                fixed in ``__init__``. Kept for backward compatibility.
            clip_value: critic weights are clamped to ``[-clip_value, clip_value]``.

        Returns:
            ``model_loader.load_model(...)`` if final-epoch artifacts already
            exist. Otherwise ``(loss_C_list, loss_G_list,
            generator_grad_norm_list, gen_data_list)``.
        """
        self.clip_value = clip_value
        return self._run_training(n_epochs, n_critic, self._wgan_critic_loss,
                                  clip_value=clip_value, save_kwargs={'clip_value': clip_value})

    def train_WGAN_GP(self, n_epochs, n_critic, batch_size: int = 64, lambda_gp = 10):
        """Train with the WGAN-GP objective (gradient penalty weighted by ``lambda_gp``).

        Args and return value are the same as ``train_WGAN``; ``batch_size`` is
        likewise unused.
        """
        self.lambda_gp = lambda_gp

        def critic_loss(real_imgs, fake_imgs):
            # Penalty first, then the Wasserstein term (same evaluation order as before).
            penalty = lambda_gp * self.gradient_penalty(self.C, real_imgs, fake_imgs, self.device)
            return self._wgan_critic_loss(real_imgs, fake_imgs) + penalty

        return self._run_training(n_epochs, n_critic, critic_loss,
                                  save_kwargs={'lambda_gp': lambda_gp})

    def _run_training(self, n_epochs, n_critic, critic_loss_fn, clip_value=None, save_kwargs=None):
        """Run the shared WGAN / WGAN-GP loop; return cached results or training histories.

        ``critic_loss_fn(real_imgs, fake_imgs)`` returns the scalar critic loss.
        ``clip_value`` (if not None) enables critic weight clipping after every
        critic step. ``save_kwargs`` are forwarded to every ``save_model`` call.
        """
        save_kwargs = save_kwargs or {}
        self.n_epochs = n_epochs
        self.model_loader = ModelLoader(self.dataset_name, self.model_name, self.device, self.step)
        self.model_loader.set_paths(self.n_epochs)

        artifacts = [
            ('Critic path', self.model_loader.critic_path),
            ('Generator path', self.model_loader.generator_path),
            ('Loss and gradient norms path', self.model_loader.list_path),
        ]
        if all(os.path.exists(path) for _, path in artifacts):
            return self.model_loader.load_model(self.n_epochs, self.C, self.G)
        for label, path in artifacts:
            print(f'{label}: {path}, exist: {os.path.exists(path)}')

        if n_critic < 1:
            raise ValueError(f'n_critic must be >= 1, got {n_critic}')
        if getattr(self, 'dataloader', None) is None:
            raise RuntimeError(f'No dataloader available for dataset {self.dataset_name!r}; '
                               'dataset preparation failed or is unsupported.')
        # Record the value actually used so save_model(self, ...) sees consistent metadata.
        self.n_critic = n_critic
        print(f'Starting training {self.model_name} for {self.n_epochs} epochs with batch size {self.batch_size} and {n_critic} critic updates.')

        loss_C = loss_G = None
        total_norm = 0.0
        for epoch in range(1, self.n_epochs + 1):
            self.model_loader.set_paths(epoch)
            for imgs, _ in self.dataloader:
                real_imgs = imgs.to(self.device)
                batch = real_imgs.size(0)

                # NOTE: the same real batch is reused for all n_critic critic updates.
                for _ in range(n_critic):
                    z = torch.randn(batch, self.latent_dim, 1, 1).to(self.device)
                    # No generator graph is needed for the critic update.
                    with torch.no_grad():
                        fake_imgs = self.G(z)

                    loss_C = critic_loss_fn(real_imgs, fake_imgs)
                    self.optimizer_C.zero_grad()
                    loss_C.backward()
                    self.optimizer_C.step()

                    if clip_value is not None:
                        # Clip critic weights; note that this is not the same as gradient clipping.
                        with torch.no_grad():
                            for p in self.C.parameters():
                                p.clamp_(-clip_value, clip_value)

                z = torch.randn(batch, self.latent_dim, 1, 1).to(self.device)
                loss_G = -torch.mean(self.C(self.G(z)).reshape(-1))

                # Also clear the critic grads produced by this backward pass.
                self.optimizer_C.zero_grad()
                self.optimizer_G.zero_grad()
                loss_G.backward()
                total_norm = self._calculate_gradient_norm(self.G)
                self.optimizer_G.step()

                self.loss_C_list.append(loss_C.item())
                self.loss_G_list.append(loss_G.item())
                self.generator_grad_norm_list.append(total_norm)

            if epoch % self.step == 0:
                self.gen_data = self.model_loader.generated_images(self.G, self.latent_dim, self.device, epoch, num_samples=50, save=False)
                self.gen_data_list.append(self.gen_data)
                # Save model, losses and generated data with the hyper-parameters actually used.
                self.model_loader.save_model(self, epoch, self.batch_size, n_critic, info=False, **save_kwargs)
                if loss_C is not None:
                    print(f"Epoch {epoch}/{self.n_epochs}| D loss: {loss_C.item():.6f} | G loss: {loss_G.item():.6f} | G grad norm: {total_norm:.4f}")

        self.model_loader.save_model(self, self.n_epochs, self.batch_size, n_critic, info=True, **save_kwargs)
        return self.loss_C_list, self.loss_G_list, self.generator_grad_norm_list, self.gen_data_list

    def _calculate_gradient_norm(self, model):
        """Return the global L2 norm of ``model``'s parameter gradients as a float.

        Parameters without a ``.grad`` are skipped; 0.0 is returned when none
        have one. Per-parameter norms are combined on-device and transferred
        with a single ``.item()`` call, instead of one host sync per parameter.
        """
        norms = [p.grad.detach().norm(2) for p in model.parameters() if p.grad is not None]
        if not norms:
            return 0.0
        return torch.stack(norms).double().norm(2).item()

    def compute_magnitude_overlap(self, epoch, max_t=5, min_t=0, steps=10, num_samples=1000, normalize=False):
        """Compare generated images against one real batch via ``norm_max_magnitude_overlap_grad``.

        Requires ``self.model_loader``, which is set by ``train_WGAN`` /
        ``train_WGAN_GP``. Both sets are flattened to ``(n, -1)`` before the
        comparison.

        NOTE(review): the real side is a single dataloader batch
        (``self.batch_size`` images), while ``num_samples`` images are generated.
        Confirm this asymmetry is intended.

        Returns:
            ``(max_overlap, t_arg_max)`` as returned by ``norm_max_magnitude_overlap_grad``.
        """
        print(f'Computing magnitude overlap at epoch {epoch} for {self.model_name}', flush=True)
        gen_data = self.model_loader.generated_images(self.G, self.latent_dim, self.device, epoch, num_samples=num_samples, save=False)

        if getattr(self, 'dataloader', None) is None:
            raise RuntimeError(f'No dataloader available for dataset {self.dataset_name!r}.')
        first_batch = next(iter(self.dataloader), None)
        if first_batch is None:
            raise RuntimeError('The dataloader is empty; cannot sample real images.')
        real_imgs = first_batch[0].to(self.device)

        max_overlap, t_arg_max = norm_max_magnitude_overlap_grad(
            real_imgs.view(real_imgs.size(0), -1),
            gen_data.view(gen_data.size(0), -1),
            device=self.device, normalize=normalize, eps=0, max_t=max_t, min_t=min_t, steps=steps,
        )
        print(f'Max magnitude overlap at t={t_arg_max}: {max_overlap}', flush=True)
        return max_overlap, t_arg_max
