import os
import json
from functools import partial
from tqdm import tqdm

import numpy as np

import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler 
import torchvision
import torchnet as tnt


from protonets.engine import Engine

import protonets.utils.data as data_utils
import protonets.utils.model as model_utils
import protonets.utils.log as log_utils
from torch.optim.lr_scheduler import _LRScheduler
import pickle


# Previously hard-coded seed values: SEED = 0 or 1234 (the seed is now read from opt['seed'] in main()).

class WarmUpLR(_LRScheduler):
    """Linear learning-rate warm-up scheduler.

    After ``k`` calls to ``step()``, each parameter group's learning rate is
    ``base_lr * min(k, total_iters) / total_iters``. It ramps linearly from 0
    to ``base_lr`` over ``total_iters`` steps and then stays at ``base_lr``.

    Args:
        optimizer: wrapped optimizer (e.g. SGD).
        total_iters: number of ``step()`` calls in the warm-up phase. A value
            <= 0 disables warm-up: the base learning rate is used as-is.
        last_epoch: index of the last step, forwarded to ``_LRScheduler``.
    """

    def __init__(self, optimizer, total_iters, last_epoch=-1):
        # Must be set before super().__init__(), which performs the initial
        # step and therefore already calls get_lr().
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the warmed-up learning rate for every parameter group."""
        if self.total_iters <= 0:
            # No warm-up phase. The old `x / (0 + 1e-8)` formula pinned the lr
            # to 0 at construction and blew it up to base_lr * 1e8 afterwards.
            return list(self.base_lrs)
        # Clamp so stepping past the warm-up phase never exceeds base_lr.
        factor = min(self.last_epoch, self.total_iters) / self.total_iters
        return [base_lr * factor for base_lr in self.base_lrs]

def main(opt):
    """Train the model described by ``opt`` using ``protonets.engine.Engine``.

    ``opt`` is a flat dict keyed by dotted option names (``'data.cuda'``,
    ``'train.epochs'``, ``'log.exp_dir'``, ...). It is written to
    ``<log.exp_dir>/opt.json`` before post-processing. After that,
    ``model.x_dim`` and ``log.fields`` are parsed in place from
    comma-separated strings.

    Files written under ``opt['log.exp_dir']``:
      * ``trace.txt``: one JSON line of meter values per epoch.
      * ``best_model.pt``: the saved model object. With a validation split
        this is the best-accuracy model; otherwise it is the latest epoch.
      * ``data.pkl``: the per-update outputs collected in ``on_update``
        (keyed by ``state['t']``). It is dumped whenever validation accuracy
        improves. Its location can be overridden with ``opt['log.data_path']``.
    """
    seed = opt['seed']
    exp_dir = opt['log.exp_dir']
    os.makedirs(exp_dir, exist_ok=True)

    # Save the raw (not yet post-processed) options.
    with open(os.path.join(exp_dir, 'opt.json'), 'w') as f:
        json.dump(opt, f)
        f.write('\n')

    trace_file = os.path.join(exp_dir, 'trace.txt')
    model_file = os.path.join(exp_dir, 'best_model.pt')
    # Was a hard-coded, user-specific absolute path; now relative to the run.
    data_file = opt.get('log.data_path') or os.path.join(exp_dir, 'data.pkl')

    # Post-process comma-separated arguments.
    opt['model.x_dim'] = list(map(int, opt['model.x_dim'].split(',')))
    opt['log.fields'] = opt['log.fields'].split(',')

    torch.manual_seed(seed)
    if opt['data.cuda']:
        torch.cuda.manual_seed(seed)

    if opt['data.trainval']:
        data = data_utils.load(opt, ['trainval'])
        train_loader = data['trainval']
        val_loader = None
    else:
        data = data_utils.load(opt, ['train', 'val'])
        train_loader = data['train']
        val_loader = data['val']

    model = model_utils.load(opt)
    if opt['data.cuda']:
        model.cuda()

    engine = Engine()

    def new_meters():
        # One running-average meter per logged field.
        return {field: tnt.meter.AverageValueMeter() for field in opt['log.fields']}

    meters = {'train': new_meters()}
    if val_loader is not None:
        meters['val'] = new_meters()

    def uses_warmup(state):
        # Only models whose name contains 'resnet' get the linear warm-up.
        return 'resnet' in state['model_name']

    def save_model(model_to_save):
        # Move to CPU before saving, then back to the GPU when CUDA is enabled.
        model_to_save.cpu()
        torch.save(model_to_save, model_file)
        if opt['data.cuda']:
            model_to_save.cuda()

    def on_start(state):
        if os.path.isfile(trace_file):
            os.remove(trace_file)
        state['scheduler'] = lr_scheduler.StepLR(
            state['optimizer'], opt['train.decay_every'], gamma=0.5)
        if uses_warmup(state):
            # One warm-up step per training batch for the first warm_epoch epochs.
            state['warmup_scheduler'] = WarmUpLR(
                state['optimizer'], len(state['loader']) * state['warm_epoch'])
    engine.hooks['on_start'] = on_start

    def on_start_epoch(state):
        for split_meters in meters.values():
            for meter in split_meters.values():
                meter.reset()
        # The step decay is held off while the warm-up phase is running.
        if not uses_warmup(state) or state['epoch'] > state['warm_epoch']:
            state['scheduler'].step()
    engine.hooks['on_start_epoch'] = on_start_epoch

    def on_update(state):
        output = state['output']
        for field, meter in meters['train'].items():
            meter.add(output[field])
        if state.get('data_to_save') is None:
            state['data_to_save'] = {}
        # NOTE(review): if the engine never resets state['t'], this dict keeps
        # growing for the whole run; confirm that is intended.
        state['data_to_save'][state['t']] = {
            'dists': output['dists'].cpu().numpy(),
            'proto_radius': output['proto_radius'].cpu().numpy(),
            'class': output['class'],
        }
    engine.hooks['on_update'] = on_update

    def on_forward(state):
        if uses_warmup(state) and state['epoch'] <= state['warm_epoch']:
            state['warmup_scheduler'].step()
    engine.hooks['on_forward'] = on_forward

    def on_end_epoch(hook_state, state):
        if val_loader is not None:
            hook_state.setdefault('best_acc', -np.inf)
            hook_state.setdefault('wait', 0)
            model_utils.evaluate(state['model'],
                                 val_loader,
                                 meters['val'],
                                 desc="Epoch {:d} valid".format(state['epoch']))

        meter_vals = log_utils.extract_meter_values(meters)
        print("Epoch {:02d}: {:s}".format(state['epoch'], log_utils.render_meter_values(meter_vals)))
        meter_vals['epoch'] = state['epoch']
        with open(trace_file, 'a') as f:
            json.dump(meter_vals, f)
            f.write('\n')

        if val_loader is None:
            # No validation split: always keep the latest model.
            save_model(state['model'])
            return

        if meter_vals['val']['acc'] > hook_state['best_acc']:
            hook_state['best_acc'] = meter_vals['val']['acc']
            print("==> best model (acc = {:0.6f}), saving model...".format(hook_state['best_acc']))

            with open(data_file, 'wb') as f:
                # .get: nothing has been collected if no update ran yet.
                pickle.dump(state.get('data_to_save', {}), f)
                print('data saved!')

            save_model(state['model'])
            hook_state['wait'] = 0
        else:
            # Early stopping: count epochs without a validation improvement.
            hook_state['wait'] += 1
            if hook_state['wait'] > opt['train.patience']:
                print("==> patience {:d} exceeded".format(opt['train.patience']))
                state['stop'] = True

    engine.hooks['on_end_epoch'] = partial(on_end_epoch, {})

    engine.train(
        model=model,
        loader=train_loader,
        optim_method=getattr(optim, opt['train.optim_method']),
        proto_optim_method=getattr(optim, opt['train.proto_optim_method']),
        optim_config={'lr': opt['train.learning_rate'],
                      'weight_decay': opt['train.weight_decay']},
        max_epoch=opt['train.epochs'],
        proto_optim_config={
            'lr': opt['train.proto_learning_rate'],
            'weight_decay': opt['train.proto_weight_decay'],
        },
        model_name=opt['model.model_name'],
        data_cuda=opt['data.cuda'],
    )
