from tqdm import tqdm
from .models.bigproto import BigProtonet
import torch
from protonets.data.base import CudaTransform



class Engine(object):
    """Training loop driven by a shared ``state`` dict and pluggable hooks.

    ``self.hooks`` maps each hook name to a callable that receives ``state``;
    every hook defaults to a no-op and callers may overwrite entries.
    """

    def __init__(self):
        hook_names = ['on_start', 'on_start_epoch', 'on_sample', 'on_forward',
                      'on_backward', 'on_end_epoch', 'on_update', 'on_end']

        # Every hook starts out as a no-op; callers replace entries as needed.
        self.hooks = {hook_name: (lambda state: None) for hook_name in hook_names}

    def _get_optimizer_params(self, model, state):
        """Build optimizer parameter groups for ``model``.

        Args:
            model: module exposing an ``encoder`` submodule and, optionally,
                a ``proto_param`` submodule.
            state: training state. ``state['optim_config']`` is always read
                (it must contain ``'lr'``); ``state['proto_optim_config']`` is
                read only when ``model`` has a ``proto_param`` attribute.

        Returns:
            ``(encoder_groups, proto_groups)``.

            ``encoder_groups`` holds two groups. The first contains encoder
            parameters whose names match none of the ``no_decay`` patterns,
            configured with ``optim_config``. The second contains those that
            do match, with ``weight_decay=0.0`` and the same ``lr``.

            ``proto_groups`` is a single group with every ``proto_param``
            parameter configured with ``proto_optim_config``, or ``[]`` when
            the model has no ``proto_param``.
        """
        no_decay = ('bias', 'LayerNorm.bias', 'LayerNorm.weight')
        optim_config = state['optim_config']

        # Single pass over the encoder, preserving parameter order in each group.
        decay_params, no_decay_params = [], []
        for name, param in model.encoder.named_parameters():
            if any(nd in name for nd in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        parameters_to_optimize = [
            {
                'params': decay_params,
                **optim_config,
            },
            {
                'params': no_decay_params,
                'weight_decay': 0.0,
                'lr': optim_config['lr'],
            },
        ]

        parameters_to_optimize_for_big_protos = []
        if hasattr(model, "proto_param"):
            parameters_to_optimize_for_big_protos = [
                {
                    'params': list(model.proto_param.parameters()),
                    **state['proto_optim_config'],
                }
            ]
        return parameters_to_optimize, parameters_to_optimize_for_big_protos

    def _update_params(self, state):
        """Point the prototype optimizer's last param group at the current
        ``model.proto_param`` parameters.

        NOTE(review): presumably needed because ``proto_param`` can register
        new parameters (e.g. per-class radii) during training, which a
        previously captured list would miss — confirm.
        """
        state['proto_optimizer'].param_groups[-1]['params'] = list(
            state['model'].proto_param.parameters()
        )

    @staticmethod
    def _log_proto_radii(model):
        """Print the current prototype radii of ``model``, if there are any.

        Does nothing when ``model`` has no ``proto_param``, or when
        ``proto_param.radius_d`` is missing or empty. ``torch.hstack`` raises
        on an empty list, so the empty case must be skipped.
        """
        proto_param = getattr(model, 'proto_param', None)
        radius_d = getattr(proto_param, 'radius_d', None)
        if not radius_d:
            return
        radii = [radius_d[key] for key in radius_d]
        print('proto r:', torch.hstack(radii))

    def train(self, **kwargs):
        """Train until ``max_epoch`` epochs finish or a hook sets ``state['stop']``.

        Required keyword arguments: ``model``, ``loader``, ``optim_method``,
        ``proto_optim_method``, ``optim_config``, ``proto_optim_config``,
        ``max_epoch``, ``model_name`` and ``data_cuda``.

        ``warm_epoch`` is optional (default 20). It is only stored in
        ``state`` and is not read by this loop; presumably it is consumed by
        a hook or the model.

        Every hook in ``self.hooks`` is called with the shared ``state`` dict.
        """
        state = {
            'model': kwargs['model'],
            'loader': kwargs['loader'],
            'optim_method': kwargs['optim_method'],
            'proto_optim_method': kwargs['proto_optim_method'],
            'optim_config': kwargs['optim_config'],
            'proto_optim_config': kwargs['proto_optim_config'],
            'max_epoch': kwargs['max_epoch'],
            'epoch': 0,  # epochs done so far
            't': 0,  # samples seen so far
            'batch': 0,  # samples seen in current epoch
            'stop': False,
            'model_name': kwargs['model_name'],
            'warm_epoch': kwargs.get('warm_epoch', 20),
            'data.cuda': kwargs['data_cuda'],
        }
        cudatrans = CudaTransform()

        parameters, parameters_for_proto = self._get_optimizer_params(state['model'], state)

        # ResNet backbones get momentum=0.9; other models use the optimizer's defaults.
        if 'resnet' in state['model_name']:
            state['optimizer'] = state['optim_method'](parameters, momentum=0.9)
        else:
            state['optimizer'] = state['optim_method'](parameters)

        # Only BigProtonet models get a separate optimizer for prototype parameters.
        state['proto_optimizer'] = None
        if isinstance(state['model'], BigProtonet):
            state['proto_optimizer'] = state['proto_optim_method'](parameters_for_proto)

        self.hooks['on_start'](state)
        while state['epoch'] < state['max_epoch'] and not state['stop']:
            state['model'].train()

            self.hooks['on_start_epoch'](state)

            state['epoch_size'] = len(state['loader'])

            desc = "Epoch {:d} train".format(state['epoch'] + 1)
            for sample in tqdm(state['loader'], desc=desc):
                state['sample'] = sample
                self.hooks['on_sample'](state)

                state['optimizer'].zero_grad()
                if state['proto_optimizer'] is not None:
                    state['proto_optimizer'].zero_grad()

                if state['data.cuda']:
                    # The return value is ignored, so this relies on CudaTransform
                    # updating the sample in place — NOTE(review): confirm.
                    cudatrans(state['sample'])
                loss, state['output'] = state['model'].loss(state['sample'])
                # The batch is removed from state before on_forward, so hooks from
                # here on cannot read state['sample'].
                del state['sample']
                self.hooks['on_forward'](state)

                loss.backward()
                self.hooks['on_backward'](state)

                state['optimizer'].step()
                if state['proto_optimizer'] is not None:
                    # Re-sync the prototype param list before stepping.
                    self._update_params(state)
                    state['proto_optimizer'].step()

                state['t'] += 1
                state['batch'] += 1
                self.hooks['on_update'](state)

            state['epoch'] += 1
            state['batch'] = 0
            self.hooks['on_end_epoch'](state)
            self._log_proto_radii(state['model'])

        self.hooks['on_end'](state)
