
from typing import List
import torch
from torch.optim import Adam
from torch import nn
from torch.nn.parallel import DistributedDataParallel

from .basics import GetAttr
from .callback.core import * 
from .callback.tracking import * 
from .callback.scheduler import *
from .utils import *
from pathlib import Path
from tqdm import tqdm

import numpy as np

# from sklearn.base import BaseEstimator
from unittest.mock import patch


class Learner(GetAttr):
    """Callback-driven training, validation and prediction loop for a PyTorch model.

    The loop itself only pushes batches through the model. At every stage it
    fires a named event (e.g. ``'before_fit'``, ``'before_batch_train'``,
    ``'after_epoch_valid'``); ``self(name)`` calls attribute ``name`` on each
    registered callback in order. Callbacks are bound via ``cb.learner`` and
    can read or modify learner state such as ``self.batch``, ``self.pred``,
    ``self.loss`` and ``self.epoch``.
    """

    def __init__(self, dls, model,
                 loss_func=None,
                 lr=1e-3,
                 cbs=None,
                 metrics=None,
                 opt_func=Adam,
                 **kwargs):
        """
        Args:
            dls: data-loaders container. When truthy, ``dls.train.dataset.n_inp``
                is read here, ``dls.train`` / ``dls.valid`` are iterated by
                ``fit`` and ``dls.add_dl`` is used by ``predict``.
            model: the ``nn.Module`` to train.
            loss_func: called as ``loss_func(pred, target)``; its result must
                support ``.backward()`` without arguments (i.e. a scalar tensor).
            lr: learning rate, passed to ``opt_func``.
            cbs: a callback, or a list/tuple of callbacks, merged into the
                default callbacks (same exact type replaces the default).
            metrics: stored as ``self.metrics``; not read inside this class.
            opt_func: optimizer factory, called as ``opt_func(params, lr)``.
            **kwargs: accepted for compatibility; unused.
        """
        self.model, self.dls, self.loss_func, self.lr = model, dls, loss_func, lr
        self.opt_func = opt_func
        self.opt = self.opt_func(self.model.parameters(), self.lr)

        self.metrics = metrics
        if self.dls: self.n_inp = self.dls.train.dataset.n_inp
        # True only while lr_finder() runs; set before 'init_cb' fires so
        # callbacks can read it during initialisation
        self.run_finder = False
        # normalise `cbs` to a list (a tuple is expanded, a single callback wrapped)
        if cbs and not isinstance(cbs, list):
            cbs = list(cbs) if isinstance(cbs, tuple) else [cbs]
        self.initialize_callbacks(cbs)

    def default_callback(self):
        """Return fresh instances of the callbacks every learner starts with."""
        default_cbs = [SetupLearnerCB(), TrackTimerCB(),
                       TrackTrainingCB(train_metrics=True, valid_metrics=True)]
        return default_cbs

    def initialize_callbacks(self, cbs):
        """Install the default callbacks merged with `cbs`, then PrintResultsCB.

        Every installed callback gets ``cb.learner = self`` and then the
        ``'init_cb'`` event is fired.
        """
        default_cbs = self.default_callback()
        self.cbs = update_callbacks(cbs, default_cbs) if cbs else default_cbs
        # PrintResultsCB is appended after the defaults and the user callbacks
        self.cbs += [PrintResultsCB()]
        for cb in self.cbs: cb.learner = self
        self('init_cb')

    def add_callback(self, cb):
        """Bind `cb` to this learner and register it, replacing any callback of
        the same exact type. A falsy `cb` is ignored."""
        if not cb: return
        cb.learner = self
        self.cbs = update_callback(cb, self.cbs)

    def add_callbacks(self, cbs):
        """Register a single callback or each callback of a list/tuple."""
        if not isinstance(cbs, list):
            cbs = list(cbs) if isinstance(cbs, tuple) else [cbs]
        for cb in cbs: self.add_callback(cb)

    def remove_callback(self, cb):
        """Detach `cb` from this learner and unregister the first callback of its type.

        Returns whatever the module-level ``remove_callback`` reports as removed.
        """
        # clear the back-reference so the callback no longer points at this learner
        cb.learner = None
        self.cbs, removed_cb = remove_callback(cb, self.cbs)
        return removed_cb

    def fit(self, n_epochs, lr=None, cbs=None, do_valid=True):
        """Train for `n_epochs` epochs.

        Args:
            n_epochs: number of epochs to run.
            lr: if truthy, stored as ``self.lr``; this method does not modify
                the optimizer's param groups itself.
            cbs: extra callback(s) registered before training starts.
            do_valid: run a validation epoch after each training epoch when
                ``self.dls.valid`` is truthy.

        A ``KeyboardInterrupt`` ends training early; ``'after_fit'`` still fires.
        """
        self.n_epochs = n_epochs
        if cbs: self.add_callbacks(cbs)
        if lr: self.lr = lr
        self('before_fit')
        try:
            for self.epoch in range(n_epochs):
                self('before_epoch')
                self.one_epoch(train=True)
                if self.dls.valid and do_valid: self.one_epoch(train=False)
                self('after_epoch')
        except KeyboardInterrupt:
            # deliberate: Ctrl-C stops training but still lets 'after_fit' run
            pass
        self('after_fit')

    def one_epoch(self, train):
        """Run one training epoch if `train` is truthy, else one validation epoch."""
        self.epoch_train() if train else self.epoch_validate()

    def epoch_train(self):
        """Put the model in train mode and run every batch of ``self.dls.train``."""
        self('before_epoch_train')
        self.model.train()
        self.dl = self.dls.train
        self.all_batches('train')
        self('after_epoch_train')

    def epoch_validate(self, dl=None):
        """Put the model in eval mode and run every batch of `dl` (default
        ``self.dls.valid``) under ``torch.no_grad()``; skipped if the loader is falsy."""
        self('before_epoch_valid')
        self.model.eval()
        self.dl = dl if dl else self.dls.valid
        if self.dl:
            with torch.no_grad(): self.all_batches('valid')
        self('after_epoch_valid')

    def all_batches(self, type):
        """Iterate ``self.dl``, storing ``self.iter`` / ``self.batch`` and running
        the per-batch step for `type` ('train', 'valid' or 'predict'; any other
        value only iterates)."""
        for num, batch in enumerate(self.dl):
            self.iter, self.batch = num, batch
            if type == 'train': self.batch_train()
            elif type == 'valid': self.batch_validate()
            elif type == 'predict': self.batch_predict()

    def batch_train(self):
        """One optimisation step on ``self.batch``, wrapped in its events."""
        self('before_batch_train')
        self._do_batch_train()
        self('after_batch_train')

    def batch_validate(self):
        """One evaluation step on ``self.batch``, wrapped in its events."""
        self('before_batch_valid')
        self._do_batch_validate()
        self('after_batch_valid')

    def batch_predict(self, model=None):
        """One prediction step on ``self.batch`` with `model` (default ``self.model``)."""
        if model is None: model = self.model
        self('before_batch_predict')
        self._do_batch_predict(model)
        self('after_batch_predict')

    def _do_batch_train(self):
        """Forward, loss, backward and optimizer step; sets ``self.pred`` / ``self.loss``."""
        # self.batch must unpack into (inputs, targets)
        xb, yb = self.batch
        self.opt.zero_grad()
        self.pred = self.model(xb)
        self.loss = self.loss_func(self.pred, yb)
        self.loss.backward()
        self.opt.step()

    def _do_batch_validate(self):
        """Forward pass (through the unwrapped model); loss only when targets exist."""
        xb, yb = self.batch
        self.pred = get_model(self.model)(xb)
        if yb is not None:
            self.loss = self.loss_func(self.pred, yb)

    def _do_batch_predict(self, model):
        """Forward pass of the inputs of ``self.batch`` through (unwrapped) `model`."""
        xb, _ = self.batch
        self.pred = get_model(model)(xb)

    def _predict(self, dl=None):
        """Run the prediction loop over `dl` in eval mode without gradients.

        If `dl` is None only ``'before_predict'`` fires and nothing is run.
        """
        self('before_predict')
        if dl is None: return
        self.dl = dl
        self.n_inp = dl.dataset.n_inp
        self.model.eval()
        with torch.no_grad(): self.all_batches('predict')
        self('after_predict')

    def predict(self, test_data, Dataset=None, Dataloader=None, batch_size=16):
        """Run inference on `test_data` and return the collected predictions.

        Args:
            test_data: a ``DataLoader``, or data that ``_prepare_data`` turns
                into one (e.g. a tensor, numpy array or dataset).
            Dataset, Dataloader: optional factories; when both are given the
                loader is ``Dataloader(Dataset(test_data), batch_size)``.
            batch_size: batch size used when a loader has to be built.
        Returns:
            The ``preds`` gathered by a ``GetPredictionsCB``; also stored in
            ``self.preds``. The callback stays registered afterwards (a later
            call replaces it).
        """
        cb = GetPredictionsCB()
        self.add_callback(cb)
        test_dl = self._prepare_data(test_data, Dataset, Dataloader, batch_size)
        self._predict(test_dl)
        self.preds = cb.preds
        return self.preds

    def _prepare_data(self, test_data, Dataset=None, Dataloader=None, batch_size=16):
        """Turn `test_data` into something ``all_batches`` can iterate.

        - a ``DataLoader`` (including subclasses) is returned unchanged;
        - if both `Dataset` and `Dataloader` are given: ``Dataloader(Dataset(test_data), batch_size)``;
        - else if ``self.dls`` is truthy: ``self.dls.add_dl(test_data, batch_size=batch_size)``;
        - otherwise `test_data` is returned as-is (assumed to already yield batches).
        """
        if isinstance(test_data, torch.utils.data.DataLoader): return test_data
        if Dataset and Dataloader:
            return Dataloader(Dataset(test_data), batch_size)
        if self.dls:
            return self.dls.add_dl(test_data, batch_size=batch_size)
        return test_data

    def get_layer_output(self, inp, layers=None, unwrap=False):
        """Return the outputs of `layers` of ``self.model`` for input `inp`.

        Args:
            inp: numpy array (converted with ``torch.Tensor``, i.e. float32) or
                torch tensor, moved to the model's device; anything else is
                passed to the model unchanged.
            layers, unwrap: forwarded to the module-level ``get_layer_output``.
        """
        self.model.eval()
        device = next(self.model.parameters()).device
        if isinstance(inp, np.ndarray): inp = torch.Tensor(inp)
        if isinstance(inp, torch.Tensor): inp = inp.to(device)
        return get_layer_output(inp, model=self.model, layers=layers, unwrap=unwrap)

    def fine_tune(self, n_epochs, base_lr=None, freeze_epochs=1, pct_start=0.3):
        """Two-stage fine-tuning.

        Stage 1: ``freeze()`` then one-cycle training for `freeze_epochs` at
        `base_lr` (default ``self.lr``). Stage 2: ``unfreeze()`` then one-cycle
        training for `n_epochs` at ``base_lr / 2``. ``freeze()`` only has an
        effect when the model has a ``head`` attribute.
        """
        if not base_lr: base_lr = self.lr
        print('Finetune the head')
        self.freeze()
        self.fit_one_cycle(freeze_epochs, lr_max=base_lr, pct_start=pct_start)
        print('Finetune the entire network')
        self.unfreeze()
        self.fit_one_cycle(n_epochs, lr_max=base_lr/2, pct_start=pct_start)

    def fit_one_cycle(self, n_epochs, lr_max=None, pct_start=0.3):
        """``fit`` with a ``OneCycleLR`` callback; `lr_max` defaults to ``self.lr``
        and `pct_start` is forwarded to ``OneCycleLR``."""
        self.n_epochs = n_epochs
        self.lr_max = lr_max if lr_max else self.lr
        cb = OneCycleLR(lr_max=self.lr_max, pct_start=pct_start)
        self.fit(self.n_epochs, lr=self.lr_max, cbs=cb)

    def lr_finder(self, start_lr=1e-7, end_lr=10, num_iter=100, step_mode='exp', show_plot=True, suggestion='valley'):
        """Run a learning-rate range test with ``LRFinderCB``.

        Trains without validation for enough epochs to cover `num_iter`
        batches. ``self.run_finder`` is True for the duration of the test and
        is reset afterwards on every exit path. The finder callback is always
        unregistered again.

        Returns:
            ``cb.suggested_lr`` when `suggestion` is truthy, otherwise None.
        """
        n_epochs = num_iter // len(self.dls.train) + 1
        self.run_finder = True
        cb = LRFinderCB(start_lr, end_lr, num_iter, step_mode, suggestion=suggestion)
        try:
            self.fit(n_epochs=n_epochs, cbs=cb, do_valid=False)
            # plot while the callback is still attached to this learner
            if show_plot: cb.plot_lr_find()
        finally:
            # restore learner state so the finder does not leak into later fit() calls
            self.remove_callback(cb)
            self.run_finder = False
        if suggestion: return cb.suggested_lr

    def freeze(self):
        """Freeze every parameter except those of the model's ``head``.

        No-op when the (unwrapped) model has no ``head`` attribute.
        """
        model = get_model(self.model)
        if not hasattr(model, 'head'): return
        for param in model.parameters(): param.requires_grad = False
        for param in model.head.parameters(): param.requires_grad = True

    def unfreeze(self):
        """Make every parameter of the (unwrapped) model trainable again."""
        for param in get_model(self.model).parameters(): param.requires_grad = True

    def __call__(self, name):
        """Fire event `name`: call ``cb.<name>()`` on each callback in order.

        Every callback must expose the attribute; a value of None is skipped.
        """
        for cb in self.cbs:
            attr = getattr(cb, name)
            if attr is not None: attr()

    def save(self, file, path, **kwargs):
        """Save the model (and the optimizer state unless ``with_opt=False`` is
        passed) to ``path/<file>.pth`` and return the resulting path.

        `kwargs` are forwarded to ``save_model``.
        """
        file = join_path_file(file, path, ext='.pth')
        save_model(file, self.model, getattr(self, 'opt', None), **kwargs)
        return file

    def load(self, file, with_opt=False, device='cuda', strict=True, **kwargs):
        """Load weights from `file` into ``self.model`` (and ``self.opt`` if
        `with_opt`), mapping tensors to `device`. Extra kwargs are ignored."""
        load_model(file, self.model, self.opt, with_opt, device=device, strict=strict)





def save_model(file, model, opt, with_opt=True, pickle_protocol=2):
    """Serialise the state of `model` (unwrapped via `get_model`) to `file`.

    When `with_opt` is truthy and `opt` is not None, the saved object is
    ``{'model': model_state, 'opt': optimizer_state}``; otherwise it is the
    bare model state dict.
    """
    model_state = get_model(model).state_dict()
    bundle_opt = with_opt and opt is not None
    if bundle_opt:
        payload = {'model': model_state, 'opt': opt.state_dict()}
    else:
        payload = model_state
    torch.save(payload, file, pickle_protocol=pickle_protocol)


def load_model(file, model, opt=None, with_opt=False, device='cpu', strict=True):
    """Load weights written by `save_model` from `file` into `model`.

    Both formats produced by `save_model` are accepted: a bare state dict, or
    ``{'model': ..., 'opt': ...}`` when the optimizer was saved as well. The
    format is detected from the file itself, so `with_opt` only decides
    whether the optimizer state is restored.

    Args:
        file: anything accepted by ``torch.load``.
        model: module to load into; wrapped models are unwrapped via `get_model`.
        opt: optimizer to restore into when `with_opt` is true.
        with_opt: restore the optimizer state if `opt` is given and the file
            contains one (a warning is emitted when it does not).
        device: ``map_location`` for ``torch.load``; `model` is moved there.
        strict: forwarded to ``load_state_dict``.
    Returns:
        `model`, moved to `device`.
    """
    import warnings

    state = torch.load(file, map_location=device)
    # save_model(with_opt=True) wraps the weights; detect it regardless of `with_opt`
    # so a default save followed by a default load round-trips
    has_opt = isinstance(state, dict) and set(state) == {'model', 'opt'}
    model_state = state['model'] if has_opt else state
    get_model(model).load_state_dict(model_state, strict=strict)
    if with_opt and opt is not None:
        if has_opt:
            opt.load_state_dict(state['opt'])
        else:
            warnings.warn(f"{file} contains no optimizer state; only the model weights were loaded.")
    return model.to(device)
      

def join_path_file(file, path, ext=''):
    """Build ``path/<file><ext>``, creating the `path` directory if necessary.

    Values of `file` that are neither ``str`` nor ``Path`` (e.g. an open file
    object) are returned untouched and no directory is created.
    """
    if isinstance(file, (str, Path)):
        folder = path if isinstance(path, Path) else Path(path)
        folder.mkdir(parents=True, exist_ok=True)
        return folder / f'{file}{ext}'
    return file


def get_model(model):
    "Return the module wrapped by (Distributed)DataParallel, or `model` itself otherwise."
    wrappers = (DistributedDataParallel, nn.DataParallel)
    if isinstance(model, wrappers):
        return model.module
    return model


def transfer_weights(weights_path, model, exclude_head=True, device='cpu'):
    """Copy checkpoint weights into `model` wherever both name and shape match.

    Args:
        weights_path: anything accepted by ``torch.load``; either a bare state
            dict or the ``{'model': ..., 'opt': ...}`` bundle written by
            ``save_model(with_opt=True)``.
        model: destination module; matching tensors are overwritten in place.
        exclude_head: skip every entry whose name contains the substring 'head'.
        device: ``map_location`` for ``torch.load``; `model` is moved there.
    Returns:
        `model`, moved to `device`.
    Raises:
        Exception: if none of the model's state names occur in the checkpoint.
    """
    new_state_dict = torch.load(weights_path, map_location=device)
    # unwrap the save_model(with_opt=True) bundle so the weight names can match
    if isinstance(new_state_dict, dict) and set(new_state_dict) == {'model', 'opt'}:
        new_state_dict = new_state_dict['model']
    matched_layers = 0
    unmatched_layers = []
    for name, param in model.state_dict().items():
        if exclude_head and 'head' in name: continue
        if name not in new_state_dict:
            # e.g. a freshly initialised layer that the checkpoint never had
            unmatched_layers.append(name)
            continue
        matched_layers += 1
        input_param = new_state_dict[name]
        # state_dict() tensors reference the module's storage, so copy_ updates the model
        if input_param.shape == param.shape: param.copy_(input_param)
        else: unmatched_layers.append(name)
    if matched_layers == 0:
        raise Exception("No shared weight names were found between the models")
    if unmatched_layers:
        print(f'check unmatched_layers: {unmatched_layers}')
    else:
        print(f"weights from {weights_path} successfully transferred!\n")
    return model.to(device)


def update_callback(cb, list_cbs):
    """Register `cb` in `list_cbs`, replacing every callback of the same exact type.

    Entries whose ``type`` equals ``type(cb)`` are dropped (subclasses are not
    treated as the same type) and `cb` is appended at the end. The list is
    modified in place and also returned.
    """
    # rebuild through slice assignment: calling remove() while iterating the
    # same list would skip the element following each removal
    list_cbs[:] = [existing for existing in list_cbs if type(existing) != type(cb)]
    list_cbs.append(cb)
    return list_cbs

def update_callbacks(list_cbs, default_cbs):
    """Merge each callback of `list_cbs`, in order, into `default_cbs`.

    Each merge goes through `update_callback`, which replaces same-type
    entries and appends the new callback. Returns the merged list.
    """
    merged = default_cbs
    for new_cb in list_cbs:
        merged = update_callback(new_cb, merged)
    return merged

def remove_callback(cb, list_cbs):
    """Remove the first callback in `list_cbs` whose exact type matches `cb`.

    `list_cbs` is modified in place.

    Returns:
        ``(list_cbs, removed)`` where `removed` is the callback taken out of
        the list, or ``None`` when no callback of that type was present
        (including an empty list).
    """
    for idx, existing in enumerate(list_cbs):
        if type(existing) == type(cb):
            return list_cbs, list_cbs.pop(idx)
    return list_cbs, None


def get_layer_output(inp, model, layers=None, unwrap=False):
    """Capture the forward outputs of selected child modules of `model`.

    Args:
        inp: passed as-is to ``model(inp)``.
        model: network to run (switched to eval mode).
        layers: a name or list of names of direct children, resolved with
            ``getattr``; defaults to every direct child.
        unwrap: if true, the layers are looked up on ``unwrap_model(model)``;
            the forward pass still runs through the original `model`.
    Returns:
        dict mapping each layer name to its output as a detached CPU numpy array.
    """
    orig_model = model
    if unwrap: model = unwrap_model(model)
    if not layers: layers = [name for name, _ in model.named_children()]
    if not isinstance(layers, list): layers = [layers]

    activation = {}

    def get_activation(name):
        # forward-hook signature: (module, inputs, output)
        def hook(module, inputs, output):
            activation[name] = output.detach().cpu().numpy()
        return hook

    handles = []
    try:
        for layer in layers:
            handles.append(getattr(model, layer).register_forward_hook(get_activation(layer)))
        model.eval()
        # activations are detached anyway, so skip building an autograd graph
        with torch.no_grad():
            orig_model(inp)
    finally:
        # always unregister, otherwise a failed call leaves hooks firing on later forwards
        for h in handles: h.remove()
    return activation