import os
import time
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from collections import OrderedDict
import getpass
import random
# from tensorboardX import SummaryWriter
from .utils import (
    AverageMeter,
    accuracy, 
    validate,
    adjust_learning_rate,
    save_checkpoint,
    load_checkpoint,
    log_msg,
)
from .dot import DistillationOrientedTrainer
from mdistiller.dist.log import SmoothedValue, MetricLogger
from torchvision.transforms import Normalize


class BaseTrainer(object):
    """Drive training, validation, logging and checkpointing of a distiller.

    The trainer owns the optimizer, tracks the best validation top-1 accuracy
    seen so far (``best_acc``) and writes ``worklog.txt`` plus the ``*.pth``
    checkpoints under ``cfg.LOG.PREFIX/experiment_name``. When
    ``args.distributed`` is set, only the process with ``args.gpu == 0``
    writes to the filesystem.

    Subclasses customise behaviour by overriding ``init_optimizer``,
    ``train`` and/or ``train_iter``.
    """

    # Per-channel (mean, std) applied after raw pixels are scaled to [0, 1],
    # keyed by ``cfg.DATASET.TYPE``.
    _CHANNEL_STATS = {
        "imagenet": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        "tiny_imagenet": ([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
    }
    # Statistics used for every other dataset type.
    # NOTE(review): values look like the usual CIFAR-100 statistics - confirm.
    _DEFAULT_CHANNEL_STATS = ([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761])

    def __init__(self, experiment_name, distiller, train_loader, val_loader, cfg, args):
        """Store collaborators, build the optimizer and prepare the log directory.

        Args:
            experiment_name: sub-path appended to ``cfg.LOG.PREFIX`` to form
                the log directory.
            distiller: the distiller to train, either bare or wrapped in an
                object exposing it as ``.module``.
            train_loader: training loader; its ``dataset`` carries the
                ``seed_offset`` attribute that ``train`` bumps every epoch.
            val_loader: validation loader used by ``train_epoch``.
            cfg: configuration node (reads ``LOG``, ``SOLVER``, ``DATASET``,
                ``DISTILLER`` and ``EXPERIMENT`` entries).
            args: runtime arguments (reads ``JPEG_enable``, ``distributed``,
                ``gpu``, ``train_mode`` and ``experiment_name``).
        """
        self.cfg = cfg
        self.args = args
        self.JPEG_enable = args.JPEG_enable
        self.distiller = distiller
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = self.init_optimizer(cfg)
        self.best_acc = -1
        # init loggers
        self.log_path = os.path.join(cfg.LOG.PREFIX, experiment_name)

        if self.args.distributed:
            self.set_randomSeed_DA()
        else:
            # Upper bound leaves 2**6 of headroom below the int32 maximum
            # (presumably room for the per-epoch increments - confirm).
            self.seed_offset = random.randint(-2**31, (2**31 - 1) - int(2 ** 6))
            # train() increments dataset.seed_offset after every epoch, so make
            # sure the attribute exists. A value the dataset already defines is
            # left untouched to keep existing runs unchanged.
            dataset = getattr(self.train_loader, "dataset", None)
            if dataset is not None and not hasattr(dataset, "seed_offset"):
                dataset.seed_offset = self.seed_offset

        if self._is_main_process():
            # exist_ok avoids the check-then-create race of exists()+makedirs().
            os.makedirs(self.log_path, exist_ok=True)

    def _is_main_process(self):
        """Return True for the single process allowed to write files."""
        return (not self.args.distributed) or self.args.gpu == 0

    def _unwrapped_distiller(self):
        """Return the distiller itself, looking through a ``.module`` wrapper."""
        return getattr(self.distiller, "module", self.distiller)

    def _resume_checkpoint_path(self):
        """Return the checkpoint file that ``train(resume=True)`` should load.

        ``train_epoch`` saves ``latest.pth``; an extension-less ``latest`` file
        is only used when ``latest.pth`` does not exist.
        """
        current = os.path.join(self.log_path, "latest.pth")
        legacy = os.path.join(self.log_path, "latest")
        if not os.path.exists(current) and os.path.exists(legacy):
            return legacy
        return current

    def _normalize_raw_pixels(self, image):
        """Map an image batch from [0, 255] to [0, 1] and standardise it per channel.

        The two ``Normalize`` modules are built once and cached on the
        instance instead of being re-created for every batch.
        """
        transforms = getattr(self, "_pixel_transforms", None)
        if transforms is None:
            mean, std = self._CHANNEL_STATS.get(
                self.cfg.DATASET.TYPE, self._DEFAULT_CHANNEL_STATS)
            transforms = (
                Normalize(mean=[0, 0, 0], std=[255., 255., 255.]),
                Normalize(mean=mean, std=std),
            )
            self._pixel_transforms = transforms
        scale_to_unit, standardize = transforms
        return standardize(scale_to_unit(image))

    def _update_train_meters(self, train_meters, loss, logits_student, logits_teacher,
                             target, batch_size):
        """Record the loss and the student/teacher top-1/top-5 accuracy of one batch."""
        meters = train_meters.meters
        meters["losses"].update(loss.item(), batch_size)

        acc1, acc5 = accuracy(logits_student, target, topk=(1, 5))
        meters["s_top1"].update(acc1[0].item(), batch_size)
        meters["s_top5"].update(acc5[0].item(), batch_size)

        acc1, acc5 = accuracy(logits_teacher, target, topk=(1, 5))
        meters["t_top1"].update(acc1[0].item(), batch_size)
        meters["t_top5"].update(acc5[0].item(), batch_size)

    def set_randomSeed_DA(self):
        """Draw a seed offset on rank 0 and share it with every process.

        Uses ``torch.distributed.broadcast`` on a CUDA tensor, so it needs an
        initialised process group and a CUDA device. ``args.gpu`` is used both
        as the CUDA device index and as the rank (assumed to coincide, as in
        the rest of this class). The result is stored on ``self.seed_offset``
        and on ``train_loader.dataset.seed_offset``.
        """
        # Non-root ranks contribute a placeholder that the broadcast overwrites.
        self.seed_offset = 0
        if self.args.gpu == 0:
            self.seed_offset = random.randint(-2**31, (2**31 - 1) - int(2 ** 6))
        self.seed_offset_tensor = torch.tensor([self.seed_offset], dtype=torch.int64)
        # Select this process's device before moving the tensor onto it.
        torch.cuda.set_device(self.args.gpu)
        self.seed_offset_tensor = self.seed_offset_tensor.cuda()
        torch.distributed.broadcast(self.seed_offset_tensor, src=0)
        # After the broadcast every rank holds rank 0's value.
        self.seed_offset = self.seed_offset_tensor.item()
        self.train_loader.dataset.seed_offset = self.seed_offset
        print(log_msg(f"Rank {self.args.gpu} - Seed Offset: {self.train_loader.dataset.seed_offset}", "INFO"))

    def init_optimizer(self, cfg):
        """Build the optimizer over the distiller's learnable parameters.

        Raises:
            NotImplementedError: if ``cfg.SOLVER.TYPE`` is not ``"SGD"``.
        """
        if cfg.SOLVER.TYPE != "SGD":
            raise NotImplementedError(cfg.SOLVER.TYPE)
        return optim.SGD(
            self._unwrapped_distiller().get_learnable_parameters(),
            lr=cfg.SOLVER.LR,
            momentum=cfg.SOLVER.MOMENTUM,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        )

    def log(self, lr, epoch, log_dict):
        """Update ``best_acc`` and append an epoch summary to ``worklog.txt``.

        Args:
            lr: learning rate used for the epoch.
            epoch: epoch number being reported.
            log_dict: mapping of metric name to numeric value; must contain
                ``"test_acc"``.

        ``best_acc`` is updated on every process; the file is written by the
        main process only.
        """
        if log_dict["test_acc"] > self.best_acc:
            self.best_acc = log_dict["test_acc"]

        if not self._is_main_process():
            return

        # The file is opened in text mode, which already translates "\n" to
        # the platform line ending; os.linesep would double the "\r" on Windows.
        separator = "-" * 25 + "\n"
        lines = [
            separator,
            "epoch: {}".format(epoch) + "\n",
            "lr: {:.4f}".format(float(lr)) + "\n",
            "best_acc: {:.2f}".format(float(self.best_acc)) + "\n",
        ]
        lines.extend("{}: {:.2f}".format(k, v) + "\n" for k, v in log_dict.items())
        lines.append(separator)
        with open(os.path.join(self.log_path, "worklog.txt"), "a") as writer:
            writer.writelines(lines)

    def train(self, resume=False):
        """Run epochs up to ``cfg.SOLVER.EPOCHS`` and record the best accuracy.

        Args:
            resume: when true, restore model, optimizer, ``best_acc`` and the
                epoch counter from the latest checkpoint in ``log_path``
                before training.

        After training, the main process appends the best accuracy to
        ``worklog.txt`` and to
        ``cfg.LOG.PREFIX/EXPERIMENT.TAG/EXPERIMENT.PROJECT/results.txt``.
        """
        epoch = 1
        if resume:
            state = load_checkpoint(self._resume_checkpoint_path())
            epoch = state["epoch"] + 1
            self.distiller.load_state_dict(state["model"])
            self.optimizer.load_state_dict(state["optimizer"])
            self.best_acc = state["best_acc"]
        while epoch < self.cfg.SOLVER.EPOCHS + 1:
            self.train_epoch(epoch)
            epoch += 1
            # A new offset per epoch for whatever the dataset derives from it.
            self.train_loader.dataset.seed_offset += 1
            print(log_msg(f"Rank {self.args.gpu} - Seed Offset: {self.train_loader.dataset.seed_offset}", "INFO"))

        print(log_msg("Best accuracy:{}".format(self.best_acc), "EVAL"))

        # Only one process writes the summaries; otherwise every rank would
        # append the same line.
        if not self._is_main_process():
            return

        with open(os.path.join(self.log_path, "worklog.txt"), "a") as writer:
            writer.write("best_acc\t" + "{:.2f}".format(float(self.best_acc)))

        results_dir = os.path.join(
            self.cfg.LOG.PREFIX, self.cfg.EXPERIMENT.TAG, self.cfg.EXPERIMENT.PROJECT)
        os.makedirs(results_dir, exist_ok=True)
        with open(os.path.join(results_dir, "results.txt"), "a") as writer:
            writer.write("{}: {:.4f}\n".format(self.args.experiment_name, float(self.best_acc)))

    def train_epoch(self, epoch, log_freq=200):
        """Train for one epoch, validate, log the metrics and save checkpoints.

        Args:
            epoch: 1-based epoch number.
            log_freq: print frequency forwarded to ``MetricLogger.log_every``
                for the training loop.
        """
        lr = adjust_learning_rate(epoch, self.cfg, self.optimizer)

        metric_logger_train = MetricLogger(delimiter="  ")
        header = f"Train [{epoch}]:"
        for meter_name in ("losses", "s_top1", "s_top5", "t_top1", "t_top5"):
            metric_logger_train.add_meter(
                meter_name, SmoothedValue(window_size=10, fmt='{value:.3f}'))

        # train loops
        self.distiller.train()
        if self.args.train_mode or self.cfg.DISTILLER.TYPE in ["WTTM"]:
            self._unwrapped_distiller().set_teacher_train()
            print("Teacher is on TRAIN mode")
        else:
            print("Teacher is on EVAL mode")

        for data in metric_logger_train.log_every(self.train_loader, log_freq, header):
            self.train_iter(data, epoch, metric_logger_train)
        metric_logger_train.synchronize_between_processes()

        metric_logger_val = MetricLogger(delimiter="  ")
        for meter_name in ("top1", "top5", "test_loss"):
            metric_logger_val.add_meter(
                meter_name, SmoothedValue(window_size=10, fmt='{value:.3f}'))

        # validate
        validate_metric_logger(self.val_loader, self.distiller, metric_logger_val)
        metric_logger_val.synchronize_between_processes()
        msg = f"s_top1_val {metric_logger_val.top1.global_avg} \t s_top1_train {metric_logger_train.s_top1.global_avg:.4f} \t t_top1_train {metric_logger_train.t_top1.global_avg:.4f}"
        print(log_msg(msg, "EVAL"))

        # log
        log_dict = OrderedDict({
            "train_acc": metric_logger_train.s_top1.global_avg,
            "train_loss": metric_logger_train.losses.global_avg,
            "test_acc": metric_logger_val.top1.global_avg,
            "test_acc_top5": metric_logger_val.top5.global_avg,
            "test_loss": metric_logger_val.test_loss.global_avg,
        })
        self.log(lr, epoch, log_dict)

        # saving checkpoint (main process only)
        if not self._is_main_process():
            return
        state = {
            "epoch": epoch,
            "model": self.distiller.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "best_acc": self.best_acc,
        }
        student_state = {"model": self._unwrapped_distiller().student.state_dict()}
        save_checkpoint(state, os.path.join(self.log_path, "latest.pth"))
        save_checkpoint(student_state, os.path.join(self.log_path, "student_latest.pth"))

    def train_iter(self, data, epoch, train_meters):
        """Run one optimisation step on an ``(image, target, index)`` batch.

        Args:
            data: batch triple; the third element (``index``) is not used here.
            epoch: current epoch, forwarded to the distiller.
            train_meters: metric logger whose ``meters`` receive the loss and
                the student/teacher accuracies.
        """
        self.optimizer.zero_grad()
        image, target, _index = data
        image = image.float().cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # forward
        # With JPEG_enable the distiller receives both the standardised image
        # ([0,255] -> [0,1] -> per-channel normalised) and the raw one.
        if self.JPEG_enable:
            model_input = self._normalize_raw_pixels(image)
        else:
            model_input = image
        logits_student, logits_teacher, losses_dict = self.distiller(
            image=model_input, perturbedInput=image, target=target, epoch=epoch)

        # backward
        loss = sum(term.mean() for term in losses_dict.values())
        loss.backward()
        self.optimizer.step()

        # collect info
        self._update_train_meters(
            train_meters, loss, logits_student, logits_teacher, target, image.size(0))


class CRDTrainer(BaseTrainer):
    """Trainer for distillers that also consume ``index`` and ``contrastive_index``."""

    def train_iter(self, data, epoch, train_meters):
        """Run one optimisation step on an ``(image, target, index, contrastive_index)`` batch.

        Args:
            data: batch of four tensors; all of them are moved to the GPU.
            epoch: current epoch (not forwarded to the distiller here).
            train_meters: metric logger whose ``meters`` receive the loss and
                the student/teacher top-1/top-5 accuracies.
        """
        self.optimizer.zero_grad()

        raw_image, raw_target, raw_index, raw_contrastive = data
        image = raw_image.float().cuda(non_blocking=True)
        target = raw_target.cuda(non_blocking=True)
        index = raw_index.cuda(non_blocking=True)
        contrastive_index = raw_contrastive.cuda(non_blocking=True)

        # forward
        student_input = image
        if self.JPEG_enable:
            # Raw pixels are first divided by 255, then standardised with the
            # per-channel statistics selected by the dataset type.
            per_dataset_stats = {
                "imagenet": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                "tiny_imagenet": ([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
            }
            fallback_stats = ([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761])
            channel_mean, channel_std = per_dataset_stats.get(
                self.cfg.DATASET.TYPE, fallback_stats)
            to_unit_range = Normalize(mean=[0, 0, 0], std=[255., 255., 255.])
            standardize = Normalize(mean=channel_mean, std=channel_std)
            student_input = standardize(to_unit_range(image))

        logits_student, logits_teacher, losses_dict = self.distiller(
            image=student_input,
            perturbedInput=image,
            target=target,
            index=index,
            contrastive_index=contrastive_index,
        )

        # backward
        loss = sum([component.mean() for component in losses_dict.values()])
        loss.backward()
        self.optimizer.step()

        # collect info
        n_samples = image.size(0)
        meters = train_meters.meters
        student_top1, student_top5 = accuracy(logits_student, target, topk=(1, 5))
        meters["losses"].update(loss.detach().cpu().numpy().mean().item(), n_samples)
        meters["s_top1"].update(student_top1[0].item(), n_samples)
        meters["s_top5"].update(student_top5[0].item(), n_samples)

        teacher_top1, teacher_top5 = accuracy(logits_teacher, target, topk=(1, 5))
        meters["t_top1"].update(teacher_top1[0].item(), n_samples)
        meters["t_top5"].update(teacher_top5[0].item(), n_samples)


class DOT(BaseTrainer):
    """Trainer that steps a ``DistillationOrientedTrainer`` optimizer.

    Each iteration back-propagates the distillation loss (``loss_kd``) and the
    task loss (``loss_ce``) separately, calling ``optimizer.step_kd()`` after
    the former and ``optimizer.step()`` after the latter.
    """

    def init_optimizer(self, cfg):
        """Build the optimizer with separate task and distillation momenta.

        The task momentum is ``MOMENTUM - DOT.DELTA`` and the distillation
        momentum is ``MOMENTUM + DOT.DELTA``.

        Raises:
            NotImplementedError: if ``cfg.SOLVER.TYPE`` is not ``"SGD"``.
        """
        if cfg.SOLVER.TYPE != "SGD":
            raise NotImplementedError(cfg.SOLVER.TYPE)
        m_task = cfg.SOLVER.MOMENTUM - cfg.SOLVER.DOT.DELTA
        m_kd = cfg.SOLVER.MOMENTUM + cfg.SOLVER.DOT.DELTA
        # Accept both a bare distiller and one exposed through ``.module``,
        # matching the two cases BaseTrainer.init_optimizer handles.
        distiller = getattr(self.distiller, "module", self.distiller)
        return DistillationOrientedTrainer(
            distiller.get_learnable_parameters(),
            lr=cfg.SOLVER.LR,
            momentum=m_task,
            momentum_kd=m_kd,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        )

    def train(self, resume=False):
        """Run epochs up to ``cfg.SOLVER.EPOCHS`` and append the best accuracy to ``worklog.txt``.

        Args:
            resume: when true, restore model, optimizer, ``best_acc`` and the
                epoch counter from the latest checkpoint in ``log_path``.
        """
        epoch = 1
        if resume:
            # train_epoch saves the checkpoint as "latest.pth"; fall back to an
            # extension-less "latest" file only when that one is missing.
            checkpoint_path = os.path.join(self.log_path, "latest.pth")
            legacy_path = os.path.join(self.log_path, "latest")
            if not os.path.exists(checkpoint_path) and os.path.exists(legacy_path):
                checkpoint_path = legacy_path
            state = load_checkpoint(checkpoint_path)
            epoch = state["epoch"] + 1
            self.distiller.load_state_dict(state["model"])
            self.optimizer.load_state_dict(state["optimizer"])
            self.best_acc = state["best_acc"]
        while epoch < self.cfg.SOLVER.EPOCHS + 1:
            self.train_epoch(epoch)
            epoch += 1
            self.train_loader.dataset.seed_offset += 1
            print(log_msg(f"Rank {self.args.gpu} - Seed Offset: {self.train_loader.dataset.seed_offset}", "INFO"))

        print(log_msg("Best accuracy:{}".format(self.best_acc), "EVAL"))
        # Same gating as BaseTrainer: in distributed runs only the process
        # with gpu == 0 writes, so the line is not appended once per rank.
        if (not self.args.distributed) or self.args.gpu == 0:
            with open(os.path.join(self.log_path, "worklog.txt"), "a") as writer:
                writer.write("best_acc\t" + "{:.2f}".format(float(self.best_acc)))

    def train_iter(self, data, epoch, train_meters):
        """Run one two-stage (kd, then ce) optimisation step on an ``(image, target, index)`` batch.

        Args:
            data: batch triple; the third element (``index``) is not used here.
            epoch: current epoch, forwarded to the distiller.
            train_meters: metric logger whose ``meters`` receive the summed
                loss and the student/teacher accuracies.
        """
        image, target, _index = data
        image = image.float().cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # forward
        model_input = image
        if self.JPEG_enable:
            # [0,255] -> [0,1], then per-channel standardisation by dataset type.
            channel_stats = {
                "imagenet": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                "tiny_imagenet": ([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
            }
            mean, std = channel_stats.get(
                self.cfg.DATASET.TYPE,
                ([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]),
            )
            scale_to_unit = Normalize(mean=[0, 0, 0], std=[255., 255., 255.])
            standardize = Normalize(mean=mean, std=std)
            model_input = standardize(scale_to_unit(image))
        logits_student, logits_teacher, losses_dict = self.distiller(
            image=model_input, perturbedInput=image, target=target, epoch=epoch)

        # dot backward
        loss_ce = losses_dict['loss_ce'].mean()
        loss_kd = losses_dict['loss_kd'].mean()
        self.optimizer.zero_grad(set_to_none=True)
        # The graph is kept because loss_ce is back-propagated through it next.
        loss_kd.backward(retain_graph=True)
        self.optimizer.step_kd()
        self.optimizer.zero_grad(set_to_none=True)
        loss_ce.backward()
        self.optimizer.step()

        # collect info
        batch_size = image.size(0)
        meters = train_meters.meters
        acc1, acc5 = accuracy(logits_student, target, topk=(1, 5))
        meters["losses"].update((loss_ce + loss_kd).item(), batch_size)
        meters["s_top1"].update(acc1[0].item(), batch_size)
        meters["s_top5"].update(acc5[0].item(), batch_size)

        acc1, acc5 = accuracy(logits_teacher, target, topk=(1, 5))
        meters["t_top1"].update(acc1[0].item(), batch_size)
        meters["t_top5"].update(acc5[0].item(), batch_size)


class CRDDOT(BaseTrainer):
    """Two-stage (kd, then ce) trainer for batches that carry contrastive indices.

    Uses a ``DistillationOrientedTrainer`` optimizer and forwards ``index``
    and ``contrastive_index`` to the distiller.
    """

    def init_optimizer(self, cfg):
        """Build the optimizer with separate task and distillation momenta.

        The task momentum is ``MOMENTUM - DOT.DELTA`` and the distillation
        momentum is ``MOMENTUM + DOT.DELTA``.

        Raises:
            NotImplementedError: if ``cfg.SOLVER.TYPE`` is not ``"SGD"``.
        """
        if cfg.SOLVER.TYPE != "SGD":
            raise NotImplementedError(cfg.SOLVER.TYPE)
        m_task = cfg.SOLVER.MOMENTUM - cfg.SOLVER.DOT.DELTA
        m_kd = cfg.SOLVER.MOMENTUM + cfg.SOLVER.DOT.DELTA
        # Accept both a bare distiller and one exposed through ``.module``,
        # matching the two cases BaseTrainer.init_optimizer handles.
        distiller = getattr(self.distiller, "module", self.distiller)
        return DistillationOrientedTrainer(
            distiller.get_learnable_parameters(),
            lr=cfg.SOLVER.LR,
            momentum=m_task,
            momentum_kd=m_kd,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        )

    def train(self, resume=False):
        """Run epochs up to ``cfg.SOLVER.EPOCHS`` and append the best accuracy to ``worklog.txt``.

        Args:
            resume: when true, restore model, optimizer, ``best_acc`` and the
                epoch counter from the latest checkpoint in ``log_path``.
        """
        epoch = 1
        if resume:
            # train_epoch saves the checkpoint as "latest.pth"; fall back to an
            # extension-less "latest" file only when that one is missing.
            checkpoint_path = os.path.join(self.log_path, "latest.pth")
            legacy_path = os.path.join(self.log_path, "latest")
            if not os.path.exists(checkpoint_path) and os.path.exists(legacy_path):
                checkpoint_path = legacy_path
            state = load_checkpoint(checkpoint_path)
            epoch = state["epoch"] + 1
            self.distiller.load_state_dict(state["model"])
            self.optimizer.load_state_dict(state["optimizer"])
            self.best_acc = state["best_acc"]
        while epoch < self.cfg.SOLVER.EPOCHS + 1:
            self.train_epoch(epoch)
            epoch += 1
            self.train_loader.dataset.seed_offset += 1
            print(log_msg(f"Rank {self.args.gpu} - Seed Offset: {self.train_loader.dataset.seed_offset}", "INFO"))
        print(log_msg("Best accuracy:{}".format(self.best_acc), "EVAL"))
        # Same gating as BaseTrainer: in distributed runs only the process
        # with gpu == 0 writes, so the line is not appended once per rank.
        if (not self.args.distributed) or self.args.gpu == 0:
            with open(os.path.join(self.log_path, "worklog.txt"), "a") as writer:
                writer.write("best_acc\t" + "{:.2f}".format(float(self.best_acc)))

    def train_iter(self, data, epoch, train_meters):
        """Run one two-stage step on an ``(image, target, index, contrastive_index)`` batch.

        Args:
            data: batch of four tensors; all of them are moved to the GPU.
            epoch: current epoch (not forwarded to the distiller here).
            train_meters: metric logger whose ``meters`` receive the summed
                loss and the student/teacher accuracies.
        """
        image, target, index, contrastive_index = data
        image = image.float().cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        index = index.cuda(non_blocking=True)
        contrastive_index = contrastive_index.cuda(non_blocking=True)

        # forward
        model_input = image
        if self.JPEG_enable:
            # [0,255] -> [0,1], then per-channel standardisation by dataset type.
            channel_stats = {
                "imagenet": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                "tiny_imagenet": ([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
            }
            mean, std = channel_stats.get(
                self.cfg.DATASET.TYPE,
                ([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]),
            )
            scale_to_unit = Normalize(mean=[0, 0, 0], std=[255., 255., 255.])
            standardize = Normalize(mean=mean, std=std)
            model_input = standardize(scale_to_unit(image))
        logits_student, logits_teacher, losses_dict = self.distiller(
            image=model_input,
            perturbedInput=image,
            target=target,
            index=index,
            contrastive_index=contrastive_index,
        )

        # dot backward
        loss_ce = losses_dict['loss_ce'].mean()
        loss_kd = losses_dict['loss_kd'].mean()
        # Gradients are cleared right before each backward pass, so no extra
        # zero_grad() is needed at the top of the iteration.
        self.optimizer.zero_grad(set_to_none=True)
        # The graph is kept because loss_ce is back-propagated through it next.
        loss_kd.backward(retain_graph=True)
        self.optimizer.step_kd()
        self.optimizer.zero_grad(set_to_none=True)
        loss_ce.backward()
        self.optimizer.step()

        # collect info
        batch_size = image.size(0)
        meters = train_meters.meters
        acc1, acc5 = accuracy(logits_student, target, topk=(1, 5))
        meters["losses"].update((loss_ce + loss_kd).item(), batch_size)
        meters["s_top1"].update(acc1[0].item(), batch_size)
        meters["s_top5"].update(acc5[0].item(), batch_size)

        acc1, acc5 = accuracy(logits_teacher, target, topk=(1, 5))
        meters["t_top1"].update(acc1[0].item(), batch_size)
        meters["t_top5"].update(acc5[0].item(), batch_size)


def validate_metric_logger(val_loader, distiller, metric_logger, log_freq=50):
    """Evaluate ``distiller`` on ``val_loader`` and feed the results into ``metric_logger``.

    Args:
        val_loader: iterable yielding ``(image, target)`` batches.
        distiller: model put into eval mode and called as ``distiller(image=...)``.
        metric_logger: logger whose ``top1``, ``top5`` and ``test_loss`` meters
            are updated once per batch, weighted by the batch size.
        log_freq: print frequency forwarded to ``metric_logger.log_every``.
    """
    loss_fn = nn.CrossEntropyLoss()
    distiller.eval()
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for batch in metric_logger.log_every(val_loader, log_freq, "EVAL:"):
            raw_image, raw_target = batch
            image = raw_image.float().cuda(non_blocking=True)
            target = raw_target.cuda(non_blocking=True)

            logits = distiller(image=image)
            batch_loss = loss_fn(logits, target)
            top1, top5 = accuracy(logits, target, topk=(1, 5))

            n_samples = target.size(0)
            metric_logger.meters["top1"].update(top1[0].item(), n_samples)
            metric_logger.meters["top5"].update(top5[0].item(), n_samples)
            metric_logger.meters["test_loss"].update(
                batch_loss.detach().cpu().numpy().mean().item(), n_samples)

