import argparse
import json
import os
import time
from pathlib import Path

import lightning as L
import torch
import torch.nn.functional as F
import transformers
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from lightning.fabric.strategies import DeepSpeedStrategy
from peft import LoraConfig, get_peft_model
from PIL import Image
from torch.utils.data import Dataset


class NEWModel_CLIP(torch.nn.Module):
    """CLIP backbone augmented with learnable visual prompts and MLP heads.

    Submodules / parameters registered by the constructor:
        model: the wrapped backbone; it must expose ``vision_model``,
            ``text_model``, ``visual_projection`` and ``text_projection``.
        Phi, Psi: two ``MLPModel`` heads with identical configuration
            (``new_dim`` -> ``old_dim * 4`` -> ``old_dim``, dropout 0.0).
        prompts: ``(n_prompts, hidden_size)`` tensor drawn from a standard
            normal, appended to the vision token sequence in
            ``forward_image``.
        logit_scale: scalar parameter initialised to
            ``logit_scale_init_value``.
    """

    def __init__(self, model, old_dim, new_dim, logit_scale_init_value, hidden_size, n_prompts=10):
        super().__init__()
        self.model = model

        # Both heads share one configuration. Phi is built before Psi so the
        # order in which random initialisation is consumed stays fixed.
        head_kwargs = dict(input_dim=new_dim, hidden_dim=old_dim*4, output_dim=old_dim, dropout=0.0)
        self.Phi = MLPModel(**head_kwargs)
        self.Psi = MLPModel(**head_kwargs)

        prompt_init = torch.randn(n_prompts, hidden_size)
        self.prompts = torch.nn.Parameter(prompt_init)
        self.logit_scale = torch.nn.Parameter(torch.tensor(logit_scale_init_value))

    def forward_image(self, image):
        """Encode ``image`` with the prompt tokens appended after the patch tokens.

        Returns the visual projection of the post-layernormed token at
        sequence index 0 of the encoder's first output.
        """
        vision = self.model.vision_model
        batch_size = len(image)

        # Patch embeddings first, then one copy of the prompts per sample.
        tokens = torch.cat(
            [
                vision.embeddings(image),
                self.prompts.unsqueeze(0).expand(batch_size, -1, -1),
            ],
            dim=1,
        )

        # `pre_layrnorm` (sic) is the attribute name the backbone exposes.
        hidden_states = vision.encoder(vision.pre_layrnorm(tokens))[0]

        # Index 0 is taken as the pooled token (presumably the class token of
        # the Hugging Face CLIP vision tower — confirm against the backbone).
        first_token = vision.post_layernorm(hidden_states[:, 0, :])
        return self.model.visual_projection(first_token)

    def forward_text(self, input_ids, attn):
        """Encode text and return the text projection of the model's pooled output.

        ``attn`` is forwarded to the text model as ``attention_mask``; the
        pooled vector is read from index 1 of the text model's output.
        """
        pooled = self.model.text_model(input_ids=input_ids, attention_mask=attn)[1]
        return self.model.text_projection(pooled)

class DLoader(Dataset):
    """Image-caption dataset that preprocesses each pair with one processor.

    Each entry of ``test_list`` is indexed with the keys ``"filename"`` (path
    handed to ``PIL.Image.open``) and ``"caption"`` (text handed to the
    processor).

    Args:
        test_list: indexable collection of entries as described above.
        new_processor: callable invoked as
            ``new_processor(images=..., text=..., return_tensors="pt",
            truncation=True, padding="max_length")``; its result must provide
            ``pixel_values``, ``input_ids`` and ``attention_mask``.
        old_size: accepted for call compatibility; not used by this class.
        new_size: accepted for call compatibility; not used by this class.
    """

    def __init__(self, test_list, new_processor, old_size, new_size):
        self.test_list = test_list
        self.new_processor = new_processor

    def __len__(self):
        """Return the number of entries in ``test_list``."""
        return len(self.test_list)

    def _load_image(self, id: int):
        """Open the image file of entry ``id`` and return it converted to RGB."""
        # `convert` returns a new, fully loaded image, so the underlying file
        # can be closed deterministically instead of waiting for garbage
        # collection (matters with many DataLoader workers).
        with Image.open(self.test_list[id]["filename"]) as img:
            return img.convert("RGB")

    def _load_target(self, id: int):
        """Return the caption of entry ``id``."""
        return self.test_list[id]["caption"]

    def __getitem__(self, idx):
        """Return ``(pixel_values, input_ids, attention_mask)`` for entry ``idx``.

        Each returned item is element 0 of the corresponding processor
        output, i.e. the leading dimension added by the processor is dropped.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = self._load_image(idx)
        caption = self._load_target(idx)

        new_data = self.new_processor(
            images=image,
            text=caption,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
        )
        new_image = new_data['pixel_values'][0]
        new_text = new_data['input_ids'][0]
        new_attn = new_data['attention_mask'][0]

        return new_image, new_text, new_attn
    
def collate_fn(batch):
    """Collate ``batch`` with PyTorch's default collate after dropping ``None`` samples."""
    kept_samples = [sample for sample in batch if sample is not None]
    return torch.utils.data.dataloader.default_collate(kept_samples)

def get_args_parser():
    """Build the command-line parser for the training script.

    The parser is created with ``add_help=False``, so it does not register
    ``-h/--help`` itself.

    Returns:
        argparse.ArgumentParser: parser exposing every option read by the
        training code, including ``pin_mem`` (default ``True``, disabled
        with ``--no_pin_mem``).
    """
    parser = argparse.ArgumentParser('XBT Learning', add_help=False)
    parser.add_argument('--batch_size', default=128, type=int, help='Batch size per GPU')
    parser.add_argument('--epochs', default=1, type=int)
    parser.add_argument('--image_size', default=224, type=int)
    parser.add_argument('--dataset', default='dataset', type=str)
    parser.add_argument('--old_model', default='openai/clip-vit-base-patch32', type=str)
    parser.add_argument('--new_model', default='openai/clip-vit-large-patch14', type=str)
    parser.add_argument('--n_prompts', default=10, type=int)
    parser.add_argument('--ckpt', default='', type=str, help='Path to pretrained Phi module')
    parser.add_argument('--weight_decay', type=float, default=0.01,
                        help='weight decay (default: 0.01)')
    parser.add_argument('--init_lr', type=float, default=1e-4, metavar='LR')
    parser.add_argument('--output_dir', default='out', help='path where to save, empty for no saving')
    parser.add_argument('--num_workers', default=12, type=int)
    parser.add_argument('--world_size', default=8, type=int, help='number of distributed processes')

    # The DataLoader setup reads `args.pin_mem`; without these options that
    # attribute would not exist on the parsed namespace.
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in the DataLoader (default: enabled)')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem',
                        help='Disable pinned CPU memory in the DataLoader')
    parser.set_defaults(pin_mem=True)

    return parser

def _extract_phi_state_dict(checkpoint):
    """Return the Phi module's weights from a loaded checkpoint mapping.

    NOTE(review): the on-disk layout of the ``--ckpt`` file is not visible
    from this script, so the handling below is an assumption to confirm:
      * a nested dict under ``"model"``, ``"state_dict"`` or ``"module"`` is
        unwrapped first;
      * keys of the form ``"Phi.<name>"`` or ``"<prefix>.Phi.<name>"`` are
        selected and returned as ``"<name>"``;
      * if no such key exists, the mapping is assumed to already be Phi's
        own state dict and is returned as a plain dict.
    """
    state = checkpoint
    for wrapper_key in ("model", "state_dict", "module"):
        inner = state.get(wrapper_key)
        if isinstance(inner, dict):
            state = inner
            break

    marker = "Phi."
    phi_state = {}
    for key, value in state.items():
        head, sep, tail = key.partition(marker)
        # Only accept the marker at a module boundary (start or after a dot).
        if sep and (not head or head.endswith(".")):
            phi_state[tail] = value

    return phi_state if phi_state else dict(state)


def main(args):
    """Set up Fabric/DeepSpeed, build the LoRA-wrapped model and run training.

    Steps, in order: launch Fabric with DeepSpeed ZeRO stage 2; read the
    training list (JSON) from ``args.dataset``; load the new CLIP model,
    wrap it with LoRA on ``q_proj``/``v_proj`` and additionally mark every
    parameter whose name contains ``"projection"`` as trainable; load the
    pretrained Phi weights from ``args.ckpt`` and freeze Phi except for its
    ``LayerNorm`` layers; build the DataLoader and an AdamW optimizer over
    the trainable parameters; call ``train``.

    Args:
        args: parsed namespace from ``get_args_parser``.

    Raises:
        ValueError: if ``args.ckpt`` is empty.
    """
    # Fail before spawning distributed workers: Phi is frozen below, so
    # training without its pretrained weights cannot proceed meaningfully.
    if not args.ckpt:
        raise ValueError("--ckpt must point to a pretrained Phi checkpoint")

    ds_config = {
        "train_micro_batch_size_per_gpu": args.batch_size,
        "zero_optimization": {"stage": 2},
    }

    fabric = L.Fabric(
        accelerator="cuda",
        devices=args.world_size,
        strategy=DeepSpeedStrategy(config=ds_config),
        precision="bf16",
    )
    fabric.launch()
    # Different seed per rank.
    fabric.seed_everything(1337 + fabric.global_rank)

    if fabric.global_rank == 0:
        os.makedirs(args.output_dir, exist_ok=True)

    with open(args.dataset, encoding="utf-8") as f:
        train_list = json.load(f)

    # Capture the identifiers before they are lower-cased on `args` below.
    old_name = args.old_model
    new_name = args.new_model

    old_config = transformers.AutoConfig.from_pretrained(old_name)
    new_config = transformers.AutoConfig.from_pretrained(new_name)

    args.new_model = args.new_model.lower()
    args.old_model = args.old_model.lower()

    new_dim = new_config.projection_dim
    old_dim = old_config.projection_dim

    lora_config = LoraConfig(
        r=16,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.1,
        bias="none",
    )

    with fabric.device:
        new_processor = transformers.AutoProcessor.from_pretrained(new_name)
        new_model = transformers.CLIPModel.from_pretrained(new_name)
        new_model = get_peft_model(new_model, lora_config)
        # Re-enable gradients for parameters whose name contains
        # "projection", in addition to the LoRA adapters.
        for name, param in new_model.named_parameters():
            if "projection" in name:
                param.requires_grad = True
        new_model = NEWModel_CLIP(
            new_model,
            old_dim,
            new_dim,
            new_config.logit_scale_init_value,
            new_config.vision_config.hidden_size,
            args.n_prompts,
        ).bfloat16()

        # Deserialize on CPU so ranks do not all materialize the checkpoint
        # on the device it was saved from; load_state_dict copies the values
        # into the existing parameters.
        state_dict = torch.load(args.ckpt, map_location="cpu")
        phi_state_dict = _extract_phi_state_dict(state_dict)
        new_model.Phi.load_state_dict(phi_state_dict)

        for param in new_model.Phi.parameters():
            param.requires_grad = False

        # Unfreeze LayerNorm parameters
        for layer in new_model.Phi.layers:
            if isinstance(layer, torch.nn.LayerNorm):
                for param in layer.parameters():
                    param.requires_grad = True

    print_trainable_parameters(fabric, new_model)

    dataset_train = DLoader(
        train_list,
        new_processor,
        old_config.vision_config.image_size,
        new_config.vision_config.image_size,
    )

    train_loader = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        # Fall back to True when the namespace carries no `pin_mem` attribute.
        pin_memory=getattr(args, "pin_mem", True),
        drop_last=True,
        shuffle=True,
        collate_fn=collate_fn,
    )

    train_loader = fabric.setup_dataloaders(train_loader)

    # Only parameters that still require gradients are handed to AdamW.
    trainable_params = [p for p in new_model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=args.init_lr, weight_decay=args.weight_decay)
    new_model, optimizer = fabric.setup(new_model, optimizer)

    train(fabric, new_model, optimizer, train_loader)

def train(
    fabric: L.Fabric,
    new_model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    train_loader,
) -> None:
    """Run the training loop and save a checkpoint after every epoch.

    Per step: encode the image and text batches, L2-normalize each embedding,
    map it through ``new_model.Phi``, pass it through ``batch_align``, and
    minimise ``clip_loss`` on the scaled similarity matrix of the
    re-normalized embeddings.

    NOTE: ``epochs`` and ``output_dir`` are read from the module-level
    ``args`` namespace (bound in the ``__main__`` guard), not from a
    parameter.

    Args:
        fabric: Fabric instance used for printing, backward, saving and
            synchronisation.
        new_model: model exposing ``forward_image``, ``forward_text``,
            ``Phi`` and ``logit_scale``.
        optimizer: optimizer returned by ``fabric.setup``.
        train_loader: iterable yielding ``(image, input_ids, attention_mask)``.
    """
    new_model.train(True)

    step = 0
    total_steps = len(train_loader) * args.epochs

    start_time = time.time()

    for epoch in range(args.epochs):
        for samples in train_loader:
            new_image, new_text, new_attn = samples

            # Clear gradients every step so each update uses only the
            # current batch, whatever the optimizer wrapper does on step().
            optimizer.zero_grad()

            new_x_i = new_model.forward_image(new_image.bfloat16())
            new_x_t = new_model.forward_text(new_text, new_attn)

            # NOTE(review): both modalities go through Phi; new_model.Psi is
            # not used in this loop — confirm this is intended.
            new_x_i = new_model.Phi(F.normalize(new_x_i))
            new_x_i = batch_align(fabric, new_x_i)

            new_x_t = new_model.Phi(F.normalize(new_x_t))
            new_x_t = batch_align(fabric, new_x_t)

            logit_scale = new_model.logit_scale.exp()

            btloss_IT = clip_loss(logit_scale * F.normalize(new_x_i) @ F.normalize(new_x_t).t())
            loss = btloss_IT

            fabric.print(f"epoch {epoch} iter {step} ({(step/total_steps)*100:.4f}%) bt_IT {btloss_IT.item():.4f}")

            fabric.backward(loss, model=new_model)
            optimizer.step()

            step += 1

        # The same path is written every epoch, so the latest save replaces
        # the previous one.
        fabric.save(args.output_dir, {"model": new_model})
        fabric.barrier()

    fabric.print(f"Training time {time.time() - start_time:.1f}s")

if __name__ == "__main__":
    # `args` is bound at module level on purpose: train() reads it as a global.
    args = get_args_parser().parse_args()
    output_dir = args.output_dir
    if output_dir:
        # Create the output directory (and any missing parents) up front.
        Path(output_dir).mkdir(parents=True, exist_ok=True)
    main(args)