import os
import wandb
import torch
import logging
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch.distributed as dist

from classification.trainer.base import Trainer


# Module-wide default device: CUDA when available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Configure root logging at INFO so trainer progress messages are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def is_ddp():
    """Return True when torch.distributed is available and a process group is initialized."""
    if not dist.is_available():
        return False
    return dist.is_initialized()

def is_main_process(args=None):
    """Return True on rank 0, or unconditionally when not running under DDP.

    ``args`` is accepted only for call-site compatibility and is ignored.
    """
    if is_ddp():
        return dist.get_rank() == 0
    return True


class JCGResNet_Cls_Trainer(Trainer):
    """Classification trainer with per-epoch evaluation, best-checkpoint saving
    and wandb logging; runs single-process or under DDP.

    NOTE(review): relies on attributes presumably prepared by the ``Trainer``
    base class and/or ``setting()`` (not visible here): ``model``,
    ``optimizer``, ``scheduler``, ``trainset``, ``testset``,
    ``train_dataloader``, ``test_dataloader``, ``train_sampler``,
    ``global_step``, ``output_dir``, ``save_file``, ``run_file``.
    """

    def __init__(self, args):
        super().__init__(args)

    def train(self):
        """Run the training loop.

        Each epoch: one pass over ``train_dataloader`` (an optimizer step per
        batch), an evaluation via :meth:`eval` on every rank, and -- on the main
        process only -- logging to wandb and saving a checkpoint whenever the
        evaluation accuracy improves. Training stops early once
        ``args.max_steps`` (when > 0) optimizer steps have been taken.

        Returns:
            ``(best_eval_results, last_eval_results)``: dicts with keys
            ``epoch``, ``loss``, ``accuracy``, ``class_accuracy``. Both stay
            empty on non-main ranks and if no evaluation completed.
        """
        device = next(self.model.parameters()).device
        world_size = dist.get_world_size() if is_ddp() else 1
        main = is_main_process(self.args)

        logger.info("***********Fine Tuning***********")
        logger.info(" Num examples = %d", len(self.trainset))
        logger.info(" Num Epochs = %d", self.args.num_epoch)
        logger.info(" Batch size per GPU = %d", self.args.per_gpu_train_batch_size)
        logger.info(" Total train batch size = %d", self.args.per_gpu_train_batch_size * world_size)
        logger.info("Total optimization steps = %d", self.args.t_total)

        best_acc = 0.0
        best_eval_results = {}
        last_eval_results = {}

        self.model.zero_grad()

        if main:
            wandb.init(project=self.args.project_name, name=self.run_file, entity=self.args.entity)
        else:
            wandb.init(mode='disabled')

        # Close the wandb run on every exit path: normal end, max_steps, or an exception.
        try:
            for epoch in tqdm(range(self.args.num_epoch), desc="Epoch", disable=not main):
                if is_ddp():
                    # Give the distributed sampler a different shuffle each epoch.
                    self.train_sampler.set_epoch(epoch)

                train_loss, train_acc, stopped = self._train_one_epoch(device, main)
                if stopped:
                    return best_eval_results, last_eval_results

                # The epoch's batch tensors were released when the helper returned;
                # hand cached blocks back before evaluation to lower peak memory.
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                # eval() performs all_reduce collectives, so every rank must call it.
                avg_loss, acc_eval, class_accuracy = self.eval()

                if main:
                    logger.info(f"Epoch {epoch + 1} Evaluation Accuracy: {acc_eval:.4f}")
                    logger.info(f"Epoch {epoch + 1} Training Loss: {train_loss:.4f}")

                    last_eval_results = {
                        'epoch': epoch + 1,
                        'loss': avg_loss,
                        'accuracy': acc_eval,
                        'class_accuracy': class_accuracy
                    }

                    if acc_eval > best_acc:
                        best_acc = acc_eval
                        logger.info(f"🎉 New best model found! Accuracy: {best_acc:.4f}. Saving model...")
                        best_eval_results = dict(last_eval_results)
                        output_dir = self._save_best_checkpoint()
                        logger.info(f"Best model checkpoint saved to {output_dir}")

                    # Training metrics are epoch averages computed from this rank's shard only.
                    logs = {
                        'Train/00.Loss': train_loss,
                        'Train/01.accuracy': train_acc,
                        'Train/learning_rate': self.scheduler.get_last_lr()[0],
                    }
                    if avg_loss is not None:
                        logs['Eval/00.Loss'] = avg_loss
                        logs['Eval/01.accuracy'] = acc_eval
                        logs.update(class_accuracy)
                    wandb.log(logs)

                # NOTE(review): the scheduler is stepped once per epoch while t_total
                # is logged as "Total optimization steps" -- confirm the scheduler is
                # meant to be epoch-based.
                self.scheduler.step()

            logger.info("Training finished. The best model is saved in the 'best_model' directory.")
        finally:
            wandb.finish()
        return best_eval_results, last_eval_results

    def _train_one_epoch(self, device, show_progress):
        """Run one pass over ``train_dataloader`` with an optimizer step per batch.

        Args:
            device: device the batches are moved to.
            show_progress: wrap the loader in a tqdm progress bar.

        Returns:
            ``(mean_loss, accuracy, stopped)``:
            - ``mean_loss`` is the mean of the per-batch losses.
            - ``accuracy`` is the fraction of correct predictions over every
              element compared. It and ``mean_loss`` are NaN when the loader
              yields nothing.
            - ``stopped`` tells whether ``args.max_steps`` was reached, which
              cuts the epoch short.
        """
        iteration = tqdm(self.train_dataloader, desc="Iteration") if show_progress else self.train_dataloader
        # eval() leaves the model in eval mode, so switch back once per epoch.
        self.model.train()

        # Accumulate on-device so there is no host sync on every step.
        loss_sum = torch.zeros((), device=device)
        correct = torch.zeros((), device=device)
        num_batches = 0
        num_seen = 0
        stopped = False

        for data, class_label in iteration:
            data = data.to(device, non_blocking=True).float().contiguous()
            class_label = class_label.to(device, non_blocking=True).contiguous()

            outputs = self.model(data, label=class_label)
            loss = outputs['loss']
            hits = outputs['pred'] == class_label

            loss_sum += loss.detach()
            correct += hits.sum()
            num_batches += 1
            num_seen += hits.numel()

            loss.backward()
            self.optimizer.step()
            self.model.zero_grad()
            self.global_step += 1

            if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
                stopped = True
                break

        mean_loss = (loss_sum / num_batches).item() if num_batches else float("nan")
        accuracy = (correct / num_seen).item() if num_seen else float("nan")
        return mean_loss, accuracy, stopped

    def _save_best_checkpoint(self):
        """Save model, args, optimizer and scheduler state under
        ``<output_dir>/<model_type>/<save_file>/best_model`` and return that path."""
        output_dir = os.path.join(
            self.output_dir, self.args.model_type, self.save_file, "best_model"
        )
        os.makedirs(output_dir, exist_ok=True)

        # Unwrap a wrapper exposing the real network as ``.module`` (e.g. DDP)
        # so the saved state_dict loads into the bare model.
        model_to_save = self.model.module if hasattr(self.model, "module") else self.model
        torch.save(model_to_save.state_dict(), os.path.join(output_dir, "model.pt"))
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
        torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer_fc.pt"))
        torch.save(self.scheduler.state_dict(), os.path.join(output_dir, "scheduler_fc.pt"))
        return output_dir

    def eval(self):
        """Evaluate on ``test_dataloader``, summing statistics across DDP ranks.

        Under DDP this must be called on every rank because it runs
        ``all_reduce`` collectives.

        Returns:
            On the main process, ``(avg_loss, accuracy, class_accuracy)``:
            - ``avg_loss`` is the sample-weighted mean loss.
            - ``accuracy`` is the fraction of correct predictions. It and
              ``avg_loss`` are NaN if no samples were seen.
            - ``class_accuracy`` maps ``"Eval/class_<c>_acc"`` to the accuracy
              for every class with at least one sample.

            On other ranks, ``(None, None, None)``.
        """
        logger.info("*********** Evaluation ***********")
        logger.info("Evaluate {} dataset".format(self.args.dataset))
        logger.info(" Num examples = %d", len(self.testset))
        logger.info(" Num Epochs = %d", self.args.num_epoch)
        device = next(self.model.parameters()).device
        self.model.eval()

        num_classes = self._eval_num_classes()

        loss_sum = torch.zeros((), device=device)
        sample_cnt = torch.zeros((), device=device)
        correct = torch.zeros((), device=device)
        correct_pc = torch.zeros(num_classes, device=device)
        total_pc = torch.zeros(num_classes, device=device)

        main = is_main_process(self.args)
        iterator = tqdm(self.test_dataloader, desc="Eval") if main else self.test_dataloader

        with torch.no_grad():
            for data, class_label in iterator:
                b = data.size(0)
                data = data.to(device, non_blocking=True).float().contiguous()
                class_label = class_label.to(device, non_blocking=True).contiguous()

                out = self.model(data, label=class_label)
                hits = out["pred"] == class_label

                loss_sum += out["loss"].detach() * b
                sample_cnt += b
                correct += hits.sum()

                # Per-class tallies with bincount: one kernel per batch instead of
                # a host-synchronising Python loop over classes. Labels outside
                # [0, num_classes) are skipped, as the per-class loop did.
                labels_flat = class_label.reshape(-1).long()
                in_range = (labels_flat >= 0) & (labels_flat < num_classes)
                total_pc += torch.bincount(labels_flat[in_range], minlength=num_classes)
                correct_pc += torch.bincount(
                    labels_flat[in_range & hits.reshape(-1)], minlength=num_classes
                )

        if is_ddp():
            for stat in (loss_sum, sample_cnt, correct, correct_pc, total_pc):
                dist.all_reduce(stat, op=dist.ReduceOp.SUM)

        if not main:
            return None, None, None

        if sample_cnt.item() <= 0:
            avg_loss = float("nan")
            acc = float("nan")
        else:
            avg_loss = (loss_sum / sample_cnt).item()
            acc = (correct / sample_cnt).item()

        # One device->host transfer for all classes; clamp only guards the
        # classes that are filtered out below anyway (count 0).
        totals = total_pc.tolist()
        ratios = (correct_pc / total_pc.clamp_min(1)).tolist()
        digits = max(1, len(str(num_classes - 1)))
        class_accuracy = {
            f"Eval/class_{c:0{digits}d}_acc": ratios[c]
            for c in range(num_classes)
            if totals[c] > 0
        }

        return avg_loss, acc, class_accuracy

    def _eval_num_classes(self):
        """Return ``max(label) + 1`` over ``self.testset`` (0 if it has no labels).

        The label scan may load every sample, so the result is cached and reused
        for as long as ``self.testset`` is the same object.
        """
        cached = getattr(self, "_eval_num_classes_cache", None)
        if cached is not None and cached[0] is self.testset:
            return cached[1]
        labels = self.get_labels_from_dataset(self.testset)
        num_classes = int(torch.max(labels).item()) + 1 if labels.numel() > 0 else 0
        self._eval_num_classes_cache = (self.testset, num_classes)
        return num_classes

    def get_labels_from_dataset(self, dataset):
        """Return the labels of ``dataset`` as a tensor.

        For a TensorDataset (detected via its ``tensors`` attribute) the second
        tensor is returned directly; otherwise every sample is loaded and its
        second element collected, which can be expensive for large datasets.
        """
        if hasattr(dataset, 'tensors'):  # TensorDataset
            return dataset.tensors[1]
        return torch.tensor([label for _, label in dataset])

    def run(self):
        """Entry point: call ``setting()``, then train if ``args.do_train``,
        otherwise evaluate if ``args.do_eval``.

        Results are passed to ``save_results`` on the main process only.
        """
        self.setting()
        if self.args.do_train:
            best_eval_results, last_eval_results = self.train()
            if is_main_process():
                self.save_results(best_eval_results, last_eval_results)
            rank = dist.get_rank() if is_ddp() else 0
            print(f"[Rank {rank}] Reached end of train method. Returning results...")
            return

        if self.args.do_eval:
            # Every rank must run eval() (collectives); only rank 0 gets results.
            avg_loss, acc, _ = self.eval()
            if is_main_process():
                self.save_results({'accuracy': acc}, {'accuracy': acc})
                print("Evaluation is Over")
                print("Average Loss: {} \n".format(avg_loss))
            return

        return