import argparse
import random
from torch import optim
import pandas as pd
from utils import *
from dataset import *
from model import *
from torch import nn, Tensor
from models.Trans_En import TransformerModel
import datetime

# os.environ["CUDA_VISIBLE_DEVICES"] = '5'
def set_seed(seed):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy, and PyTorch (CPU generator and
    all CUDA devices), then switches cuDNN to deterministic mode with
    benchmarking disabled.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Same seeding order as before: random, NumPy, torch CPU, torch CUDA.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # cuDNN flags: no benchmarking, deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy between logits and a soft target distribution.

    The log-softmax of ``preds`` is taken over the last dimension, weighted
    by ``targets`` and summed over dimension 1 (for 2-D inputs these are the
    same dimension).

    Args:
        preds: Logits tensor.
        targets: Target weights, broadcastable against ``preds``.
        reduction: ``'none'`` returns the per-row loss, ``'mean'`` its mean,
            ``'sum'`` its sum.

    Returns:
        The per-row loss tensor for ``'none'``, otherwise a scalar tensor.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
            (Previously an unknown value silently returned ``None``.)
    """
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(
            "reduction must be 'none', 'mean' or 'sum', got {!r}".format(reduction)
        )

    loss = (-targets * nn.functional.log_softmax(preds, dim=-1)).sum(1)
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss

################## Trainer (change loss)
class Trainer:
    """Train a ``TransformerModel`` with a symmetric two-view contrastive loss.

    The class relies on module-level names set up by the script section of
    this file: ``seed``, ``args``, ``config``, ``device``, ``train_loader``
    and ``test_loader``.
    """

    # Temperature used to scale the logits and the soft targets of the loss.
    temperature = 1

    def __init__(self):
        """Seed the RNGs, build the model and create the early-stop helper."""
        set_seed(seed)
        self.model = TransformerModel(in_channel=2048, d_model=512)

        print('Model Size: {}'.format(sum(p.numel() for p in self.model.parameters())))
        # NOTE(review): the early-stop helper is created but not consulted by
        # the training loop below.
        self.es = EarlyStop(patience=args.patience)

    def train(self, save_dir, num_epochs):
        """Run the full train/evaluate loop.

        Args:
            save_dir: Output directory. Accepted for interface compatibility;
                the current loop does not write checkpoints.
            num_epochs: Number of epochs to run.
        """
        self.model.to(device)
        optimizer = optim.Adam(self.model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=config.step_size, gamma=config.gamma)

        for epoch in range(num_epochs):
            self.train_single_epoch(train_loader, optimizer, scheduler, epoch)
            self.test(epoch, test_loader)

    def _contrastive_loss(self, item):
        """Embed both views of one batch and return the scalar contrastive loss.

        Args:
            item: Batch whose first two entries are the two view tensors.

        Returns:
            Scalar loss tensor (mean over the batch).
        """
        # Original comments give the permuted shape as [b, l, c], e.g.
        # [20, 100, 2048] -- TODO confirm against the dataset.
        samples_c = item[0].to(device).permute(0, 2, 1)
        samples_h = item[1].to(device).permute(0, 2, 1)
        # Both views go through the same model. The use of ``.T`` below
        # assumes 2-D embeddings (batch, dim) -- TODO confirm.
        out_c = self.model(samples_c)
        out_h = self.model(samples_h)

        logits = (out_c @ out_h.T) / self.temperature
        viewc_similarity = out_c @ out_c.T
        viewh_similarity = out_h @ out_h.T
        # Soft targets: softmax over the averaged within-view similarities.
        targets = F.softmax(
            (viewc_similarity + viewh_similarity) / 2 * self.temperature, dim=-1
        )
        c_loss = cross_entropy(logits, targets, reduction='none')
        h_loss = cross_entropy(logits.T, targets.T, reduction='none')
        return ((c_loss + h_loss) / 2.0).mean()

    def train_single_epoch(self, train_loader, optimizer, scheduler, epoch=0):
        """Train the model for one pass over ``train_loader``.

        Args:
            train_loader: Iterable of training batches.
            optimizer: Optimizer updating the model parameters.
            scheduler: Learning-rate scheduler, stepped once after the epoch.
            epoch: Epoch index, used only for logging.

        Returns:
            Mean training loss of the epoch as a float (0.0 for an empty loader).
        """
        self.model.train()
        loss_sum, nums = 0.0, 0
        for item in train_loader:
            nums += 1
            loss = self._contrastive_loss(item)

            optimizer.zero_grad()  # gradients are reset for every batch
            loss.backward()
            optimizer.step()

            # Accumulate a plain float so the autograd graph of each batch
            # can be freed.
            loss_sum += loss.item()
            if nums % 100 == 0:
                print('{} | Epoch {} ({}) loss: {}'.format(datetime.datetime.now(), epoch, nums, loss_sum / nums))

        # The scheduler follows an epoch-based schedule, so advance it once
        # per epoch after the optimizer updates.
        scheduler.step()

        mean_loss = loss_sum / max(nums, 1)
        print('{} ---------------------------| Epoch {} loss is {}'.format(datetime.datetime.now(), epoch,
                                                                           mean_loss))
        return mean_loss

    def test(self, epoch, test_loader):
        """Evaluate the model on ``test_loader`` without gradient tracking.

        Args:
            epoch: Epoch index, used only for logging.
            test_loader: Iterable of evaluation batches.

        Returns:
            Mean evaluation loss as a float (0.0 for an empty loader).
        """
        self.model.eval()
        self.model.to(device)
        loss_sum, nums = 0.0, 0
        with torch.no_grad():
            for item in test_loader:
                nums += 1
                loss_sum += self._contrastive_loss(item).item()

        mean_loss = loss_sum / max(nums, 1)
        print('{} ---------------------------| Test Epoch {} loss is {}'.format(datetime.datetime.now(), epoch,
                                                                                mean_loss))
        return mean_loss

###------------------------main here ------------------
# Compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- command-line arguments ----
parser = argparse.ArgumentParser()
parser.add_argument('--action', default='train') #train
parser.add_argument('--feature_path', type=str, default='/data/peiyao/data/all_data/Assembly101/TSM_features') #/mnt/data/zhanzhong/assembly/
parser.add_argument('--dataset', default="assembly")
parser.add_argument('--split', default='train_val')  # split name; also part of the output directory names
parser.add_argument('--seed', type=int, default=42)  # parsed as int so it can seed the RNGs directly
parser.add_argument('--test_aug', type=int, default=0)
parser.add_argument('--patience', type=int, default=20)
args = parser.parse_args()

seed = args.seed
set_seed(seed)
VIEWS = ['C10095_rgb', 'C10115_rgb', 'C10118_rgb', 'C10119_rgb', 'C10379_rgb', 'C10390_rgb', 'C10395_rgb', 'C10404_rgb',
             'HMC_21176875_mono10bit', 'HMC_84346135_mono10bit', 'HMC_21176623_mono10bit', 'HMC_84347414_mono10bit',
             'HMC_21110305_mono10bit', 'HMC_84355350_mono10bit', 'HMC_21179183_mono10bit', 'HMC_84358933_mono10bit']

# ---- training configuration ----
# All hyper-parameters are passed as keyword arguments; attribute-style
# assignments are not valid inside a call.
config = dotdict(
    epochs=200,
    dataset=args.dataset,
    feature_size=2048,
    gamma=0.5,
    step_size=200,
    split=args.split,
    learning_rate=1e-4,
    weight_decay=1e-4,
    batch_size=20,
)

# ---- output directories ----
TYPE = '/c2f_{}'.format(args.seed)
model_dir = "./models/" + args.dataset + "/" + args.split + TYPE
results_dir = "./results/" + args.dataset + "/" + args.split + TYPE

# exist_ok avoids the check-then-create race of a separate existence test.
os.makedirs(model_dir, exist_ok=True)
os.makedirs(results_dir, exist_ok=True)

# ---- dataset locations ----
anno_data_path = "/data/peiyao/data/all_data/Assembly101/annotations/"
vid_list_path = anno_data_path + "coarse-annotations/coarse_splits/"
gt_path = anno_data_path + "coarse-annotations/coarse_labels/"
mapping_file = anno_data_path + "coarse-annotations/actions.csv"
features_path = args.feature_path

config.features_path = features_path
config.gt_path = gt_path
config.VIEWS = VIEWS

# ---- label mappings: class name -> id and id -> class name ----
actions = pd.read_csv(mapping_file, header=0, names=['action_id', 'verb_id', 'noun_id', 'action_cls', 'verb_cls', 'noun_cls'])
actions_dict, label_dict = dict(), dict()
for action_id, action_cls in zip(actions['action_id'], actions['action_cls']):
    actions_dict[action_cls] = int(action_id)
    label_dict[int(action_id)] = action_cls

num_classes = len(actions_dict)
# ``config.num_class`` is not assigned anywhere else in this file, so take
# the class count from the action mapping for any consumer of ``config``.
config.num_class = num_classes

############################dataloader
def _init_fn(worker_id):
    """Seed NumPy inside a DataLoader worker process.

    The module-level ``seed`` is offset by ``worker_id`` so runs stay
    reproducible while each worker gets its own NumPy random stream; giving
    every worker the same seed would make them draw identical random numbers.

    Args:
        worker_id: Worker index passed by the DataLoader to ``worker_init_fn``.
    """
    # NumPy only accepts seeds in [0, 2**32).
    np.random.seed((int(seed) + worker_id) % 2**32)

###########################postprocessor
# Post-processor built from the config and label mappings, moved to the compute device.
postprocessor = PostProcess(config, label_dict, actions_dict, gt_path).to(device)

trainer = Trainer()

# Keyword arguments shared by the train and validation datasets / loaders.
_dataset_kwargs = dict(
    fold_file_name=vid_list_path,
    actions_dict=actions_dict,
    zoom_crop=(0.5, 2),
)
_loader_kwargs = dict(
    batch_size=config.batch_size,
    pin_memory=True,
    num_workers=0,
    worker_init_fn=_init_fn,
)

# Training data is shuffled every epoch.
train_dataset = AugmentDataset(config, fold='train', **_dataset_kwargs)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, shuffle=True, **_loader_kwargs)

# Validation data keeps its original order.
test_dataset = AugmentDataset_val(config, fold='val', **_dataset_kwargs)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, shuffle=False, **_loader_kwargs)

trainer.train(model_dir, num_epochs=config.epochs)



