import logging
import time

from data_helper import create_train_dataloaders

from model_helper import Actor

import os
import math
from utils import *
from lr_scheduler import *

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel
import wandb

from tqdm import tqdm
from copy import deepcopy

# Module-level logger used by the Trainer class below.
logger = logging.getLogger(__name__)

# SECURITY: a hard-coded WANDB_API_KEY used to be kept here in a comment.
# Never commit credentials to source control; export WANDB_API_KEY in the
# environment instead, and rotate any key that was previously committed.
# Select wandb's "offline" mode via the environment; this is set before the
# wandb.init() call below (presumably so that init picks it up — confirm).
os.environ["WANDB_MODE"] = "offline"

# Start the wandb run at import time, using the team entity "tnl-dev".
# NOTE(review): this runs in every process that imports this module, so under
# DDP each rank presumably creates its own run — confirm this is intended.
wandb.init(project="MPT", entity="tnl-dev", name='MPT',group="v0")

# Alternative configuration for a personal entity, kept for reference only:
#wandb.init(project="SATracker", entity="tnl-peterbishop", name='SATracker',group="v7_test")


class Trainer:
    """Drive training of an ``Actor`` with optional DistributedDataParallel.

    The trainer builds the actor, optionally restores weights from a resume
    or pretrain checkpoint, builds the optimizer/scheduler, and then runs
    the epoch loop in :meth:`run`, writing checkpoints and wandb metrics.

    Args:
        cfg: Project configuration object. The fields read here are
            ``cfg.train.*`` (``ddp``, ``resume``, ``pretrain``,
            ``start_epoch``, ``end_epoch``, ``save_steps``, ``print_freq``,
            ``grad_clip_norm``, ``lr.type``) and ``cfg.common.ckpt_dir``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.train_cfg = cfg.train

        if cfg.train.ddp.istrue:
            local_rank = cfg.train.ddp.local_rank
            self.device = torch.device("cuda", local_rank)
        else:
            self.device = torch.device("cuda", 0)

        self.actor = Actor(cfg, self.device)
        self.actor = self.actor.to(self.device)

        # Load the checkpoint once and map it onto this process' own device,
        # so DDP ranks do not all materialize their copy on cuda:0.
        resume_state = None
        if cfg.train.resume:
            # Resume: restore all network parameters and continue training.
            resume_state = torch.load(cfg.train.resume, map_location=self.device)
            self.actor.net.load_state_dict(resume_state['net'], strict=True)
            # run() starts at start_epoch + 1, i.e. the epoch after the saved one.
            self.cfg.train.start_epoch = resume_state['epoch']
        elif cfg.train.pretrain:
            # Pretrain: only parameters whose key contains 'backbone' are loaded.
            state = torch.load(cfg.train.pretrain, map_location=self.device)
            backbone_state = {
                key: value for key, value in state.items() if 'backbone' in key
            }
            if not backbone_state:
                logger.warning(
                    "pretrain checkpoint %s contains no 'backbone' keys; "
                    "no weights were loaded", cfg.train.pretrain)
            self.actor.net.load_state_dict(backbone_state, strict=False)
            print("load pretrain model success")

        if cfg.train.ddp.istrue:
            self.actor = DistributedDataParallel(
                self.actor,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True,
            )
        self.optimizer, self.scheduler = build_siamese_opt_lr(cfg, self._net)

        if resume_state is not None:
            # Only the "last" checkpoint stores optimizer state; periodic and
            # best-IoU checkpoints store None, which must not be loaded.
            optimizer_state = resume_state.get('optimizer')
            if optimizer_state is not None:
                self.optimizer.load_state_dict(optimizer_state)
            else:
                logger.warning(
                    "resume checkpoint %s has no optimizer state; "
                    "the optimizer starts fresh", cfg.train.resume)
            # NOTE(review): scheduler state is not stored in checkpoints, so
            # it is not restored here — confirm the schedule resumes correctly.

        self.step = 0
        self.start_time = time.time()

        self.train_loader = create_train_dataloaders(self.cfg)

    @property
    def _net(self):
        """Return the underlying network, unwrapping the DDP wrapper if used."""
        if self.cfg.train.ddp.istrue:
            return self.actor.module.net
        return self.actor.net

    def _is_main_process(self):
        """Return True when not using DDP, or when this process is rank 0."""
        return not self.cfg.train.ddp.istrue or dist.get_rank() == 0

    def run(self):
        """Run the epoch loop from ``start_epoch + 1`` through ``end_epoch``.

        After every epoch the scheduler is stepped and, on the main process,
        checkpoints are written: a periodic one (every ``save_steps`` epochs
        and for every epoch once at most 50 epochs remain), a rolling
        ``model_last_epoch.bin`` that includes optimizer state, and a
        best-IoU checkpoint whenever the epoch's average IoU improves.
        """
        best_iou = 0
        iou_epoch = 0
        best_loss = 1000
        loss_epoch = 0

        start_epoch = self.train_cfg.start_epoch
        end_epoch = self.train_cfg.end_epoch
        ckpt_dir = self.cfg.common.ckpt_dir
        is_main = self._is_main_process()

        for epoch in range(start_epoch + 1, end_epoch + 1):

            if self.cfg.train.ddp.istrue:
                self.train_loader.sampler.set_epoch(epoch)

            avg_loss, avg_iou = self.train(epoch)

            if self.scheduler is not None:
                if self.train_cfg.lr.type != 'cosine':
                    self.scheduler.step()
                else:
                    self.scheduler.step(epoch - 1)

            periodic_save = (epoch % self.cfg.train.save_steps == 0) or (end_epoch - epoch <= 50)
            if periodic_save and is_main:
                checkpoint_dir = f'{ckpt_dir}/model_epoch_{epoch}.bin'
                self.save_checkpoint(epoch, checkpoint_dir)
                print("save path ",checkpoint_dir)

            if is_main:
                # save_checkpoint replaces the file atomically, so the previous
                # "last" checkpoint survives if this write fails part-way.
                checkpoint_dir = f'{ckpt_dir}/model_last_epoch.bin'
                self.save_checkpoint(epoch, checkpoint_dir, last=True)

            if avg_loss < best_loss:
                best_loss = avg_loss
                loss_epoch = epoch

            if avg_iou > best_iou:
                previous_best = f'{ckpt_dir}/model_bestiou_[{iou_epoch}-{str(best_iou)}].bin'
                # Track the best value on every rank, not only on rank 0.
                iou_epoch = epoch
                best_iou = avg_iou
                if is_main:
                    checkpoint_dir = f'{ckpt_dir}/model_bestiou_[{iou_epoch}-{str(best_iou)}].bin'
                    self.save_checkpoint(epoch, checkpoint_dir)
                    # Drop the old best only after the new one is safely on disk.
                    if os.path.exists(previous_best):
                        os.remove(previous_best)
                    print("save path ",checkpoint_dir)

            if is_main:
                logger.info(f"Epoch {epoch} training over, avg_loss:{avg_loss:.5f}, avg_iou:{avg_iou:.5f}, best loss is {best_loss:.3f} from epoch {loss_epoch}, best iou is {best_iou:.3f} from epoch {iou_epoch}.")

        logger.info('\ntraining over')

    def train(self, epoch):
        """Train for one epoch and return its averaged metrics.

        Args:
            epoch: 1-based index of the epoch being trained; used for logging
                and for the remaining-time estimate.

        Returns:
            A ``(avg_loss, avg_iou)`` tuple, each rounded to 5 decimals.
        """
        self.start_time = time.time()

        self.actor.train()
        self.step = 0
        num_total_steps = len(self.train_loader)
        logger.info(f"train_dataloader size is: {num_total_steps}")

        is_main = self._is_main_process()
        # Epochs still to run after this one; run() iterates through
        # end_epoch inclusive, so this is 0 during the final epoch.
        remaining_epochs = self.train_cfg.end_epoch - epoch

        average_loss = AverageMeter()
        average_giou_loss = AverageMeter()
        average_l1_loss = AverageMeter()
        average_IoU = AverageMeter()

        for batch in self.train_loader:

            for key, value in batch.items():
                batch[key] = value.to(self.device)

            loss, model_loss = self.actor(batch)

            giou_loss = model_loss['giou']
            l1_loss = model_loss['l1']
            IoU = model_loss['IoU']

            self.optimizer.zero_grad()
            loss.backward()
            if self.cfg.train.grad_clip_norm > 0:
                torch.nn.utils.clip_grad_norm_(self._net.parameters(), self.cfg.train.grad_clip_norm)

            # Read the scalar once; every .item() call synchronizes with the device.
            loss = loss.item()
            if self.is_valid_number(loss):
                self.optimizer.step()
            else:
                logger.warning(
                    "Epoch %s step %s: skipping optimizer step for invalid loss %s",
                    epoch, self.step + 1, loss)

            self.step += 1

            batch_size = batch['template'].shape[0]

            average_loss.update(loss, batch_size)
            average_giou_loss.update(giou_loss, batch_size)
            average_l1_loss.update(l1_loss, batch_size)
            average_IoU.update(IoU, batch_size)

            if self.step % self.cfg.train.print_freq == 0 and is_main:
                time_per_step = (time.time() - self.start_time) / max(1, self.step)
                remaining_steps = (num_total_steps - self.step) + remaining_epochs * num_total_steps
                remaining_time = seconds_to_dhms(int(time_per_step * remaining_steps))
                lr = get_lr(self.optimizer)

                metrics = {
                    "Loss/total": average_loss.avg,
                    "Loss/giou_loss": average_giou_loss.avg,
                    "Loss/l1_loss": average_l1_loss.avg,
                    "IoU": average_IoU.avg,
                }

                wandb.log(metrics)

                logger.info(f"Epoch {epoch} step {self.step}/{num_total_steps} eta {remaining_time}: loss {average_loss.avg:.5f}, IoU {average_IoU.avg:.5f}, lr:{lr:.7f}")

        if is_main:
            metrics = {
                "Avg/Loss": average_loss.avg,
                "Avg/IoU": average_IoU.avg,
            }

            wandb.log(metrics)

        return round(average_loss.avg,5), round(average_IoU.avg,5)

    def save_checkpoint(self, epoch, checkpoint_dir, last=False):
        """Save a checkpoint of the network and other variables.

        The file is written to ``<checkpoint_dir>.tmp`` first and then moved
        into place with ``os.replace``, so an existing checkpoint at the same
        path is never left truncated by an interrupted write.

        Args:
            epoch: Epoch number stored in the checkpoint.
            checkpoint_dir: Destination file path (despite the name, a file).
            last: When True, the optimizer state is stored as well; otherwise
                the ``'optimizer'`` entry is ``None``.
        """
        net = self._net

        state = {
            'epoch': epoch,
            'actor_type': type(self.actor).__name__,
            'net_type': type(net).__name__,
            'net': net.state_dict(),
            'net_info': getattr(net, 'info', None),
            'optimizer': self.optimizer.state_dict() if last else None,
        }

        parent = os.path.dirname(checkpoint_dir)
        if parent:
            os.makedirs(parent, exist_ok=True)

        tmp_path = f'{checkpoint_dir}.tmp'
        try:
            torch.save(state, tmp_path)
            os.replace(tmp_path, checkpoint_dir)
        except Exception:
            # Do not leave a partial temp file behind on failure.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    def load_checkpoint(self, checkpoint_dir):
        """Load a checkpoint of the network and other variables.

        Restores the network weights, sets ``cfg.train.start_epoch`` to the
        saved epoch (``run`` begins at ``start_epoch + 1``), and restores the
        optimizer state when the checkpoint contains one.

        Args:
            checkpoint_dir: Path of the checkpoint file to load.
        """
        state = torch.load(checkpoint_dir, map_location=self.device)
        self._net.load_state_dict(state['net'], strict=True)
        self.cfg.train.start_epoch = state['epoch']

        # Checkpoints saved with last=False carry no optimizer state.
        optimizer_state = state.get('optimizer')
        if optimizer_state is not None:
            self.optimizer.load_state_dict(optimizer_state)

    def is_valid_number(self, x):
        """Return True if ``x`` is not NaN, not infinite, and not above 1e4."""
        return not(math.isnan(x) or math.isinf(x) or x > 1e4)
