import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Literal, Optional, Tuple, Union
import copy
import os


class ContinualOCL(nn.Module):
    """Wrap a network with task bookkeeping and per-child checkpointing for continual learning.

    The wrapper tracks the index of the current task, can store per-task
    snapshots of the ``state_dict`` of each direct child of ``net``, and can
    restore either the whole network or only the modules named in
    ``isolation_parameters`` from an in-memory snapshot or a checkpoint file.

    Args:
        net: The wrapped ``nn.Module``; ``forward`` simply calls it.
        num_task: Total number of tasks; ``begin_task`` asserts that the
            current task index is below this value.
        isolation: When truthy, ``isolation_parameters`` must be non-empty.
        isolation_parameters: Names selecting what ``load_isolated_checkpoint``
            restores. Defaults to a new empty list per instance.
    """

    def __init__(
        self,
        net,
        num_task: int = 2,
        isolation: Optional[bool] = None,
        isolation_parameters: Optional[List[str]] = None,
    ):
        super().__init__()

        # A fresh list per instance; a literal ``[]`` default would be shared
        # by every instance created without this argument.
        if isolation_parameters is None:
            isolation_parameters = []

        if isolation:
            assert len(isolation_parameters) != 0

        self.net = net

        self.num_task = num_task
        self._current_task = 0

        # Names of the direct children of ``net``; snapshots are taken per child.
        self.children_list = [name for name, _ in self.net.named_children()]

        self.isolation = isolation
        self.isolation_parameters = isolation_parameters

        # task index -> {child name -> deep-copied state_dict of that child}
        self.checkpoints = {}
        self.checkpoint = None

        self.current_epoch = 0

    def current_task(self):
        """Return the index of the task currently being trained."""
        return self._current_task

    def last_tast(self):
        """Return the index of the previous task, clamped at 0.

        (The spelling of the name is kept for compatibility with callers.)
        """
        return max(self._current_task - 1, 0)

    def set_grads(self, new_grads: torch.Tensor) -> None:
        """Scatter a flat gradient vector into the ``.grad`` of every parameter of ``net``.

        Consecutive slices of ``new_grads`` are assigned in ``net.parameters()``
        order (the same order ``get_grads`` uses). Each ``.grad`` becomes a view
        into ``new_grads``.
        """
        offset = 0
        for pp in self.net.parameters():
            # numel() is a plain int and is 1 for 0-dim parameters, so the
            # slice bounds are always valid integer indices.
            count = pp.numel()
            pp.grad = new_grads[offset:offset + count].view(pp.size())
            offset += count

    def get_grads(self) -> torch.Tensor:
        """Concatenate the gradients of all parameters of ``net`` into one flat vector.

        Parameters whose ``.grad`` is ``None`` contribute zeros, so the result
        always has one entry per parameter element and lines up with
        ``set_grads``.
        """
        grads = []
        for pp in self.net.parameters():
            if pp.grad is None:
                grads.append(torch.zeros(pp.numel(), dtype=pp.dtype, device=pp.device))
            else:
                grads.append(pp.grad.reshape(-1))
        return torch.cat(grads)

    def get_params(self) -> torch.Tensor:
        """Return all parameters of ``net`` flattened into a single vector."""
        return torch.nn.utils.parameters_to_vector(self.net.parameters())

    def set_params(self, new_params: torch.Tensor) -> None:
        """Overwrite the parameters of ``net`` from a flat vector (see ``get_params``)."""
        torch.nn.utils.vector_to_parameters(new_params, self.net.parameters())

    def get_parameters(self):
        """Return the parameter iterator of the wrapped network."""
        return self.net.parameters()

    def get_checkpoint(self, module=None):
        """Return a deep copy of the ``state_dict`` of ``net`` or of one of its attributes.

        Args:
            module: Name of an attribute of ``net`` (e.g. a child module), or
                ``None`` to snapshot the whole network.

        Raises:
            AttributeError: if ``net`` has no attribute named ``module``.
        """
        target = self.net if module is None else getattr(self.net, module)
        return copy.deepcopy(target.state_dict())

    def update_checkpoint(self, task_num: int = -1):
        """Store a snapshot of every direct child of ``net`` under ``task_num``.

        ``task_num == -1`` means the current task. An existing snapshot for
        the same index is replaced.
        """
        if task_num == -1:
            task_num = self.current_task()

        self.checkpoints[task_num] = {
            child: self.get_checkpoint(module=child) for child in self.children_list
        }

    @staticmethod
    def _load_state_from_file(checkpoint_path: str) -> Dict[str, torch.Tensor]:
        """Read the ``'model'`` entry of a checkpoint file, stripping a leading ``'net.'`` from keys."""
        assert os.path.exists(checkpoint_path), f'No file at {checkpoint_path}'
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        state = torch.load(checkpoint_path, map_location='cpu')['model']
        prefix = 'net.'
        # Presumably keys were saved from this wrapper (hence the 'net.' prefix);
        # keys without the prefix are kept as-is instead of being truncated.
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state.items()
        }

    def reload_checkpoint(self, task_num: int = -1, checkpoint_path=''):
        """Restore ``net`` from an in-memory task snapshot or from a checkpoint file.

        Args:
            task_num: Snapshot to restore; ``-1`` or an index with no stored
                snapshot falls back to the current task. Unused when
                ``checkpoint_path`` is given.
            checkpoint_path: When non-empty, the ``'model'`` entry of this file
                is loaded into the whole network with ``strict=True`` instead.

        Raises:
            KeyError: if no snapshot exists for the resolved task index.
            AssertionError: if ``checkpoint_path`` does not exist.
        """
        if task_num == -1 or task_num not in self.checkpoints:
            task_num = self.current_task()

        if checkpoint_path == '':
            for name, module_state in self.checkpoints[task_num].items():
                if len(module_state) > 0:
                    print(f'- Used Parameter for Training: Task {task_num} {name}')
                    getattr(self.net, name).load_state_dict(module_state, strict=True)
                else:
                    # Children with an empty state_dict have nothing to restore.
                    print(f'- Use {name}')
        else:
            self.net.load_state_dict(self._load_state_from_file(checkpoint_path), strict=True)

    def load_isolated_checkpoint(self, task_num: int = -1, checkpoint_path: str = ''):
        """Restore only what ``isolation_parameters`` selects.

        Without ``checkpoint_path``, every name must be a key of the stored task
        snapshot (a direct child of ``net``), and that child is loaded with
        ``strict=True``. With ``checkpoint_path``, every ``state_dict`` entry of
        ``net`` whose key *contains* one of the names (substring match) is taken
        from the file; all other entries keep their current values.

        Raises:
            AssertionError: if isolation is disabled, no names are configured,
                or ``checkpoint_path`` does not exist.
            KeyError: if a selected name/key is missing from the snapshot or file.
        """
        assert self.isolation, f"Tring to load isolated paramenters, however isolation flag is {self.isolation}..."
        assert len(self.isolation_parameters) != 0, f"Tring to load isolated paramenters, however isolation parameters is {self.isolation_parameters}..."

        if task_num == -1 or task_num not in self.checkpoints:
            task_num = self.current_task()

        print(f'- Isloating parameters for task{task_num}\n{self.isolation_parameters}\n')

        if checkpoint_path == '':
            checkpoint = self.checkpoints[task_num]
            for prarm_name in self.isolation_parameters:
                print(f'- Parameter isoloation: from task {task_num} {prarm_name}')
                getattr(self.net, prarm_name).load_state_dict(checkpoint[prarm_name], strict=True)
        else:
            checkpoint = self._load_state_from_file(checkpoint_path)
            # Start from the full current state (parameters AND buffers), so that
            # buffers of isolated modules are restored as well, matching the
            # in-memory branch, while non-isolated entries are reloaded unchanged.
            new_checkpoint = {}
            for key, value in self.net.state_dict().items():
                isolated = any(name in key for name in self.isolation_parameters)
                new_checkpoint[key] = checkpoint[key] if isolated else value

            missing_params = self.net.load_state_dict(new_checkpoint, strict=False)
            print('Caution: Loading isolation', missing_params)

    def freeze_module(self, freeze_prams):
        """Placeholder hook; currently does nothing."""
        pass

    def begin_task_(self):
        """Assert the current task index is in range and print a start banner."""
        task_idx = self.current_task()
        num_task = self.num_task
        assert task_idx < num_task, f"Current task({task_idx}) can't be larger than maximum number of task({num_task})..."
        print(f"\n-------------------------------------------------------------------------------")
        print(f"------------------------------ Starting task {task_idx}/{num_task} -----------------------------")
        print(f"-------------------------------------------------------------------------------\n")

    def begin_task(self, **kwargs):
        """Hook called at the start of a task; extra keyword arguments are ignored."""
        self.begin_task_()

    def ene_task_(self, x=None):
        """Print an end banner and advance the task counter (``x`` is unused)."""
        task_idx = self.current_task()
        num_task = self.num_task
        print(f"\n---------------------------------End task {task_idx}/{num_task}---------------------------------\n\n\n")
        self._current_task += 1
        if self._current_task == num_task:
            print(f"\n------------------------------------------End training------------------------------------------\n\n\n")

    def end_task(self, **kwargs):
        """Hook called at the end of a task; extra keyword arguments are ignored."""
        self.ene_task_()

    def inter_task(self, **kwargs):
        """Hook called between tasks; currently does nothing."""
        pass

    def forward_(self, x):
        """Run the wrapped network on ``x``."""
        return self.net(x)

    def forward_without_ddp(self, x):
        """Run the wrapped network on ``x`` (currently identical to ``forward_``)."""
        return self.net(x)

    def forward(self, x):
        """Delegate to ``forward_`` so subclasses can override the forward pass."""
        return self.forward_(x)
