import datetime
import time
import torch
import os
from tqdm import tqdm
from torch.nn import functional as F
import torch.nn as nn
import numpy as np
import clip
from collections import OrderedDict
from dassl.data import DataManager, DataManager_sf
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import (
    MetricMeter, AverageMeter, tolist_if_not, count_num_param, load_checkpoint,
    save_checkpoint, mkdir_if_missing, resume_from_checkpoint,
    load_pretrained_weights
)
from dassl.modeling import build_head, build_backbone
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.modeling.ops import AngularPenaltySMLoss, EntropyMaximization, InfoNCE
from dassl.evaluation import build_evaluator
from dassl.engine.dg.PromptGenerator import PromptGenerator


class clip_net(nn.Module):
    """CLIP backbone paired with an angular-penalty classifier.

    Text features (``backbone.forward_text``) are used during training and
    image features (``backbone.forward_image``) at inference. Both are
    L2-normalised and then classified by ``self.adms_loss.fc``.
    """

    def __init__(self, cfg, model_cfg, num_classes, device, loss_type='arcface', **kwargs):
        super(clip_net, self).__init__()
        self.device = device
        # Create the CLIP backbone; extra kwargs are forwarded to the builder.
        self.backbone = build_backbone(
            model_cfg.BACKBONE.NAME,
            verbose=cfg.VERBOSE,
            device=self.device,
            **kwargs
        )
        self.fdim = self.backbone._out_features
        self.head = None

        # Angular-penalty softmax loss. Its ``fc`` layer is also used directly
        # as the classifier in forward_text_one / forward_img.
        self.adms_loss = AngularPenaltySMLoss(self.fdim, num_classes, loss_type=loss_type, s=cfg.ARCFACE_S,
                                              m=cfg.ARCFACE_M)
        self.infonce_loss = InfoNCE(negative_mode='paired')

    def evaluate_model_performance_info(self, model_outputs):
        """Turn pairwise InfoNCE losses between three outputs into fusion weights.

        Args:
            model_outputs: sequence of embedding tensors. Only indices 0, 1
                and 2 are used, and all entries are copied to CPU first.

        Returns:
            list of three scalar tensors ``[w0, w1, w2]``. ``w_i`` is the sum
            of the two InfoNCE losses that involve output ``i``, divided by
            the sum of those quantities over all three outputs. The weights
            therefore add up to 1.
        """
        outputs = [model_output.cpu() for model_output in model_outputs]

        loss_01 = self.infonce_loss(outputs[0], outputs[1])
        loss_02 = self.infonce_loss(outputs[0], outputs[2])
        loss_12 = self.infonce_loss(outputs[1], outputs[2])

        # Total pairwise loss each output takes part in.
        scores = [loss_01 + loss_02, loss_01 + loss_12, loss_02 + loss_12]

        # Compute the denominator once so every weight is normalised by the
        # same total and the three weights sum to 1.
        total = scores[0] + scores[1] + scores[2]
        return [score / total for score in scores]

    def forward_text_one(self, x, t_x, labels):
        """Classify one batch of text prompts and compute its training loss.

        Args:
            x, t_x: passed unchanged to ``backbone.forward_text``. These are
                presumably prompt embeddings and tokenized prompts; confirm
                against the backbone.
            labels: class targets for ``adms_loss``.

        Returns:
            tuple ``(embed_output, y_class, y_loss)`` with:
            - ``embed_output``: the L2-normalised text features;
            - ``y_class``: logits from ``adms_loss.fc``;
            - ``y_loss``: the angular-penalty loss.
        """
        backbone_output = self.backbone.forward_text(x, t_x)
        embed_output = backbone_output / backbone_output.norm(dim=-1, keepdim=True)

        y_class = self.adms_loss.fc(embed_output)
        y_loss = self.adms_loss(embed_output, labels)

        return embed_output, y_class, y_loss

    def forward_img(self, x, norm=True):
        """Return classifier logits for a batch of images.

        The backbone output is always L2-normalised. ``norm=True`` only
        applies a second, effectively idempotent normalisation, so ``norm``
        does not change the scale of the features passed to ``adms_loss.fc``.
        """
        backbone_output = self.backbone.forward_image(x)
        embed_output = backbone_output / backbone_output.norm(dim=-1, keepdim=True)

        if norm:
            embed_output = embed_output / embed_output.norm(dim=-1, keepdim=True)

        y_class = self.adms_loss.fc(embed_output)
        return y_class

    def predictor(self, feat, teat):
        """Return cosine-similarity scores (scaled by 100) between feat and teat rows.

        Assumes 2-D inputs of shape (N, D) and (M, D), which gives an (N, M)
        result. Confirm against the callers.
        """
        feat_p = feat / feat.norm(dim=-1, keepdim=True)
        teat_p = teat / teat.norm(dim=-1, keepdim=True)
        scores = 100.0 * feat_p @ teat_p.T
        return scores


@TRAINER_REGISTRY.register()
class ECS_clip(TrainerX):
    """Trainer that fits a CLIP-based classifier on generated text prompts.

    Each training batch (from ``DataManager_sf`` built with a
    ``PromptGenerator``) carries three prompt variants per sample. Each
    variant gets its own optimisation step, followed by one step on a
    weighted fusion of the three losses. Evaluation classifies images with
    ``clip_net.forward_img``.
    """

    # Class counts for the supported datasets, keyed by cfg.DATASET.NAME.
    _NUM_CLASSES = {
        'PACS_SF': 7,
        'OfficeHomeDG_SF': 65,
        'VLCS_SF': 5,
        'DomainNet_SF': 345,
        'TerraIncognita_SF': 10,
    }

    def __init__(self, cfg):
        # Deliberately does not call the base-class __init__. The pieces it
        # would set up are initialised here in a custom order.
        self._models = OrderedDict()
        self._optims = OrderedDict()
        self._scheds = OrderedDict()
        self._writer = None
        self.check_cfg(cfg)
        if torch.cuda.is_available() and cfg.USE_CUDA:
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        # Save as attributes some frequently used variables
        self.start_epoch = self.epoch = 0
        self.max_epoch = cfg.OPTIM.MAX_EPOCH
        self.output_dir = cfg.OUTPUT_DIR
        self.cfg = cfg

        # Order matters: init_train_data needs self.model.backbone, and
        # build_data_loader needs the prompt generator.
        self.build_model()
        self.init_train_data()
        self.build_data_loader()
        self.evaluator = build_evaluator(cfg, lab2cname=self.lab2cname)
        self.best_result = -np.inf

    def build_model(self):
        """Build ``clip_net``, load optional pretrained weights, and register it.

        Raises:
            ValueError: if ``cfg.DATASET.NAME`` is not a supported dataset.
        """
        cfg = self.cfg
        print("Building model")

        dataset_name = cfg.DATASET.NAME
        if dataset_name not in self._NUM_CLASSES:
            raise ValueError(
                f"Unsupported dataset for ECS_clip: {dataset_name!r}; "
                f"expected one of {sorted(self._NUM_CLASSES)}"
            )
        self.num_classes = self._NUM_CLASSES[dataset_name]

        self.model = clip_net(cfg, cfg.MODEL, self.num_classes, self.device)
        if cfg.MODEL.INIT_WEIGHTS:
            load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS)
        self.model.to(self.device)

        self.optim = build_optimizer(self.model, cfg.OPTIM)
        self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
        self.register_model("model", self.model, self.optim, self.sched)

    def init_train_data(self):
        """Read class names and build the prompt generator.

        Class names come from ``<cfg.TXTS_PATH>/<cfg.DATASET.NAME>.txt``, one
        per line, and keep their file order.
        """
        txt_path = os.path.join(self.cfg.TXTS_PATH, self.cfg.DATASET.NAME + '.txt')
        with open(txt_path, 'r', encoding='utf-8') as f:
            classnames = f.read().splitlines()
        self.num_classes = len(classnames)
        # Attribute name keeps its original spelling; build_data_loader and
        # before_epoch read it under this name.
        self.prompt_generater = PromptGenerator(self.cfg, classnames, self.model.backbone, self.device)

    def build_data_loader(self):
        """Create the source-free data manager and expose its loaders and metadata."""
        dm = DataManager_sf(self.cfg, self.prompt_generater)
        self.train_loader_x = dm.train_loader_x

        self.val_loader = dm.val_loader
        self.test_loader = dm.test_loader

        self.num_classes = dm.num_classes
        self.num_source_domains = dm.num_source_domains
        self.lab2cname = dm.lab2cname
        self.dm = dm

    def forward_backward(self, batch):
        """Run one training step on a batch containing three prompt variants.

        Each variant is forwarded and optimised separately and in order.
        The three losses are then combined with weights from
        ``clip_net.evaluate_model_performance_info`` and optimised once more.

        Returns:
            dict with the fused loss under ``"y_loss"`` and the accuracy of
            the fused logits under ``"acc"``.
        """
        input_embed, input_token, target = self.parse_batch_train(batch)

        # Forward and update each branch in turn. Later branches therefore
        # see parameters already updated by earlier ones.
        model_outputs, y_losses = [], []
        for branch in range(3):
            embed_output, _, y_loss = self.model.forward_text_one(
                input_embed[branch], input_token[branch], target
            )
            self.model_backward_and_update(y_loss)
            model_outputs.append(embed_output)
            y_losses.append(y_loss)

        weights = self.model.evaluate_model_performance_info(model_outputs)
        backbone_output_fusion = sum(
            output * weight for output, weight in zip(model_outputs, weights)
        )
        y_class_fusion = self.model.adms_loss.fc(backbone_output_fusion)
        y_loss_fusion = sum(loss * weight for loss, weight in zip(y_losses, weights))
        # NOTE(review): every entry of y_losses was already passed to
        # model_backward_and_update above. If that helper does a plain
        # loss.backward() (no retain_graph) and steps the optimizer in place,
        # backpropagating y_loss_fusion through the same graphs is expected
        # to raise a RuntimeError in PyTorch. Confirm this path runs as
        # intended.
        self.model_backward_and_update(y_loss_fusion)

        loss_summary = {
            "y_loss": y_loss_fusion.item(),
            "acc": compute_accuracy(y_class_fusion, target)[0].item(),
        }

        # Step the LR scheduler once per epoch, after the last batch.
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch):
        """Move per-branch prompt embeddings, tokens and labels to the device.

        Returns:
            ``(input_embed, input_token, target)``, where the first two are
            lists with one tensor per prompt branch.
        """
        target = batch["label"]
        # as_tensor avoids the copy (and warning) torch.tensor() makes when
        # the collated entries are already tensors.
        input_embed = [torch.as_tensor(embed).to(self.device) for embed in batch["embedding"]]
        input_token = [torch.as_tensor(token).to(self.device) for token in batch["tokenized_prompts"]]

        target = target.to(self.device)
        return input_embed, input_token, target

    def model_inference(self, input):
        """Return image-classification logits for evaluation."""
        return self.model.forward_img(input, norm=True)

    @torch.no_grad()
    def test(self, split=None):
        """Evaluate on each domain loader of the chosen split.

        Args:
            split: ``"val"`` or ``"test"``. Defaults to ``cfg.TEST.SPLIT``.
                Falls back to ``"test"`` when no validation loader exists.

        Returns:
            list with one entry per domain loader: the first metric the
            evaluator reports (presumably accuracy; confirm against the
            evaluator).
        """
        self.set_model_mode("eval")
        self.evaluator.reset()
        result = []
        if split is None:
            split = self.cfg.TEST.SPLIT

        if split == "val" and self.val_loader is not None:
            data_loader = self.val_loader
        else:
            split = "test"  # in case val_loader is None
            data_loader = self.test_loader

        # data_loader is iterated as a collection of per-domain loaders.
        for data_loader_domain in data_loader:
            print(f"Evaluate on the *{split}* set")
            for batch in tqdm(data_loader_domain):
                input, label = self.parse_batch_test(batch)
                output = self.model_inference(input)
                self.evaluator.process(output, label)

            results = self.evaluator.evaluate()

            for k, v in results.items():
                tag = f"{split}/{k}"
                self.write_scalar(tag, v, self.epoch)
            result.append(next(iter(results.values())))
            self.evaluator.reset()

        return result

    def after_epoch(self):
        """Evaluate, save ``model-best.pth.tar`` on improvement, and print elapsed time.

        The saved score is the mean over domains. ``self.time_start`` is not
        set in this class; presumably ``before_train`` sets it.
        """
        curr_result = np.mean(self.test())
        if curr_result > self.best_result:
            self.best_result = curr_result
            self.save_model(
                self.epoch,
                self.output_dir,
                val_result=curr_result,
                model_name="model-best.pth.tar"
            )
        # Show elapsed time
        elapsed = round(time.time() - self.time_start)
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print(f"Elapsed: {elapsed}")

    def run_epoch(self):
        """Train for one epoch, logging progress to stdout and scalars to the writer."""
        self.set_model_mode("train")
        losses = MetricMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        self.num_batches = len(self.train_loader_x)

        end = time.time()
        # self.batch_idx is read by forward_backward, so it is bound as an attribute.
        for self.batch_idx, batch in enumerate(self.train_loader_x):
            data_time.update(time.time() - end)
            loss_summary = self.forward_backward(batch)
            batch_time.update(time.time() - end)
            losses.update(loss_summary)

            meet_freq = (self.batch_idx + 1) % self.cfg.TRAIN.PRINT_FREQ == 0
            only_few_batches = self.num_batches < self.cfg.TRAIN.PRINT_FREQ
            if meet_freq or only_few_batches:
                # Batches left in this epoch plus all batches of later epochs.
                nb_remain = (self.num_batches - self.batch_idx - 1) + (
                    self.max_epoch - self.epoch - 1
                ) * self.num_batches
                eta_seconds = batch_time.avg * nb_remain
                eta = str(datetime.timedelta(seconds=int(eta_seconds)))

                info = [
                    f"epoch [{self.epoch + 1}/{self.max_epoch}]",
                    f"batch [{self.batch_idx + 1}/{self.num_batches}]",
                    f"time {batch_time.val:.3f} ({batch_time.avg:.3f})",
                    f"data {data_time.val:.3f} ({data_time.avg:.3f})",
                    f"{losses}",
                    f"lr {self.get_current_lr():.4e}",
                    f"eta {eta}",
                ]
                print(" ".join(info))

            n_iter = self.epoch * self.num_batches + self.batch_idx
            for name, meter in losses.meters.items():
                self.write_scalar("train/" + name, meter.avg, n_iter)
            self.write_scalar("train/lr", self.get_current_lr(), n_iter)

            end = time.time()

    def before_epoch(self):
        """Refresh the prompt generator's style when ``cfg.TRAINER.REFRESH`` is "GaussMix"."""
        if self.cfg.TRAINER.REFRESH in ["GaussMix"]:
            print("*****refresh style : ",self.cfg.TRAINER.REFRESH,"*****")
            self.prompt_generater.refresh_style()

    def train(self):
        """Run the epoch loop from ``start_epoch`` up to ``max_epoch``."""
        self.before_train()
        for self.epoch in range(self.start_epoch, self.max_epoch):
            self.before_epoch()
            self.run_epoch()
            self.after_epoch()
        self.after_train()
