from codecs import ignore_errors
from collections import OrderedDict, defaultdict

import torch
import torch.nn as nn
import tqdm
from torch.nn import CrossEntropyLoss
from torchvision.models.segmentation import deeplabv3_resnet101

import tasks
from dataset import transform
from phase_1.finetune import ram_plus

class MultiHeadBinaryClassifier(nn.Module):
    """Per-pixel two-way (background / foreground) classifier.

    A single 1x1 convolution with stride 1 maps an ``in_channels`` feature map
    to two logit channels without changing its spatial size.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        # The attribute name is part of the state_dict keys; keep it stable.
        self.classifiers = nn.Conv2d(
            in_channels=in_channels,
            out_channels=2,
            kernel_size=1,
            stride=1,
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Return the two-channel logits for ``features``."""
        return self.classifiers(features)


import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import os


def ddp_setup():
    """Initialise the NCCL process group and bind this process to its GPU.

    The CUDA device index is read from the ``LOCAL_RANK`` environment variable
    after the process group has been initialised.
    """
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))


class CustomDeepLabV3(nn.Module):
    """DeepLabV3-ResNet101 backbone with one independent binary head per class.

    Every head is ``Sequential(ASPP, MultiHeadBinaryClassifier(256))`` and emits
    two-channel (background / foreground) logits.  A ``ram_plus`` model is also
    built and stored as ``phase_1``; ``forward`` does not use it yet.

    Args:
        num_classes: number of binary heads (entries of ``classifier_head``).
        pretrained: if True, load torchvision's ``COCO_WITH_VOC_LABELS_V1``
            weights for the backbone and the ASPP module.
        phase1_checkpoint: value passed as ``ram_plus(pretrained=...)``;
            defaults to the module-level ``opts.checkpoint``.
        phase1_config: mapping providing ``image_size``, ``vit``,
            ``vit_grad_ckpt`` and ``vit_ckpt_layer``; defaults to the
            module-level ``config``.
    """

    def __init__(self, num_classes: int, pretrained: bool = True,
                 phase1_checkpoint=None, phase1_config=None):
        super().__init__()
        import copy

        weights = 'COCO_WITH_VOC_LABELS_V1' if pretrained else None
        deeplab_model = deeplabv3_resnet101(weights=weights, progress=True)

        self.backbone = deeplab_model.backbone
        # classifier[0] is the ASPP module; a fresh 2-way 1x1 conv follows it.
        head_template = nn.Sequential(deeplab_model.classifier[0], MultiHeadBinaryClassifier(256))
        # Each class gets its own deep copy.  Appending one module object N times
        # would make every ModuleList entry alias the same parameters, so
        # training or loading head i would silently overwrite all the others.
        # State-dict keys ("classifier_head.{i}.…") are unchanged by this.
        self.classifier_head = nn.ModuleList(
            copy.deepcopy(head_template) for _ in range(num_classes)
        )

        if phase1_checkpoint is None:
            # `parser` is the ArgumentParser and has no `.checkpoint`; the parsed
            # value lives on `opts`.  NOTE(review): confirm argparser defines
            # --checkpoint pointing at the ram_plus weights.
            phase1_checkpoint = opts.checkpoint
        if phase1_config is None:
            phase1_config = config
        self.phase_1 = ram_plus(
            pretrained=phase1_checkpoint,
            image_size=phase1_config['image_size'],
            vit=phase1_config['vit'],
            vit_grad_ckpt=phase1_config['vit_grad_ckpt'],
            vit_ckpt_layer=phase1_config['vit_ckpt_layer'],
        )

    def forward(self, x: torch.Tensor, label: int) -> torch.Tensor:
        """Segment ``x`` with the head ``classifier_head[label]``.

        Args:
            x: image batch; its last two dimensions are the spatial size.
            label: index of the binary head to apply.

        Returns:
            Two-channel logits bilinearly resized to ``x``'s spatial size.
        """
        input_shape = x.shape[-2:]
        features = self.backbone(x)['out']
        # TODO: phase_1 tag predictions are not wired in; the head is selected
        # solely by the caller-supplied `label`.
        output = self.classifier_head[label](features)
        return nn.functional.interpolate(output, size=input_shape, mode="bilinear", align_corners=False)


# ddp_setup()

# rank = int(os.environ["RANK"])
# world_size = int(os.environ["WORLD_SIZE"])
# local_rank = int(os.environ["LOCAL_RANK"])
# device = f"cuda:{local_rank}"
device = "cuda:0"
import argparser

parser = argparser.get_argparser()
parser.add_argument('--config', default=r'./phase_1/config/pretrain.yaml')
parser.add_argument("--model-type", type=str, choices=("ram_plus", "ram", "tag2text"), required=True)
parser.add_argument('--output-dir', default='output/Pretrain')
import yaml

# Parse first: the ArgumentParser itself has no `.config` attribute, only the
# parsed namespace does.
opts = parser.parse_args()
# NOTE(review): yaml.Loader can construct arbitrary Python objects; acceptable
# for a trusted local config, yaml.safe_load would suffice for plain data.
with open(opts.config, 'r') as config_file:
    config = yaml.load(config_file, Loader=yaml.Loader)
opts = argparser.modify_command_options(opts)
# Fixed experiment settings for this script.
opts.dataset = "voc"
opts.task = "1-1"
opts.data_root = r"./data/VOC2012"

# Standard ImageNet per-channel mean / std, shared by both pipelines.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

# Training pipeline: RandomResizedCrop to opts.crop_size with range (0.5, 2.0),
# random horizontal flip, tensor conversion, normalisation.
train_transform = transform.Compose([
    transform.RandomResizedCrop(opts.crop_size, (0.5, 2.0)),
    transform.RandomHorizontalFlip(),
    transform.ToTensor(),
    transform.Normalize(mean=list(_IMAGENET_MEAN), std=list(_IMAGENET_STD)),
])

# Validation pipeline: resize to (crop_size, crop_size), tensor conversion,
# normalisation.
val_transform = transform.Compose([
    transform.Resize((opts.crop_size, opts.crop_size)),
    transform.ToTensor(),
    transform.Normalize(mean=list(_IMAGENET_MEAN), std=list(_IMAGENET_STD)),
])
from dataset import VOCSegmentationIncremental, AdeSegmentationIncremental


def get_dataset(_opts, task):
    """Build the datasets for incremental step ``task``.

    Side effect: sets ``_opts.step = task``.

    Returns:
        ``(train_dataset, val_dataset, total_val_dataset)``:
        the training split (with ``create_ood=True``), the validation split
        restricted to this step's labels, and the validation split using the
        labels of the "offline" task at step 0.  Index files are read from
        ``<path_base>-ov-ood``.

    Raises:
        NotImplementedError: if ``_opts.dataset`` is neither "voc" nor "ade".
    """
    _opts.step = task
    labels, labels_old, path_base = tasks.get_task_labels(_opts.dataset, _opts.task, task)
    val_labels, val_labels_old, _ = tasks.get_task_labels(_opts.dataset, "offline", 0)
    index_dir = path_base + "-ov-ood"

    dataset_classes = {"voc": VOCSegmentationIncremental, "ade": AdeSegmentationIncremental}
    if _opts.dataset not in dataset_classes:
        raise NotImplementedError(f"{_opts.dataset} is not implemented.")
    dataset_cls = dataset_classes[_opts.dataset]

    # Keyword arguments identical for all three splits.
    common = dict(
        root=_opts.data_root,
        masking=not _opts.no_mask,
        overlap=True,
        disable_background=_opts.disable_background,
        data_masking=_opts.data_masking,
    )

    train_split = dataset_cls(
        train=True,
        transform=train_transform,
        labels=list(labels),
        labels_old=list(labels_old),
        idxs_path=index_dir + f"/train-{_opts.step}.npy",
        step=_opts.step,
        create_ood=True,
        **common,
    )
    task_val_split = dataset_cls(
        train=False,
        transform=val_transform,
        labels=list(labels),
        labels_old=list(labels_old),
        idxs_path=index_dir + f"/val-task-{_opts.step}.npy",
        step=0,
        **common,
    )
    full_val_split = dataset_cls(
        train=False,
        transform=val_transform,
        labels=list(val_labels),
        labels_old=list(val_labels_old),
        idxs_path=index_dir + f"/val-{_opts.step}.npy",
        step=0,
        **common,
    )
    return train_split, task_val_split, full_val_split


# _,test_val = get_dataset(opts,4)

from torch.utils.data import DataLoader
from dice import DiceLoss
from torchvision.ops import sigmoid_focal_loss
import torch.nn.functional as F

import segmentation_models_pytorch as smp


class CombinedLoss(nn.Module):
    """Weighted sum of segmentation_models_pytorch focal and dice losses.

    ``loss = focal_weight * Focal(outputs, targets) + dice_weight * Dice(outputs, targets)``,
    both in "multiclass" mode (focal gamma 2.0, dice computed from logits).
    """

    def __init__(self, focal_weight=1.0, dice_weight=0.05):
        super().__init__()
        self.focal_weight = focal_weight
        self.dice_weight = dice_weight
        self.focal_loss = smp.losses.FocalLoss(mode="multiclass", gamma=2.0)
        self.dice_loss = smp.losses.DiceLoss(mode="multiclass", from_logits=True)

    def forward(self, outputs, targets):
        """Return the combined loss; ``targets`` is cast to long class indices."""
        class_targets = targets.long()
        focal_value = self.focal_loss(outputs, class_targets)
        dice_value = self.dice_loss(outputs, class_targets)
        return self.focal_weight * focal_value + self.dice_weight * dice_value


from torch.optim import SGD, AdamW


def _binary_target(label: torch.Tensor, label_num: int) -> torch.Tensor:
    """Return a long {0, 1} mask that is 1 exactly where ``label == label_num``.

    Every other value, including the 255 ignore index, becomes 0 (background).
    """
    return (label == label_num).long()


def _zero_binary_stats() -> list:
    """Return fresh CPU accumulators ``[tp, fp, fn, tn]``, one slot per binary class."""
    return [torch.zeros(2, device="cpu") for _ in range(4)]


def _accumulate_binary_stats(stats: list, logits: torch.Tensor, target: torch.Tensor) -> None:
    """Add smp confusion counts of ``argmax(logits, dim=1)`` vs ``target`` into ``stats`` in place."""
    pred = torch.argmax(logits.detach(), dim=1).long()
    batch_stats = smp.metrics.get_stats(pred, target, mode="multiclass", num_classes=2)
    for total, batch in zip(stats, batch_stats):
        total += batch.sum(dim=0).cpu()


def _binary_iou(stats: list) -> torch.Tensor:
    """Return per-class IoU (index 0 background, 1 foreground) from ``[tp, fp, fn, tn]``."""
    tp, fp, fn, tn = stats
    return smp.metrics.iou_score(tp[None], fp[None], fn[None], tn[None], reduction="none").mean(dim=0)


@torch.no_grad()
def _evaluate_binary(model: nn.Module, dataloader, label_num: int) -> torch.Tensor:
    """Return the per-class IoU of head ``label_num - 1`` over ``dataloader``.

    The model is put in eval mode for the pass, so normalisation layers use
    their running statistics rather than batch statistics.  Its previous mode
    is restored afterwards.
    """
    was_training = model.training
    model.eval()
    stats = _zero_binary_stats()
    try:
        for image, label in tqdm.tqdm(dataloader, total=len(dataloader)):
            target = _binary_target(label, label_num).to(device)
            logits = model(image.to(device), label_num - 1)
            _accumulate_binary_stats(stats, logits, target)
    finally:
        model.train(was_training)
    return _binary_iou(stats)


def _append_iou_log(path: str, task: int, epoch: int, iou: torch.Tensor) -> None:
    """Append the background / foreground IoU of one evaluation to the text file ``path``."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "a") as f:
        f.write(f"task: {task}, epoch: {epoch}, background_iou for val is {iou[0].item()}\n")
        f.write(f"task: {task}, epoch: {epoch}, foreground_iou for val is {iou[1].item()}\n")


def train(task: int):
    """Train the binary head for class ``task`` (head index ``task - 1``).

    All parameters except ``classifier_head`` are frozen; the head is optimised
    with AdamW on the binary target ``label == task``.

    After every epoch, the training IoU is printed and the full state_dict is
    saved to ``./mycheckpoint/ood/voc-task{task - 1}.pth``.

    Periodic evaluations:
        * every 20 epochs, on the full validation split ("segmentation only");
        * every 10 epochs, on the task-specific split ("oracle").

    Evaluation results are appended under ``./training_results/ood/``.
    """
    # DDP variant (disabled): ddp_setup(); device = f"cuda:{LOCAL_RANK}";
    # model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)
    model = CustomDeepLabV3(20)
    for param in model.parameters():
        param.requires_grad = False
    for param in model.classifier_head.parameters():
        param.requires_grad = True
    train_dataset, specific_val, val_dataset = get_dataset(opts, task - 1)
    params = [p for p in model.parameters() if p.requires_grad]
    optim = AdamW(params, lr=0.00006)
    loss_fn = CombinedLoss(focal_weight=0.25, dice_weight=0.75)
    batch_size = 200
    label_num = task  # foreground class id handled by this head
    model = model.to(device)
    epoch = 71
    # Shuffle training data so batches are not composed identically every epoch.
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    total_val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
    val_dataloader = DataLoader(specific_val, batch_size=batch_size, shuffle=False, num_workers=4)

    checkpoint_path = f"./mycheckpoint/ood/voc-task{label_num - 1}.pth"
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    for e in range(epoch):
        # Evaluations below switch to eval mode; make sure training runs in train mode.
        model.train()
        stats = _zero_binary_stats()
        pbar = tqdm.tqdm(train_dataloader, total=len(train_dataloader), desc="Training")
        for image, label in pbar:
            optim.zero_grad()
            image = image.to(device)
            target = _binary_target(label, label_num).to(device)
            pred = model(image, label_num - 1)
            loss = loss_fn(pred, target)
            loss.backward()
            optim.step()
            pbar.set_postfix(loss=f'{loss.item():.4f}')
            _accumulate_binary_stats(stats, pred, target)
        final_iou = _binary_iou(stats)
        print(f"background iou:{final_iou[0].item()}")
        print(f"foreground iou:{final_iou[1].item()}")

        torch.save(model.state_dict(), checkpoint_path)

        if e % 20 == 0:
            final_iou = _evaluate_binary(model, total_val_dataloader, label_num)
            print(f"segmentation only background iou:{final_iou[0].item()}")
            print(f"segmentation only foreground iou:{final_iou[1].item()}")
            _append_iou_log(f"./training_results/ood/offline_{task}_segmentation_only.txt", task, e, final_iou)
        if e % 10 == 0:
            final_iou = _evaluate_binary(model, val_dataloader, label_num)
            print(f"oracle background iou:{final_iou[0].item()}")
            print(f"oracle foreground iou:{final_iou[1].item()}")
            _append_iou_log(f"./training_results/ood/offline_{task}_oracle.txt", task, e, final_iou)


# train(5)

def val_for_target_task(task: int):
    """Evaluate the saved head for class ``task`` on its task-specific val split.

    Loads ``./mycheckpoint/ood/voc-task{task - 1}.pth`` into a fresh
    ``CustomDeepLabV3(20)`` and applies head ``task - 1`` in eval mode.  It then
    prints the background / foreground IoU against the binary mask
    ``label == task``; every other value, including 255, counts as background.
    """
    ckp = torch.load(f"./mycheckpoint/ood/voc-task{task - 1}.pth", map_location="cpu")
    model = CustomDeepLabV3(20)
    model.load_state_dict(ckp, strict=True)
    model = model.to(device).eval()
    # Only the task-specific validation split is needed here.
    _, val_dataset, _ = get_dataset(opts, task - 1)
    val_dataloader = DataLoader(val_dataset, batch_size=100, shuffle=False, num_workers=4)
    label_num = task

    # [tp, fp, fn, tn] accumulators, one slot per binary class.
    stats = [torch.zeros(2, device="cpu") for _ in range(4)]
    with torch.no_grad():
        for image, label in tqdm.tqdm(val_dataloader, total=len(val_dataloader)):
            target = (label == label_num).long().to(device)
            logits = model(image.to(device), label_num - 1)
            pred = torch.argmax(logits, dim=1).long()
            batch_stats = smp.metrics.get_stats(pred, target, mode="multiclass", num_classes=2)
            for total, batch in zip(stats, batch_stats):
                total += batch.sum(dim=0).cpu()
        tp, fp, fn, tn = stats
        final_iou = smp.metrics.iou_score(tp[None], fp[None], fn[None], tn[None],
                                          reduction="none").mean(dim=0)
        print(f"background iou:{final_iou[0].item()}")
        print(f"foreground iou:{final_iou[1].item()}")

# val_for_target_task(5)
# for i in range(16,21):
#     train(i)
def create_total_model(task_num):
    """Assemble one CustomDeepLabV3 whose heads come from per-task checkpoints.

    Backbone weights are the ``backbone.`` entries of
    ``./mycheckpoint/voc-task0.pth``.  Head ``i`` receives the
    ``classifier_head.{i}.`` entries of ``./mycheckpoint/voc-task{i}.pth``, for
    ``i`` in ``range(task_num)``.  All loads use ``strict=True``.
    """
    def select_prefixed(state, prefix):
        """Return the entries of ``state`` under ``prefix``, with the prefix stripped."""
        return OrderedDict(
            (key[len(prefix):], value) for key, value in state.items() if key.startswith(prefix)
        )

    model = CustomDeepLabV3(task_num)
    checkpoint_dir = r"./mycheckpoint"

    backbone_source = torch.load(os.path.join(checkpoint_dir, "voc-task0.pth"))
    model.backbone.load_state_dict(select_prefixed(backbone_source, "backbone."), strict=True)
    print("finish load backbone")

    for head_idx in tqdm.tqdm(range(task_num), total=task_num):
        head_path = os.path.join(checkpoint_dir, f"voc-task{head_idx}.pth")
        task_state = torch.load(head_path, map_location=lambda storage, loc: storage)
        head_state = select_prefixed(task_state, f"classifier_head.{head_idx}.")
        model.classifier_head[head_idx].load_state_dict(head_state, strict=True)
        print(f"finish load {head_idx} task")
    return model


def val():
    """Oracle multi-head evaluation on the VOC "offline" validation split.

    Model ``i`` (``i`` = 0..19) is loaded from
    ``./mycheckpoint/ood/voc-task{i}.pth`` and used only through head ``i``.

    For each validation image, only the heads of the foreground classes present
    in the ground truth are run (an oracle on the label set).  Their foreground
    masks are merged by keeping the largest class id per pixel.  Pixels with no
    firing head are background.

    Per-class IoU over 21 classes is printed; pixels labelled 255 are ignored.

    Side effect: sets the global ``opts.task`` to ``"offline"``.
    """
    global opts
    ignore_index = 255
    num_classes = 21  # background + 20 foreground classes
    model_list = [CustomDeepLabV3(20).to(device) for _ in range(20)]
    for i, model in enumerate(model_list):
        ckp = torch.load(rf"./mycheckpoint/ood/voc-task{i}.pth", map_location=device)
        model.load_state_dict(ckp, strict=True)
        # Eval mode so normalisation layers use running statistics at test time.
        model.eval()
    opts.task = "offline"
    total_tp = torch.zeros(num_classes, device="cpu")
    total_fp = torch.zeros(num_classes, device="cpu")
    total_fn = torch.zeros(num_classes, device="cpu")
    total_tn = torch.zeros(num_classes, device="cpu")
    _, val_dataset, _ = get_dataset(opts, 0)
    with torch.no_grad():
        for i in tqdm.tqdm(range(len(val_dataset)), total=len(val_dataset)):
            image, label = val_dataset[i]
            image = image.to(device)
            label = label.to(device).long()
            present = [c for c in torch.unique(label).tolist() if c not in (0, ignore_index)]
            # Start from all-background so images without foreground classes
            # are still scored instead of crashing on an empty stack.
            final_pred = torch.zeros_like(label)
            for cls in present:
                logits = model_list[cls - 1](image[None], cls - 1)
                foreground = torch.argmax(logits, dim=1)[0] == 1
                # Where several heads fire, the largest class id wins.
                final_pred = torch.maximum(final_pred, foreground.long() * cls)
            tp, fp, fn, tn = smp.metrics.get_stats(final_pred[None], label[None], mode="multiclass",
                                                   num_classes=num_classes, ignore_index=ignore_index)
            total_tp += tp.sum(dim=0).cpu()
            total_fp += fp.sum(dim=0).cpu()
            total_fn += fn.sum(dim=0).cpu()
            total_tn += tn.sum(dim=0).cpu()
    per_class_iou = \
        smp.metrics.iou_score(total_tp[None], total_fp[None], total_fn[None], total_tn[None], reduction="none")[0]
    for class_idx, iou in enumerate(per_class_iou):
        print(f"Class {class_idx}: IoU = {iou.item():.4f}")

# Module-level entry point: runs the oracle evaluation whenever this file is
# executed or imported.  NOTE(review): consider an `if __name__ == "__main__":`
# guard so that importing the module does not start evaluation.
val()