import numpy as np
# import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
 
from sklearn.manifold import TSNE
from scipy import linalg
import os
 
 
#Pytorch imports
 
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import models
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from tqdm import tqdm
from glob import glob
from PIL import Image
from sklearn.model_selection import train_test_split
from torchinfo import summary
 
# List of classes in CIFAR10 (commented out; kept for reference only)
# classes = ['airplane',
#  'automobile',
#  'bird',
#  'cat',
#  'deer',
#  'dog',
#  'frog',
#  'horse',
#  'ship',
#  'truck']
 
import pickle
 
# Focal sets (sets of class labels) that define the model's output columns,
# produced by an earlier step and stored as a pickle.
# NOTE(review): pickle.load can execute arbitrary code — only load from a trusted location.
budgeted_sets_file = '/mnt/eris-alpha/mubashar/shared_data/imagnet_budgeted_set_classes.pickle'
with open(budgeted_sets_file, 'rb') as f:
    new_classes = pickle.load(f)

data_dir = '/mnt/eris-alpha/datasets/ILSVRC/Data/CLS-LOC/train'
# Without recursive=True, "**" behaves like "*", so this matches exactly one
# directory level (the synset folders) below data_dir.
all_paths = glob(f"{data_dir}/**/*.JPEG")

# Seeded 95/5 train/validation split.
# NOTE(review): glob returns paths in filesystem order, so the split is only
# reproducible for the same directory listing order.
val_ratio = 0.05
train_paths, val_paths = train_test_split(
    all_paths,
    test_size=val_ratio,
    random_state=42,
)
 
synset_mapping_file = '/mnt/eris-alpha/datasets/LOC_synset_mapping.txt'

# Lookup tables built from the synset mapping file. Each line is parsed as a
# 9-character synset id followed by a human-readable label; the running line
# counter becomes the class number.
class_mapping_dict = {}         # synset id    -> label
class_mapping_dict_number = {}  # class number -> label
mapping_class_to_number = {}    # label        -> class number
mapping_number_to_class = {}    # class number -> synset id
i = 0

# NOTE(review): mapping_class_to_number is keyed by label, so if two synsets
# share an identical label the later line wins — verify labels are unique.
with open(synset_mapping_file) as mapping_file:  # closed deterministically
    for line in mapping_file:
        if not line.strip():
            # Skip blank lines instead of registering an empty-label class.
            continue
        synset_id = line[:9].strip()
        class_label = line[9:].strip()
        class_number = i
        class_mapping_dict[synset_id] = class_label
        class_mapping_dict_number[class_number] = class_label
        mapping_class_to_number[class_label] = class_number
        mapping_number_to_class[class_number] = synset_id
        i += 1

# Labels ordered by class number.
classes = list(class_mapping_dict_number.values())
 
class ImagenetDataset(Dataset):
    """ImageNet images paired with multi-hot targets over focal sets.

    Each item is ``(image_tensor, target)``, where ``target`` is the row for
    the image's class in a (len(classes), len(new_classes)) 0/1 matrix.
    Entry j of a row is 1 when every label of ``new_classes[j]`` is in
    ``classes`` and the row's class label is one of them.

    Args:
        file_paths: image paths laid out as ``.../<synset_id>/<file>``; the
            parent directory name is used to look up the class.
        classes: class labels ordered by class number.
        new_classes: focal sets (sets of labels) defining the target columns.
        transform: optional callable applied to the resized PIL image.
            Defaults to the training augmentation pipeline below; pass a
            deterministic pipeline for evaluation.

    Relies on the module-level ``mapping_class_to_number`` and
    ``class_mapping_dict`` tables in ``__getitem__``.
    """

    def __init__(self, file_paths, classes, new_classes, transform=None):
        self.file_paths = file_paths
        if transform is None:
            # Default training augmentation.
            transform = transforms.Compose([
                transforms.RandomResizedCrop(176, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
                transforms.RandomHorizontalFlip(),
                transforms.TrivialAugmentWide(interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                transforms.RandomErasing(p=0.1),
            ])
        self.transform = transform
        self.target_transform = self.groundtruthmod(classes, new_classes)

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        path = self.file_paths[idx]
        # The context manager releases the file handle once the pixels have
        # been decoded into the new RGB image.
        # NOTE(review): the fixed 224x224 resize ignores aspect ratio before the
        # random crop — confirm this is intended.
        with Image.open(path) as raw:
            img = raw.convert("RGB").resize((224, 224))
        img = self.transform(img)

        synset_id = path.split("/")[-2]
        label = mapping_class_to_number[class_mapping_dict[synset_id]]
        return img, self.target_transform(label)

    def groundtruthmod(self, classes, new_classes):
        """Return a function mapping a class number to its multi-hot row.

        The (len(classes), len(new_classes)) 0/1 matrix is cached in
        ``set_class_map.npy`` in the working directory. A cached matrix whose
        shape does not match the current inputs is treated as stale and
        rebuilt (only the shape is checked, not the contents).
        """
        cache_path = "set_class_map.npy"
        expected_shape = (len(classes), len(new_classes))

        set_class_map = None
        if os.path.exists(cache_path):
            set_class_map = np.load(cache_path)
            if set_class_map.shape != expected_shape:
                set_class_map = None  # cache from a different class setup

        if set_class_map is None:
            set_class_map = np.zeros(expected_shape, dtype=int)
            # Build the label set once instead of once per (label, focal set) pair.
            label_set = set(classes)
            for j, class_ in enumerate(new_classes):
                # Focal sets containing labels outside `classes` get an all-zero column.
                if not class_.issubset(label_set):
                    continue
                for i, label in enumerate(classes):
                    if label in class_:
                        set_class_map[i, j] = 1
            np.save(cache_path, set_class_map)

        def mod(y):
            return set_class_map[y]

        return mod
 
def collate_fn(data):
    """Stack a batch of (input, target) pairs into two NumPy arrays.

    NOTE(review): not passed to the DataLoaders in this script, which use the
    default collate function.
    """
    inputs, targets = zip(*data)
    batch_inputs = np.stack(inputs)
    batch_targets = np.stack(targets)
    return batch_inputs, batch_targets
       
# One dataset per split; both use the preprocessing built by ImagenetDataset.
train_dataset = ImagenetDataset(train_paths, classes, new_classes)
val_dataset = ImagenetDataset(val_paths, classes, new_classes)

# Batches of 128 images, loaded by 5 worker processes.
# NOTE(review): the validation split also gets the random training augmentation
# and shuffling, neither of which evaluation needs.
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=5)
val_dataloader = DataLoader(val_dataset, batch_size=128, shuffle=True, num_workers=5)
 
 
# Initializing parameters
k = 3000  # number of non-singleton focal sets
num_classes = 1000  # number of ImageNet-1k classes
input_shape = (224, 224, 3)  # standard ImageNet input shape (H, W, C)

# ResNet-50 backbone. Its pretrained weights are replaced by the checkpoint
# loaded below (load_state_dict is strict).
model = torchvision.models.resnet50(progress=False, weights='IMAGENET1K_V2')
# New head: one sigmoid output per focal set in new_classes; the loss treats
# these outputs as belief values.
model.fc = nn.Sequential(
    nn.Linear(in_features=2048, out_features=len(new_classes), bias=True),
    nn.Sigmoid(),
)
# To train only the head, set requires_grad = False on model.parameters() and
# then back to True on model.fc.parameters().

device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")  # GPU ids start from 0
# Move the model to the selected device before anything runs on it, and hand
# the same device to summary(): without `device=` torchinfo picks its own
# default device for the dry-run forward pass, which may be a different GPU.
model.to(device)
checkpoint = torch.load('model_img_v2_fine-Copy4.pth', map_location=device)
model.load_state_dict(checkpoint["model"])
summary(model, input_size=(1, 3, 224, 224), device=device)
 
epochs = 5  # number of training epochs

# AdamW over all model parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.00002)

# Cosine decay over T_max scheduler steps (the training loop steps once per epoch).
main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=100, eta_min=0
)

# Optional linear warm-up from 1% of the base LR ahead of the cosine decay.
# The warm-up length is also the SequentialLR milestone, and the warm-up is only
# built when it is non-zero: a LinearLR with total_iters=0 never ramps back up,
# so chaining it with a later milestone keeps the LR at start_factor * lr
# (here 5e-7) until that milestone is reached.
warmup_epochs = 0
if warmup_epochs > 0:
    warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.01, total_iters=warmup_epochs
    )
    lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer,
        schedulers=[warmup_lr_scheduler, main_lr_scheduler],
        milestones=[warmup_epochs],
    )
else:
    warmup_lr_scheduler = None
    lr_scheduler = main_lr_scheduler

# To resume a run, call optimizer.load_state_dict(checkpoint["optimizer"]) and
# restore the scheduler saved under checkpoint["lr_scheduler"].
# Möbius inverse transformation: zero-initialised (n, n) coefficient matrix,
# n = len(new_classes).
mass_co = np.zeros([len(new_classes)] * 2)
 
def mass_coeff(new_classes):
    """Build the Möbius-inversion matrix that maps belief values to masses.

    For focal sets A_i and B_j, entry [j, i] is (-1) ** (len(A_i) - len(B_j))
    when B_j is a subset of A_i and 0 otherwise. Hence ``beliefs @ matrix``
    gives m(A_i) = sum over B ⊆ A_i of (-1) ** (|A_i| - |B|) * bel(B).

    Args:
        new_classes: sequence of focal sets (iterables of hashable labels).

    Returns:
        A new float64 ndarray of shape (len(new_classes), len(new_classes)).
    """
    # Convert each focal set once instead of on every pair comparison; keep the
    # original element lengths for the sign exponent.
    focal_sets = [set(c) for c in new_classes]
    sizes = [len(c) for c in new_classes]
    n = len(focal_sets)
    # Allocate a fresh matrix rather than filling a shared module-level buffer,
    # so the result always matches len(new_classes).
    coeffs = np.zeros((n, n))
    for i, A in enumerate(focal_sets):
        for j, B in enumerate(focal_sets):
            if B <= A:  # B is a subset of A
                coeffs[j, i] = (-1) ** (sizes[i] - sizes[j])
    return coeffs
 
# Möbius-inversion coefficient matrix (beliefs -> masses). It is computed once,
# because it is O(len(new_classes)**2) to build. It is stored as a float32 tensor
# on the training device so the loss can compute mass = y_pred @ mass_coeff_matrix.
mass_coeff_matrix = torch.tensor(mass_coeff(new_classes), dtype=torch.float32).to(device)
 
# Hyperparameters: weights of the two mass regularisers in the RS-CNN loss.
ALPHA = 1e-3
BETA = 1e-3


def BinaryCrossEntropy(y_true, y_pred):
    """RS-CNN loss: binary cross-entropy on belief outputs plus mass regularisers.

    Args:
        y_true: multi-hot targets, (batch, len(new_classes)); a tensor or
            anything ``torch.as_tensor`` accepts.
        y_pred: predicted belief values in [0, 1], same shape as ``y_true``.

    Returns:
        Scalar tensor. It is the mean over columns of the batch-averaged BCE,
        plus ALPHA * mean(ReLU(-mass)) (penalises negative masses), plus
        BETA * ReLU(mean total mass - 1) (penalises total mass above 1).
        Masses are computed as ``y_pred @ mass_coeff_matrix``.
    """
    eps = torch.finfo(torch.float32).eps
    # as_tensor converts without the extra copy (and the copy-construct
    # UserWarning) that torch.tensor() incurs when y_true is already a tensor.
    y_true = torch.as_tensor(y_true, dtype=torch.float32, device=device)
    # Targets are clipped to [eps, 1], so zero targets still add a tiny
    # eps * log(y_pred) term.
    y_true = torch.clip(y_true, eps, 1)
    y_pred = torch.clip(y_pred, eps, 1 - eps)
    term_0 = (1 - y_true) * torch.log(1 - y_pred + eps)
    term_1 = y_true * torch.log(y_pred + eps)

    # Binary cross-entropy per output column, averaged over the batch.
    bce_loss = -torch.mean(term_0 + term_1, dim=0)

    # Masses from the predicted belief function via the Möbius coefficients.
    mass = torch.matmul(y_pred, mass_coeff_matrix)

    # Regulariser 1: push predicted masses to be non-negative.
    mass_reg = torch.mean(F.relu(-mass))

    # Regulariser 2: penalise a mean total mass above 1.
    mass_sum = F.relu(torch.mean(torch.sum(mass, dim=-1)) - 1)

    total_loss = bce_loss + ALPHA * mass_reg + BETA * mass_sum
    return total_loss.mean()
 
 
 
 
train_epoch_loss = []  # mean training loss of each epoch
val_epoch_loss = []    # mean validation loss of each epoch


def _batch_correct(outputs, labels):
    """Count samples whose top output column matches their top target column.

    Only the first len(classes) columns are compared.
    NOTE(review): assumes those columns of new_classes are the singleton
    classes — confirm against the pickled focal sets.
    """
    preds = np.argmax(outputs[:, :len(classes)].detach().cpu().numpy(), axis=-1)
    targets = np.argmax(labels[:, :len(classes)].cpu().numpy(), axis=-1)
    return np.sum(preds == targets)


for i in range(1, epochs + 1):
    # ---- Training phase ----
    # Set every epoch, because the validation phase switches to eval mode.
    model.train()
    running_loss = 0
    running_correct = 0
    total_items = 0
    num_batches = 0
    pbar = tqdm(train_dataloader)
    for inputs, labels in pbar:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = BinaryCrossEntropy(labels, outputs)
        loss.backward()
        optimizer.step()  # adjust learning weights

        batch_loss = loss.item()
        running_loss += batch_loss
        running_correct += _batch_correct(outputs, labels)
        total_items += len(labels)
        num_batches += 1

        pbar.set_description(f"Training Epoch {i}/{epochs}: ")
        pbar.set_postfix({"batch_loss": batch_loss, "avg_loss": running_loss / num_batches, "accuracy": (running_correct / total_items) * 100})
    # max(..., 1) avoids dividing by zero for an empty loader.
    train_epoch_loss.append(running_loss / max(num_batches, 1))
    lr_scheduler.step()

    # ---- Validation phase ----
    # eval(): BatchNorm uses its running statistics instead of batch statistics
    # and does not update them with validation data.
    model.eval()
    val_running_loss = 0
    running_correct = 0
    total_items = 0
    val_batches = 0
    pbar = tqdm(val_dataloader)
    with torch.no_grad():
        for inputs, labels in pbar:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            val_loss = BinaryCrossEntropy(labels, outputs)

            val_running_loss += val_loss.item()
            running_correct += _batch_correct(outputs, labels)
            total_items += len(labels)
            val_batches += 1

            pbar.set_description(f"Validation Epoch {i}/{epochs}: ")
            pbar.set_postfix({"val_loss": val_running_loss / val_batches, "accuracy": (running_correct / total_items) * 100})
    # Record the validation average (not the training one).
    val_epoch_loss.append(val_running_loss / max(val_batches, 1))

    # Overwrite the checkpoint after every epoch.
    # NOTE(review): the scheduler object itself (not its state_dict) is pickled,
    # keeping the checkpoint format unchanged for existing loaders.
    checkpoint = {
        'epoch': i,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler,
    }
    torch.save(checkpoint, 'model_img_v2_fine_ab.pth')
 