from copy import copy, deepcopy
import json
import os

import torch.nn as nn

from datasets import get_dataset

from typing import Any

from utils import binary_to_boolean_type

from datasets.utils.continual_dataset import ContinualDataset
from models.utils.continual_model import ContinualModel
from utils.args import ArgumentParser
import torch.func as func
import gc
import open_clip # type: ignore

from typing import Tuple, List

import matplotlib.pyplot as plt
# Turn off matplotlib's interactive mode for this process.
plt.ioff()

import torch
from models.clip_ft_utils.utils import add_clip_args
from models.clip_ft_utils.backbone import Backbone, create_clip
from models.clip_ft_utils.backbone import build_classification_head
from models.clip_ft_utils.utils import get_parameter
from models.clip_ft_utils.utils import OptimizerBuilder
from models.clip_ft_utils.utils import compute_acc_on_last_task

from models.clip_ft_utils.merging import get_merging_function
from models.clip_ft_utils.merging import add_merging_args

import wandb


class CLIPTauJp(ContinualModel):
    """Continual fine-tuning of the CLIP visual encoder with per-task weight deltas.

    For every task the CLIP visual encoder is reloaded from the pre-trained
    model and a fresh set of additive weight deltas is trained:
    - by default, full-size zero-initialised tensors;
    - with ``--use_lora``, LoRA factor pairs ``(B, A)`` for 2-D weights.

    With ``--tangent`` the L2-normalised encoder output is linearised around
    the pre-trained weights, ``f(theta) + J_f(theta) @ delta``, via
    ``torch.func.jvp``. Otherwise the deltas are added to the weights and the
    encoder is called functionally.

    At the end of each task the deltas are passed to the configured merging
    function. The merged result is what :meth:`forward` applies.

    With a non-zero ``--reg_lambda`` the norm of the linearised feature change
    ``J_f(theta) @ delta`` is penalised. The penalty uses batches drawn from
    the other tasks' training loaders, or from ImageNet with
    ``--penalty_imagenet``.
    """
    NAME = 'clip_ft_taujp'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']
    net: Backbone

    # Default ImageNet location for --penalty_imagenet (overridable with --penalty_imagenet_path).
    DEFAULT_PENALTY_IMAGENET_PATH = "/home/aba/thomas/FFTMammoth/data/imagenet1k"

    # Each rule is (substring of the encoder parameter name, args flag that
    # makes its delta trainable).
    # Order matters: the first matching rule wins. For example, a name that
    # contains both "attn" and "proj" is treated as an attention weight.
    _FT_RULES = (
        ("mlp", "ft_linears"),
        ("attn", "ft_attention"),
        ("proj", "ft_proj"),
        ("ln", "ft_ln"),
        ("class_embed", "ft_class_embed"),
        ("conv", "ft_conv"),
        ("positional_embedding", "ft_pos_embed"),
    )
    # With --use_lora, a 2-D weight is only adapted if it matches one of the
    # first _N_LORA_RULES rules (mlp / attn / proj). Other 2-D weights get no delta.
    _N_LORA_RULES = 3

    @staticmethod
    def get_parser(parser) -> ArgumentParser:
        """Register this model's command-line options on ``parser`` and return it."""
        add_clip_args(parser)
        add_merging_args(parser)
        parser.add_argument('--save_task_vectors', type=binary_to_boolean_type, default=0)

        parser.add_argument('--tangent', type=binary_to_boolean_type, default=1)
        parser.add_argument('--chunks', type=int, default=1,
                            help='choose how many chunks for virtual batch size')

        parser.add_argument('--use_lora', type=binary_to_boolean_type, default=0)

        parser.add_argument('--scheduler_ntk', type=str, default='cosine_plus',
                            choices=["none", "cosine", "cosine_plus", "decay", "step"])

        parser.add_argument('--reg_lambda', type=float, default=0.0)
        parser.add_argument('--penalty_ideal', type=binary_to_boolean_type, default=0)
        parser.add_argument('--clip_grad_norm', type=float, default=None, required=False)
        parser.add_argument('--penalty_imagenet', type=binary_to_boolean_type, default=0)
        parser.add_argument('--penalty_imagenet_path', type=str,
                            default=CLIPTauJp.DEFAULT_PENALTY_IMAGENET_PATH,
                            help='root of the ImageNet data used when --penalty_imagenet=1')

        return parser

    def __init__(self, backbone, loss, args, transform, dataset):
        """Build the CLIP network, transforms, optimizer builder, merger and penalty loaders.

        The ``backbone`` argument is not used: the network is created from
        ``args.clip_backbone``.

        Raises:
            ValueError: if ``dataset`` is None.
        """
        if dataset is None:
            raise ValueError("CLIPTauJp requires a dataset")

        # Only the pre-processing transforms are kept from this model instance.
        _, train_preprocess, val_preprocess = \
            open_clip.create_model_and_transforms(args.clip_backbone, pretrained='openai', device=torch.device('cpu'))

        clip_model = create_clip(args.clip_backbone, torch.device(args.device))

        super().__init__(clip_model, loss, args, transform, dataset=dataset)  # type: ignore

        self.net = Backbone(clip_model, dataset, args)
        self.param_names = [name for name, _ in self.net.visual_encoder.named_parameters()]

        # The network itself is frozen; only the deltas created in begin_task are optimised.
        for param in self.net.parameters():
            param.requires_grad = False

        if self.args.save_task_vectors:
            task_vector_path = "checkpoints/tau_j_p"
            os.makedirs(task_vector_path, exist_ok=True)
            json_args = deepcopy(vars(args))
            # The device entry may hold a torch.device, which json cannot serialise.
            json_args.pop('device', None)
            with open(f"{task_vector_path}/{self.args.conf_jobnum}_{dataset.NAME}_args.json", "w") as f:
                json.dump(json_args, f)

        torch.backends.cuda.enable_mem_efficient_sdp(False)

        clip_model = clip_model.to(dtype=torch.float32)
        clip_model.eval()

        self.clip_model = clip_model
        self.clip_transform = train_preprocess
        self.clip_eval_transform = val_preprocess

        self.optimizer_builder = OptimizerBuilder(cmd_args=self.args)

        self.cur_offset = None
        self.cls_head: nn.Module = None  # type: ignore

        # Per-task deltas, (re)created in begin_task: name -> [delta] or [B, A].
        self.delta_w_dict: dict[str, Any] = None  # type: ignore
        self.delta_w_names: list[str] = None  # type: ignore
        self.delta_w_shapes: dict[str, Any] = None  # type: ignore

        self.scheduler1 = None
        self.num_batches = 0

        self.merging = get_merging_function(self.args, self.device)

        # Deltas applied by forward(); set to the merged vector in end_task.
        self.merged_task_vector = []

        self.num_total_tasks: int = dataset.N_TASKS  # type: ignore
        self.dataset_name: str = dataset.NAME  # type: ignore

        # individual_*: accuracy on each task right after training on it.
        # norm_*: later accuracies divided by those (see end_eval).
        self.individual_acc, self.individual_mask_acc = [], []
        self.norm_acc, self.norm_mask_acc = [], []

        # Index into self.train_loaders of the loader used for the next penalty batch.
        self.task_id_pointer = 0
        self.train_loaders = self._build_penalty_loaders()

    def _build_penalty_loaders(self):
        """Return the training loaders that provide batches for the feature-change penalty.

        - With ``--penalty_imagenet``, a single loader is built from
          ``seq-imagenet1k`` at ``args.penalty_imagenet_path``.
        - Otherwise, ``get_data_loaders`` is called once per task on a fresh
          copy of the current dataset. Presumably each call yields the next
          task's loader.

        All loaders use the CLIP training transform.
        """
        loaders = []
        if self.args.penalty_imagenet:
            control_args = deepcopy(self.args)
            control_args.dataset = 'seq-imagenet1k'
            control_args.data_path = getattr(self.args, 'penalty_imagenet_path',
                                             self.DEFAULT_PENALTY_IMAGENET_PATH)
            control_args.n_experiences = 1
            control_args.seed = 0
            control_args.transform_type = 'weak'
            joint_dataset = get_dataset(control_args)
            train_dl, _ = joint_dataset.get_data_loaders()
            train_dl.dataset.transform = self.clip_transform
            loaders.append(train_dl)
        else:
            joint_dataset = get_dataset(self.args)
            for _ in range(self.num_total_tasks):
                train_dl, _ = joint_dataset.get_data_loaders()
                train_dl.dataset.transform = self.clip_transform
                loaders.append(train_dl)
        return loaders

    def create_param_like(self, param, requires_grad):
        """Return ``[zeros_like(param)]`` in float32 on ``args.device``.

        The single-element list mirrors the ``[B, A]`` pairs used for LoRA deltas.
        """
        return [torch.zeros_like(param, dtype=torch.float32, requires_grad=requires_grad, device=self.args.device)]

    def create_lora_param_like(self, fin, fout, requires_grad, r1=None, r2=None):
        """Return a LoRA factor pair ``(B, A)`` for a ``(fout, fin)`` weight.

        ``B`` has shape ``(fout, r2)`` and ``A`` has shape ``(r1, fin)``. Both
        ranks default to 16, and ``B @ A`` needs ``r1 == r2``. ``B`` uses the
        'zeros' init config and ``A`` the 'kaiming' one (presumably so that
        ``B @ A`` starts at zero).
        """
        r1 = 16 if r1 is None else r1
        r2 = 16 if r2 is None else r2
        config = ('kaiming', 'zeros')
        return get_parameter((fout, r2), self.device, config[1], False, requires_grad), \
            get_parameter((r1, fin), self.device, config[0], False, requires_grad)

    def _create_delta(self, name, param):
        """Return the delta for encoder parameter ``name``, or None if it gets none.

        The first rule in ``_FT_RULES`` whose substring occurs in ``name``
        decides which ``args.ft_*`` flag makes the delta trainable.

        - With ``--use_lora``, a 2-D weight gets a LoRA pair ``[B, A]`` (rank 48
          for attention, 16 otherwise), and only the mlp/attn/proj rules apply.
        - Every other parameter gets a full-size ``[delta]``.
        """
        use_lora = self.args.use_lora and len(param.shape) == 2
        rules = self._FT_RULES[:self._N_LORA_RULES] if use_lora else self._FT_RULES
        for key, flag in rules:
            if key not in name:
                continue
            requires_grad = getattr(self.args, flag) == 1
            if not use_lora:
                return self.create_param_like(param, requires_grad=requires_grad)
            fout, fin = param.shape[0], param.shape[1]
            if key == "attn":
                return list(self.create_lora_param_like(fin, fout, requires_grad, r1=16 * 3, r2=16 * 3))
            return list(self.create_lora_param_like(fin, fout, requires_grad))
        return None

    def begin_task(self, dataset):
        """Prepare the model for a new task.

        Steps:
        1. Set the CLIP transforms on the train loader and the newest test loader.
        2. Build the current task's classification head.
        3. Reload and freeze the CLIP visual encoder.
        4. Create fresh deltas (or LoRA pairs) for the selected encoder parameters.
        5. Build the optimizer and scheduler over those deltas.
        """
        torch.cuda.empty_cache()
        dataset.test_loaders[-1].dataset.transform = self.clip_eval_transform
        dataset.train_loader.dataset.transform = self.clip_transform  # type: ignore

        self.cur_offset = self.compute_offsets(self.current_task)

        if isinstance(dataset.N_CLASSES_PER_TASK, int):
            self.cpt = dataset.N_CLASSES_PER_TASK
        else:
            self.cpt = dataset.N_CLASSES_PER_TASK[-1]

        if self.current_task != 0:
            self.net.task_id += 1

        self.cls_head = build_classification_head(self.clip_model, dataset, self.cur_offset)

        print("\nRELOADING CLIP VISUAL ENCODER")
        self.net.copy_visual_encoder(self.clip_model)

        for param in self.net.visual_encoder.parameters():
            param.requires_grad = False

        print("\nCLIP VISUAL ENCODER RELOADED\n\n")

        self.delta_w_dict = {}
        self.delta_w_shapes = {}

        for name, param in self.net.visual_encoder.named_parameters():
            self.delta_w_shapes[name] = param.shape
            if name == 'proj':
                # Skip the visual encoder's own projection layer, which has been replaced.
                continue
            delta = self._create_delta(name, param)
            if delta is not None:
                self.delta_w_dict[name] = delta

        self.delta_w_names = list(self.delta_w_dict.keys())

        all_params = [p for param_list in self.delta_w_dict.values() for p in param_list]
        num_batches: int = len(dataset.train_loader)  # type: ignore

        self.opt, self.scheduler1 = self.optimizer_builder.build_opt_and_sched(all_params, num_batches)

        # First penalty loader: a task other than the current one. With
        # --penalty_imagenet it is always 0 (there is a single loader).
        self.task_id_pointer = 1 if (self.current_task == 0 and not self.args.penalty_imagenet) else 0

        self.train()

    def get_parameter_from_dict(self, name):
        """Return the current delta tensor for encoder parameter ``name``.

        For a full delta this is the stored tensor itself; for a LoRA pair
        ``[B, A]`` it is the product ``B @ A``.

        Raises:
            KeyError: if ``name`` has no delta for this task.
            ValueError: if the stored entry holds neither 1 nor 2 tensors.
        """
        if name not in self.delta_w_dict:
            raise KeyError(f"no delta registered for parameter {name!r}")
        list_params = self.delta_w_dict[name]
        if len(list_params) == 1:
            return list_params[0]
        if len(list_params) == 2:
            return list_params[0] @ list_params[1]
        raise ValueError(f"unexpected number of tensors ({len(list_params)}) for parameter {name!r}")

    def get_all_parameters_from_dict(self):
        """Return the current deltas for all tuned parameters, in ``delta_w_names`` order."""
        return [self.get_parameter_from_dict(k) for k in self.delta_w_names]

    def _base_params(self):
        """Return the visual encoder's own parameters named in ``delta_w_names``, in that order.

        The lookup is by name, so the result is aligned positionally with the deltas.
        """
        named = dict(self.net.visual_encoder.named_parameters())
        return [named[name] for name in self.delta_w_names]

    def end_task(self, dataset: ContinualDataset) -> None:
        """Evaluate this task's deltas, merge them, and release per-task state.

        Steps:
        1. Using this task's own deltas and all heads, record the accuracy
           on the last task.
        2. Optionally save the deltas, head and metadata under
           ``checkpoints/tau_j_p``.
        3. Freeze the deltas, add them to the merger and keep the merged
           result in ``merged_task_vector``.
        4. Reload the visual encoder and drop the optimizer, scheduler and
           delta dict.
        """
        print(f"Current task: {self.current_task}")

        self.eval()

        # forward() applies merged_task_vector, so temporarily set it to this
        # task's deltas; the accuracy computation below presumably goes
        # through forward(). Detach: forward runs under no_grad, so the
        # autograd history (e.g. of B @ A) would only hold memory.
        self.merged_task_vector = [self.get_parameter_from_dict(key).detach().clone()
                                   for key in self.delta_w_names]

        actual_seen_classes = self.n_seen_classes

        self.cls_head = build_classification_head(self.clip_model, dataset, self.cur_offset, all_heads=True)
        self._n_seen_classes = dataset.N_CLASSES
        try:
            acc, acc_mask_classes = compute_acc_on_last_task(self, dataset)
        finally:
            self._n_seen_classes = actual_seen_classes
        self.individual_acc.append(acc)
        self.individual_mask_acc.append(acc_mask_classes)

        if self.args.save_task_vectors:
            base_path = f"checkpoints/tau_j_p/{self.args.conf_jobnum}_{dataset.NAME}_task_{self.current_task}"
            task_vector_path = f"{base_path}.pt"
            os.makedirs(os.path.dirname(task_vector_path), exist_ok=True)
            torch.save(self.merged_task_vector, task_vector_path)
            torch.save(self.cls_head.state_dict(), f"{base_path}_cls_head.pt")
            torch.save({
                'delta_w_names': self.delta_w_names
                }, f"{base_path}_meta.pt")
            print(f"Task vector saved to {task_vector_path}")

        # Release the evaluation copies before merging.
        self.merged_task_vector.clear()

        self.cls_head = build_classification_head(self.clip_model, dataset, self.cur_offset, eval=True)

        for param_list in self.delta_w_dict.values():
            for param in param_list:
                param.requires_grad = False

        self.merging.add({
            key: self.get_parameter_from_dict(key) for key in self.delta_w_names
        })

        self.merged_task_vector = self.merging.merge(self.delta_w_names)

        self.opt.zero_grad()  # type: ignore
        self.opt = None

        self.net.copy_visual_encoder(self.clip_model)

        del self.opt, self.scheduler1, self.delta_w_dict
        # Collect first so that empty_cache can hand the freed blocks back to the device.
        gc.collect()
        torch.cuda.empty_cache()

        return super().end_task(dataset)

    def end_eval(self, dataset: ContinualDataset, accs: Tuple[List, List]) -> None:
        """Log the final accuracies normalised by each task's accuracy right after training on it.

        ``accs[0][t]`` and ``accs[1][t]`` are divided by ``individual_acc[t]``
        and ``individual_mask_acc[t]`` respectively (both recorded in
        end_task). The means are logged to wandb when a run is active.
        """
        def safe_den(y, eps=1e-8):
            # Avoid dividing by a (near-)zero individual accuracy.
            return y if abs(y) >= eps else y + eps

        self.norm_acc = [acc / safe_den(self.individual_acc[t])
                         for t, acc in enumerate(accs[0])]
        self.norm_mask_acc = [acc / safe_den(self.individual_mask_acc[t])
                              for t, acc in enumerate(accs[1])]

        # wandb.log errors out if wandb.init was never called (e.g. logging disabled).
        if wandb.run is None:
            return

        log_data = {}
        if self.norm_acc:
            log_data["RESULT_mean_norm_acc"] = sum(self.norm_acc) / len(self.norm_acc)
        if self.norm_mask_acc:
            log_data["RESULT_mean_norm_mask_acc"] = sum(self.norm_mask_acc) / len(self.norm_mask_acc)
        log_data["Task"] = self.current_task
        wandb.log(log_data)

    def penalty_weight(self):
        """Return the batch-mean L2 norm of the linearised feature change on one penalty batch.

        A batch is taken from ``self.train_loaders[self.task_id_pointer]``.
        The JVP of the normalised encoder output is computed with respect to
        the tuned weights, in the direction of the current deltas. Its norm
        over dim 1 is then averaged.

        Raises:
            RuntimeError: if the selected loader yields no batch.
        """
        dl_train = self.train_loaders[self.task_id_pointer]
        # NOTE(review): a fresh iterator is created on every call. For a
        # multi-process DataLoader without persistent workers this starts new
        # workers each time; confirm the cost is acceptable.
        data = next(iter(dl_train), None)
        if data is None:
            raise RuntimeError(f"penalty data loader {self.task_id_pointer} yielded no batch")

        inputs = data[0].to(self.device)

        forward_fun = self.create_functional(inputs, self.delta_w_names)
        _, jvp = func.jvp(forward_fun, (tuple(self._base_params()),),
                          (tuple(self.get_all_parameters_from_dict()),))

        return torch.norm(jvp, dim=1).mean()

    def create_functional(self, inputs, delta_names):
        """Return ``f(param_values)``: the L2-normalised visual features of ``inputs``.

        ``param_values`` replace, positionally, the encoder parameters named in
        ``delta_names``. All other parameters and buffers are the module's own.
        """
        def func_network(param_values):
            overrides = dict(zip(delta_names, param_values))
            features = func.functional_call(self.net.visual_encoder, overrides, inputs)  # type: ignore
            return nn.functional.normalize(features, dim=-1)
        return func_network

    def increment_task_pointer(self):
        """Advance ``task_id_pointer`` to the next task loader, skipping the current task.

        Wraps around to the first task other than the current one. Does
        nothing with ``--penalty_imagenet``, where there is a single loader.
        """
        if self.args.penalty_imagenet:
            return

        # One loader per task was built in __init__.
        n_loaders = self.num_total_tasks
        self.task_id_pointer += 1
        if self.task_id_pointer == self.current_task:
            self.task_id_pointer += 1
        if self.task_id_pointer >= n_loaders:
            self.task_id_pointer = 1 if self.current_task == 0 else 0

    def observe(self, inputs, labels, not_aug_inputs, epoch=None):
        """Accumulate gradients for one batch and step the optimizer every ``chunks`` batches.

        Returns:
            The task loss divided by ``chunks``, as a Python float. The
            penalty term is not included.
        """
        if self.args.tangent:
            # Linearised model: f(theta) + J_f(theta) @ delta.
            forward_fun = self.create_functional(inputs, self.delta_w_names)
            image_features, jvp = func.jvp(forward_fun, (tuple(self._base_params()),),  # type: ignore
                                           (tuple(self.get_all_parameters_from_dict()),))
            image_features = image_features + jvp
        else:
            # Plain fine-tuning: evaluate the encoder at theta + delta.
            dict_param = {name: delta + base for name, delta, base in
                          zip(self.delta_w_names, self.get_all_parameters_from_dict(), self._base_params())}
            image_features = func.functional_call(self.net.visual_encoder, dict_param, inputs)  # type: ignore
            image_features = nn.functional.normalize(image_features, dim=-1)

        similarity = self.cls_head(image_features)
        # The current task's head is indexed from 0, so shift the labels.
        loss_task = self.loss(similarity, labels - self.n_past_classes)
        loss = loss_task / self.args.chunks

        loss.backward()

        # Skip the penalty when it cannot contribute. A zero weight would only
        # waste a JVP pass, and an out-of-range pointer (single task) has no
        # other task's data to draw from.
        if self.args.reg_lambda != 0 and self.task_id_pointer < len(self.train_loaders):
            if not self.args.penalty_ideal:
                loss_penalty = self.penalty_weight() / self.args.chunks
                (self.args.reg_lambda * loss_penalty).backward()
            else:
                # One penalty batch per other task.
                for _ in range(self.num_total_tasks - 1):
                    loss_penalty = self.penalty_weight() / self.args.chunks
                    (self.args.reg_lambda * loss_penalty).backward()
                    self.increment_task_pointer()

        # Gradient accumulation: one optimizer step every `chunks` batches.
        # Assumes the base class sets task_iteration to 0 for the first batch
        # of a task and increments it after each observe() call — TODO confirm.
        if (self.task_iteration + 1) % self.args.chunks == 0:
            if self.scheduler1:
                # 0-based index of the optimizer step about to be taken.
                self.scheduler1(self.task_iteration // self.args.chunks)
            if self.args.clip_grad_norm is not None and self.args.clip_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(
                    (p for group in self.opt.param_groups for p in group['params']),
                    self.args.clip_grad_norm
                )
            self.opt.step()  # type: ignore
            self.opt.zero_grad()  # type: ignore

            if not self.args.penalty_ideal:
                self.increment_task_pointer()

        return loss.item()

    @torch.no_grad()
    def forward(self, x):
        """Return the classification-head outputs for the first ``n_seen_classes`` classes.

        The encoder uses ``merged_task_vector`` as its deltas. They are
        linearised with ``--tangent``, and added to the weights otherwise.
        """
        base_params = self._base_params()
        if self.args.tangent:
            forward_fun = self.create_functional(x, self.delta_w_names)
            image_features, jvp = func.jvp(forward_fun, (tuple(base_params),),  # type: ignore
                                           (tuple(self.merged_task_vector),))
            image_features = image_features + jvp
        else:
            dict_param = {key: base_params[i] + self.merged_task_vector[i]
                          for i, key in enumerate(self.delta_w_names)}
            image_features = func.functional_call(self.net.visual_encoder, dict_param, x)  # type: ignore
            image_features = nn.functional.normalize(image_features, dim=-1)

        similarity = self.cls_head(image_features)
        return similarity[:, :self.n_seen_classes]

    def get_debug_iters(self):
        """Return the iteration count to use in debug mode (presumably read by the training loop)."""
        return 5