import os
import argparse
import random
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder, DatasetFolder
from torch.amp import GradScaler, autocast
from torch.autograd import Function
from tqdm import tqdm
from PIL import Image
import json

import open_clip

def set_seed(seed):
    """Seed NumPy, Python's ``random`` module and PyTorch with the same value.

    Args:
        seed: Integer seed handed unchanged to each random number generator.
    """
    # Same order as before: NumPy first, then the stdlib, then PyTorch.
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)

class ReverseLayerF(Function):
    """Gradient reversal: identity in the forward pass, scaled negation in backward.

    Call as ``ReverseLayerF.apply(x, alpha)``.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        """Return a view of ``x`` and remember ``alpha`` for the backward pass."""
        ctx.alpha = alpha
        passthrough = x.view_as(x)
        return passthrough

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``-alpha * grad_output`` for ``x`` and no gradient for ``alpha``."""
        # Keep the tensor on the left-hand side so the product is computed by
        # torch even when alpha is a NumPy scalar.
        flipped = grad_output.neg()
        return flipped * ctx.alpha, None

class CLIP_with_discriminator(nn.Module):
    """Wrap a CLIP model and add a two-way linear domain discriminator head.

    Args:
        clip_model: The CLIP model to wrap; it is called as
            ``clip_model(image, text)`` by :meth:`forward`.
        feature_dim: Size of the last dimension of the features passed to
            :meth:`discriminate`. Defaults to 768, the value that was
            previously hard-coded.
    """

    def __init__(self, clip_model, feature_dim=768):
        super().__init__()
        self.clip = clip_model
        # Binary head: two logits per feature vector.
        self.discriminator_head = nn.Linear(feature_dim, 2)

    def forward(self, image, text):
        """Delegate to the wrapped CLIP model and return its output unchanged."""
        return self.clip(image, text)

    def discriminate(self, feature, alpha):
        """Return domain logits for ``feature`` through a gradient-reversal layer.

        Args:
            feature: Tensor whose last dimension matches ``feature_dim``.
            alpha: Scale applied to the reversed gradient in the backward pass.
        """
        reversed_input = ReverseLayerF.apply(feature, alpha)
        return self.discriminator_head(reversed_input)

class DomainNetDataset(DatasetFolder):
    """Image-only dataset backed by an ``ImageFolder`` rooted at ``root``.

    The class labels discovered by ``ImageFolder`` are discarded; each item is
    just the loaded (and, if a transform is set, transformed) image.
    ``DatasetFolder.__init__`` is intentionally not invoked here.
    """

    def __init__(self, root, preprocess):
        self.root = root
        # Delegate directory scanning and image loading to ImageFolder.
        self.imagefolder = ImageFolder(root, transform=preprocess)
        self.samples = self.imagefolder.samples

    def __getitem__(self, index):
        """Load the image at ``index`` and return it without its label."""
        image_path = self.samples[index][0]
        image = self.imagefolder.loader(image_path)
        transform = self.imagefolder.transform
        if transform is None:
            return image
        return transform(image)

    def __len__(self):
        """Return the number of images found under ``root``."""
        return len(self.samples)

class GCCDataset(DatasetFolder):
    """Image/caption dataset described by a ``metadata.json`` file.

    The metadata file is read from the parent directory of ``root`` and is
    indexed by integer position; each entry must provide an ``'image'`` path
    (relative to ``root``) and a ``'caption'`` string.

    Args:
        root: Directory containing the image files.
        preprocess: Callable applied to each opened PIL image.
        tokenizer: Callable applied to each caption; its result has its
            leading dimension squeezed away before being returned.
    """

    def __init__(self, root, preprocess, tokenizer):
        self.root = root
        metadata_path = os.path.join(self.root, '..', 'metadata.json')
        # Use a context manager so the file handle is closed deterministically.
        with open(metadata_path, encoding='utf-8') as metadata_file:
            self.samples = json.load(metadata_file)
        self.preprocess = preprocess
        self.tokenizer = tokenizer

    def __getitem__(self, index):
        """Return ``(preprocessed_image, tokenized_caption)`` for entry ``index``."""
        entry = self.samples[index]
        image_path = os.path.join(self.root, entry['image'])
        # Close the underlying image file as soon as preprocessing has consumed it.
        with Image.open(image_path) as image:
            sample = self.preprocess(image)
        # presumably the tokenizer returns shape (1, context_length) for a
        # single string — TODO confirm against the tokenizer in use
        tokenized_caption = self.tokenizer(entry['caption']).squeeze(0)
        return sample, tokenized_caption

    def __len__(self):
        """Return the number of entries listed in the metadata file."""
        return len(self.samples)

def get_batch(dataloader_iter, dataloader):
    """Fetch the next batch, restarting the iterator once it is exhausted.

    Args:
        dataloader_iter: Iterator currently being consumed.
        dataloader: Iterable used to create a fresh iterator on exhaustion.

    Returns:
        A ``(batch, iterator)`` pair; ``iterator`` is the one to pass to the
        next call (either the original or the freshly created one).
    """
    try:
        batch = next(dataloader_iter)
    except StopIteration:
        # Wrap around: start a new pass over the data.
        dataloader_iter = iter(dataloader)
        batch = next(dataloader_iter)
    return batch, dataloader_iter

def classification_loss(image, text, logit_scale, device):
    """Average of image-to-text and text-to-image cross-entropy losses.

    Row ``i`` of ``image`` is treated as matching row ``i`` of ``text``, so the
    target for every row is its own index.

    Args:
        image: 2-D tensor of image features, one row per sample.
        text: 2-D tensor of text features with the same number of rows.
        logit_scale: Multiplier applied to the similarity logits.
        device: Device on which the target indices are created.

    Returns:
        Scalar tensor holding the mean of the two directional losses.
    """
    targets = torch.arange(len(image), device=device, dtype=torch.long)

    image_to_text = logit_scale * image @ text.T
    text_to_image = logit_scale * text @ image.T

    image_loss = F.cross_entropy(image_to_text, targets)
    text_loss = F.cross_entropy(text_to_image, targets)
    return (image_loss + text_loss) / 2

def train_unlearn(args, model, tokenizer, gcc_loader, domainnet_loader, device):
    """Train with a retention loss plus an adversarial domain loss.

    Every step draws one image/caption batch from ``gcc_loader`` and one
    image batch from ``domainnet_loader``. The retention loss is the
    symmetric cross-entropy from ``classification_loss`` on the GCC batch. For
    the domain loss, the DomainNet images are concatenated with Gaussian noise
    of the same shape, encoded by the visual tower, and the output of
    ``model.clip.visual.ln_post`` (captured with a forward hook) is classified
    by ``model.discriminate`` (label 1 = DomainNet image, 0 = noise).

    Whichever loader is shorter is restarted whenever it runs out, so an epoch
    spans ``max(len(gcc_loader), len(domainnet_loader))`` steps.

    Args:
        args: Namespace providing ``lr``, ``wd``, ``epochs``, ``alpha_factor``,
            ``save_every`` and ``weights_dir``.
        model: Module exposing ``clip`` (with ``visual.ln_post`` and
            ``encode_image``) and ``discriminate(feature, alpha)``.
        tokenizer: Unused; kept so existing call sites remain valid.
        gcc_loader: Loader yielding ``(images, tokenized_captions)`` batches.
        domainnet_loader: Loader yielding image batches.
        device: Device the batches are moved to.

    Raises:
        ValueError: If either loader yields no batches.
    """
    gcc_len, domainnet_len = len(gcc_loader), len(domainnet_loader)
    if gcc_len == 0 or domainnet_len == 0:
        raise ValueError(
            "Both data loaders must yield at least one batch "
            f"(gcc: {gcc_len}, domainnet: {domainnet_len})."
        )
    num_batches = max(gcc_len, domainnet_len)
    total_steps = args.epochs * num_batches

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
    scaler = GradScaler('cuda')

    activation = {}

    def capture_penultimate(module, inputs, output):
        # Overwritten on every visual forward pass; only the most recent
        # output is kept.
        activation['penultimate'] = output.clone()

    hook_handle = model.clip.visual.ln_post.register_forward_hook(capture_penultimate)

    gcc_iter = iter(gcc_loader)
    domainnet_iter = iter(domainnet_loader)

    try:
        for epoch in range(args.epochs):
            model.train()
            total_gcc_loss, total_domain_loss = 0, 0

            progress = tqdm(range(num_batches), desc=f"Epoch {epoch+1}/{args.epochs}")

            for i in progress:
                optimizer.zero_grad()

                # Training progress in [0, 1) drives a sigmoid-shaped ramp of
                # the gradient-reversal strength, from 0 towards alpha_factor.
                p = (i + epoch * num_batches) / total_steps
                alpha = float((2. / (1. + np.exp(-10 * p)) - 1) * args.alpha_factor)

                # get_batch restarts an exhausted loader; calling ``.next()``
                # on the iterator is not part of the Python 3 protocol.
                (gcc_images, gcc_captions), gcc_iter = get_batch(gcc_iter, gcc_loader)
                dn_images, domainnet_iter = get_batch(domainnet_iter, domainnet_loader)

                with autocast('cuda', dtype=torch.bfloat16):
                    # GCC (retention) loss
                    gcc_images, gcc_captions = gcc_images.to(device), gcc_captions.to(device)
                    image_features, text_features, logit_scale = model(gcc_images, gcc_captions)
                    gcc_loss = classification_loss(image_features, text_features, logit_scale, device)

                    # DomainNet (unlearning) loss
                    dn_images = dn_images.to(device)
                    random_noise = torch.randn_like(dn_images)

                    # The return value is discarded: the forward hook stores
                    # the ln_post output we actually need.
                    _ = model.clip.encode_image(torch.cat([dn_images, random_noise]))
                    penultimate_features = activation['penultimate']
                    if penultimate_features.dim() == 3:
                        # Some open_clip releases run ln_post on the whole
                        # token sequence rather than on the pooled token.
                        # NOTE(review): assumes batch-first layout with the
                        # class token at position 0 — confirm for the pooling
                        # mode of the model in use.
                        penultimate_features = penultimate_features[:, 0]

                    domain_preds = model.discriminate(penultimate_features, alpha)
                    domain_labels = torch.cat([
                        torch.ones(len(dn_images), device=device, dtype=torch.long),
                        torch.zeros(len(random_noise), device=device, dtype=torch.long),
                    ])
                    domain_loss = F.cross_entropy(domain_preds, domain_labels)

                    loss = gcc_loss + domain_loss

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                # Read each loss back once per step.
                gcc_loss_value = gcc_loss.item()
                domain_loss_value = domain_loss.item()
                total_gcc_loss += gcc_loss_value
                total_domain_loss += domain_loss_value
                progress.set_postfix(gcc_loss=f"{gcc_loss_value:.4f}", domain_loss=f"{domain_loss_value:.4f}")

            # The cosine schedule advances once per epoch.
            scheduler.step()
            avg_gcc_loss = total_gcc_loss / num_batches
            avg_domain_loss = total_domain_loss / num_batches
            print(f"Epoch {epoch+1} finished. Avg GCC Loss: {avg_gcc_loss:.4f}, Avg Domain Loss: {avg_domain_loss:.4f}")

            if (epoch + 1) % args.save_every == 0:
                torch.save(model.clip.state_dict(), os.path.join(args.weights_dir, f'UNLEARNED_MODEL_{epoch+1}e.pt'))
    finally:
        # Do not leave the hook attached to the model after training.
        hook_handle.remove()

def main():
    """Parse CLI options, build the model and loaders, train, and save weights.

    The unlearning data is the concatenation of one ``DomainNetDataset`` per
    sub-directory of ``<datadir>/<unlearn_domain>``. After training, the CLIP
    state dict is written to ``<weights_dir>/UNLEARNED_MODEL_final.pt``.

    Raises:
        ValueError: If the unlearning root contains no sub-directories.
    """
    parser = argparse.ArgumentParser(description="Unlearn a domain from CLIP.")
    parser.add_argument('--datadir', type=str, default='../data', help='Root data directory.')
    parser.add_argument('--unlearn_domain', type=str, default='domainnet_loo_test', help='Domain to unlearn.')
    parser.add_argument('--retention_data', type=str, default='eleven/Pretrain_CC3M/images', help='Path to retention dataset (GCC).')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--wd', type=float, default=0.1)
    parser.add_argument('--alpha_factor', type=float, default=1.0, help='Factor for adversarial loss.')
    parser.add_argument('--weights_dir', type=str, default='./weights')
    parser.add_argument('--save_every', type=int, default=5)
    args = parser.parse_args()

    set_seed(args.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    os.makedirs(args.weights_dir, exist_ok=True)

    clip_model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai')
    tokenizer = open_clip.get_tokenizer('ViT-B-32')
    model = CLIP_with_discriminator(clip_model).to(device)

    # Dataloaders
    gcc_dataset = GCCDataset(root=os.path.join(args.datadir, args.retention_data), preprocess=preprocess, tokenizer=tokenizer)
    gcc_loader = DataLoader(gcc_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, drop_last=True)

    domainnet_root = os.path.join(args.datadir, args.unlearn_domain)
    # os.listdir returns entries in arbitrary order; sort them so the
    # concatenated dataset has the same index layout on every run.
    domain_dirs = sorted(
        entry for entry in os.listdir(domainnet_root)
        if os.path.isdir(os.path.join(domainnet_root, entry))
    )
    if not domain_dirs:
        raise ValueError(f"No domain sub-directories found under {domainnet_root!r}")
    domainnet_datasets = [
        DomainNetDataset(root=os.path.join(domainnet_root, entry), preprocess=preprocess)
        for entry in domain_dirs
    ]
    domainnet_full_dataset = torch.utils.data.ConcatDataset(domainnet_datasets)
    domainnet_loader = DataLoader(domainnet_full_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, drop_last=True)

    train_unlearn(args, model, tokenizer, gcc_loader, domainnet_loader, device)

    torch.save(model.clip.state_dict(), os.path.join(args.weights_dir, 'UNLEARNED_MODEL_final.pt'))

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
