import torch
import torch.nn as nn
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import requests
import copy
from transformers import ViTImageProcessor, ViTModel
from torch.utils.data import DataLoader
import random
import types

T = 10  # Number of tasks

# Image preprocessor for the ViT backbone; it uses the same checkpoint that
# `Model` loads below.
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')

class Model(nn.Module):
    """ViT backbone with a linear classification head on its pooled output.

    Args:
        num_classes: number of logits produced by the head.
        pretrained_name: Hugging Face checkpoint used for the ViT backbone.
    """

    def __init__(self, num_classes, pretrained_name='google/vit-base-patch16-224-in21k'):
        super().__init__()
        self.model = ViTModel.from_pretrained(pretrained_name)
        # Size the head from the backbone config instead of hard-coding 768,
        # so checkpoints with a different hidden size also work.
        self.fc = nn.Linear(self.model.config.hidden_size, num_classes)

    def forward(self, **inputs):
        """Return logits; keyword inputs (e.g. ``pixel_values``) go straight to the ViT."""
        outputs = self.model(**inputs)
        return self.fc(outputs.pooler_output)

# Device used throughout the script (model, batches, mask parameters).
device = "cuda:1"
# 10-way classifier, cast to float32 and then moved onto `device`.
model = Model(10).to(torch.float32).to(device)

class Chicken(nn.Module):
    """Wrap an existing ``nn.Module`` with per-task SVD masks ("policies").

    ``Chicken(model, ...)`` does not return a plain ``Chicken``. It builds (and
    caches) a subclass with bases ``(Chicken, type(model))`` and returns an
    instance of it whose ``__dict__`` is a shallow copy of ``model.__dict__``.
    The wrapper therefore shares submodules and parameters with ``model`` while
    gaining the mask-management methods below.

    Every state-dict entry whose name contains none of ``matching_texts`` is
    decomposed once with SVD (``W = U diag(S) V^T``). Each mask added by
    :meth:`add_class` holds one learnable vector per decomposed weight.
    :meth:`compose_new_params` rescales ``S`` with the activated mask and
    rebuilds the weight, and :meth:`apply_policy_to_model` writes the rebuilt
    weights into the model.
    """

    # Default substrings of parameter names that are NOT decomposed/masked.
    _DEFAULT_MATCHING_TEXTS = ("layernorm", "bias", "embeddings", "layrnorm", "layer_norm")

    # Cache of generated wrapper subclasses, keyed by the wrapped model's class,
    # so wrapping the same model type repeatedly reuses one subclass.
    _cls_cache: dict[type[nn.Module], type[nn.Module]] = {}

    # -------- object construction --------
    def __new__(cls, model: nn.Module, *args, **kwargs):
        if cls is Chicken:  # only when the user calls Chicken(...) directly
            # Validate before type(model) is used as a base class; otherwise a
            # non-Module argument fails with an obscure class-creation error.
            if not isinstance(model, nn.Module):
                raise TypeError("model is not an torch.nn.Module")
            base = type(model)

            # reuse cached subclass if it exists
            wrapped = cls._cls_cache.get(base)
            if wrapped is None:
                wrapped = types.new_class(
                    f"Chicken Learning{base.__name__}",
                    (Chicken, base),  # MRO: Chicken -> base model
                    {},
                )
                cls._cls_cache[base] = wrapped

            # Allocate an instance of the generated subclass and shallow-copy
            # the model's attributes (modules, parameters, buffers, flags).
            inst = super().__new__(wrapped)
            inst.__dict__.update(model.__dict__)
            return inst

        # if somebody subclasses Chicken explicitly, honour normal behaviour
        return super().__new__(cls)

    def __init__(self, model, device="cpu", init_val=0.1, max_mult=1, matching_texts=None):
        """Snapshot and decompose ``model``'s weights; start with no masks.

        Args:
            model: the wrapped module (its ``state_dict`` is deep-copied).
            device: device on which new mask parameters are created and onto
                which :meth:`load_weights` maps loaded tensors.
            init_val: scale of the Gaussian noise used to initialise masks.
            max_mult: multiplier applied to the sigmoid-activated mask.
            matching_texts: substrings. Parameters whose name contains any of
                them are left untouched (not decomposed or masked). Defaults to
                ``["layernorm", "bias", "embeddings", "layrnorm", "layer_norm"]``.

        ``nn.Module.__init__`` is not called here. On the ``Chicken(model)``
        path the module internals were already copied from ``model`` in
        ``__new__``.
        """
        if not isinstance(model, nn.Module):
            raise TypeError("model is not an torch.nn.Module")

        if matching_texts is None:
            matching_texts = Chicken._DEFAULT_MATCHING_TEXTS

        self.init_val = init_val
        # Copy, so neither a caller's list nor a shared default is aliased.
        self.matching_texts = list(matching_texts)
        self.device = device

        # Frozen snapshot of the original weights and their SVD factors.
        self.base_params = copy.deepcopy(model.state_dict())
        self.decomposed_params = self.decompose(self.base_params, self.matching_texts)

        self.class_policy_map = {}  # class name -> mask index

        self.learnable_params = {}  # mask index -> {param name: mask Parameter}
        self.trainable_params = {}  # mask index -> list of mask Parameters
        self.num_params = 0  # total element count over all mask Parameters
        self.max_mult = max_mult
        self.enable_mask = []  # enable_mask[i] switches mask i on/off
        self.new_mask_idx = 0  # index the next add_class() call will use
        self.selected_mask = -1  # no mask selected

    def _normalize_mask_idx(self, mask_idx):
        """Map a negative (from-the-end) mask index to its absolute index."""
        if mask_idx < 0:
            mask_idx += len(self.enable_mask)
        return mask_idx

    def set_mask(self, mask_num: int = 0):
        """Select which mask the training helpers use by default.

        Args:
            mask_num (int): index of an existing mask. Negative values count
                from the end (``-1`` is the most recently added mask).

        Returns:
            bool: ``True`` on success.

        Raises:
            IndexError: if ``mask_num`` does not refer to an existing mask.
        """
        try:
            self.enable_mask[mask_num]
        except IndexError:
            raise IndexError("the mask number is out of range") from None

        # Store the absolute index: the mask dicts are keyed 0..n-1, so a raw
        # negative index would make later lookups raise KeyError.
        self.selected_mask = self._normalize_mask_idx(mask_num)

        return True

    def get_mask(self, mask_num: int = -1):
        """Return the parameter dict of mask ``mask_num``.

        Args:
            mask_num (int): mask index. Negative values count from the end, so
                the default ``-1`` returns the most recently added mask.

        Returns:
            dict: ``{param name: mask Parameter}`` for that mask.

        Raises:
            IndexError: if ``mask_num`` does not refer to an existing mask.
        """
        try:
            self.enable_mask[mask_num]
        except IndexError:
            raise IndexError("the mask number is out of range") from None

        # The mask dicts are keyed 0..n-1, so resolve negative indices first.
        return self.learnable_params[self._normalize_mask_idx(mask_num)]

    def add_class(self, class_names):
        """Register a new task: map ``class_names`` to a freshly created mask.

        One mask vector of length ``min(weight.shape)`` is created per
        decomposed weight. Its values are ``randn * init_val`` in bfloat16 on
        ``self.device``. The new mask is enabled but not selected.

        Returns:
            bool: ``True``.
        """
        for name in class_names:
            self.class_policy_map[name] = self.new_mask_idx

        mask_params = {}
        for k, v in self.base_params.items():
            if any(text in k for text in self.matching_texts):
                continue  # skipped params are never masked
            mask = nn.Parameter(
                torch.randn(min(v.shape), device=self.device, dtype=torch.bfloat16) * self.init_val,
                requires_grad=True,
            )
            mask_params[k] = mask
            self.num_params += mask.numel()

        self.learnable_params[self.new_mask_idx] = mask_params
        self.trainable_params[self.new_mask_idx] = list(mask_params.values())
        self.enable_mask.append(True)

        self.new_mask_idx += 1

        return True

    def get_trainable_parameters(self, mask_idx=None):
        """Return the list of mask Parameters of ``mask_idx`` (default: the selected mask)."""
        if mask_idx is None:
            mask_idx = self.selected_mask

        return self.trainable_params[mask_idx]

    def save_weights(self, path):
        """Save the weight snapshot, SVD factors, masks and mask bookkeeping to ``path``.

        ``class_policy_map`` and ``enable_mask`` are stored as well, so that a
        reload restores the class-to-mask routing and any disabled masks.
        """
        save_info = {
            "base_params": self.base_params,
            "decomposed_params": self.decomposed_params,
            "learnable_params": self.learnable_params,
            "class_policy_map": self.class_policy_map,
            "enable_mask": self.enable_mask,
        }
        torch.save(save_info, path)

    def load_weights(self, path):
        """Restore state written by :meth:`save_weights` and select the latest mask.

        Tensors are mapped onto ``self.device``. For checkpoints that lack
        ``class_policy_map``/``enable_mask``, the current class map is kept and
        every loaded mask is marked enabled. The model's weights are not
        modified; call :meth:`apply_policy_to_model` to apply the loaded mask.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        load_info = torch.load(path, map_location=self.device)

        self.base_params = load_info["base_params"]
        self.decomposed_params = load_info["decomposed_params"]
        self.learnable_params = load_info["learnable_params"]
        self.class_policy_map = load_info.get("class_policy_map", self.class_policy_map)

        self.trainable_params = {
            idx: list(params.values()) for idx, params in self.learnable_params.items()
        }
        self.enable_mask = list(
            load_info.get("enable_mask", [True] * len(self.learnable_params))
        )
        self.num_params = sum(
            p.numel() for params in self.learnable_params.values() for p in params.values()
        )
        self.new_mask_idx = len(self.learnable_params)

        # Select the most recently added mask (set_mask()'s default is mask 0).
        if self.learnable_params:
            self.set_mask(self.latest_mask_idx)

    @property
    def class_map(self):
        """Human-readable table of which class names are routed to which mask."""
        # inverse map: mask index -> class names, in insertion order
        inverse_map = {}
        for name, mask_idx in self.class_policy_map.items():
            inverse_map.setdefault(mask_idx, []).append(name)

        string = "CLASS MAP\n"
        string += "------------------\n"
        for mask_idx, names in inverse_map.items():
            string += f"{mask_idx}: {', '.join(names)}\n"
        string += "------------------\n"
        return string

    @property
    def latest_mask_idx(self):
        """Index of the most recently added mask (-1 when there is none)."""
        return self.new_mask_idx - 1

    @staticmethod
    def decompose(base_params, skip_match_texts=()):
        """Return the SVD factors of every entry not matched by ``skip_match_texts``.

        For each kept entry ``k`` with value ``W``, the result holds
        ``"{k}.U"``, ``"{k}.S"`` and ``"{k}.V"`` with ``W = U @ diag(S) @ V.T``
        (reduced SVD, using ``torch.svd``'s ``V`` convention).
        """
        decomposed_params = {}
        for k, v in base_params.items():
            if any(text in k for text in skip_match_texts):
                continue  # skip this param

            # torch.svd is deprecated; torch.linalg.svd returns V^H instead of V.
            U, S, Vh = torch.linalg.svd(v, full_matrices=False)
            decomposed_params[f"{k}.U"] = U
            decomposed_params[f"{k}.S"] = S
            decomposed_params[f"{k}.V"] = Vh.mH
        return decomposed_params

    def activate_mask(self, p, mask_idx):
        """Return ``sigmoid(p) * max_mult`` (bfloat16), or all ones if the mask is disabled."""
        if self.enable_mask[mask_idx]:
            return torch.sigmoid(p).to(torch.bfloat16) * self.max_mult
        else:
            return torch.ones_like(p).to(torch.bfloat16)

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped model's ``forward``."""
        return super().forward(*args, **kwargs)

    def compose_new_params(
        self,
        param_name,
        mask_idx,
    ):
        """Rebuild weight ``param_name`` from its SVD with mask ``mask_idx`` applied.

        Computes ``U @ diag(S * m) @ V.T`` with ``m = activate_mask(...)`` and
        rescales it by ``sum(S) / sum(S * m)``, so the total singular-value
        mass is unchanged. Only the mask parameter requires grad; the SVD
        factors come from a state-dict snapshot.
        """
        mm = self.activate_mask(self.learnable_params[mask_idx][param_name], mask_idx)
        U = self.decomposed_params[f"{param_name}.U"]
        S = self.decomposed_params[f"{param_name}.S"]
        V = self.decomposed_params[f"{param_name}.V"]
        scaled_s = S * mm
        return (U @ torch.diag_embed(scaled_s) @ V.T) * (S.sum() / scaled_s.sum())

    def toggle_mask(self, mask_value=True, mask_idx=None):
        """Enable (``True``) or disable (``False``) a mask and re-apply it to the model.

        A disabled mask activates to all ones, so the rebuilt weights equal the
        snapshot. Defaults to the latest mask. The model weights are rebuilt
        from ``mask_idx``, which may differ from the selected mask.
        """
        if mask_idx is None:
            mask_idx = self.latest_mask_idx

        self.enable_mask[mask_idx] = mask_value

        self.apply_policy_to_model(mask_idx)

    def update_backward(
        self,
        mask_idx=None,
    ):
        """Push the model weights' gradients back into mask ``mask_idx``.

        After ``loss.backward()``, each masked weight ``W`` holds ``dL/dW`` in
        ``.grad``. Calling ``compose_new_params(k, mask_idx).backward(dL/dW)``
        accumulates the chain-rule gradient into that mask's parameter.
        Defaults to the selected mask. Weights without a gradient are skipped.
        """
        if mask_idx is None:
            mask_idx = self.selected_mask

        for k in self.base_params:
            if any(text in k for text in self.matching_texts):
                continue
            grad = self.get_parameter(k).grad
            if grad is None:
                # No gradient reached this weight; nothing to push back.
                continue
            # Each compose call builds its own small graph whose only shared
            # nodes are leaf tensors, so there is no need to retain it.
            self.compose_new_params(k, mask_idx).backward(grad)

    def train(self, mask_idx=None):
        """Make only the weights covered by mask ``mask_idx`` require grad.

        Defaults to the selected mask. A bool argument follows the standard
        ``nn.Module.train(mode)`` contract instead, which ``nn.Module.eval()``
        relies on (it calls ``self.train(False)``).

        Returns:
            self
        """
        if isinstance(mask_idx, bool):
            return super().train(mask_idx)

        if mask_idx is None:
            mask_idx = self.selected_mask

        # set all the other training parameters to false
        for p in self.parameters():
            p.requires_grad_(False)

        # the weights rebuilt from the mask receive gradients
        for k in self.learnable_params[mask_idx]:
            self.get_parameter(k).requires_grad_(True)

        return self

    def infer(self):
        """
        Inference with routing (not implemented yet; currently a no-op).
        """
        pass

    def apply_policy_to_model(self, mask_idx=None):
        """Rebuild every masked weight from mask ``mask_idx`` and load it into the model.

        Skipped parameters are reset to their snapshot values. Defaults to the
        selected mask.
        """
        if mask_idx is None:
            mask_idx = self.selected_mask

        # Values are only copied into the model; no autograd graph is needed.
        with torch.no_grad():
            updated_params = {
                k: v
                if any(skip in k for skip in self.matching_texts)
                else self.compose_new_params(k, mask_idx)
                for k, v in self.base_params.items()
            }
        self.load_state_dict(updated_params, strict=False)

if __name__ == "__main__":
    # patch the vision model: Chicken(...) returns an instance of a generated
    # subclass that shares `model`'s submodules and parameters
    vision_model = Chicken(model, device=device)
    model = vision_model

    def accuracy_fn(logits, labels):
        """Fraction of rows whose arg-max logit equals the label."""
        pred = torch.argmax(logits, dim=-1)
        return torch.mean((pred == labels).to(torch.float32))

    # CrossEntropyLoss applies log-softmax itself and must receive raw logits;
    # feeding softmax probabilities would apply the softmax twice.
    loss_fn = torch.nn.CrossEntropyLoss()

    ## LOADING THE DATASET

    from dataset_cifar import CIFAR100Dataset
    train_dataset = CIFAR100Dataset("Dataset/cifar-100-python", train=True, num_tasks=T)
    test_dataset = CIFAR100Dataset("Dataset/cifar-100-python", train=False, num_tasks=T)

    def collate_wrapper(batch):
        """Turn (image array, label) pairs into ViT ``pixel_values`` plus a label tensor."""
        batch_inputs = {"pixel_values": []}
        batch_labels = []
        for data, label in batch:
            image = Image.fromarray(data)
            inputs = processor(images=image, return_tensors="pt")
            batch_labels.append(label)

            batch_inputs["pixel_values"].append(inputs["pixel_values"])
        batch_inputs['pixel_values'] = torch.cat(batch_inputs["pixel_values"], dim=0)
        batch_labels = torch.tensor(batch_labels)
        return batch_inputs, batch_labels

    def move_to_device(data, labels):
        """Move a collated batch onto `device` (`data` is updated in place)."""
        for key in data:
            data[key] = data[key].to(device)
        return data, labels.to(device)

    def evaluate(dataloader, task):
        """Print the running loss/accuracy of the current model weights over `dataloader`."""
        avg_acc = 0
        avg_loss = 0
        num = 0
        with torch.no_grad():  # evaluation only; no autograd graph needed
            for data, labels in dataloader:
                data, labels = move_to_device(data, labels)

                logits = model(**data)
                loss = loss_fn(logits, labels)
                acc = accuracy_fn(logits, labels)

                num += 1
                avg_acc += acc.item()
                avg_loss += loss.item()
                print(f"(TESTING) task: {task}, loss: {avg_loss/num:.4f}, acc: {avg_acc/num:.4f}")

    ## TESTING THE SWITCHING OF MASK

    print(" =================== TESTING SWITCHING MASK ===================")

    num_epochs = 10

    for t in range(2):
        train_dataset.set_task(t)
        test_dataset.set_task(t)
        train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=False, collate_fn=collate_wrapper)
        test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=False, collate_fn=collate_wrapper)

        # RESET OF THE POLICY
        torch.manual_seed(0)
        random.seed(0)

        # Add a mask for this task's classes and select it explicitly:
        # set_mask()'s default (mask 0) would keep retraining the first mask.
        model.add_class(class_names=[str(x) for x in train_dataset.subset_classes])
        model.set_mask(model.latest_mask_idx)

        model.apply_policy_to_model()
        model.train()

        print(model.class_map)

        optimizer = torch.optim.Adam(model.get_trainable_parameters(), lr=1e-2)

        optimizer.zero_grad()
        model.zero_grad(set_to_none=True)

        for epoch in range(num_epochs):
            avg_acc = 0
            avg_loss = 0
            num = 0
            for data, labels in train_dataloader:
                data, labels = move_to_device(data, labels)

                logits = model(**data)
                loss = loss_fn(logits, labels)
                acc = accuracy_fn(logits, labels)

                # Backward
                optimizer.zero_grad()
                # The optimizer only owns the mask parameters. The model
                # weights' .grad must be cleared as well, otherwise
                # update_backward() would push back gradients accumulated
                # over every previous batch.
                model.zero_grad(set_to_none=True)
                loss.backward()
                model.update_backward()
                optimizer.step()

                # rebuild the model weights from the updated mask
                model.apply_policy_to_model()

                num += 1
                avg_acc += acc.item()
                avg_loss += loss.item()

                print(f"(TRAINING) task: {t}, epoch: {epoch}/{num_epochs}, loss: {avg_loss/num:.4f}, acc: {avg_acc/num:.4f}")

        model.save_weights("test.pt")

    # TESTING THE LOADING

    print(" =================== TESTING LOADING WEIGHTS ===================")

    model.load_weights("test.pt")
    # Select the latest loaded mask and rebuild the weights from it, so the
    # evaluation actually exercises the loaded masks.
    model.set_mask(model.latest_mask_idx)
    model.apply_policy_to_model()
    evaluate(test_dataloader, t)

    # TESTING DISABLE MASK

    print(" =================== TESTING DISABLE MASK ===================")

    model.toggle_mask(False)  # disable latest mask
    evaluate(test_dataloader, t)